
[DeepSeek-V4] Implement MoE routing primitives (HashRouter, TopKRouter, RoutedMoE) #3871

Open

parambole wants to merge 1 commit into deepseek_v4_core_primitives from dsv4-moe-routing-primitives

Conversation

@parambole (Collaborator) commented May 11, 2026

Description

Implement Mixture of Experts (MoE) routing gates and execution layers required for DeepSeek-V4 integration into MaxText:

  • HashRouter: Token routing mechanism utilizing MD5 hash projections for deterministic expert assignment without auxiliary loss.
  • TopKRouter: Gated top-k router implementing sigmoid scaling and score normalization across selected experts (a rough sketch of both routing schemes follows this list).
  • RoutedMoE & RoutedAndSharedMoE: Execution layers supporting layer_idx routing, gate clamping, and FP32 expert summation parity.
  • Unit test suite (tests/unit/deepseek_v4_vs_reference_test.py) validating MoE routing parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
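For context, here is a minimal sketch of the two routing schemes described above. It is not the MaxText code: the function names, the exact hashing scheme (an MD5 digest of the raw token ID is assumed here), and the scaling factor are illustrative only.

```python
import hashlib

import jax
import jax.numpy as jnp
import numpy as np


def hash_route(token_ids: np.ndarray, num_experts: int) -> np.ndarray:
  """Deterministic expert assignment from token IDs (illustrative only).

  Each token ID is hashed and reduced modulo the expert count, so the
  assignment needs no learned gate and no auxiliary load-balancing loss.
  """
  def assign(tid):
    digest = hashlib.md5(str(int(tid)).encode()).digest()
    return int.from_bytes(digest[:8], "little") % num_experts
  return np.vectorize(assign)(token_ids)


def topk_route(gate_logits: jax.Array, top_k: int, scaling_factor: float = 1.0):
  """Sigmoid-scored top-k routing with normalization over the selected experts.

  gate_logits: [num_tokens, num_experts]; returns (weights, expert_indices).
  """
  scores = jax.nn.sigmoid(gate_logits.astype(jnp.float32))            # sigmoid scaling
  top_scores, top_indices = jax.lax.top_k(scores, top_k)              # pick k experts per token
  weights = top_scores / jnp.sum(top_scores, axis=-1, keepdims=True)  # normalize over selected experts
  return weights * scaling_factor, top_indices
```

In the PR itself, layer_idx additionally decides which scheme a layer uses, with the leading num_hash_layers decoder layers taking the hash path (see the review threads below).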

Tests

Tested on CPU:

pytest tests/unit/deepseek_v4_vs_reference_test.py

tests/unit/deepseek_v4_vs_reference_test.py ......                       [100%]
======================== 6 passed, 8 warnings in 3.99s =========================

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov bot commented May 11, 2026

Codecov Report

❌ Patch coverage is 17.64706% with 84 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/layers/moe.py | 16.83% | 72 Missing and 12 partials ⚠️ |


@parambole force-pushed the dsv4-moe-routing-primitives branch from 37ee811 to 31329c5 on May 11, 2026 20:38
@parambole force-pushed the deepseek_v4_core_primitives branch from 1ab79e5 to c025463 on May 12, 2026 17:22
@parambole force-pushed the dsv4-moe-routing-primitives branch from 31329c5 to 22a57ff on May 12, 2026 17:23
@parambole force-pushed the deepseek_v4_core_primitives branch from c025463 to 68c44a6 on May 12, 2026 21:12
@parambole force-pushed the dsv4-moe-routing-primitives branch from 22a57ff to 32869e5 on May 12, 2026 21:12
@parambole force-pushed the deepseek_v4_core_primitives branch 2 times, most recently from 72a92a7 to e81f52d on May 14, 2026 17:45
[DeepSeek-V4] Implement MoE routing primitives (HashRouter, TopKRouter, RoutedMoE)

Implement Mixture of Experts routing gates and execution layers for DeepSeek-V4 integration into MaxText:

- HashRouter: Token routing mechanism utilizing MD5 hash projections for deterministic expert assignment.
- TopKRouter: Gated top-k router implementing sigmoid scaling and score normalization.
- RoutedMoE & RoutedAndSharedMoE: Execution layers supporting layer_idx routing and FP32 expert summation parity.
- Parity verification: Extended unit test suite (deepseek_v4_vs_reference_test.py) validating routing parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
@parambole force-pushed the dsv4-moe-routing-primitives branch from 32869e5 to c92f2e0 on May 14, 2026 17:51
@parambole changed the title from "Implement custom MoE HashRouter, TopKRouter, and sqrtsoftplus" to "Implement DeepSeek-V4 MoE routing primitives (HashRouter, TopKRouter, RoutedMoE)" on May 14, 2026
@parambole changed the title from "Implement DeepSeek-V4 MoE routing primitives (HashRouter, TopKRouter, RoutedMoE)" to "[DeepSeek-V4] Implement MoE routing primitives (HashRouter, TopKRouter, RoutedMoE)" on May 14, 2026
@github-actions

🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@RissyRan (Collaborator) left a comment

Thanks for the change! Have a few comments.

Comment thread src/maxtext/layers/moe.py
return jnp.sqrt(jax.nn.softplus(x))


class DeepSeekV4TopKRouter(nnx.Module):

I see the logic is similar for GateLogit + topK.

Gate logic:

class GateLogit(nnx.Module):

Topk:

def get_topk(self, gate_logits, pre_bias_logits, rngs=None):

Shall we leverage _sqrtsoftplus and combine them to avoid some duplication?
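One possible shape of that consolidation, sketched with hypothetical names (gate_and_topk, score_fn) rather than the actual GateLogit/get_topk signatures; it assumes the DeepSeek-V4 path differs only in its score activation:

```python
import jax
import jax.numpy as jnp


def _sqrtsoftplus(x):
  return jnp.sqrt(jax.nn.softplus(x))


def gate_and_topk(hidden, gate_kernel, top_k, score_fn=jax.nn.sigmoid):
  """Shared gate + top-k helper; the score activation is the only knob.

  hidden:      [num_tokens, emb]
  gate_kernel: [emb, num_experts]
  Pass score_fn=_sqrtsoftplus for the DeepSeek-V4 variant.
  """
  logits = hidden @ gate_kernel                      # [num_tokens, num_experts]
  scores = score_fn(logits.astype(jnp.float32))      # variant-specific scoring
  top_scores, top_indices = jax.lax.top_k(scores, top_k)
  return top_scores, top_indices
```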

Comment thread src/maxtext/layers/moe.py
)


class DeepSeekV4HashRouter(nnx.Module):

Wondering if we should name it HashRouter directly? Do you know if it's specific for DS v4?

Comment thread src/maxtext/layers/moe.py
with jax.named_scope("ffn_act"):
if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS:
if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4:
limit = getattr(self.config, "swiglu_limit", 1.0)

Shall we reuse the self.config.mlp_activations_limit config as below?
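For reference, a minimal sketch of the clamped activation such a limit would feed, assuming a plain SwiGLU form; the exact clamping and the config field name (mlp_activations_limit vs. swiglu_limit) are the open questions here, so treat this as illustrative:

```python
import jax
import jax.numpy as jnp


def clamped_swiglu(gate: jax.Array, linear: jax.Array, limit: float) -> jax.Array:
  """Illustrative clamped SwiGLU; the real MaxText activation path may differ."""
  gate = jnp.clip(gate, None, limit)        # cap the gated branch from above
  linear = jnp.clip(linear, -limit, limit)  # bound the linear branch on both sides
  return jax.nn.silu(gate) * linear
```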

Comment thread src/maxtext/layers/moe.py
w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec)

if gate_weights is not None:
gate_weights_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))

Nit: could you help add shape info here? Same comment for gate_indices.

Comment thread src/maxtext/layers/moe.py
routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
gate_logits, pre_bias_logits = self.gate(routing_inputs)

if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4:

Could you help add some comments explaining the conditions? Thanks!

@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment

## 📋 Review Summary

The Pull Request successfully implements the MoE routing primitives required for DeepSeek-V4, including HashRouter, TopKRouter, and updates to the execution layers. The code is well-structured and includes comprehensive unit tests validating parity with PyTorch reference implementations.

🔍 General Feedback

  • Integration Gaps: While the routing primitives are correct, the DeepSeek model definition (src/maxtext/models/deepseek.py) needs corresponding updates to pass layer_idx and input_ids to these new layers. Without these, the model will not behave correctly in DeepSeek-V4 mode.
  • Configuration Maintainability: The use of getattr with hardcoded defaults for swiglu_limit and num_hash_layers should be replaced with formal parameters in MaxTextConfig to ensure better discoverability and type safety (see the sketch after this list).
  • Precision Parity: The explicit use of FP32 for expert summation and routing calculations is a positive highlight as it ensures numerical parity with reference implementations.
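A hedged sketch of what formalizing those knobs could look like; the field names follow the review comments, but the defaults, docstrings, and dataclass shown here are illustrative rather than the actual MaxTextConfig layout:

```python
from dataclasses import dataclass


@dataclass
class MoERoutingConfig:
  """Illustrative config fragment; the real fields would live in MaxTextConfig
  (src/maxtext/configs/types.py) so they are documented and type-checked."""
  # Clamp applied to the SwiGLU activations in the DeepSeek-V4 MLP path.
  swiglu_limit: float = 1.0
  # Number of leading decoder layers that route with the HashRouter.
  num_hash_layers: int = 3
```

With formal fields, call sites can read config.swiglu_limit directly instead of getattr(self.config, "swiglu_limit", 1.0), so a misspelled name fails loudly instead of silently falling back to a default.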

Comment thread src/maxtext/layers/moe.py
matmul_precision=self.config.matmul_precision,
shard_mode=config.shard_mode,
rngs=self.rngs,
self.is_hash = self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < getattr(

🔴 The integration of layer_idx is crucial for DeepSeek-V4 to distinguish between Hash and Top-K layers. However, the model definition in src/maxtext/models/deepseek.py does not yet pass this parameter during instantiation of RoutedAndSharedMoE. This will cause all layers to default to layer_idx=0, resulting in all layers using the HashRouter if num_hash_layers >= 1.
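A toy illustration of that failure mode (not MaxText code), assuming the constructor default layer_idx=0 and the getattr default of 3 for num_hash_layers:

```python
NUM_HASH_LAYERS = 3  # default referenced in the review


def is_hash_layer(layer_idx: int = 0) -> bool:
  # Mirrors the `0 <= layer_idx < num_hash_layers` check in the PR.
  return 0 <= layer_idx < NUM_HASH_LAYERS


# If the model never passes layer_idx, every layer takes the hash path:
print([is_hash_layer() for _ in range(6)])   # [True, True, True, True, True, True]
# With the index threaded through, only the first three layers do:
print([is_hash_layer(i) for i in range(6)])  # [True, True, True, False, False, False]
```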

Comment thread src/maxtext/layers/moe.py
shard_mode=config.shard_mode,
rngs=self.rngs,
self.is_hash = self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < getattr(
config, "num_hash_layers", 3

🟡 num_hash_layers should ideally be a formal configuration parameter in MaxTextConfig (in src/maxtext/configs/types.py) rather than being retrieved via getattr with a hardcoded default of 3.

Comment thread src/maxtext/layers/moe.py
gate_inputs: jax.Array | None = None,
out_sharding: NamedSharding | None = None,
input_ids: jax.Array | None = None,
gate_weights: jax.Array | None = None,

🟠 DeepSeekV4HashRouter requires input_ids for expert assignment. While the interface now correctly supports this, ensure that src/maxtext/models/deepseek.py is updated to pass decoder_input_tokens as input_ids during the MoE block call, as it currently lacks this linkage.

Comment thread src/maxtext/layers/moe.py
with jax.named_scope("ffn_act"):
if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS:
if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4:
limit = getattr(self.config, "swiglu_limit", 1.0)

🟡 Instead of using getattr with a hardcoded default of 1.0, consider adding swiglu_limit to the MaxTextConfig class. This improves discoverability, documentation, and type safety for the configuration.

Comment thread src/maxtext/layers/moe.py
Computes logits, static routing weights based on token IDs, and expert indices.
"""

def __init__(

🟢 The docstring mentions "static routing weights", which might be confusing as the weights themselves are learned based on logits. Only the expert assignment (indices) is static based on token IDs.

Suggested change
"""Hash Router for DeepSeek-V4 MoE routing.
Computes learned routing weights for a static expert assignment determined by token IDs.
"""

def __init__(

