[DeepSeek-V4] Implement Compressed Attention Layers #3866
Conversation
🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR implements the core compressed attention layers for DeepSeek-V4 integration, including HCA, CSA, and the Lightning Indexer. The implementation is technically sound, follows established patterns in MaxText, and includes comprehensive parity tests against PyTorch.
🔍 General Feedback
- Efficiency: The main coordinator block uses `jnp.repeat` to broadcast MQA keys/values, which is memory-intensive. Switching to `jnp.einsum` broadcasting is recommended (see the sketch below).
- Typo: A minor typo, `swaped`, was found in the indexer module.
- Config: Ensure `compress_ratios` is documented as a required list when using these attention variants, to avoid a runtime `IndexError`.
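For the efficiency point above, here is a minimal sketch (all shapes and variable names are illustrative, not taken from the PR) showing the memory cost of `jnp.repeat` versus letting `einsum` broadcast the single MQA KV head:

```python
import jax
import jax.numpy as jnp

# Illustrative MQA shapes; none of these names come from the PR itself.
B, H, S, S_kv, D = 2, 8, 16, 16, 64
q = jax.random.normal(jax.random.PRNGKey(0), (B, H, S, D))
kv = jax.random.normal(jax.random.PRNGKey(1), (B, 1, S_kv, D))  # single shared KV head

# jnp.repeat materializes H copies of the KV head before the contraction.
k_rep = jnp.repeat(kv, H, axis=1)                       # [B, H, S_kv, D]
print(k_rep.nbytes // kv.nbytes)                        # 8x the memory of kv
logits_rep = jnp.einsum("bhsd,bhkd->bhsk", q, k_rep)

# Broadcasting inside einsum avoids the copy: drop the size-1 head axis.
logits_bc = jnp.einsum("bhsd,bkd->bhsk", q, jnp.squeeze(kv, axis=1))
assert jnp.allclose(logits_rep, logits_bc, atol=1e-5)
```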
# Compute attention logits
# logits: [B, H, S, S_kv]
logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling
Suggested change:
# Keep k and v as the single shared KV head: [B, 1, S_kv, D_head]
k = kv
v = kv
# Compute attention logits with head broadcasting: [B, H, S, S_kv].
# The size-1 KV-head axis (label x) is summed out by einsum, which is
# equivalent to broadcasting the single head across all query heads.
logits = jnp.einsum("bhsd, bxkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling
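A quick sanity check (illustrative shapes, not the PR's code) that an einsum which sums out the size-1 KV-head axis labeled `x` matches the `jnp.repeat` baseline:

```python
import jax
import jax.numpy as jnp

B, H, S, S_kv, D = 2, 8, 16, 16, 64  # illustrative sizes
q = jax.random.normal(jax.random.PRNGKey(0), (B, H, S, D))
kv = jax.random.normal(jax.random.PRNGKey(1), (B, 1, S_kv, D))

# The size-1 axis x is contracted away, which equals broadcasting the single head.
logits_x = jnp.einsum("bhsd,bxkd->bhsk", q, kv)
# Baseline that repeats the KV head across all query heads first.
logits_rep = jnp.einsum("bhsd,bhkd->bhsk", q, jnp.repeat(kv, H, axis=1))
assert jnp.allclose(logits_x, logits_rep, atol=1e-5)
```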
# Project attention weights onto values
# attn_output: [B, H, S, D_head]
attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)
Suggested change:
# Project attention weights onto values with head broadcasting: [B, H, S, D_head].
# v keeps its single KV head [B, 1, S_kv, D_head]; the size-1 axis (label x) is summed out.
attn_output = jnp.einsum("bhsk, bxkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)
swaped_kv = jnp.swapaxes(compressed_kv, -1, -2)
swaped_kv = jnp.expand_dims(swaped_kv, axis=1)
# scores: [B, S, H, W]
scores = jnp.matmul(q, swaped_kv)
Suggested change:
# swapped_kv: [B, 1, D_idx, W]
swapped_kv = jnp.swapaxes(compressed_kv, -1, -2)
swapped_kv = jnp.expand_dims(swapped_kv, axis=1)
# scores: [B, S, H, W]
scores = jnp.matmul(q, swapped_kv)
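To make the shape flow of these indexer scores explicit, a small standalone sketch; the sizes are illustrative, and the assumption (based on the comments above) is that `compressed_kv` is `[B, W, D_idx]` with `D_idx` the indexer head dim and `W` the number of compressed blocks:

```python
import jax
import jax.numpy as jnp

B, S, H, D_idx, W = 2, 16, 4, 32, 8  # illustrative: W = number of compressed KV blocks
q = jax.random.normal(jax.random.PRNGKey(0), (B, S, H, D_idx))
compressed_kv = jax.random.normal(jax.random.PRNGKey(1), (B, W, D_idx))

# [B, W, D_idx] -> [B, D_idx, W] -> [B, 1, D_idx, W]; the size-1 axis lets
# jnp.matmul broadcast the compressed keys across the sequence dimension of q.
swapped_kv = jnp.expand_dims(jnp.swapaxes(compressed_kv, -1, -2), axis=1)

# Batched matmul over the last two axes: [B, S, H, D_idx] @ [B, 1, D_idx, W] -> [B, S, H, W]
scores = jnp.matmul(q, swapped_kv)
assert scores.shape == (B, S, H, W)
```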
| """Configuration specific to DeepSeek-V4 stateless compressed attention layers.""" | ||
|
|
||
| compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.") | ||
| compress_ratios: list[int] = Field( |
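Tying back to the `IndexError` concern in the general feedback, a hedged sketch of how a required, non-empty `compress_ratios` could be enforced at config-load time; the class name and validator are hypothetical and assume a pydantic-style config, as the `Field(...)` usage above suggests:

```python
from pydantic import BaseModel, Field, field_validator


class CompressedAttentionConfig(BaseModel):
  """Hypothetical stand-in for the DeepSeek-V4 compressed-attention config fields."""

  compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.")
  compress_ratios: list[int] = Field(..., description="Per-layer compression ratios; required for compressed attention variants.")

  @field_validator("compress_ratios")
  @classmethod
  def _non_empty(cls, v: list[int]) -> list[int]:
    # Fail at config-load time instead of with an IndexError deep in the layer code.
    if not v:
      raise ValueError("compress_ratios must be a non-empty list when using compressed attention.")
    return v
```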
RissyRan left a comment:
I didn't see sharding annotations. It would be good to start adding some of them in this PR, e.g. starting with those weights.
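To illustrate what weight sharding annotations could look like, a minimal sketch using plain `jax.sharding`; the mesh axis name and the weight shape are assumptions, not MaxText's actual logical axes:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed single-axis mesh; MaxText's real logical axes ("data", "fsdp", "tensor", ...)
# and the compressor weight shapes would differ.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Hypothetical compressor projection weight: [embed_dim, index_head_dim]
w = jnp.zeros((1024, 128))

# Shard the output-feature axis across the "model" mesh axis; inside a jitted
# layer the same intent can be expressed with jax.lax.with_sharding_constraint.
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
```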
return compressed_kv, block_bias


class DeepSeekV4Indexer(nnx.Module):
Have you considered reusing the v3.2 Indexer?
Ref: doc
# See the License for the specific language governing permissions and
# limitations under the License.


"""Compressed Attention layers and long-range compressors."""
For the file name, maybe compressed_attention.py, as you mentioned in the comment here?
Description
Implement compressed attention mechanisms and indexer modules required for DeepSeek-V4 integration into MaxText:
- `CSACompressor` & `HCACompressor`: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
- `LightningIndexer`: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
- Configuration: Register attention compression hyperparameters (`compress_ratios`, `index_head_dim`, `sliding_window`) in `types.py` and `base.yml`.
- Parity verification: Extended unit test suite (`tests/unit/deepseek_v4_vs_reference_test.py`) validating attention compression parity against PyTorch reference implementations at `atol=1e-5, rtol=1e-5`.

Tests
Tested on CPU
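For reference, the parity assertions described above typically follow a pattern like this hedged sketch; the helper name and argument handling are placeholders, not the actual test code:

```python
import numpy as np


def assert_parity(jax_out, torch_out, atol=1e-5, rtol=1e-5):
  """Compare a JAX output against a PyTorch reference at the tolerances used in the PR."""
  np.testing.assert_allclose(
      np.asarray(jax_out),
      torch_out.detach().cpu().numpy(),
      atol=atol,
      rtol=rtol,
  )
```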
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.