[DeepSeek-V4] Implement MoE routing primitives (HashRouter, TopKRouter, RoutedMoE) #3871
parambole wants to merge 1 commit into
Conversation
Codecov Report: ❌ Patch coverage is
RissyRan left a comment:
Thanks for the change! Have a few comments.
Diff context:
return jnp.sqrt(jax.nn.softplus(x))

class DeepSeekV4TopKRouter(nnx.Module):
I see the logic is similar to the existing GateLogic + top-k:
Gate logic: maxtext/src/maxtext/layers/moe.py, line 174 in b5e5330
Top-k: maxtext/src/maxtext/layers/moe.py, line 599 in b5e5330
Shall we leverage _sqrtsoftplus and combine them to avoid some duplication?
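For context, a rough sketch of the two pieces being discussed: _sqrt_softplus mirrors the activation in the diff hunk above, and the top-k step is the generic sigmoid-score-then-renormalize pattern, not the PR's actual implementation.

```python
import jax
import jax.numpy as jnp


def _sqrt_softplus(x):
  # Activation shown in the diff context above.
  return jnp.sqrt(jax.nn.softplus(x))


def topk_route(router_logits, k):
  # Generic top-k gating sketch (assumed, not the PR's code): scale logits into
  # (0, 1) with a sigmoid, keep the k best experts per token, and renormalize
  # the selected scores so they sum to one.
  scores = jax.nn.sigmoid(router_logits.astype(jnp.float32))
  weights, indices = jax.lax.top_k(scores, k)
  weights = weights / jnp.clip(weights.sum(axis=-1, keepdims=True), 1e-9)
  return weights, indices
```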
Diff context:
)

class DeepSeekV4HashRouter(nnx.Module):
Wondering if we should name it HashRouter directly. Do you know if it's specific to DS v4?
Diff context:
with jax.named_scope("ffn_act"):
if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS:
if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4:
limit = getattr(self.config, "swiglu_limit", 1.0)
Shall we reuse the self.config.mlp_activations_limit config as below?
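As a point of comparison, here is a minimal sketch of a limit-clamped SwiGLU-style activation; the clipping scheme and parameter names are assumptions, not the formula used in this PR.

```python
import jax
import jax.numpy as jnp


def clamped_swiglu(gate_proj, up_proj, limit=1.0):
  # Assumed clamping scheme: cap the gated branch at +limit and keep the linear
  # branch within [-limit, limit] before applying SiLU gating.
  gate = jnp.clip(gate_proj, None, limit)
  up = jnp.clip(up_proj, -limit, limit)
  return jax.nn.silu(gate) * up
```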
Diff context:
w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec)

if gate_weights is not None:
gate_weights_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
Nit: could you help add shape info here? Same comment for gate_indices.
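For illustration, the requested shape note might look like the following; the axis names are assumptions, not taken from the PR.

```python
# gate_weights: [batch, seq_len, experts_per_token] routing weights supplied by an external gate.
# gate_indices: [batch, seq_len, experts_per_token] expert ids selected for each token.
```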
Diff context:
routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
gate_logits, pre_bias_logits = self.gate(routing_inputs)

if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4:
Could you help add some comments explaining the conditions? Thanks!
The Pull Request successfully implements the MoE routing primitives required for DeepSeek-V4, including HashRouter, TopKRouter, and updates to the execution layers. The code is well-structured and includes comprehensive unit tests validating parity with PyTorch reference implementations.
🔍 General Feedback
- Integration Gaps: While the routing primitives are correct, the DeepSeek model definition (src/maxtext/models/deepseek.py) needs corresponding updates to pass `layer_idx` and `input_ids` to these new layers. Without these, the model will not behave correctly in DeepSeek-V4 mode.
- Configuration Maintainability: The use of `getattr` with hardcoded defaults for `swiglu_limit` and `num_hash_layers` should be replaced with formal parameters in `MaxTextConfig` to ensure better discoverability and type safety.
- Precision Parity: The explicit use of FP32 for expert summation and routing calculations is a positive highlight, as it ensures numerical parity with reference implementations (see the sketch below).
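A minimal sketch of what fp32 expert combination can look like, assuming the per-token expert outputs and routing weights are already gathered; shapes and names here are illustrative only, not the PR's actual code.

```python
import jax.numpy as jnp


def combine_expert_outputs(expert_out, routing_weights, out_dtype=jnp.bfloat16):
  # expert_out:      [tokens, experts_per_token, model_dim]
  # routing_weights: [tokens, experts_per_token]
  # Accumulate the weighted sum in float32, then cast back to the activation dtype.
  combined = jnp.einsum(
      "tkd,tk->td",
      expert_out.astype(jnp.float32),
      routing_weights.astype(jnp.float32),
  )
  return combined.astype(out_dtype)
```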
Diff context:
matmul_precision=self.config.matmul_precision,
shard_mode=config.shard_mode,
rngs=self.rngs,
self.is_hash = self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < getattr(
🔴 The integration of layer_idx is crucial for DeepSeek-V4 to distinguish between Hash and Top-K layers. However, the model definition in src/maxtext/models/deepseek.py does not yet pass this parameter during instantiation of RoutedAndSharedMoE. This will cause all layers to default to layer_idx=0, resulting in all layers using the HashRouter if num_hash_layers >= 1.
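To make the concern concrete, the condition in the diff reduces to a per-layer flag like the sketch below (the default of 3 is taken from the getattr fallback shown above); any layer constructed with the default layer_idx=0 would therefore pick the hash router.

```python
def uses_hash_router(layer_idx: int, num_hash_layers: int = 3) -> bool:
  # The first num_hash_layers decoder layers route by token-ID hash; the remaining
  # layers use the learned top-k router.
  return 0 <= layer_idx < num_hash_layers
```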
Diff context:
shard_mode=config.shard_mode,
rngs=self.rngs,
self.is_hash = self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < getattr(
config, "num_hash_layers", 3
🟡 num_hash_layers should ideally be a formal configuration parameter in MaxTextConfig (in src/maxtext/configs/types.py) rather than being retrieved via getattr with a hardcoded default of 3.
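A minimal sketch of promoting these values to formal config fields; the class and field names and their placement relative to MaxTextConfig are assumptions, not the PR's actual schema.

```python
from dataclasses import dataclass


@dataclass
class DeepSeekV4MoEConfig:
  # Hypothetical fields; the PR currently reads these via getattr with hardcoded defaults.
  num_hash_layers: int = 3    # layers [0, num_hash_layers) use the hash router
  swiglu_limit: float = 1.0   # clamp applied in the DeepSeek-V4 ffn activation
```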
Diff context:
gate_inputs: jax.Array | None = None,
out_sharding: NamedSharding | None = None,
input_ids: jax.Array | None = None,
gate_weights: jax.Array | None = None,
🟠 DeepSeekV4HashRouter requires input_ids for expert assignment. While the interface now correctly supports this, ensure that src/maxtext/models/deepseek.py is updated to pass decoder_input_tokens as input_ids during the MoE block call, as it currently lacks this linkage.
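For reference, a minimal sketch of deterministic, MD5-based expert assignment from token ids; it does not reproduce the PR's actual hash projection, it only illustrates why routing needs input_ids and no auxiliary balancing loss. Because the assignment is a pure function of the vocabulary id, it could also be precomputed once per vocabulary and looked up at run time.

```python
import hashlib

import numpy as np


def hash_expert_assignment(input_ids: np.ndarray, num_experts: int) -> np.ndarray:
  # Map each token id to a fixed expert via an MD5 digest, so expert assignment is
  # deterministic and identical across runs.
  def assign(token_id: int) -> int:
    digest = hashlib.md5(str(int(token_id)).encode("utf-8")).hexdigest()
    return int(digest, 16) % num_experts

  return np.vectorize(assign)(input_ids)
```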
Diff context:
with jax.named_scope("ffn_act"):
if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS:
if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4:
limit = getattr(self.config, "swiglu_limit", 1.0)
🟡 Instead of using getattr with a hardcoded default of 1.0, consider adding swiglu_limit to the MaxTextConfig class. This improves discoverability, documentation, and type safety for the configuration.
Diff context:
Computes logits, static routing weights based on token IDs, and expert indices.
"""

def __init__(
🟢 The docstring mentions "static routing weights", which might be confusing as the weights themselves are learned based on logits. Only the expert assignment (indices) is static based on token IDs.
Suggested change:
def __init__(
"""Hash Router for DeepSeek-V4 MoE routing.
Computes learned routing weights for a static expert assignment determined by token IDs.
"""
Description
Implement Mixture of Experts (MoE) routing gates and execution layers required for DeepSeek-V4 integration into MaxText:
- HashRouter: Token routing mechanism utilizing MD5 hash projections for deterministic expert assignment without auxiliary loss.
- TopKRouter: Gated top-k router implementing sigmoid scaling and score normalization across selected experts.
- RoutedMoE & RoutedAndSharedMoE: Execution layers supporting layer_idx routing, gate clamping, and FP32 expert summation parity.
- Parity verification: Extended unit test suite (tests/unit/deepseek_v4_vs_reference_test.py) validating MoE routing parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
Tests
Tested on CPU:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.