
[DeepSeek-V4] Implement Compressed Attention Layers#3866

Open
parambole wants to merge 1 commit into dsv4-moe-routing-primitives from deepseek_v4_compressed_attention

Conversation

@parambole (Collaborator) commented May 11, 2026

Description

Implement compressed attention mechanisms and indexer modules required for DeepSeek-V4 integration into MaxText:

  • CSACompressor & HCACompressor: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
  • LightningIndexer: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
  • Configuration: Register attention compression hyperparameters (compress_ratios, index_head_dim, sliding_window) in types.py and base.yml.
  • Unit test suite (tests/unit/deepseek_v4_vs_reference_test.py) validating attention compression parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
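The parity check these tests perform can be sketched as follows; the tensors below are hypothetical stand-ins, while the real suite compares MaxText module outputs against the PyTorch reference forward passes:

```python
import numpy as np

# Hypothetical stand-ins for the two implementations' outputs; the real
# suite runs the MaxText modules and the PyTorch reference models.
reference_out = np.random.default_rng(0).standard_normal((2, 8, 16))
candidate_out = reference_out + 1e-7  # tiny numerical drift is tolerated

# Assert elementwise parity at the tolerances used by the suite.
np.testing.assert_allclose(candidate_out, reference_out, atol=1e-5, rtol=1e-5)
```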

Tests

Tested on CPU

pytest tests/unit/deepseek_v4_vs_reference_test.py

tests/unit/deepseek_v4_vs_reference_test.py ..........                   [100%]
======================= 10 passed, 10 warnings in 20.42s =======================

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented May 11, 2026

Codecov Report

❌ Patch coverage is 7.63052% with 230 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/layers/attention_compressed.py | 7.63% | 230 Missing ⚠️ |


@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 5f54827 to 07eb3e2 Compare May 11, 2026 19:39
@parambole parambole changed the base branch from deepseek_v4_core_primitives to dsv4-moe-routing-primitives May 11, 2026 20:29
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 37ee811 to 31329c5 Compare May 11, 2026 20:38
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 07eb3e2 to 4520166 Compare May 11, 2026 20:43
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 31329c5 to 22a57ff Compare May 12, 2026 17:23
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 4520166 to 10ca4f6 Compare May 12, 2026 17:23
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 22a57ff to 32869e5 Compare May 12, 2026 21:12
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 10ca4f6 to 31a5932 Compare May 12, 2026 21:13
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 32869e5 to c92f2e0 Compare May 14, 2026 17:51
…ghtningIndexer)

Implement compressed attention mechanisms and indexer modules for DeepSeek-V4 integration into MaxText:

- CSACompressor & HCACompressor: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
- LightningIndexer: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
- Configuration: Register attention compression hyperparameters (compress_ratios, index_head_dim, sliding_window) in types.py and base.yml.
- Parity verification: Extended unit test suite (deepseek_v4_vs_reference_test.py) validating attention compression parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 31a5932 to c98a34e Compare May 14, 2026 17:53
@parambole parambole changed the title Implement DeepSeek-V4 Compressed Attention Layers [DeepSeek-V4] Implement Compressed Attention Layers May 14, 2026
@github-actions bot commented:
🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment:

## 📋 Review Summary

This PR implements the core compressed attention layers for DeepSeek-V4 integration, including HCA, CSA, and the Lightning Indexer. The implementation is technically sound, follows established patterns in MaxText, and includes comprehensive parity tests against PyTorch.

🔍 General Feedback

  • Efficiency: The main coordinator block uses jnp.repeat for broadcasting MQA keys/values, which is memory-intensive. Switching to jnp.einsum broadcasting is recommended.
  • Typo: A minor typo ("swaped" instead of "swapped") was found in the indexer module.
  • Config: Ensure compress_ratios is properly documented as a required list when using these attention variants to avoid runtime IndexError.
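The `jnp.repeat` vs. einsum-broadcast point can be illustrated with a small numpy sketch (`jnp.einsum` follows the same subscript semantics); the shapes here are illustrative, not taken from the PR:

```python
import numpy as np

B, H, S, S_kv, D = 2, 4, 5, 6, 3
rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, S, D))
kv = rng.standard_normal((B, 1, S_kv, D))  # shared MQA key/value head

# Memory-heavy path: materialize a copy of the shared head for every query head.
k_repeated = np.repeat(kv, H, axis=1)                      # [B, H, S_kv, D]
logits_repeat = np.einsum("bhsd,bhkd->bhsk", q, k_repeated)

# Broadcast path: drop the singleton head axis and omit it from the key
# subscript, so einsum shares the same keys across all query heads.
logits_bcast = np.einsum("bhsd,bkd->bhsk", q, kv.squeeze(1))

assert np.allclose(logits_repeat, logits_bcast)
```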


# Compute attention logits
# logits: [B, H, S, S_kv]
logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling

🔴 Memory/Performance: `jnp.repeat` creates large intermediate tensors that are unnecessary given the shared MQA structure. You can use broadcasting in `jnp.einsum` to achieve the same result more efficiently.
Suggested change
logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling
# Drop the singleton MQA head axis (einsum subscripts must be letters, so the
# shared head is broadcast by omitting it from the key/value subscript).
# k and v become [B, S_kv, D_head]
k = jnp.squeeze(kv, axis=1)
v = jnp.squeeze(kv, axis=1)
# Compute attention logits with head broadcasting: [B, H, S, S_kv]
logits = jnp.einsum("bhsd, bkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling


# Project attention weights onto values
# attn_output: [B, H, S, D_head]
attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)

🔴 Memory/Performance: Use broadcasting in `jnp.einsum` here as well to avoid the `jnp.repeat` from earlier.
Suggested change
attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)
# Project attention weights onto values with head broadcasting: [B, H, S, D_head]
# v: [B, S_kv, D_head] (singleton head axis dropped)
attn_output = jnp.einsum("bhsk, bkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)

swaped_kv = jnp.swapaxes(compressed_kv, -1, -2)
swaped_kv = jnp.expand_dims(swaped_kv, axis=1)
# scores: [B, S, H, W]
scores = jnp.matmul(q, swaped_kv)

🟢 Typo: "swaped" should be "swapped".
Suggested change
scores = jnp.matmul(q, swaped_kv)
# swapped_kv: [B, 1, D_idx, W]
swapped_kv = jnp.swapaxes(compressed_kv, -1, -2)
swapped_kv = jnp.expand_dims(swapped_kv, axis=1)
# scores: [B, S, H, W]
scores = jnp.matmul(q, swapped_kv)
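Assuming the shapes implied by the comments in this snippet (q: [B, S, H, D_idx], compressed_kv: [B, W, D_idx] — these shapes are an assumption, not stated in the PR), the broadcasting of the singleton axis against S can be checked with a small numpy sketch:

```python
import numpy as np

B, S, H, D_idx, W = 2, 5, 3, 4, 6
rng = np.random.default_rng(0)
q = rng.standard_normal((B, S, H, D_idx))            # queries per position/head
compressed_kv = rng.standard_normal((B, W, D_idx))   # compressed key/value blocks

# swapped_kv: [B, 1, D_idx, W]; the singleton axis broadcasts against S in matmul.
swapped_kv = np.expand_dims(np.swapaxes(compressed_kv, -1, -2), axis=1)
scores = np.matmul(q, swapped_kv)  # [B, S, H, W]

assert scores.shape == (B, S, H, W)
```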

"""Configuration specific to DeepSeek-V4 stateless compressed attention layers."""

compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.")
compress_ratios: list[int] = Field(

🟡 Configuration: The default `compress_ratios` is an empty list. Since `DeepSeekV4Attention` relies on this list having at least `layer_idx + 1` elements when using compressed layer types, this will cause an `IndexError` at runtime unless the user provides a full list. It might be better to provide a default or add a check with a clear error message.
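One way to turn that runtime `IndexError` into a clear error message, sketched with a hypothetical helper (`get_compress_ratio` is not part of the PR):

```python
def get_compress_ratio(compress_ratios: list[int], layer_idx: int) -> int:
    """Hypothetical guard: fail fast with an actionable message instead of
    letting an out-of-range index raise a bare IndexError at runtime."""
    if layer_idx >= len(compress_ratios):
        raise ValueError(
            f"compress_ratios has {len(compress_ratios)} entries but layer "
            f"{layer_idx} requires at least {layer_idx + 1}; set compress_ratios "
            "in the config when using compressed attention layer types."
        )
    return compress_ratios[layer_idx]
```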

@RissyRan (Collaborator) left a comment:

I didn't see sharding annotations. It would be good to start adding some of them in this PR, i.e. starting with those weights.

return compressed_kv, block_bias


class DeepSeekV4Indexer(nnx.Module):
Collaborator comment:

Have you considered re-using the v3.2 Indexer?

Ref: doc

# See the License for the specific language governing permissions and
# limitations under the License.

"""Compressed Attention layers and long-range compressors."""
Collaborator comment:

For the file name, maybe compressed_attention.py, as you mentioned in the comment here?



8 participants