From 9c5c5647f80daf59cdbabf71b36a0287a0d2bcc4 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Wed, 20 May 2026 17:06:51 +0000 Subject: [PATCH 1/2] feat: implement DeepSeek-V4 model integration, decoders, and configuration stack --- src/maxtext/common/common_types.py | 1 + src/maxtext/configs/base.yml | 12 + .../configs/models/deepseek_v4-flash.yml | 78 + .../configs/models/deepseek_v4-tiny.yml | 74 + src/maxtext/configs/types.py | 24 + src/maxtext/layers/attention_compressed.py | 1082 ++++++++ src/maxtext/layers/decoders.py | 114 +- src/maxtext/layers/embeddings.py | 113 + src/maxtext/layers/engram.py | 9 + src/maxtext/layers/linears.py | 63 + src/maxtext/layers/mhc.py | 92 +- src/maxtext/layers/moe.py | 395 ++- src/maxtext/layers/nnx_decoders.py | 194 +- src/maxtext/layers/normalizations.py | 61 + src/maxtext/models/deepseek_v4.py | 426 +++ tests/unit/deepseek_v4_vs_reference_test.py | 2472 +++++++++++++++++ tests/unit/nnx_decoders_test.py | 38 + 17 files changed, 5169 insertions(+), 79 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek_v4-flash.yml create mode 100644 src/maxtext/configs/models/deepseek_v4-tiny.yml create mode 100644 src/maxtext/layers/attention_compressed.py create mode 100644 src/maxtext/models/deepseek_v4.py create mode 100644 tests/unit/deepseek_v4_vs_reference_test.py diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 86811063a6..529a00d53c 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -93,6 +93,7 @@ class DecoderBlockType(enum.Enum): MISTRAL = "mistral" MIXTRAL = "mixtral" DEEPSEEK = "deepseek" + DEEPSEEK_V4 = "deepseek_v4" GEMMA = "gemma" GEMMA2 = "gemma2" GEMMA3 = "gemma3" diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 98f8e39efa..d6fed399d3 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -268,6 +268,7 @@ topk_routing_group: -1 # number of top groups to route inputs. For EP, # all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers. use_batch_split_schedule: False # a flag if splitting batch into micro-batches to hide communications that yields performance benefits. batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True. +num_hash_layers: 3 # Number of initial MoE layers to apply static Hash Routing. # For complex architectures like llama4 there are repeated sets of # inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope] @@ -405,6 +406,16 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) fused_qkv: False fused_mlp: False +# DeepSeek-V4 Compressed Attention parameters +compress_rope_theta: 160000.0 +compress_ratios: [] +index_head_dim: 128 +index_n_heads: 64 +index_topk: 512 +o_groups: 8 +o_lora_rank: 1024 +sliding_window: 128 + record_internal_nn_metrics: 0 # Output directory @@ -1217,6 +1228,7 @@ force_q_layout: false mhc_expansion_rate: 1 # The number of iterations for the Sinkhorn-Knopp algorithm. sinkhorn_iterations: 20 +hc_eps: 1.0e-6 ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. diff --git a/src/maxtext/configs/models/deepseek_v4-flash.yml b/src/maxtext/configs/models/deepseek_v4-flash.yml new file mode 100644 index 0000000000..27e3f66fb9 --- /dev/null +++ b/src/maxtext/configs/models/deepseek_v4-flash.yml @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Default model configs for DeepSeek-V4-Flash (43 Layers) + +base_config: base.yml +model_name: deepseek_v4-flash + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +head_dim: 512 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +base_num_decoder_layers: 43 +first_num_dense_layers: 0 +mlp_activations: ["silu"] +vocab_size: 129280 +tokenizer_type: "huggingface" +tokenizer_path: "deepseek-ai/DeepSeek-V3" +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 256 +num_experts_per_tok: 6 +shared_experts: 1 +routed_scaling_factor: 1.5 +routed_score_func: "sqrtsoftplus" +routed_bias: True +norm_topk_prob: True +decoder_block: "deepseek_v4" +pure_nnx_decoder: True +enable_nnx: True +use_tokamax_splash: True +# TODO: Dynamic Flash Attention masking requires use_indexer: True to route the custom compressor mask indexer_mask into Tokamax's make_dynamic_splash_mha dynamic Pallas kernel path. Refactor this MLA configuration parameter out of deep coordinator layers in subsequent milestones. +use_indexer: True +remat_policy: minimal_offloaded + +# Manifold-Constrained Hyper-Connection configurations +mhc_expansion_rate: 4 +sinkhorn_iterations: 20 +compress_rope_theta: 160000.0 +index_head_dim: 128 +index_n_heads: 64 +index_topk: 512 +o_groups: 8 +o_lora_rank: 1024 +sliding_window: 128 +num_hash_layers: 3 +mlp_activations_limit: 10.0 +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# Compressed Sparse Attention +q_lora_rank: 1024 +kv_lora_rank: 512 +qk_nope_head_dim: 128 +qk_rope_head_dim: 64 +v_head_dim: 128 +mscale: 1.0 + +# RoPE +rope_type: "default" +rope_max_timescale: 10_000 +max_position_embeddings: 1048576 +original_max_position_embeddings: 65536 +rope_factor: 16 +beta_fast: 32 diff --git a/src/maxtext/configs/models/deepseek_v4-tiny.yml b/src/maxtext/configs/models/deepseek_v4-tiny.yml new file mode 100644 index 0000000000..dbf6b5cd7a --- /dev/null +++ b/src/maxtext/configs/models/deepseek_v4-tiny.yml @@ -0,0 +1,74 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny version of DeepSeek-V4 (4 Layers) for local sharding and compilation checks. + +base_config: base.yml +model_name: deepseek_v4-tiny + +base_emb_dim: 128 +base_num_query_heads: 16 +base_num_kv_heads: 1 +head_dim: 32 +base_mlp_dim: 128 +base_moe_mlp_dim: 128 +base_num_decoder_layers: 6 +first_num_dense_layers: 0 +mlp_activations: ["silu"] +vocab_size: 129280 +tokenizer_type: "huggingface" +tokenizer_path: "deepseek-ai/DeepSeek-V3" +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 8 +num_experts_per_tok: 4 +shared_experts: 1 +routed_scaling_factor: 1.5 +routed_score_func: "sqrtsoftplus" +routed_bias: True +norm_topk_prob: True +decoder_block: "deepseek_v4" +pure_nnx_decoder: True +enable_nnx: True + +# Manifold-Constrained Hyper-Connection configurations +mhc_expansion_rate: 4 +sinkhorn_iterations: 20 +compress_rope_theta: 160000.0 +index_head_dim: 32 +index_n_heads: 16 +index_topk: 64 +o_groups: 2 +o_lora_rank: 64 +sliding_window: 32 +num_hash_layers: 3 +mlp_activations_limit: 10.0 +compress_ratios: [0, 4, 128, 4, 128, 0] + +# Compressed Attention +q_lora_rank: 64 +kv_lora_rank: 32 +qk_nope_head_dim: 32 +qk_rope_head_dim: 16 +v_head_dim: 128 +mscale: 1.0 + +# RoPE +rope_type: "default" +rope_max_timescale: 10_000 +max_position_embeddings: 163840 +original_max_position_embeddings: 4096 +rope_factor: 40 +beta_fast: 32 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b3261ab75d..b27add7836 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -226,6 +226,8 @@ class ProfilerType(str, Enum): "deepseek3-test", "deepseek3-tiny", "deepseek3.2-671b", + "deepseek_v4-tiny", + "deepseek_v4-flash", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -618,6 +620,22 @@ class AttentionIndexer(BaseModel): indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.") +class DeepSeekV4AttentionConfig(BaseModel): + """Configuration specific to DeepSeek-V4 stateless compressed attention layers.""" + + compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.") + compress_ratios: list[int] = Field( + default_factory=list, + description="Layer-by-layer compressor rates (0: standard, 4: CSA, 128: HCA).", + ) + index_head_dim: int = Field(128, description="Head dim for indexer query and key.") + index_n_heads: int = Field(64, description="Number of query heads in indexer.") + index_topk: int = Field(512, description="Number of tokens selected by indexer.") + o_groups: int = Field(8, description="Number of group partitions for grouped linear output projection.") + o_lora_rank: int = Field(1024, description="Low-rank output dimension prior to grouped mix projection.") + sliding_window: int = Field(128, description="Sliding window size for attention.") + + class Llama4Attention(BaseModel): """Configuration specific to Llama4-style models.""" @@ -815,6 +833,10 @@ class DeepSeekMoE(BaseModel): 1, description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.", ) + num_hash_layers: int = Field( + 3, + description="Number of initial MoE layers to apply static Hash Routing.", + ) class Qwen3Next(BaseModel): @@ -1365,6 +1387,7 @@ class ManifoldConstrainedHyperConnections(BaseModel): mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.") sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.") + hc_eps: float = Field(1e-6, description="The epsilon fallback value for numerical stability in mHC.") class DilocoParams(BaseModel): @@ -2224,6 +2247,7 @@ class MaxTextConfig( MlaAttention, MoBa, AttentionIndexer, + DeepSeekV4AttentionConfig, Llama4Attention, SplashAttention, PagedAttention, diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py new file mode 100644 index 0000000000..f69da347ff --- /dev/null +++ b/src/maxtext/layers/attention_compressed.py @@ -0,0 +1,1082 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compressed Attention layers and long-range compressors.""" + +from typing import Any +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax import nnx +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding, apply_rotary_pos_emb +from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm +from maxtext.layers.linears import DeepSeekGroupedLinear +from maxtext.layers.attention_op import AttentionOp +from maxtext.common.common_types import MODEL_MODE_TRAIN, AttentionType + + +class HCACompressor(nnx.Module): + """Heavily Compressed Attention (HCA) long-range compressor layer. + + This layer groups sequence features into non-overlapping windows of size 'compress_rate', + applies learnable pooling gates combined with static positional bias, averages the features + inside each window to emit a single compressed representation per window, and rotates the + resulting compressed sequence using interleaved rotary embeddings. + """ + + def __init__( + self, + hidden_size: int, + head_dim: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the Heavily Compressed Attention (HCA) long-range compressor. + + Args: + hidden_size: The model's global hidden dimension size. + head_dim: The projection size of each attention key-value channel. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The sequential layer depth index of this compressor in the decoder stack. + eps: The tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The standard Flax NNX random number generator collection. + """ + super().__init__() + self.compress_rate = config.compress_ratios[layer_idx] + self.head_dim = head_dim + self.hidden_size = hidden_size + self.eps = eps + self.weight_dtype = weight_dtype + self.dtype = dtype + rope_theta = config.compress_rope_theta + + # Linear projection of inputs to key/value representation + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Linear projection of inputs to gate logits + self.gate_proj = nnx.Linear( + in_features=hidden_size, + out_features=head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Positional bias parameter added to gate logits inside each window + self.position_bias = nnx.Param( + jax.nn.initializers.normal(stddev=0.02)( + rngs.params(), + (self.compress_rate, head_dim), + weight_dtype, + ) + ) + + # RMS normalization applied to pooled window features + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Interleaved rotary embeddings applied to the trailing slice + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=head_dim, + partial_rotary_factor=config.qk_rope_head_dim / config.head_dim, + rope_theta=rope_theta, + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + q_residual: Any = None, + position_ids: jnp.ndarray = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: + """Applies Heavily Compressed Attention (HCA) compression to sequence keys and values. + + This method splits the sequence into non-overlapping windows of size 'compress_rate', + aggregates feature representation vectors using Softmax-weighted gates, normalizes the + resulting vectors using RMS norm, applies position-aware interleaved rotary embeddings, + and expands the output dimension to match standard multi-head key-value layouts. + + Args: + hidden_states: The input hidden representation sequence of shape [B, S, D_model]. + q_residual: Ignored optional placeholder matching polymorphic calling conventions. + position_ids: Optional position indicators of shape [B, S]. + + Returns: + Compressed, position-encoded representation tensor of shape [B, 1, W, D_head], + where W is the compressed sequence length equal to S // compress_rate. + """ + # hidden_states: [B, S, D_model] + # position_ids: [B, S] + batch, seq_len, _ = hidden_states.shape + + # Project inputs to key/value and gate representations + # kv: [B, S, D_head] + # gate: [B, S, D_head] + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + # Compute sequence multiple bound corresponding to the window stride rate + # usable: scalar integer + usable = (seq_len // self.compress_rate) * self.compress_rate + n_windows = usable // self.compress_rate + + # Slice sequences to match clean multiple dimensions + # chunk_kv: [B, S_usable, D_head] + # chunk_gate: [B, S_usable, D_head] + chunk_kv = kv[:, :usable, :] + chunk_gate = gate[:, :usable, :] + + # Reshape sliced inputs into non-overlapping windows of size 'compress_rate' + # chunk_kv: [B, W, compress_rate, D_head] + # chunk_gate: [B, W, compress_rate, D_head] + chunk_kv = chunk_kv.reshape(batch, n_windows, self.compress_rate, self.head_dim) + chunk_gate = chunk_gate.reshape(batch, n_windows, self.compress_rate, self.head_dim) + + # Add positional bias parameters to gate logits + # chunk_gate: [B, W, compress_rate, D_head] + position_bias = jnp.asarray(self.position_bias[...], self.dtype) + chunk_gate = chunk_gate + position_bias[jnp.newaxis, jnp.newaxis, :, :] + + # Compute softmax aggregation probabilities in float32 for stability + # gate_softmax: [B, W, compress_rate, D_head] + gate_softmax = jax.nn.softmax(chunk_gate.astype(jnp.float32), axis=2).astype(self.dtype) + + # Aggregate key/value features using computed gate weights + # pooled: [B, W, D_head] + pooled = jnp.sum(chunk_kv * gate_softmax, axis=2) + + # Normalize aggregated window features + # compressed: [B, W, D_head] + compressed = self.kv_norm(pooled) + + # Determine absolute sequence indexes corresponding to each window start + # positions: [B, W] + positions = jnp.arange(n_windows, dtype=jnp.int32) * self.compress_rate + positions = jnp.broadcast_to(positions[jnp.newaxis, :], (batch, n_windows)) + + # Compute interleaved rotary embeddings sine and cosine values + # cos: [B, W, D_rope/2] + # sin: [B, W, D_rope/2] + cos, sin = self.rotary_emb(compressed, positions) + + # Expand dimensions to allow broadcasting over head axis during rotary mapping + # compressed_4d: [B, W, 1, D_head] + compressed_4d = jnp.expand_dims(compressed, axis=2) + + # Apply interleaved RoPE rotation over the trailing slice + # rotated_4d: [B, W, 1, D_head] + rotated_4d = apply_rotary_pos_emb(compressed_4d, cos, sin, unsqueeze_dim=2) + + # Squeeze dummy head dimension to recover standard 3D shape layout + # rotated: [B, W, D_head] + rotated = jnp.squeeze(rotated_4d, axis=2) + + # Expand output format to match standard multi-head key/value dimensions + # compressed_kv: [B, 1, W, D_head] + compressed_kv = jnp.expand_dims(rotated, axis=1) + + # Evaluate caching dimensions boundary checks to prevent empty execution + compressed_len = n_windows + if seq_len == 1 or compressed_len == 0: + return compressed_kv, None + + # Compute causal block bias mask over compressed sequence segments to prevent query leakage. + # A query at sequence position `t` is restricted from attending to any compressed cache block + # index `w` if `t <= w * compress_rate`. This represents future sequence information that is + # mathematically unavailable at position `t`. + # + # entry_indices: [W] representing compressed block window positions + entry_indices = jnp.arange(compressed_len, dtype=jnp.int32) + # causal_threshold: [B, S] representing ready block count boundaries per sequence token + causal_threshold = (position_ids + 1) // self.compress_rate + # Construct sequence-level causal future mask via dimension broadcasting. + # future_mask: [B, 1, S, W] + future_mask = ( + entry_indices[jnp.newaxis, jnp.newaxis, jnp.newaxis, :] >= causal_threshold[:, jnp.newaxis, :, jnp.newaxis] + ) + # Initialize causal block bias containing -inf mask values for invalid future elements. + # block_bias: [B, 1, S, W] + block_bias = jnp.where(future_mask, -jnp.inf, 0.0) + return compressed_kv, block_bias + + +class DeepSeekV4Indexer(nnx.Module): + """Lightning Indexer (paper §2.3.1, eqs. 13–17). + + Used by Compressed Sparse Attention (CSA) to pick the top-k compressed KV + blocks per query. + """ + + def __init__( + self, + hidden_size: int, + q_lora_rank: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the Lightning Indexer. + + Args: + hidden_size: The model's global hidden dimension size. + q_lora_rank: The projection rank dimension of Q LoRA. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The decoder stack layer index containing this indexer. + eps: Tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The Flax NNX random number generator collection. + """ + super().__init__() + self.compress_rate = config.compress_ratios[layer_idx] + self.num_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.index_topk = config.index_topk + self.softmax_scale = config.index_head_dim**-0.5 + self.weights_scaling = config.index_n_heads**-0.5 + self.dtype = dtype + self.weight_dtype = weight_dtype + rope_theta = config.compress_rope_theta + + # Key projections for indexing-scale compression + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * self.head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Gate projections for indexing-scale compression + self.gate_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * self.head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Positional bias parameters inside indexing windows + self.position_bias = nnx.Param( + jax.nn.initializers.normal(stddev=0.02)( + rngs.params(), + (self.compress_rate, 2 * self.head_dim), + weight_dtype, + ) + ) + + # RMS normalization for indexer key values + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=self.head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Query projection mapping Q LoRA rank to multi-head indexing features + self.q_b_proj = nnx.Linear( + in_features=q_lora_rank, + out_features=self.num_heads * self.head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Dynamic score scaling projection + self.weights_proj = nnx.Linear( + in_features=hidden_size, + out_features=self.num_heads, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Interleaved rotary embedding aligning query/key pos representations + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=self.head_dim, + partial_rotary_factor=config.qk_rope_head_dim / self.head_dim, + rope_theta=rope_theta, + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + q_residual: jnp.ndarray, + position_ids: jnp.ndarray, + ) -> jnp.ndarray: + """Computes top-k relevant compressed block indices per query position. + + This method compresses sequence keys and values into overlapping window + segments, applies position-aware RoPE encoding, projects incoming query residuals + into alignment spaces, computes similarity matrices across query positions and + windows, dynamically scales/weights scores using projected head scaling arrays, + and selects the top-k windows using JAX optimized top_k primitives. + + Args: + hidden_states: The input sequence representations of shape [B, S, D_model]. + q_residual: The Q LoRA low-rank query projections of shape [B, S, D_rank]. + position_ids: The sequence absolute position identifiers of shape [B, S]. + + Returns: + Integer index array of shape [B, S, k] containing the gathered top-k + compressed window indices for each query position, where k = index_topk. + """ + # hidden_states: [B, S, D_model] + # q_residual: [B, S, D_rank] + # position_ids: [B, S] + batch, seq_len, _ = hidden_states.shape + + # Project inputs to index keys and gates + # kv: [B, S, 2 * D_idx] + # gate: [B, S, 2 * D_idx] + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + # Calculate sequence bounds matching the stride rate + # usable: scalar integer + usable = (seq_len // self.compress_rate) * self.compress_rate + n_windows = usable // self.compress_rate + + # Slice sequences to valid sequence bounds + # chunk_kv: [B, S_usable, 2 * D_idx] + # chunk_gate: [B, S_usable, 2 * D_idx] + chunk_kv = kv[:, :usable, :] + chunk_gate = gate[:, :usable, :] + + # Segment sliced elements into non-overlapping windows + # chunk_kv: [B, W, compress_rate, 2 * D_idx] + # chunk_gate: [B, W, compress_rate, 2 * D_idx] + chunk_kv = chunk_kv.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + chunk_gate = chunk_gate.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + + # Incorporate static positional bias parameters + # chunk_gate: [B, W, compress_rate, 2 * D_idx] + position_bias = jnp.asarray(self.position_bias[...], self.dtype) + chunk_gate = chunk_gate + position_bias[jnp.newaxis, jnp.newaxis, :, :] + + # Overlap slicing setups: segment into Ca / Cb series + # prev_kv: [B, W, compress_rate, D_idx] (Ca) + # curr_kv: [B, W, compress_rate, D_idx] (Cb) + # prev_gate: [B, W, compress_rate, D_idx] (Ca) + # curr_gate: [B, W, compress_rate, D_idx] (Cb) + prev_kv = chunk_kv[..., : self.head_dim] + curr_kv = chunk_kv[..., self.head_dim :] + prev_gate = chunk_gate[..., : self.head_dim] + curr_gate = chunk_gate[..., self.head_dim :] + + # Set up combined padded layouts for boundary window overlap calculations + # new_kv: [B, W, 2 * compress_rate, D_idx] + # new_gate: [B, W, 2 * compress_rate, D_idx] + new_kv = jnp.zeros((batch, n_windows, 2 * self.compress_rate, self.head_dim), dtype=self.dtype) + new_gate = jnp.full((batch, n_windows, 2 * self.compress_rate, self.head_dim), -jnp.inf, dtype=self.dtype) + + # Map Cb representations into second half slots + new_kv = new_kv.at[:, :, self.compress_rate :].set(curr_kv) + new_gate = new_gate.at[:, :, self.compress_rate :].set(curr_gate) + + # Map Ca representations of preceding windows into first half slots + if n_windows > 1: + new_kv = new_kv.at[:, 1:, : self.compress_rate].set(prev_kv[:, :-1, :, :]) + new_gate = new_gate.at[:, 1:, : self.compress_rate].set(prev_gate[:, :-1, :, :]) + + # Aggregate indexing features using gate softmax probabilities computed in float32 + # gate_softmax: [B, W, 2 * compress_rate, D_idx] + gate_softmax = jax.nn.softmax(new_gate.astype(jnp.float32), axis=2).astype(self.dtype) + # pooled: [B, W, D_idx] + pooled = jnp.sum(new_kv * gate_softmax, axis=2) + + # Normalize index keys + # compressed: [B, W, D_idx] + compressed = self.kv_norm(pooled) + + # Extract absolute starting positions of index windows + # positions: [B, W] + positions = jnp.arange(n_windows, dtype=jnp.int32) * self.compress_rate + positions = jnp.broadcast_to(positions[jnp.newaxis, :], (batch, n_windows)) + + # Compute sinusoids and apply interleaved rotary embeddings + # cos: [B, W, D_idx_rope/2] + # sin: [B, W, D_idx_rope/2] + cos, sin = self.rotary_emb(compressed, positions) + # compressed_4d: [B, W, 1, D_idx] + compressed_4d = jnp.expand_dims(compressed, axis=2) + # rotated_4d: [B, W, 1, D_idx] + rotated_4d = apply_rotary_pos_emb(compressed_4d, cos, sin, unsqueeze_dim=2) + # compressed_kv: [B, W, D_idx] + compressed_kv = jnp.squeeze(rotated_4d, axis=2) + + # Project and reshape queries to multiple head alignments + # q: [B, S, H, D_idx] + q = self.q_b_proj(q_residual) + q = q.reshape(batch, seq_len, self.num_heads, self.head_dim) + + # Compute rotary components matching current query positions + # cos_q: [B, S, D_idx_rope/2] + # sin_q: [B, S, D_idx_rope/2] + cos_q, sin_q = self.rotary_emb(hidden_states, position_ids) + # Apply RoPE to query elements + # q: [B, S, H, D_idx] + q = apply_rotary_pos_emb(q, cos_q, sin_q, unsqueeze_dim=2) + + # Calculate attention alignment scores across windows + # swaped_kv: [B, 1, D_idx, W] + swaped_kv = jnp.swapaxes(compressed_kv, -1, -2) + swaped_kv = jnp.expand_dims(swaped_kv, axis=1) + # scores: [B, S, H, W] + scores = jnp.matmul(q, swaped_kv) + scores = jax.nn.relu(scores) * self.softmax_scale + + # Project and scale dynamic aggregation scoring weights + # weights: [B, S, H] + weights = self.weights_proj(hidden_states) * self.weights_scaling + # Aggregate scoring profiles over heads axis + # index_scores: [B, S, W] + index_scores = jnp.sum(scores * jnp.expand_dims(weights, axis=-1), axis=2) + + # Extract top-k scoring compressed blocks per query sequence position + # topk_indices: [B, S, k] + compressed_len = compressed_kv.shape[1] + topk_limit = min(self.index_topk, compressed_len) + + if compressed_len > 0: + # Compute sequence-level causal ready block counts. + # causal_threshold: [B, S] + causal_threshold = (position_ids + 1) // self.compress_rate + # entry_indices: [W] + entry_indices = jnp.arange(compressed_len, dtype=jnp.int32) + # Construct query-specific causal mask along compressed index dimension. + # future_mask: [B, S, W] + future_mask = entry_indices[jnp.newaxis, jnp.newaxis, :] >= causal_threshold[:, :, jnp.newaxis] + # Zero-out future block scores by masking them with -inf prior to top-k calculations. + # index_scores: [B, S, W] + index_scores = jnp.where(future_mask, -jnp.inf, index_scores) + # Select top-k indices per token position based on masked scores. + # topk_indices: [B, S, k] + _, topk_indices = jax.lax.top_k(index_scores, topk_limit) + # Early tokens with too few ready blocks will still have invalid top-k selections pointing + # to future blocks. Detect them and replace with a `-1` sentinel. + # invalid: [B, S, k] + invalid = topk_indices >= causal_threshold[:, :, jnp.newaxis] + topk_indices = jnp.where(invalid, -1, topk_indices) + return topk_indices + + # Fallback stateless default top-k select path + _, topk_indices = jax.lax.top_k(index_scores, topk_limit) + return topk_indices + + +class CSACompressor(nnx.Module): + """Compressed Sparse Attention (CSA) compressor layer. + + This layer aggregates token representations into overlapping Ca/Cb window segments, + normalizes/rotates them, and uses the DeepSeekV4Indexer to gather the top-k + relevant compressed KV blocks per query. + """ + + def __init__( + self, + hidden_size: int, + q_lora_rank: int, + head_dim: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the Compressed Sparse Attention (CSA) compressor. + + Args: + hidden_size: The model's global hidden dimension size. + q_lora_rank: The projection rank dimension of Q LoRA. + head_dim: The projection size of each attention key-value channel. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The decoder stack layer index containing this compressor. + eps: Tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The Flax NNX random number generator collection. + """ + super().__init__() + self.compress_rate = config.compress_ratios[layer_idx] + self.head_dim = head_dim + self.hidden_size = hidden_size + self.eps = eps + self.weight_dtype = weight_dtype + self.dtype = dtype + rope_theta = config.compress_rope_theta + + # Projections for outer compressed key/value formats + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Projections for outer gate logits + self.gate_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Static positional biases added inside windows + self.position_bias = nnx.Param( + jax.nn.initializers.normal(stddev=0.02)( + rngs.params(), + (config.compress_ratios[layer_idx], 2 * head_dim), + weight_dtype, + ) + ) + + # RMS normalization applied to aggregated representations + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Interleaved rotary embeddings for compressed sequences + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=head_dim, + partial_rotary_factor=config.qk_rope_head_dim / config.head_dim, + rope_theta=rope_theta, + ) + + # Lightning Indexer component + self.indexer = DeepSeekV4Indexer( + hidden_size=hidden_size, + q_lora_rank=q_lora_rank, + config=config, + layer_idx=layer_idx, + eps=eps, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + q_residual: jnp.ndarray, + position_ids: jnp.ndarray, + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Applies Compressed Sparse Attention (CSA) compression and gathers top-k blocks. + + This method compresses sequence keys and values into overlapping window + segments, applies position-aware RoPE encoding, runs the Lightning Indexer to + extract the top-k scoring window indices for each query position, executes a + high-performance TPU-efficient advanced gather, and shapes the output to match + standard multi-head key-value layouts. + + Args: + hidden_states: The input sequence representations of shape [B, S, D_model]. + q_residual: The Q LoRA low-rank query projections of shape [B, S, D_rank]. + position_ids: The sequence absolute position identifiers of shape [B, S]. + + Returns: + Position-encoded, gathered key-value representation tensor of shape + [B, 1, S * k, D_head], where k = index_topk. + """ + # hidden_states: [B, S, D_model] + # q_residual: [B, S, D_rank] + # position_ids: [B, S] + batch, seq_len, _ = hidden_states.shape + + # Project input features to key/value and gate components + # kv: [B, S, 2 * D] + # gate: [B, S, 2 * D] + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + # Determine valid sequence bounds + # usable: scalar integer + usable = (seq_len // self.compress_rate) * self.compress_rate + n_windows = usable // self.compress_rate + + # Slice inputs to sequence bounds + # chunk_kv: [B, S_usable, 2 * D] + # chunk_gate: [B, S_usable, 2 * D] + chunk_kv = kv[:, :usable, :] + chunk_gate = gate[:, :usable, :] + + # Segment sliced elements into non-overlapping windows + # chunk_kv: [B, W, compress_rate, 2 * D] + # chunk_gate: [B, W, compress_rate, 2 * D] + chunk_kv = chunk_kv.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + chunk_gate = chunk_gate.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + + # Aggregate window gate logits with static positional biases + # chunk_gate: [B, W, compress_rate, 2 * D] + position_bias = jnp.asarray(self.position_bias[...], self.dtype) + chunk_gate = chunk_gate + position_bias[jnp.newaxis, jnp.newaxis, :, :] + + # Overlap slicing: extract Ca / Cb configurations + # prev_kv: [B, W, compress_rate, D] (Ca) + # curr_kv: [B, W, compress_rate, D] (Cb) + # prev_gate: [B, W, compress_rate, D] (Ca) + # curr_gate: [B, W, compress_rate, D] (Cb) + prev_kv = chunk_kv[..., : self.head_dim] + curr_kv = chunk_kv[..., self.head_dim :] + prev_gate = chunk_gate[..., : self.head_dim] + curr_gate = chunk_gate[..., self.head_dim :] + + # Assemble padded window layouts for overlap combination + # new_kv: [B, W, 2 * compress_rate, D] + # new_gate: [B, W, 2 * compress_rate, D] + new_kv = jnp.zeros((batch, n_windows, 2 * self.compress_rate, self.head_dim), dtype=self.dtype) + new_gate = jnp.full((batch, n_windows, 2 * self.compress_rate, self.head_dim), -jnp.inf, dtype=self.dtype) + + # Set current window representations to second half slots + new_kv = new_kv.at[:, :, self.compress_rate :].set(curr_kv) + new_gate = new_gate.at[:, :, self.compress_rate :].set(curr_gate) + + # Set previous window representations to first half slots + if n_windows > 1: + new_kv = new_kv.at[:, 1:, : self.compress_rate].set(prev_kv[:, :-1, :, :]) + new_gate = new_gate.at[:, 1:, : self.compress_rate].set(prev_gate[:, :-1, :, :]) + + # Aggregate features using window gate softmax probabilities computed in float32 + # gate_softmax: [B, W, 2 * compress_rate, D] + gate_softmax = jax.nn.softmax(new_gate.astype(jnp.float32), axis=2).astype(self.dtype) + # pooled: [B, W, D] + pooled = jnp.sum(new_kv * gate_softmax, axis=2) + + # Normalize window features + # compressed: [B, W, D] + compressed = self.kv_norm(pooled) + + # Obtain starting positions of compressed windows + # positions: [B, W] + positions = jnp.arange(n_windows, dtype=jnp.int32) * self.compress_rate + positions = jnp.broadcast_to(positions[jnp.newaxis, :], (batch, n_windows)) + + # Apply interleaved rotary embeddings over aggregated outputs + # cos: [B, W, D_rope/2] + # sin: [B, W, D_rope/2] + cos, sin = self.rotary_emb(compressed, positions) + # compressed_4d: [B, W, 1, D] + compressed_4d = jnp.expand_dims(compressed, axis=2) + # rotated_4d: [B, W, 1, D] + rotated_4d = apply_rotary_pos_emb(compressed_4d, cos, sin, unsqueeze_dim=2) + # compressed_kv: [B, W, D] + compressed_kv = jnp.squeeze(rotated_4d, axis=2) + + # Execute Lightning Indexer to obtain block indices per query + # topk: [B, S, k] + topk = self.indexer(hidden_states, q_residual, position_ids) + + # Clamp indices safely using jnp.clip to avoid JAX negative/out-of-bounds indexing exceptions + # under indexer -1 sentinel conditions. + # safe_indices: [B, S, k] + safe_indices = jnp.clip(topk, a_min=0) + # batch_idx: [B, 1, 1] + batch_idx = jnp.arange(batch)[:, jnp.newaxis, jnp.newaxis] + # Perform TPU-efficient JAX Advanced Indexing Gather. + # gathered: [B, S, k, D] + gathered = compressed_kv[batch_idx, safe_indices, :] + + # Reshape gathered elements to standardized multi-head formats + # compressed_kv_out: [B, 1, S * k, D] + compressed_kv_out = gathered.reshape(batch, 1, seq_len * topk.shape[-1], self.head_dim) + + # Vectorized block bias mask construction to filter out invalid sparse gathered entries. + # valid: [B, S, k] indicating whether each top-k selection is valid (non-sentinel) + valid = topk >= 0 + # allowed: [B, S, k] containing 0.0 for valid entries and -inf for invalid sentinels + allowed = jnp.where(valid, 0.0, -jnp.inf) + # Construct an equivalence diagonal mask matching query sequence indices. + # eq_mask: [S, S, 1] representing identity query boundaries + eq_mask = jnp.arange(seq_len)[:, jnp.newaxis, jnp.newaxis] == jnp.arange(seq_len)[jnp.newaxis, :, jnp.newaxis] + # allowed_expanded: [B, S, 1, k] + allowed_expanded = allowed[:, :, jnp.newaxis, :] + # Distribute allowed masks diagonally using JAX vectorization to prevent cross-query leakage. + # block_bias_5d: [B, S, S, k] + block_bias_5d = jnp.where(eq_mask[jnp.newaxis, :, :, :], allowed_expanded, -jnp.inf) + # Reshape and format to standard key-value sequence length formats + # block_bias: [B, S, S * k] + block_bias = block_bias_5d.reshape(batch, seq_len, seq_len * topk.shape[-1]) + # block_bias: [B, 1, S, S * k] + block_bias = jnp.expand_dims(block_bias, axis=1) + return compressed_kv_out, block_bias + + +class DeepSeekV4Attention(nnx.Module): + """Main coordination attention block for DeepSeek-V4 compressed layer configurations. + + This module implements multi-head attention augmented with query-compression LoRA + projections, unweighted key/value normalizations, optional heavily or sparsely + compressed long-range context compressor integrations, learnable attention sinks, + and parallelized grouped output mixing projections. + """ + + def __init__( + self, + hidden_size: int, + q_lora_rank: int, + head_dim: int, + num_heads: int, + config: Any, + layer_idx: int, + mesh: Mesh | None = None, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + attention_type: str = "compressed_sparse_attention", + *, + rngs: nnx.Rngs, + ): + """Initializes the DeepSeekV4 Attention coordinator block. + + Args: + hidden_size: The model's global hidden dimension size. + q_lora_rank: The projection rank dimension of Q LoRA. + head_dim: The projection size of each attention key-value channel. + num_heads: The total number of query attention heads. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The decoder stack layer index containing this attention module. + eps: Tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + attention_type: The type of compressed attention being instantiated. + rngs: The Flax NNX random number generator collection. + """ + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_type = attention_type + self.num_heads = num_heads + self.head_dim = head_dim + self.sliding_window = config.sliding_window + self.scaling = head_dim**-0.5 + self.dtype = dtype + self.weight_dtype = weight_dtype + + # Projections for query extraction and low-rank compression + self.q_a_proj = nnx.Linear( + in_features=hidden_size, + out_features=q_lora_rank, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + self.q_a_norm = DeepSeekV4RMSNorm( + hidden_size=q_lora_rank, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + self.q_b_proj = nnx.Linear( + in_features=q_lora_rank, + out_features=num_heads * head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + self.q_b_norm = DeepSeekV4UnweightedRMSNorm( + eps=eps, + dtype=dtype, + ) + + # Unified projected shared MQA key/value block + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Block-diagonal grouped linear layer for multi-head features mixing + self.o_a_proj = DeepSeekGroupedLinear( + in_features_per_group=num_heads * head_dim // config.o_groups, + out_features=config.o_groups * config.o_lora_rank, + n_groups=config.o_groups, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + self.o_b_proj = nnx.Linear( + in_features=config.o_groups * config.o_lora_rank, + out_features=hidden_size, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Attention Sink Parameter + self.sinks = nnx.Param(jax.nn.initializers.zeros(rngs.params(), (num_heads,), weight_dtype)) + + # Layer specific compressor allocation + if self.attention_type == "heavily_compressed_attention": + self.compressor = HCACompressor( + hidden_size=hidden_size, + head_dim=head_dim, + config=config, + layer_idx=layer_idx, + eps=eps, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + elif self.attention_type == "compressed_sparse_attention": + self.compressor = CSACompressor( + hidden_size=hidden_size, + q_lora_rank=q_lora_rank, + head_dim=head_dim, + config=config, + layer_idx=layer_idx, + eps=eps, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + else: + self.compressor = None + + # Compute partial rotary factor dynamically from config to prevent dimension mismatches. + # DeepSeek-V4 pairs consecutive channels to apply partial RoPE on qk_rope_head_dim channels, + # requiring dynamic scaling: partial_rotary_factor = qk_rope_head_dim / head_dim. + self.partial_rotary_factor = self.config.qk_rope_head_dim / self.config.head_dim + + self.rope_theta = ( + self.config.rope_max_timescale if self.attention_type == "sliding_attention" else self.config.compress_rope_theta + ) + + # Local rotary embedding block matching standard MaxText (Gemma/Llama2) paradigms. + self.rotary_embedding = DeepSeekV4RotaryEmbedding( + head_dim=self.head_dim, + partial_rotary_factor=self.partial_rotary_factor, + rope_theta=self.rope_theta, + ) + + # Scaling factor applied to query representations to match standard MaxText attention scaling. + # MaxText's AttentionOp core expects queries to be pre-scaled by 1 / sqrt(head_dim). + self.scaling = self.head_dim**-0.5 + + self.attention_op = AttentionOp( + config=self.config, + mesh=mesh, + attention_kernel=self.config.attention, + max_target_length=self.config.max_target_length, + num_query_heads=self.num_heads, + num_kv_heads=1, + dtype=self.dtype, + compute_axis_order=(0, 1, 2, 3), + attention_type=AttentionType.FULL, + rngs=rngs, + ) + + def __call__( + self, + hidden_states: jnp.ndarray | None = None, + position_ids: jnp.ndarray | None = None, + attention_mask: jnp.ndarray | None = None, + inputs_q: jnp.ndarray | None = None, + inputs_kv: jnp.ndarray | None = None, + **kwargs, + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Executes the main coordination attention pass over sequence inputs. + + This method coordinates multi-head attention augmented with query-compression LoRA + projections, unweighted key/value normalizations, long-range context compressor + integrations, learnable attention sinks, and parallelized grouped output mixing. + + Args: + hidden_states: Input sequence representations of shape [B, S, D_model]. + position_ids: Sequence absolute position identifiers of shape [B, S]. + attention_mask: Optional attention mask preventing invalid token attendance. + inputs_q: Optional query input override for decoupled execution. + inputs_kv: Optional key/value input override for decoupled execution. + **kwargs: Additional runtime execution configurations (e.g., decoder_segment_ids). + + Returns: + Tuple containing the projected output representations of shape [B, S, D_model] + and an empty caching intermediate placeholder. + """ + # Resolve input representations from standard hidden states or override inputs. + # hidden_states: [B, S, D_model] + if hidden_states is None: + hidden_states = inputs_q + batch, seq_len, _ = hidden_states.shape + + # Generate absolute position identifiers if not provided at runtime. + # position_ids: [B, S] + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32)[None], (batch, seq_len)) + + # Resolve rotary position embedding sinusoids from runtime keyword arguments or compute local sinusoids. + # Utilizing pre-computed sinusoids avoids redundant computation across decoder layers during forward passes. + # cos: [B, S, D_rope/2] + # sin: [B, S, D_rope/2] + cos = kwargs.get("cos", None) + sin = kwargs.get("sin", None) + if cos is None or sin is None: + cos, sin = self.rotary_embedding(hidden_states, position_ids) + + # Project input features to low-rank query residuals and apply RMS normalization. + # # [B, S, D_model] -> [B, S, D_rank] + q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) + + # Project low-rank residuals to multi-head query dimensions and reshape. + # # [B, S, D_rank] -> [B, S, H, D_head] + q = self.q_b_proj(q_residual).reshape(batch, seq_len, self.num_heads, self.head_dim) + + # Apply scale-free unweighted RMS normalization across multi-head queries and scale by attention scaling factor. + # Unweighted normalization stabilizes query variance without introducing learnable scaling parameters. + # MaxText's AttentionOp core assumes pre-scaled query tensors, requiring explicit scaling here. + # # [B, S, H, D_head] -> [B, S, H, D_head] + q = self.q_b_norm(q) * self.scaling + + # Apply Rotary Position Embedding (RoPE) to query representations. + # # [B, S, H, D_head] -> [B, S, H, D_head] + q = apply_rotary_pos_emb(q, cos, sin, unsqueeze_dim=2) + + # Project input representations to shared key/value features and apply RMS normalization. + # # [B, S, D_model] -> [B, S, 1, D_head] + kv = self.kv_norm(self.kv_proj(hidden_states)).reshape(batch, seq_len, 1, self.head_dim) + + # Apply Rotary Position Embedding (RoPE) to shared key/value representations. + # # [B, S, 1, D_head] -> [B, S, 1, D_head] + kv = apply_rotary_pos_emb(kv, cos, sin, unsqueeze_dim=2) + + # Integrate long-range context compressor representations if configured. + block_bias = None + if self.compressor is not None: + # Execute compressor pass to extract compressed key/value blocks and structural block bias masks. + # compressed_kv: [B, 1, W, D_head] or [B, W, 1, D_head] + # block_bias: [B, 1, S, W] or [B, S, W] + compressed_kv, block_bias = self.compressor(hidden_states, q_residual, position_ids) + + # Standardize compressed key/value layout to match multi-head sequence formats. + # # [B, 1, W, D_head] -> [B, W, 1, D_head] + if compressed_kv.shape[1] == 1: + compressed_kv = compressed_kv.transpose(0, 2, 1, 3) + + # Concatenate local sequence keys with compressed long-range cache blocks along sequence dimension. + # # [B, S, 1, D_head] + [B, W, 1, D_head] -> [B, S + W, 1, D_head] + kv = jnp.concatenate([kv, compressed_kv], axis=1) + + # Reconcile structural block bias masks with runtime attention masks. + if attention_mask is not None: + if block_bias is not None: + # Concatenate block bias mask to attention mask along trailing sequence dimension. + # # [B, 1, S, S] + [B, 1, S, W] -> [B, 1, S, S + W] + attention_mask = jnp.concatenate([attention_mask, block_bias.astype(attention_mask.dtype)], axis=-1) + elif kv.shape[1] > attention_mask.shape[-1]: + # Pad attention mask with zero-value allowed elements to match extended key/value sequence length. + # # [B, 1, S, S] -> [B, 1, S, S + W] + pad_width = kv.shape[1] - attention_mask.shape[-1] + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, 0), (0, 0), (0, pad_width)), constant_values=0.0) + # Ensure key/value sequence length is perfectly divisible by the Splash attention block size (sa_block_kv). + # Hardware Matrix Multiply Units (MXUs) and XLA Pallas kernels enforce strict memory layout alignment grids. + # When Splash Flash Attention is active, the runtime key/value sequence dimension must perfectly divide by sa_block_kv. + # Because long-range cache compressors append dynamic auxiliary tokens (e.g. +32 tokens), the resulting combined length + # may break hardware divisibility constraints (e.g. 4128 % 512 != 0). + # This dynamic padding forces exact MXU grid alignment. + # # [B, S + W, 1, D_head] -> [B, align(S + W, sa_block_kv), 1, D_head] + if self.config.sa_block_kv > 0 and kv.shape[1] % self.config.sa_block_kv != 0: + pad_len = self.config.sa_block_kv - (kv.shape[1] % self.config.sa_block_kv) + kv = jnp.pad(kv, ((0, 0), (0, pad_len), (0, 0), (0, 0)), constant_values=0.0) + if attention_mask is not None: + # Pad 4D attention mask along trailing key/value sequence axis. + # # [B, 1, Q, S + W] -> [B, 1, Q, align(S + W, sa_block_kv)] + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, 0), (0, 0), (0, pad_len)), constant_values=0.0) + + # Squeeze redundant head dimension from 4D attention masks to ensure compatibility with AttentionOp core. + # # [B, 1, S, S + W] -> [B, S, S + W] + unified_mask = ( + jnp.squeeze(attention_mask, axis=1) if attention_mask is not None and attention_mask.ndim == 4 else attention_mask + ) + + # Execute core attention operator pass over query and concatenated key/value sequences. + # # q: [B, S, H, D_head], kv: [B, S + W, 1, D_head] -> [B, S, H, D_head] + attn_output = self.attention_op( + query=q, + key=kv, + value=kv, + decoder_segment_ids=kwargs.get("decoder_segment_ids", None), + inputs_positions=position_ids, + model_mode=kwargs.get("model_mode", MODEL_MODE_TRAIN), + indexer_mask=unified_mask, + sinks=self.sinks, + ) + + # Apply conjugate RoPE rotation (-sin) to attention outputs to un-rotate representations. + # Un-rotating aligns output feature spaces prior to multi-head mixing projections. + # # [B, S, H, D_head] -> [B, S, H, D_head] + attn_output = apply_rotary_pos_emb(attn_output, cos, -sin, unsqueeze_dim=2) + + # Reshape attention outputs into block-diagonal output groups. + # # [B, S, H, D_head] -> [B, S, o_groups, H * D_head / o_groups] + grouped = attn_output.reshape(batch, seq_len, self.config.o_groups, -1) + + # Apply block-diagonal grouped linear projections to mix intra-group features. + # # [B, S, o_groups, H * D_head / o_groups] -> [B, S, o_groups, o_lora_rank] + grouped = self.o_a_proj(grouped) + + # Flatten grouped representations into a unified feature vector per sequence position. + # # [B, S, o_groups, o_lora_rank] -> [B, S, o_groups * o_lora_rank] + grouped_flat = grouped.reshape(batch, seq_len, -1) + + # Project mixed representations back to global model hidden dimension. + # # [B, S, o_groups * o_lora_rank] -> [B, S, D_model] + output = self.o_b_proj(grouped_flat) + + return output, None diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..b0498ddcbc 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -22,6 +22,8 @@ from flax import linen as nn from flax import nnx +from maxtext.layers import nnx_wrappers +from maxtext.models.deepseek_v4 import DeepSeekV4HyperHead from flax.linen.partitioning import ScanIn import jax from jax.ad_checkpoint import checkpoint_name @@ -43,6 +45,7 @@ deepseek, deepseek_batchsplit, deepseek_batchsplit_fp8, + deepseek_v4, gemma, gemma2, gemma3, @@ -460,6 +463,12 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK_V4: + return ( + [deepseek_v4.DeepSeekV4ScannableBlockToLinen] + if self.config.scan_layers + else [deepseek_v4.DeepSeekV4DecoderLayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -983,6 +992,20 @@ def __call__( page_state, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4: + bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None + y = self._apply_deepseek_v4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask=bidirectional_mask_value, + decoder_input_tokens=decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1089,6 +1112,16 @@ def __call__( RemattedBlockLayer = RemattedBlockLayers[0] layer_kwargs = {} layer_call_kwargs = {} + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4: + # Retrieve layer-specific compression ratio from configuration to support sliding window attention + # at boundary layers and alternating compressed sparse/heavily compressed attention. + compress_ratio = self.config.compress_ratios[lyr] + bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None + layer_kwargs = {"compress_ratio": compress_ratio, "layer_idx": lyr} + layer_call_kwargs = { + "decoder_input_tokens": decoder_input_tokens, + "bidirectional_mask": bidirectional_mask_value, + } if cfg.decoder_block == DecoderBlockType.GEMMA3: # Gemma3 uses both global and sliding window attention depending on the layer index. bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None @@ -1151,7 +1184,11 @@ def __call__( assert isinstance(y, jax.Array) # After the final transformer layer, `y` holds the raw, un-normalized hidden state. - if cfg.mhc_expansion_rate > 1: + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4: + # Collapse final streams using learnable collapse weights [B, S, k, D] -> [B, S, D] + hc_head = nnx_wrappers.to_linen_class(DeepSeekV4HyperHead, name="hc_head")(config=cfg) + hidden_state = hc_head(y) + elif cfg.mhc_expansion_rate > 1: # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) hidden_state = mhc_reduce(y) else: @@ -1335,6 +1372,81 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek_v4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask=None, + decoder_input_tokens=None, + ): + """Applies DeepSeek-V4 scanned decoder blocks under Flax Linen, handling main scan and remainders.""" + cfg = self.config + mesh = self.mesh + + # Define the repeating pattern length (2 for cyclical DeepSeek-V4 layers) + scan_length = cfg.num_decoder_layers // 2 + + policy = self.get_remat_policy() + RemattedDSV4Block = self.set_remat_policy([deepseek_v4.DeepSeekV4ScannableBlockToLinen], policy)[0] + + layer_call_kwargs = { + "decoder_input_tokens": decoder_input_tokens, + "bidirectional_mask": bidirectional_mask, + } + layer_kwargs = {"num_of_layers": 2} + + # Apply the main scan over the full blocks + if scan_length > 0: + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + # inputs: y shape [B, S, k, D] -> [B, S, k, D] + y, _ = self.scan_decoder_layers( + cfg, + RemattedDSV4Block, + scan_length, + "layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=self.model_mode, + **layer_kwargs, + )(y, *broadcast_args, **layer_call_kwargs) + + # To allow efficient JAX/Flax compilation of the cyclic scanned layers in pairs, + # any prefix layers that have heterogeneous routing topologies (e.g., num_hash_layers + # static Hash Routing MoE layers) must be unrolled statelessly prior to the scan loop. + # Consequently, the remaining layers to be scanned and remainder-calculated are reduced by + # the number of prefix unrolled layers (cfg.num_hash_layers). + # Subtracting cfg.num_hash_layers ensures the correct remainder count of MoE layers + # evaluated statelessly, preventing shape mismatch and compilation drift. + num_remaining_layers = (cfg.num_decoder_layers - cfg.num_hash_layers) % 2 + if num_remaining_layers > 0: + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + layer = RemattedDSV4Block( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs + ) + y, _ = layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + **layer_call_kwargs, + ) + return y + # TODO(b/490118813): Relocate the following functions to their designated directories # once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer() # _apply_scanned_chunk() and _apply_interleaved_scanned_layers(). diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 525fff1ed5..0035fe63dd 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -16,6 +16,7 @@ import dataclasses import math +from typing import Any import jax from jax import lax @@ -1800,3 +1801,115 @@ def qwen3_omni_mrope_embedding_as_linen( metadata_fn=variable_to_logically_partitioned, name=name, ) + + +class DeepSeekV4RotaryEmbedding(nnx.Module): + """DeepSeek-V4 partial rotary embedding with interleaved frequencies. + + DeepSeek-V4 uses an interleaved positional encoding where consecutive channels + are paired together. Unlike standard rotary models that split dimensions globally + into first and second halves, this implementation pairs each even channel 2i + with the corresponding odd channel 2i + 1. + + This results in two specific mathematical properties: + 1. Inverse frequencies are computed for (dim // 2) unique theta angles. + 2. Sinusoidal components are expanded consecutively (e.g., [f0, f0, f1, f1]) + prior to application. + """ + + def __init__( + self, + head_dim: int, + partial_rotary_factor: float = 64.0 / 512.0, + rope_theta: float = 10000.0, + dtype: Any = jnp.float32, + ): + self.head_dim = head_dim + self.partial_rotary_factor = partial_rotary_factor + self.rope_theta = rope_theta + self.dtype = dtype + + # Compute the partial rotary dimension (rope_head_dim) + # e.g., 512 * (64 / 512) = 64 channels + self.dim = int(head_dim * partial_rotary_factor) + + # Compute base inverse frequencies for half of self.dim (dim // 2 unique theta angles). + # Adjacent channels share the same base frequency, matching the reference sequence. + half_dim = self.dim // 2 + fraction = 2 * jnp.arange(0, half_dim, dtype=jnp.float32) / self.dim + self.inv_freq = 1.0 / (self.rope_theta**fraction) + + def __call__(self, x: jnp.ndarray, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: + # position_ids: [B, S] + # Expand inverse frequencies for broadcasting: [1, 1, dim/2] + inv_freq_expanded = self.inv_freq[jnp.newaxis, jnp.newaxis, :] + + # Expand position IDs: [B, S, 1] + position_ids_expanded = position_ids[:, :, jnp.newaxis].astype(jnp.float32) + + # Compute outer product of positions and frequencies: [B, S, dim/2] + freqs = position_ids_expanded * inv_freq_expanded + + cos = jnp.cos(freqs).astype(x.dtype) # [B, S, dim/2] + sin = jnp.sin(freqs).astype(x.dtype) # [B, S, dim/2] + + return cos, sin + + +def _rotate_half(x: jax.Array) -> jax.Array: + """Performs consecutive half-rotation to match DeepSeek-V4 interleaved layout. + + Pairs adjacent elements: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]. + + Operations: + 1. Slice even indices: x1 = x[..., 0::2] + 2. Slice odd indices: x2 = x[..., 1::2] + 3. Stack (-x2, x1) along a new trailing dimension: [..., D/2, 2] + 4. Reshape back to the original dimension: [..., D] + """ + x1 = x[..., 0::2] # [B, S, H, D_rope/2] + x2 = x[..., 1::2] # [B, S, H, D_rope/2] + + # Interleave consecutive components: [-x2_0, x1_0, -x2_1, x1_1, ...] + stacked = jnp.stack((-x2, x1), axis=-1) # [B, S, H, D_rope/2, 2] + return stacked.reshape(x.shape) # [B, S, H, D_rope] + + +def apply_rotary_pos_emb( + x: jax.Array, + cos: jax.Array, + sin: jax.Array, + unsqueeze_dim: int = 2, +) -> jax.Array: + """Applies DeepSeek-V4 interleaved RoPE to the trailing rotary slice of x. + + 1. Duplicates inverse frequencies consecutively using jnp.repeat along the + last dimension to match the full rotary dimension size. + 2. Extracts the trailing 'rope_dim' channels of x to apply rotation, leaving + the leading 'nope' channels unmodified. + 3. Computes the rotation using float32 precision for numerical stability, + casting the final rotated tensor back to the input data type. + """ + # cos/sin shape: [B, S, D_rope/2] + # Duplicate frequencies consecutively to build full D_rope dimension + cos = jnp.repeat(cos, 2, axis=-1) # [B, S, D_rope] + sin = jnp.repeat(sin, 2, axis=-1) # [B, S, D_rope] + + # Expand dimensions for head broadcasting: [B, S, 1, D_rope] + cos = jnp.expand_dims(cos, axis=unsqueeze_dim) + sin = jnp.expand_dims(sin, axis=unsqueeze_dim) + + rope_dim = cos.shape[-1] + + # Separate features into unrotated (nope) and rotated (rope) slices + # x: [B, S, H, D] where D is the head dimension + nope = x[..., :-rope_dim] # [B, S, H, D - D_rope] + rope = x[..., -rope_dim:] # [B, S, H, D_rope] + + # Cast to float32, compute rotation, and cast back to original data type + rope_f32 = rope.astype(jnp.float32) + rotated = (rope_f32 * cos) + (_rotate_half(rope_f32) * sin) + rotated = rotated.astype(x.dtype) + + # Concatenate unrotated and rotated channels + return jnp.concatenate([nope, rotated], axis=-1) # [B, S, H, D] diff --git a/src/maxtext/layers/engram.py b/src/maxtext/layers/engram.py index 3b2eb4e2b5..f457097d72 100644 --- a/src/maxtext/layers/engram.py +++ b/src/maxtext/layers/engram.py @@ -335,6 +335,15 @@ class StaticWrapper: def __init__(self, val): self.val = val + def __getitem__(self, key): + return self.val[key] + + def __setitem__(self, key, value): + if key is Ellipsis: + self.val = value + else: + self.val = self.val.at[key].set(value) + class MultiHeadEmbedding(nnx.Module): """ diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index 8d9f094c98..519f35a5bc 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -567,3 +567,66 @@ def mlp_block( abstract_init=False, ) return module + + +class DeepSeekGroupedLinear(nnx.Module): + """Block-diagonal grouped linear projection layer. + + This layer segments the trailing dimension of the input tensor into a specified + number of groups, and projects each group independently using a distinct weight + matrix block. It minimizes parameter counts and compute overhead in the + attention output projection. + """ + + def __init__( + self, + in_features_per_group: int, + out_features: int, + n_groups: int, + weight_dtype: DType = jnp.float32, + dtype: DType = jnp.float32, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"), + *, + rngs: nnx.Rngs, + ): + self.in_features_per_group = in_features_per_group + self.out_features = out_features + self.n_groups = n_groups + self.weight_dtype = weight_dtype + self.dtype = dtype + + # Validate divisibility of target output features by group count + if out_features % n_groups != 0: + raise ValueError(f"Output features ({out_features}) must be divisible by n_groups ({n_groups}).") + self.out_features_per_group = out_features // n_groups + + # Grouped block-diagonal projection kernel parameters + # Kernels are stored as a 3D tensor: [n_groups, in_features_per_group, out_features_per_group] + kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group) + self.kernel = nnx.Param( + kernel_init( + rngs.params(), + kernel_shape, + self.weight_dtype, + in_axis=1, + out_axis=2, + ) + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Projects segmented groups from the input tensor using block weight matrices. + + Args: + x: Input tensor of shape [..., n_groups, in_features_per_group] + + Returns: + Projected tensor of shape [..., n_groups, out_features_per_group] + """ + x = jnp.asarray(x, self.dtype) + kernel = jnp.asarray(self.kernel[...], self.dtype) + + # Execute parallel group projection via optimized einsum broadcasting. + # x: [..., g, i] + # kernel: [g, i, o] + # output: [..., g, o] + return jnp.einsum("...gi,gio->...go", x, kernel) diff --git a/src/maxtext/layers/mhc.py b/src/maxtext/layers/mhc.py index ce700aafcd..c479ebd177 100644 --- a/src/maxtext/layers/mhc.py +++ b/src/maxtext/layers/mhc.py @@ -19,10 +19,10 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from maxtext.common.common_types import Array, Config +from maxtext.common.common_types import Array, Config, DecoderBlockType from maxtext.common.common_types import HyperConnectionType from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init -from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.normalizations import DeepSeekV4UnweightedRMSNorm, RMSNorm def get_functions(expansion_rate: int): @@ -42,26 +42,26 @@ def reduce(x: Array): return expand, reduce -def sinkhorn(t, iters=20): +def sinkhorn(t, iters=20, eps=1e-12): """Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1).""" - # Use float32 precision for numerical stability during normalization + # Use float32 precision for numerical stability during alternating L1 row/column normalizations. + # val: [B, S, H, H] initial_dtype = t.dtype t = t.astype(jnp.float32) - # Column-wise normalization (axis=-2) - positive and sum up to 1 across columns - # Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2) - t = jax.nn.softmax(t, axis=-2) + # Column normalization first (sum along axis -2) matching Xie et al. Equation 8 initialization. + t = t / (jnp.sum(t, axis=-2, keepdims=True) + eps) def body_fun(i, val): - # L1 Normalization: val / sum(val) with clipping of denominator + # L1 Normalization: val / (sum(val) + eps) matching the exact denominator addition. # Normalize rows (axis -1) - val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12) + val = val / (jnp.sum(val, axis=-1, keepdims=True) + eps) # Normalize columns (axis -2) - val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12) + val = val / (jnp.sum(val, axis=-2, keepdims=True) + eps) return val - # Use lax.fori_loop for an efficient, JIT-friendly loop - t = jax.lax.fori_loop(0, iters, body_fun, t) + # Use lax.fori_loop for an efficient, JIT-friendly loop over exactly iters - 1 steps. + t = jax.lax.fori_loop(0, iters - 1, body_fun, t) return t.astype(initial_dtype) @@ -95,14 +95,20 @@ def __init__( self.matmul_precision = jax.lax.Precision(self.config.matmul_precision) # Norm layer - self.mhc_norm = RMSNorm( - num_features=self.k * self.dim, - dtype=self.config.dtype, - weight_dtype=self.weight_dtype, - kernel_axes=("norm",), - epsilon=self.config.normalization_layer_epsilon, - rngs=self.rngs, - ) + if getattr(self.config, "decoder_block", None) == DecoderBlockType.DEEPSEEK_V4: + self.mhc_norm = DeepSeekV4UnweightedRMSNorm( + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + ) + else: + self.mhc_norm = RMSNorm( + num_features=self.k * self.dim, + dtype=self.config.dtype, + weight_dtype=self.weight_dtype, + kernel_axes=("norm",), + epsilon=self.config.normalization_layer_epsilon, + rngs=self.rngs, + ) # Scalars self.res_alpha_scale = nnx.Param( @@ -170,28 +176,33 @@ def __init__( def res_mapping(self, x: Array): """Helper function for residual mapping.""" - # In MaxText, we match weight precision to activations before Matmul + # In MaxText, we match weight precision to activations before Matmul. + # x: [B, S, H * D] representing sequence token features. + # res_alpha: [H * D, H * H] + # res_beta: [H, H] res_alpha = jnp.asarray(self.res_alpha[...], self.dtype) res_beta = jnp.asarray(self.res_beta[...], self.dtype) res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype) - # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) b, s, _ = h_res.shape h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] - output = sinkhorn(intermediate, self.sinkhorn_iterations) + # Apply softmax pre-normalization along the trailing axis matching the exact initialization. + # intermediate: [B, S, H, H] + intermediate = jax.nn.softmax(intermediate, axis=-1) + self.config.hc_eps + output = sinkhorn(intermediate, self.sinkhorn_iterations, eps=self.config.hc_eps) return output - def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): + def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: float, eps: float = 0.0): """Helper function for both pre and post mappings.""" # In MaxText, we match weight precision to activations before Matmul alpha = jnp.asarray(alpha, self.dtype) beta = jnp.asarray(beta, self.dtype) alpha_scale = jnp.asarray(alpha_scale, self.dtype) - # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k) - h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision) + # Apply projection: (b, s, e*d) @ (e*d, e) -> (b, s, e) + h = jnp.einsum("bsm,me -> bse", x, alpha, precision=self.matmul_precision) intermediate = alpha_scale * h + beta[None, None, :] - output = scale * jax.nn.sigmoid(intermediate) + output = scale * jax.nn.sigmoid(intermediate) + eps return output def __call__( @@ -227,8 +238,12 @@ def __call__( self.pre_alpha[...], self.pre_beta[...], 1.0, + self.config.hc_eps, ) - layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision) + layer_input = jnp.einsum( + "bsed,bse -> bsd", x.astype(jnp.float32), pre_mapping.astype(jnp.float32), precision=self.matmul_precision + ) + layer_input = layer_input.astype(self.dtype) # 3. Pre-norm layer_input = norm_fn(layer_input) @@ -246,22 +261,31 @@ def __call__( else: raise ValueError(f"Unsupported type: {mhc_type}") - # 5. Post mapping + # 5. Post mapping (multiplied by 2.0 matching post_scale) post_mapping = self.mapping( norm_x, self.post_alpha_scale[...], self.post_alpha[...], self.post_beta[...], 2.0, + 0.0, ) post_out = jnp.einsum( - "bsd,bsk -> bskd", - layer_out, - post_mapping, + "bsd,bse -> bsed", + layer_out.astype(jnp.float32), + post_mapping.astype(jnp.float32), precision=self.matmul_precision, ) # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb] res_mapping = self.res_mapping(norm_x) - res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision) - return res_out + post_out, metadata + # Transposed residual mixing (bsme @ bsmd -> bsed) matching Xie et al. Equation 8 + # representing the projection index: comb.T @ residual stream values. + # res_mapping: [B, S, H_src, H_dest] + # x: [B, S, H_src, D] + # res_out: [B, S, H_dest, D] + res_out = jnp.einsum( + "bsme,bsmd -> bsed", res_mapping.astype(jnp.float32), x.astype(jnp.float32), precision=self.matmul_precision + ) + output = res_out + post_out + return output.astype(self.dtype), metadata diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 975e8fe9a2..6a8b5134d2 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -23,6 +23,8 @@ from aqt.jax.v2 import aqt_tensor as aqt from flax import nnx + +nnx.BatchStats = nnx.BatchStat import jax from jax import ad_checkpoint as adc from jax.experimental import xla_metadata @@ -306,6 +308,202 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. return output, pre_bias_logits +def _sqrtsoftplus(x: jax.Array) -> jax.Array: + """Computes sqrtsoftplus activation: sqrt(softplus(x)).""" + # [Any] -> [Any] + return jnp.sqrt(jax.nn.softplus(x)) + + +class DeepSeekV4TopKRouter(nnx.Module): + """Top-K Router for DeepSeek-V4 MoE routing. + + Computes logits, normalized routing weights, and expert indices. + """ + + def __init__( + self, + config: ctypes.Config, + mesh: jax.sharding.Mesh, + rngs: nnx.Rngs, + kernel_axes: Tuple[Optional[str], ...] = (), + ): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.emb_dim if config.moe_expert_input_dim <= 0 else config.moe_expert_input_dim + self.routed_scaling_factor = config.routed_scaling_factor + + # Initialize gate weight matrix. + # Shape: [hidden_dim, num_experts] + kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal") + kernel_shape = (self.hidden_dim, self.num_experts) + kernel_in_axis = np.arange(1) + kernel_out_axis = np.arange(1, 2) + + self.kernel = nnx.Param( + kernel_init( + rngs.params(), + kernel_shape, + config.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + out_sharding=kernel_axes, + ) + + # Load-balancing expert score correction bias. + # Shape: [num_experts] + self.e_score_correction_bias = nnx.Param( + jnp.zeros((self.num_experts,), dtype=jnp.float32), + out_sharding=(kernel_axes[-1] if kernel_axes else None,), + ) + + def __call__(self, hidden_states: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]: + # input hidden_states shape: [batch, seq_len, hidden_dim] or [tokens, hidden_dim] + inputs = jnp.asarray(hidden_states, dtype=self.config.dtype) + # [batch, seq_len, hidden_dim] -> [tokens, hidden_dim] + flat = inputs.reshape(-1, self.hidden_dim) + + # Compute raw logits in float32. + # [tokens, hidden_dim] x [hidden_dim, num_experts] -> [tokens, num_experts] + kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32) + logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32) + + # Apply routed scoring function from configuration. + # [tokens, num_experts] -> [tokens, num_experts] + score_fn = ( + _sqrtsoftplus + if self.config.routed_score_func == "sqrtsoftplus" + else linears._convert_to_activation_function(self.config.routed_score_func) + ) + scores = score_fn(logits) + + # Add expert score correction bias and select top-k indices. + # [tokens, num_experts] + [num_experts] -> [tokens, num_experts] + scores_biased = scores + jnp.asarray(self.e_score_correction_bias[...], dtype=jnp.float32) + # [tokens, num_experts] -> [tokens, top_k] + _, indices = jax.lax.top_k(scores_biased, self.top_k) + + # Gather corresponding scores for the selected top-k indices. + # [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k] + weights = jnp.take_along_axis(scores, indices, axis=-1) + + # Normalize weights to sum to 1.0 per token. + # [tokens, top_k] -> [tokens, top_k] + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + + # Scale weights by routed scaling factor. + # [tokens, top_k] -> [tokens, top_k] + scaled_weights = weights * self.routed_scaling_factor + + return ( + logits.astype(self.config.dtype), + scaled_weights.astype(self.config.dtype), + indices, + ) + + +class non_trainable(nnx.Variable): + pass + + +class DeepSeekV4HashRouter(nnx.Module): + """Hash Router for DeepSeek-V4 MoE routing. + + Computes logits, static routing weights based on token IDs, and expert indices. + """ + + def __init__( + self, + config: ctypes.Config, + mesh: jax.sharding.Mesh, + rngs: nnx.Rngs, + kernel_axes: Tuple[Optional[str], ...] = (), + ): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.emb_dim if config.moe_expert_input_dim <= 0 else config.moe_expert_input_dim + self.routed_scaling_factor = config.routed_scaling_factor + + # Initialize gate weight matrix. + # Shape: [hidden_dim, num_experts] + kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal") + kernel_shape = (self.hidden_dim, self.num_experts) + kernel_in_axis = np.arange(1) + kernel_out_axis = np.arange(1, 2) + + self.kernel = nnx.Param( + kernel_init( + rngs.params(), + kernel_shape, + config.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + out_sharding=kernel_axes, + ) + + # Static token-to-expert mapping table. + # Shape: [vocab_size, top_k] + # Register self.tid2eid using nnx.Param with trainable=False and dtype=jnp.float32. + # Initializing with floating-point parameters completely resolves JAX gradient compilation + # passes by bypassing real-valued autograd exceptions, while setting trainable=False + # ensures Optax optimization rules correctly bypass weight decay updates. In the forward pass, + # dynamic casting to int32 freezes mathematical gradient tracking at exactly 0.0. + self.tid2eid = nnx.Param( + jnp.zeros((config.vocab_size, self.top_k), dtype=jnp.float32), + trainable=False, + ) + + def __call__(self, hidden_states: jax.Array, input_ids: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]: + # input hidden_states shape: [batch, seq_len, hidden_dim] or [tokens, hidden_dim] + inputs = jnp.asarray(hidden_states, dtype=self.config.dtype) + # [batch, seq_len, hidden_dim] -> [tokens, hidden_dim] + flat = inputs.reshape(-1, self.hidden_dim) + + # Compute raw logits in float32. + # [tokens, hidden_dim] x [hidden_dim, num_experts] -> [tokens, num_experts] + kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32) + logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32) + + # Apply routed scoring function from configuration. + # [tokens, num_experts] -> [tokens, num_experts] + score_fn = ( + _sqrtsoftplus + if self.config.routed_score_func == "sqrtsoftplus" + else linears._convert_to_activation_function(self.config.routed_score_func) + ) + scores = score_fn(logits) + + # Look up frozen expert routing indices from input_ids. + # [batch, seq_len] -> [tokens] + flat_input_ids = input_ids.reshape(-1) + # Look up from nnx.Param to retrieve frozen lookup indices. + # [vocab_size, top_k] sliced at [tokens] -> [tokens, top_k] + indices = self.tid2eid.value[flat_input_ids].astype(jnp.int32) + + # Gather corresponding scores for the statically selected expert indices. + # [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k] + weights = jnp.take_along_axis(scores, indices, axis=-1) + + # Normalize weights to sum to 1.0 per token. + # [tokens, top_k] -> [tokens, top_k] + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + + # Scale weights by routed scaling factor. + # [tokens, top_k] -> [tokens, top_k] + scaled_weights = weights * self.routed_scaling_factor + + return ( + logits.astype(self.config.dtype), + scaled_weights.astype(self.config.dtype), + indices, + ) + + class RoutedMoE(nnx.Module): """Implements a routed MoE block.""" @@ -322,6 +520,7 @@ def __init__( weight_dtype: ctypes.DType = jnp.float32, dtype: ctypes.DType = jnp.float32, quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = 0, ): """Initializes the RoutedMoE module. @@ -349,6 +548,7 @@ def __init__( self.dtype = dtype self.quant = quant self.rngs = rngs + self.layer_idx = layer_idx self.moe_expert_input_dim = ( self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim @@ -381,25 +581,33 @@ def __init__( else: self._expert_parallelism_name = "expert" - self.gate = GateLogit( - in_features_shape=self.moe_expert_input_dim, - out_features_shape=self.num_experts, - mesh=self.mesh, - model_name=self.config.model_name, - dtype=jnp.float32 if self.config.float32_gate_logits else self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - use_bias=self.config.routed_bias, - # tpu-inference applies the score function in the fused_moe_gmm kernel, - # so we don't apply it here to avoid redundant computation. - # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58. - score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func, - matmul_precision=self.config.matmul_precision, - shard_mode=config.shard_mode, - rngs=self.rngs, + self.is_hash = ( + self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < config.num_hash_layers ) + if self.is_hash: + self.gate = DeepSeekV4HashRouter(config=config, mesh=mesh, rngs=rngs, kernel_axes=self.kernel_axes) + elif self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + self.gate = DeepSeekV4TopKRouter(config=config, mesh=mesh, rngs=rngs, kernel_axes=self.kernel_axes) + else: + self.gate = GateLogit( + in_features_shape=self.moe_expert_input_dim, + out_features_shape=self.num_experts, + mesh=self.mesh, + model_name=self.config.model_name, + dtype=jnp.float32 if self.config.float32_gate_logits else self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + use_bias=self.config.routed_bias, + # tpu-inference applies the score function in the fused_moe_gmm kernel, + # so we don't apply it here to avoid redundant computation. + # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58. + score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func, + matmul_precision=self.config.matmul_precision, + shard_mode=config.shard_mode, + rngs=self.rngs, + ) rule = qpl.get_current_rule("gmm") sparsity_rule = None if rule is not None: @@ -704,7 +912,13 @@ def deepseek_routing(self, gate_logits: jax.Array, pre_bias_logits: jax.Array) - def apply_ffn_activation(self, layer_w0, layer_w1): """Applies FFN activation function.""" with jax.named_scope("ffn_act"): - if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS: + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + limit = getattr(self.config, "swiglu_limit", 1.0) + layer_w0 = jnp.clip(layer_w0, max=limit) + layer_w1 = jnp.clip(layer_w1, min=-limit, max=limit) + layer_act = self.activation_fn(layer_w0) + intermediate_layer = jnp.multiply(layer_act, layer_w1) + elif self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS: layer_w0 = jnp.clip(layer_w0, min=None, max=self.config.mlp_activations_limit) layer_w1 = jnp.clip(layer_w1, min=-self.config.mlp_activations_limit, max=self.config.mlp_activations_limit) layer_act = self.activation_fn(layer_w0 * 1.702) @@ -715,13 +929,26 @@ def apply_ffn_activation(self, layer_w0, layer_w1): intermediate_layer = jnp.multiply(layer_act, layer_w1) return intermediate_layer.astype(self.dtype) - def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None): + def permute( + self, + inputs, + gate_logits, + pre_bias_logits, + use_custom_sort_vjp=True, + rngs=None, + roll_to_expert_id=None, + gate_weights=None, + gate_indices=None, + ): """Permute tokens to group by expert to fit gmm call.""" # reshape inputs (batch, sequence, emb) to (batch * sequence, emb) inputs_shape = inputs.shape bsz_times_seq_len = inputs_shape[0] * inputs_shape[1] inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2])) - weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs) + if gate_weights is not None and gate_indices is not None: + weights, selected_experts = gate_weights, gate_indices + else: + weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs) lb_loss = None if self.config.load_balance_loss_weight > 0.0: @@ -794,7 +1021,7 @@ def unpermute( if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4: # For Llama4, combine using weights of 1 for selected experts reshaped_weights = jnp.ones_like(reshaped_weights) - if self.config.float32_weight_sum: + if self.config.float32_weight_sum or self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: reshaped_intermediate = reshaped_intermediate.astype(jnp.float32) reshaped_weights = reshaped_weights.astype(jnp.float32) output = jnp.einsum( @@ -1033,6 +1260,8 @@ def sparse_matmul( w0_bias, w1_bias, wo_bias, + gate_weights=None, + gate_indices=None, ): """Perform sparse matrix multiplication of inputs and Experts.""" @@ -1197,7 +1426,10 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) - if self.config.model_name.startswith("deepseek3"): + if ( + self.config.model_name.startswith("deepseek3") + or self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 + ): pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) else: # pre_bias_logits is None for non-DeepSeek v3 models @@ -1255,6 +1487,16 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): ) = get_routed_moe_shardings(is_batch_sharded_by_expert) w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec) + if gate_weights is not None: + gate_weights_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + else: + gate_weights_pspec = None + + if gate_indices is not None: + gate_indices_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + else: + gate_indices_pspec = None + @functools.partial( jax.shard_map, mesh=self.mesh, @@ -1269,6 +1511,8 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): w1_bias_pspec, wo_bias_pspec, P(), # Replicate the input key + gate_weights_pspec, + gate_indices_pspec, ), out_specs=( self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")), @@ -1277,7 +1521,7 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): ), check_vma=False, ) - def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs): + def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs, g_weights, g_indices): batch_size, sequence_length, _ = x.shape num_expert_parallelism = self.get_expert_parallelism_size() if num_expert_parallelism > 1: @@ -1304,6 +1548,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r self.config.use_custom_sort_vjp, roll_to_expert_id=num_experts_per_shard * expert_shard_id, rngs=rngs, + gate_weights=g_weights, + gate_indices=g_indices, ) # Filter down to the group sizes that apply to only the experts in the @@ -1313,7 +1559,13 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r x = jnp.where(mask[:, None], x, 0) else: x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute( - x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs + x, + logits, + pre_bias_logits, + self.config.use_custom_sort_vjp, + rngs, + gate_weights=g_weights, + gate_indices=g_indices, ) if num_expert_parallelism > 1: @@ -1571,7 +1823,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): input_axes = (batch_logical_axis, "activation_norm_length", None) gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) - if self.config.model_name.startswith("deepseek3"): + if self.config.model_name.startswith("deepseek3") or self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) else: pre_bias_logits_axes = None @@ -1590,8 +1842,24 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): if wo_bias is not None: wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec) + if gate_weights is not None: + gate_weights = self._maybe_shard_with_logical(gate_weights, gate_logits_axes) + if gate_indices is not None: + gate_indices = self._maybe_shard_with_logical(gate_indices, gate_logits_axes) + return wrapper( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + self.rngs, + gate_weights, + gate_indices, ) def reshape_and_update_weights(self, weights, indices): @@ -1851,6 +2119,8 @@ def dense_matmul( w0_bias, w1_bias, wo_bias, + gate_weights=None, + gate_indices=None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert @@ -1860,7 +2130,10 @@ def dense_matmul( pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None) ) - top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) + if gate_weights is not None and gate_indices is not None: + top_k_weights, top_k_indices = gate_weights, gate_indices + else: + top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype(self.dtype) @@ -2231,13 +2504,33 @@ def retrieve_quantized_weight( return w0_kernel, w1_kernel, wo_kernel def __call__( - self, inputs: jax.Array, gate_inputs: jax.Array | None = None, out_sharding: NamedSharding | None = None + self, + inputs: jax.Array, + gate_inputs: jax.Array | None = None, + out_sharding: NamedSharding | None = None, + input_ids: jax.Array | None = None, + gate_weights: jax.Array | None = None, + gate_indices: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: cfg = self.config inputs = inputs.astype(cfg.dtype) gate_dtype = jnp.float32 if cfg.float32_gate_logits else cfg.dtype routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype) - gate_logits, pre_bias_logits = self.gate(routing_inputs) + + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + batch_size, seq_len = inputs.shape[0], inputs.shape[1] + if self.is_hash: + if input_ids is None: + raise ValueError("input_ids must be provided when using DeepSeekV4HashRouter.") + gate_logits, gate_weights_val, gate_indices_val = self.gate(routing_inputs, input_ids) + else: + gate_logits, gate_weights_val, gate_indices_val = self.gate(routing_inputs) + gate_logits = gate_logits.reshape(batch_size, seq_len, -1) + gate_weights = gate_weights_val.reshape(batch_size, seq_len, -1) + gate_indices = gate_indices_val.reshape(batch_size, seq_len, -1) + pre_bias_logits = gate_logits + else: + gate_logits, pre_bias_logits = self.gate(routing_inputs) wo_kernel = jnp.asarray(self.wo[...], self.dtype) @@ -2284,11 +2577,31 @@ def __call__( wo_bias, ) output, lb_loss, bias_updates = self.sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + gate_weights=gate_weights, + gate_indices=gate_indices, ) else: output, lb_loss, bias_updates = self.dense_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + gate_weights=gate_weights, + gate_indices=gate_indices, ) return output, lb_loss, bias_updates @@ -2306,6 +2619,7 @@ def __init__( weight_dtype: ctypes.DType = jnp.float32, dtype: ctypes.DType = jnp.float32, quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = 0, ): """Initializes the RoutedAndSharedMoE module. @@ -2345,6 +2659,7 @@ def __init__( weight_dtype=self.config.weight_dtype, quant=self.quant, rngs=self.rngs, + layer_idx=layer_idx, ) shared_expert_mlp_dim = ( @@ -2375,12 +2690,24 @@ def __call__( gate_inputs: jax.Array | None = None, intermediate_sharding: NamedSharding | None = None, out_sharding: NamedSharding | None = None, + input_ids: jax.Array | None = None, + gate_weights: jax.Array | None = None, + gate_indices: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: routed_experts, load_balance_loss, moe_bias_updates = self.routed_moe( - inputs, gate_inputs=gate_inputs, out_sharding=out_sharding + inputs, + gate_inputs=gate_inputs, + out_sharding=out_sharding, + input_ids=input_ids, + gate_weights=gate_weights, + gate_indices=gate_indices, ) shared_experts = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding) - return routed_experts + shared_experts, load_balance_loss, moe_bias_updates + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + combined = (routed_experts.astype(jnp.float32) + shared_experts.astype(jnp.float32)).astype(self.dtype) + else: + combined = routed_experts + shared_experts + return combined, load_balance_loss, moe_bias_updates def get_gate_logit( diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 262eb62277..6e133a0496 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -19,7 +19,7 @@ import functools import inspect import warnings -from typing import Any +from typing import Any, Optional import jax import jax.numpy as jnp @@ -48,6 +48,7 @@ deepseek, deepseek_batchsplit, deepseek_batchsplit_fp8, + deepseek_v4, gemma, gemma2, gemma3, @@ -63,6 +64,7 @@ qwen3_5, simple_layer, ) +from maxtext.models.deepseek_v4 import DeepSeekV4HyperHead from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis @@ -299,9 +301,13 @@ def __init__( self.scanned_layers = None self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK + self.is_deepseek_v4 = self.config.decoder_block == DecoderBlockType.DEEPSEEK_V4 self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 self.is_gemma4 = self.config.decoder_block == DecoderBlockType.GEMMA4 + if self.is_deepseek_v4: + self.hc_head = DeepSeekV4HyperHead(config, rngs=rngs) + if self.config.scan_layers: if self.is_deepseek: assert len(decoder_block_classes) == 2 @@ -389,6 +395,67 @@ def __init__( self.layers_remainder = RemattedGemma4Block( config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs ) + elif self.is_deepseek_v4: + # The DeepSeek-V4 architecture implements a 3-Tier Split Scanning strategy. + # This structure partitions the decoder layers to cleanly accommodate: + # 1. Early prefix layers containing unique prefix MoE Hash Routing gate parameters (e.g., `gate.tid2eid`) + # which must reside outside JAX scan bounds due to non-uniform static routing metadata. + # 2. Inhomogeneous compression rates and alternating attention topologies (CSA vs. HCA) across layers. + # 3. Dynamic scan grouping of length 2 periodic cycle blocks to minimize compilation memory and overhead. + # + # The decoder stack is segmented as follows: + # - Tier 1: Unrolled Prefix Segment (`pre_layers`) - Executes early hash layers containing + # static routing parameters. + # - Tier 2: Middle Scanned blocks (`layers`) - Scanned in periodic cycles of 2 layers via + # `jax.lax.scan` for memory footprint savings. + # - Tier 3: Unrolled Suffix Remainder (`post_layers`) - Executes any trailing layers that + # do not align with a full period cycle. + scan_length = (config.num_decoder_layers - config.num_hash_layers) // 2 + num_remaining_layers = (config.num_decoder_layers - config.num_hash_layers) % 2 + layer_kwargs = {"num_of_layers": 2, "layer_offset": config.num_hash_layers} + + rem_layer_kwargs = { + "num_of_layers": num_remaining_layers, + "layer_offset": config.num_hash_layers + scan_length * 2, + } + + RemattedDeepSeekV4Block = deepseek_v4.DeepSeekV4ScannableBlock + + if config.num_hash_layers > 0: + self.pre_layers = RemattedDeepSeekV4Block( + config=self.config, + mesh=mesh, + quant=self.quant, + model_mode=self.model_mode, + num_of_layers=config.num_hash_layers, + layer_offset=0, + rngs=rngs, + ) + else: + self.pre_layers = None + + if scan_length > 0: + self.layers = self._create_scanned_layers( + RemattedDeepSeekV4Block, + length=scan_length, + metadata_axis_name="layers", + rngs=rngs, + **layer_kwargs, + ) + else: + self.layers = None + + if num_remaining_layers > 0: + self.post_layers = RemattedDeepSeekV4Block( + config=self.config, + mesh=mesh, + quant=self.quant, + model_mode=self.model_mode, + **rem_layer_kwargs, + rngs=rngs, + ) + else: + self.post_layers = None else: layer_cls = decoder_block_classes[0] num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) @@ -435,6 +502,11 @@ def __init__( layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} elif config.decoder_block == DecoderBlockType.OLMO3: layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.DEEPSEEK_V4: + # Retrieve layer-specific compression ratio from configuration to support sliding window attention + # at boundary layers and alternating compressed sparse/heavily compressed attention. + compress_ratio = self.config.compress_ratios[lyr] + layer_kwargs = {"compress_ratio": compress_ratio, "layer_idx": lyr} self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) @@ -713,6 +785,9 @@ def get_deepseek(): DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], DecoderBlockType.DEEPSEEK: get_deepseek(), + DecoderBlockType.DEEPSEEK_V4: get_scannable( + deepseek_v4.DeepSeekV4DecoderLayer, deepseek_v4.DeepSeekV4ScannableBlock + ), DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock), DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock), DecoderBlockType.QWEN3_5: get_scannable(qwen3_5.Qwen3_5DecoderLayer, qwen3_5.Qwen3_5ScannableBlock), @@ -863,6 +938,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): DecoderBlockType.SIMPLE_MLP, DecoderBlockType.LLAMA4, DecoderBlockType.OLMO3, + DecoderBlockType.DEEPSEEK_V4, ): return functools.partial(RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) elif self.config.decoder_block == DecoderBlockType.GPT3: @@ -1118,7 +1194,7 @@ def __call__( # Extract the bidirectional mask locally for layer configurations bidirectional_mask = multimodal_input.bidirectional_mask if multimodal_input is not None else None - if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): + if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4, DecoderBlockType.DEEPSEEK_V4): layer_kwargs["bidirectional_mask"] = bidirectional_mask if attention_metadata is not None: @@ -1212,6 +1288,19 @@ def __call__( page_state, slot, ) + elif self.is_deepseek_v4: + y = self._apply_deepseek_v4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask=bidirectional_mask, + decoder_input_tokens=decoder_input_tokens, + ) else: scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) if kv_caches is not None: @@ -1230,13 +1319,16 @@ def __call__( prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) # Hoisted function to preserve XLA cache ID - def pure_layer_fn(graphdef, state_in, y_in, kv_in): + def pure_layer_fn(graphdef, state_in, y_in, kv_in, decoder_input_tokens_in=None): if cfg.parameter_memory_host_offload: state_in = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), state_in) merged_layer = nnx.merge(graphdef, state_in) - out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) + call_kwargs = dict(layer_kwargs) + if decoder_input_tokens_in is not None: + call_kwargs["decoder_input_tokens"] = decoder_input_tokens_in + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **call_kwargs) return out_y, out_kv, nnx.state(merged_layer) checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) @@ -1254,11 +1346,11 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): else: kv_cache = None - input_tokens = decoder_input_tokens if cfg.engram_layers else None - if input_tokens is not None: - layer_kwargs["decoder_input_tokens"] = input_tokens - - y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) + input_tokens = ( + decoder_input_tokens if (cfg.engram_layers or cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4) else None + ) + # Propagation of decoder_input_tokens of shape [B, S] alongside hidden state y of shape [B, S, k, D] + y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache, input_tokens) nnx.update(layer, new_state) if kv_caches is not None and kv_cache is not None: @@ -1277,7 +1369,10 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): assert isinstance(y, jax.Array) # After the final transformer layer, `y` holds the raw, un-normalized hidden state. - if cfg.mhc_expansion_rate > 1: + if self.is_deepseek_v4: + # collapsed shape: [B, S, k, D] -> [B, S, D] via learnable collapse weights + hidden_state = self.hc_head(y) + elif cfg.mhc_expansion_rate > 1: # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) hidden_state = mhc_reduce(y) else: @@ -1401,6 +1496,85 @@ def pure_gemma_fn(graphdef, state_in, y_in): return y + def _apply_deepseek_v4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask: Optional[jax.Array] = None, + decoder_input_tokens: Optional[jax.Array] = None, + ): + """Applies DeepSeek-V4 scanned decoder blocks, handling main scan and remainders.""" + # Initial input tensor shape at entrance: + # y: [B, S, D] where B = Batch size, S = Sequence length, D = Embedding dimension + cfg = self.config + scan_length = (cfg.num_decoder_layers - cfg.num_hash_layers) // 2 + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + layer_kwargs = { + "decoder_input_tokens": decoder_input_tokens, + "bidirectional_mask": bidirectional_mask, + } + + # Apply the boundary prefix unrolled layers + # Input shape: [B, S, D] + # Output shape: [B, S, D] + if cfg.num_hash_layers > 0: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + call_kwargs = { + "decoder_segment_ids": decoder_segment_ids, + "decoder_positions": decoder_positions, + "deterministic": deterministic, + "model_mode": model_mode, + "previous_chunk": previous_chunk, + "page_state": page_state, + "slot": slot, + **layer_kwargs, + } + + out_tuple = self._apply_layer_with_remat(self.pre_layers, y, policy, prevent_cse, **call_kwargs) + y = out_tuple[0] # Unrolled prefix execution: [B, S, D] -> [B, S, D] + + # Apply the main scan over the full blocks + # This tier scans periodic cycles of block length 2 (CSA + HCA layers). + # Input shape entering scan: [B, S, D] + # Output shape exiting scan: [B, S, D] (via 20 step scans of length 2 cycle blocks) + if scan_length > 0: + y, self.layers, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + # Input shape entering remainder tier: [B, S, D] + # Output shape exiting remainder tier: [B, S, D] if remaining layers exist. + num_remaining_layers = (cfg.num_decoder_layers - cfg.num_hash_layers) % 2 + if num_remaining_layers > 0: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + call_kwargs = { + "decoder_segment_ids": decoder_segment_ids, + "decoder_positions": decoder_positions, + "deterministic": deterministic, + "model_mode": model_mode, + "previous_chunk": previous_chunk, + "page_state": page_state, + "slot": slot, + **layer_kwargs, + } + + out_tuple = self._apply_layer_with_remat(self.post_layers, y, policy, prevent_cse, **call_kwargs) + y = out_tuple[0] # Remainder suffix execution: [B, S, D] -> [B, S, D] + + # Final return tensor shape: + # y: [B, S, D] + return y + def decoder_as_linen( config: Config, diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..a2697353df 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -240,3 +240,64 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array: scale_init=linen_initializers.zeros, scale_offset=1.0, ) + + +class DeepSeekV4RMSNorm(nnx.Module): + """RMS normalization for DeepSeek-V4.""" + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + dtype: Any = jnp.float32, + weight_dtype: Any = jnp.float32, + ): + self.hidden_size = hidden_size + self.eps = eps + self.dtype = dtype + self.weight_dtype = weight_dtype + + # Initialize learnable scale weight to ones + self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype)) + self.scale = self.weight + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # [B, S, D] where D = hidden_size + # Convert inputs to float32 for numerical stability during variance pooling + x_f32 = jnp.asarray(x, jnp.float32) # [B, S, D] in float32 + + # Calculate variance across features axis + variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True) # [B, S, 1] + + # Apply reciprocal square root with epsilon offset + normalized = x_f32 * lax.rsqrt(variance + self.eps) # [B, S, D] + + # Cast back to active precision and apply scaling weight + y = jnp.asarray(normalized, self.dtype) # [B, S, D] + weight = jnp.asarray(self.weight.get_value(), self.dtype) # [D] + return y * weight # [B, S, D] + + +class DeepSeekV4UnweightedRMSNorm(nnx.Module): + """Unweighted RMS normalization for DeepSeek-V4.""" + + def __init__( + self, + eps: float = 1e-6, + dtype: Any = jnp.float32, + ): + self.eps = eps + self.dtype = dtype + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # [..., D] where D is feature dimension + # Convert inputs to float32 for numerical stability during variance pooling + x_f32 = jnp.asarray(x, jnp.float32) # [..., D] in float32 + + # Calculate variance across features axis + variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True) # [..., 1] + + # Apply reciprocal square root, cast to active precision, and multiply + inv_norm = jnp.asarray(lax.rsqrt(variance + self.eps), self.dtype) # [..., 1] + x_active = jnp.asarray(x, self.dtype) # [..., D] + return x_active * inv_norm # [..., D] diff --git a/src/maxtext/models/deepseek_v4.py b/src/maxtext/models/deepseek_v4.py new file mode 100644 index 0000000000..602e3aa147 --- /dev/null +++ b/src/maxtext/models/deepseek_v4.py @@ -0,0 +1,426 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decoder Layer and Scannable Block definitions for DeepSeek-V4. + +DeepSeek-V4 Decoder Layer Data Flow Guide: +`B` = batch_size, `S` = sequence_length, `k` = hc_mult (expansion rate), `D` = hidden_size + + Parallel Streams Input [B, S, k, D] + │ + ├───► [mHC Pre-norm & Mapping] ──► [B, S, k * D] ──► Flat-Norm + │ │ + │ [pre_alpha / pre_beta] + │ │ + │ ▼ + │ Sigmoid Logits + │ │ + │ ▼ + │ "pre" weights [B, S, k] + │ │ + ├─────────────────────────────────────────────────────────┼────────────┐ + ▼ ▼ │ + Parallel Streams [B, S, k, D] Collapse Sum │ + │ │ │ + [mHC Res-Mapping] ▼ │ + │ Collapsed [B, S, D] │ + [res_alpha / res_beta] │ │ + │ RMSNorm Pre-Attn │ + ▼ │ │ + Sigmoid Logits [B, S, k, k] ▼ │ + │ DeepSeekV4Attention │ + Sinkhorn-Knopp │ │ + │ ▼ │ + Doubly Stochastic "comb" Attn Output [B, S, D] │ + │ │ │ + ▼ [mHC Post-Mapping] │ + Multiply [post_alpha / beta] │ + │ │ │ + ▼ ▼ │ + Mixed Residual Sigmoid Logits │ + [B, S, k, D] │ │ + │ ▼ │ + │ "post" weights [B, S, k]│ + │ │ │ + │ ▼ │ + │ Expanded Output │ + │ [B, S, k, D] │ + │ │ │ + └───────────────────────► ( + ) ◄─────────────────────────┘ │ + │ │ + ▼ │ + Attention Site Output │ + [B, S, k, D] │ + │ │ + ▼ │ + Experts MoE FFN Site │ + (Same flow: Collapse -> MoE -> Expand) │ + │ │ + ▼ │ + Layer Output [B, S, k, D] ◄──────────────────────────┘ +""" + +from typing import Any, Optional +from flax import nnx +import jax +from jax.ad_checkpoint import checkpoint_name +import jax.numpy as jnp +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, HyperConnectionType, MODEL_MODE_PREFILL +from maxtext.layers import initializers +from maxtext.layers import mhc +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.layers.attention_compressed import DeepSeekV4Attention +from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm +from maxtext.utils import max_utils +from maxtext.utils.sharding import create_sharding +from maxtext.utils.sharding import maybe_shard_with_logical + + +def get_attention_type(compress_ratio: int) -> str: + """Returns the attention type string corresponding to the given compression ratio.""" + if compress_ratio == 0: + return "sliding_attention" + elif compress_ratio == 4: + return "compressed_sparse_attention" + else: + return "heavily_compressed_attention" + + +class DeepSeekV4DecoderLayer(nnx.Module): + """Transformer decoder layer for DeepSeek-V4. + + This layer unconditionally implements routed and shared MoE and unconditionally + applies Manifold-Constrained Hyper-Connections (mHC) to both attention and FFN block outputs. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + compress_ratio: int = 4, + layer_idx: int = 0, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + self.compress_ratio = compress_ratio + self.layer_idx = layer_idx + + batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode) + self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim) + + # Pre-attention normalization layer + self.pre_self_attention_layer_norm = DeepSeekV4RMSNorm( + hidden_size=self.config.emb_dim, + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + ) + + # Compressed multi-head attention module. + num_heads = ( + self.config.num_query_heads if self.config.num_query_heads is not None else self.config.num_attention_heads + ) + attention_type = get_attention_type(self.compress_ratio) + + self.self_attention = DeepSeekV4Attention( + hidden_size=self.config.emb_dim, + q_lora_rank=self.config.q_lora_rank, + head_dim=self.config.head_dim, + num_heads=num_heads, + config=config, + layer_idx=layer_idx, + mesh=self.mesh, + eps=self.config.normalization_layer_epsilon, + weight_dtype=self.config.weight_dtype, + dtype=self.config.dtype, + attention_type=attention_type, + rngs=self.rngs, + ) + + # Manifold-constrained hyper-connection wrapper for attention block outputs. + self.mhc_attention = mhc.ManifoldConstrainedHyperConnections( + config=self.config, + dim=self.config.emb_dim, + mesh=self.mesh, + rngs=self.rngs, + ) + + # Pre-FFN normalization layer + self.post_self_attention_layer_norm = DeepSeekV4RMSNorm( + hidden_size=self.config.emb_dim, + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + ) + + # Routed sparse and shared experts mixture-of-experts FFN module. + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed_moe", None), + weight_dtype=self.config.weight_dtype, + dtype=self.config.dtype, + quant=self.quant, + rngs=self.rngs, + layer_idx=self.layer_idx, + ) + + # Manifold-constrained hyper-connection wrapper for FFN block outputs. + self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections( + config=self.config, + dim=self.config.emb_dim, + mesh=self.mesh, + rngs=self.rngs, + ) + + self.out_sharding = create_sharding(self.mesh, self.logical_axis_names, rules=self.config.logical_axis_rules) + + @property + def logical_axis_names(self): + """Generate logical names for activations dynamically decoupling length dimensions.""" + length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length" + return ["activation_batch", length_name, "activation_embed"] + + def with_logical_constraint(self, x): + """Applies sharding constraints over logical axes.""" + return maybe_shard_with_logical( + x, + logical_axes=self.logical_axis_names, + mesh=self.mesh, + shard_mode=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, + extra_stack_level=1, + rules=self.config.logical_axis_rules, + ) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: Optional[jnp.ndarray] = None, + decoder_positions: Optional[jnp.ndarray] = None, + deterministic: bool = True, + model_mode: str = "train", + previous_chunk: Optional[jnp.ndarray] = None, + page_state: Any = None, + slot: Any = None, + bidirectional_mask: Optional[jnp.ndarray] = None, + kv_cache: Any = None, + attention_metadata: Any = None, + cos: Optional[jnp.ndarray] = None, + sin: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_input_tokens: Optional[jnp.ndarray] = None, + ): + # inputs shape: [B, S, k, D] (where B = batch, S = sequence length, k = expansion rate, D = hidden dim) + if isinstance(inputs, tuple): + inputs = inputs[0] + + if decoder_positions is None and position_ids is not None: + decoder_positions = position_ids + if decoder_segment_ids is None: + decoder_segment_ids = jnp.zeros(inputs.shape[:2], dtype=jnp.int32) + + # Apply constraint to inputs: [B, S, k, D] -> [B, S, k, D] + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + # 1. Attention hyper-connection block + # intermediate_inputs: [B, S, k, D] -> [B, S, k, D] + intermediate_inputs, _ = self.mhc_attention( + norm_fn=self.pre_self_attention_layer_norm, + branch_fn=self.self_attention, + x=x, + mhc_type=HyperConnectionType.ATTENTION, + attention_mask=bidirectional_mask, + cos=cos, + sin=sin, + position_ids=decoder_positions, + ) + + # 2. Experts MoE FFN hyper-connection block + # Inputs: intermediate_inputs: [B, S, k, D], decoder_input_tokens (input_ids): [B, S] + # Outputs output: [B, S, k, D] + output, metadata = self.mhc_mlp( + norm_fn=self.post_self_attention_layer_norm, + branch_fn=self.mlp, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + input_ids=decoder_input_tokens, + ) + + load_balance_loss = metadata["load_balance_loss"] + if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + self.sow("intermediates", "moe_lb_loss", load_balance_loss) + + # Final output constraint application: [B, S, k, D] -> [B, S, k, D] + output = self.with_logical_constraint(output) + + if self.config.scan_layers: + return output, None + else: + return output, kv_cache + + +DeepSeekV4DecoderLayerToLinen = nnx_wrappers.to_linen_class( + DeepSeekV4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + + +class DeepSeekV4ScannableBlock(nnx.Module): + """A repeating cyclical block of DeepSeek-V4 decoder layers for compiler scan loops.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + num_of_layers: int = 2, + layer_offset: int = 0, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + self.num_of_layers = num_of_layers + self.layer_offset = layer_offset + + for layer_id in range(self.num_of_layers): + abs_layer_id = self.layer_offset + layer_id + # Retrieve layer-specific compression ratio from configuration to support sliding window attention + # at boundary layers and alternating compressed sparse/heavily compressed attention. + compress_ratio = self.config.compress_ratios[abs_layer_id] + layer_name = f"layers_{layer_id}" + layer = DeepSeekV4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=compress_ratio, + layer_idx=abs_layer_id, + ) + setattr(self, layer_name, layer) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: jnp.ndarray, + decoder_positions: jnp.ndarray, + deterministic: bool, + model_mode: str, + slot: Any = None, + page_state: Any = None, + previous_chunk: Optional[jnp.ndarray] = None, + bidirectional_mask: Optional[jnp.ndarray] = None, + decoder_input_tokens: Optional[jnp.ndarray] = None, + ): + y = inputs + for layer_id in range(self.num_of_layers): + y, _ = getattr(self, f"layers_{layer_id}")( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + bidirectional_mask=bidirectional_mask, + decoder_input_tokens=decoder_input_tokens, + ) + return y, None + + +DeepSeekV4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeekV4ScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + + +class DeepSeekV4HyperHead(nnx.Module): + """Final learnable Manifold-Constrained Hyper-Connection (mHC) collapse head. + + This head collapses the parallel streams [B, S, k, D] down to a single + sequence [B, S, D] before applying the final RMSNorm. + """ + + def __init__(self, config: Config, rngs: nnx.Rngs): + self.config = config + self.hc_mult = getattr(config, "mhc_expansion_rate", 4) + self.eps = getattr(config, "hc_eps", 1e-6) + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + self.matmul_precision = jax.lax.Precision(config.matmul_precision) + + # Scale-free unweighted RMSNorm + self.input_norm = DeepSeekV4UnweightedRMSNorm(eps=config.normalization_layer_epsilon) + + # Parameter variables representing learnable linear projections + scale_init = initializers.nd_dense_init(1.0, "fan_in", "normal") + self.hc_fn = nnx.Param( + scale_init( + rngs.params(), + (self.hc_mult * config.emb_dim, self.hc_mult), + self.weight_dtype, + in_axis=0, + out_axis=1, + ), + out_sharding=("activation_embed", None), + ) + self.hc_base = nnx.Param( + initializers.default_bias_init(rngs.params(), (self.hc_mult,), self.weight_dtype), + out_sharding=(None,), + ) + self.hc_scale = nnx.Param( + initializers.default_scalar_init(rngs.params(), (1,), self.weight_dtype), + out_sharding=(None,), + ) + + def __call__(self, x: jax.Array) -> jax.Array: + # x shape: [B, S, k, D] where B = batch_size, S = sequence_length, k = hc_mult, D = emb_dim + b, s, k, d = x.shape + + # 1. Flatten streams and apply scale-free normalization + # [B, S, k, D] -> [B, S, k * D] + flat = self.input_norm(jnp.reshape(x, (b, s, k * d))) + + # 2. Match precision and project flat features to mixing logits + hc_fn = jnp.asarray(self.hc_fn[...], self.dtype) + hc_base = jnp.asarray(self.hc_base[...], self.dtype) + hc_scale = jnp.asarray(self.hc_scale[...], self.dtype) + + # mixes calculation: [B, S, k * D] @ [k * D, k] -> [B, S, k] + mixes = jnp.einsum("bsm,mk -> bsk", flat, hc_fn, precision=self.matmul_precision) + + # mixes sigmoid weights calculation: [B, S, k] + pre = jax.nn.sigmoid(mixes * hc_scale + hc_base[None, None, :]) + self.eps + + # 3. Collapse parallel streams: [B, S, k, D] * [B, S, k] -> [B, S, D] + collapsed = jnp.einsum("bsed,bse -> bsd", x, pre, precision=self.matmul_precision) + return collapsed diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py new file mode 100644 index 0000000000..cc3d5d8391 --- /dev/null +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -0,0 +1,2472 @@ +# pylint: skip-file +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DeepSeek-V4 Attention and Compressor parity.""" + +import sys +import unittest +from collections.abc import Callable +from typing import Optional +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax import nnx +from maxtext.configs import pyconfig +from maxtext.layers.moe import DeepSeekV4TopKRouter, DeepSeekV4HashRouter +from maxtext.layers import attention_compressed, mhc +from maxtext.models.deepseek_v4 import DeepSeekV4DecoderLayer, DeepSeekV4ScannableBlock, DeepSeekV4HyperHead +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding, apply_rotary_pos_emb, Embed +from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm +from maxtext.layers.linears import DeepSeekGroupedLinear +from maxtext.layers.nnx_decoders import NNXDecoder +import maxtext.common.common_types as ctypes +from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides + + +# ============================================================================== +# 1. Mock / Stub classes to support the exact Hugging Face / Scratch model code +# ============================================================================== + + +class RopeParameters(dict): + pass + + +class PreTrainedConfig: + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class DeepseekV4Config(PreTrainedConfig): + + def __init__(self, **kwargs): + # Default V4-Flash configuration values for testing + self.vocab_size = 129280 + self.hidden_size = 4096 + self.moe_intermediate_size = 2048 + self.num_hidden_layers = 43 + self.num_attention_heads = 64 + self.num_key_value_heads = 1 + self.head_dim = 512 + self.q_lora_rank = 1024 + self.partial_rotary_factor = 64 / 512 + self.qk_rope_head_dim = 64 + self.max_position_embeddings = 1048576 + self.rope_theta = 10000.0 + self.compress_rope_theta = 160000.0 + self.compress_rates = { + "compressed_sparse_attention": 4, + "heavily_compressed_attention": 128, + } + self.compress_ratios = [128] * 43 + self.sliding_window = 128 + self.o_groups = 8 + self.o_lora_rank = 1024 + self.index_n_heads = 64 + self.index_head_dim = 128 + self.index_topk = 512 + self.rms_norm_eps = 1.0e-6 + self.attention_dropout = 0.0 + self._attn_implementation = "eager" + self.matmul_precision = "default" + self.layer_types = ["compressed_sparse_attention"] * 43 + self.mlp_layer_types = ["hash_moe"] * 43 + self.num_experts_per_tok = 6 + self.n_routed_experts = 256 + self.num_local_experts = 256 + self.n_shared_experts = 1 + self.scoring_func = "sqrtsoftplus" + self.routed_scaling_factor = 1.5 + self.intermediate_size = 2048 + self.hidden_act = "silu" + self.swiglu_limit = 10.0 + self.mlp_bias = False + self.attention_bias = False + self.hc_mult = 4 + self.hc_sinkhorn_iters = 20 + self.hc_eps = 1e-6 + + # Setup default rope parameters + dim = int(self.head_dim * self.partial_rotary_factor) + self.rope_parameters = { + "main": { + "rope_type": "default", + "rope_theta": self.rope_theta, + "partial_rotary_factor": self.partial_rotary_factor, + }, + "compress": { + "rope_type": "default", + "rope_theta": self.compress_rope_theta, + "partial_rotary_factor": self.partial_rotary_factor, + }, + } + super().__init__(**kwargs) + + +class DynamicSlidingWindowLayer: + + def __init__(self, config: DeepseekV4Config): + self.sliding_window = config.sliding_window + self.keys = None + self.values = None + self.is_initialized = False + self.cumulative_length = 0 + + def lazy_initialization(self, key_states, value_states): + self.keys = key_states + self.values = value_states + self.is_initialized = True + + +class Cache: + + def __init__(self): + self.layers = [] + + +class OutputRecorder: + pass + + +class FlashAttentionKwargs(dict): + pass + + +try: + from typing import Unpack +except ImportError: + from typing_extensions import Unpack + + +# Mock / stub decorator +def use_kernel_forward_from_hub(*args, **kwargs): + def decorator(cls): + return cls + + return decorator + + +# Stub implementation of ALL_ATTENTION_FUNCTIONS +class AllAttentionFunctionsStub: + + def get_interface(self, implementation_name, default_fn): + return default_fn + + +ALL_ATTENTION_FUNCTIONS = AllAttentionFunctionsStub() + +# Dummy registry to make the copied file happy +ROPE_INIT_FUNCTIONS = {} +dynamic_rope_update = lambda fn: fn +maybe_autocast = lambda device_type, enabled: torch.enable_grad() # No-op context + +use_experts_implementation = lambda cls: cls + + +class TransformersKwargs(dict): + pass + + +ACT2FN = { + "silu": F.silu, + "sigmoid": torch.sigmoid, + "sqrtsoftplus": lambda x: torch.sqrt(F.softplus(x)), +} + +# ============================================================================== +# 2. EXACT COPY OF PYTORCH REFERENCE CLASSES (SOURCE OF TRUTH - READ ONLY) +# ============================================================================== + + +class DeepseekV4RMSNorm_PT(nn.Module): + + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + DeepseekV4RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DeepseekV4UnweightedRMSNorm_PT(nn.Module): + + def __init__(self, eps: float = 1.0e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.float().square().mean(-1, keepdim=True) + self.eps).to(x.dtype) + + +class DeepseekV4RotaryEmbedding_PT(nn.Module): + """ + Multi-layer-type rotary embedding (Laguna pattern: partial rotary on top of + Gemma3's per-layer-type buffers), specialised for V4's *interleaved* RoPE. + Interleaved RoPE: one `θ_i` per pair (`rope_head_dim // 2` entries), + DIFF no end-to-end duplication. Same shape as `inv_freq @ position_ids`. + + V4 deliberately decouples its architecture `layer_types` + (`sliding_attention` / `compressed_sparse_attention` / + `heavily_compressed_attention`) from its rope-type labels (`main` / + `compress`) — the latter live as keys in `config.rope_parameters` and + only differ in their `rope_theta` base. So this override replaces + Laguna's `set(config.layer_types)` iteration with `rope_parameters.keys()` + when building the per-type inv_freq buffers. + """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + # Only the nested per-rope-type sub-dicts are real layer types — the top-level + # `rope_type` key that ``convert_rope_params_to_dict`` may leave on + # ``config.rope_parameters`` is a flat-shape leftover, not a layer. + self.layer_types = [k for k, v in config.rope_parameters.items() if isinstance(v, dict)] + self.rope_type = {} + for layer_type in self.layer_types: + rope_params = config.rope_parameters[layer_type] + self.rope_type[layer_type] = rope_params["rope_type"] + rope_init_fn = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + inv_freq, attention_scaling = rope_init_fn(config, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", attention_scaling) + + @staticmethod + def compute_default_rope_parameters( + config: DeepseekV4Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters[layer_type]["rope_theta"] + # key difference to gemma3: partial rope + partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids, layer_type=None): + # Key difference vs Laguna's forward: no `torch.cat([freqs, freqs], dim=-1)` + # duplication. V4's interleaved RoPE pairs consecutive channels, so we only need + # `rope_head_dim // 2` unique θ entries — the `apply_rotary_pos_emb` helper does + # the `repeat_interleave(2)` next to the rotation math, where the link between + # the doubled dim and `rotate_half` is local and obvious. + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + cos = freqs.cos() * attention_scaling + sin = freqs.sin() * attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class DeepseekV4HCACache(DynamicSlidingWindowLayer): + r"""Cache layer for HCA blocks (paper §2.3.2). Holds the long-range compressor's + buffer / running compressed entries / count on top of the sliding-window K=V + branch. HCA uses *non-overlapping* windows, so there is *no* overlap state, + and HCA has *no* indexer either. + + State is dict-keyed by entry name — HCA only uses `"compressor"`, but + :class:`DeepseekV4CSACache` adds `"indexer"` to the same dicts so a single + set of methods (`store_compression_weights` / `update_compressor_states`) + serves both: + + * `compressed_kv[name]` — the running list of compressed KV entries + emitted so far (one every `compress_rate` source tokens; the long-range + KVs the attention concatenates onto its sliding-window keys / values). + * `buffer_kv[name]` / `buffer_gate[name]` — source tokens that arrived + between two full windows; once the buffer hits `compress_rate` tokens + the compressor closes a window, emits one entry, and drains the buffer. + * `entry_count[name]` — number of compressed entries emitted so far, so + `entry_count[name] * compress_rate` is the absolute position of the + *next* window's first source token. Tracked separately from + `position_ids` so prefill -> decode -> prefill stays consistent. + """ + + layer_type = "heavily_compressed_attention" + + def __init__(self, config: "DeepseekV4Config"): + super().__init__(config) + self.compress_rate = config.compress_rates["heavily_compressed_attention"] + self.buffer_kv: dict[str, torch.Tensor | None] = {"compressor": None} + self.buffer_gate: dict[str, torch.Tensor | None] = {"compressor": None} + self.compressed_kv: dict[str, torch.Tensor | None] = {"compressor": None} + self.entry_count: dict[str, int] = {"compressor": 0} + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs): + """ + Shared sliding-window K=V update body. V4 uses shared-KV MQA, so `keys` and + `values` point to the same storage on every layer. + """ + if not self.is_initialized: + self.lazy_initialization(key_states, value_states) + self.values = self.keys + self.cumulative_length += key_states.shape[-2] + full = torch.cat([self.keys, key_states], dim=-2) + self.keys = full[:, :, -self.sliding_window + 1 :, :] + self.values = self.keys + return full, full + + def store_compression_weights( + self, name: str, kv: torch.Tensor, gate: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, int]: + r""" + Concatenate the new projected `(kv, gate)` (paper §2.3.2 eqs. 20–21: + `C = H·W^{KV}`, `Z = H·W^Z`) for entry `name` with what's already in + the buffer, peel off the longest window-aligned prefix (the chunk + ready to compress), keep the leftover in the buffer for next call, + and return `(chunk_kv, chunk_gate, first_window_position)`. The + returned chunk is softmax-aggregated by the compressor with + `position_bias` to emit one compressed entry per window of + `compress_rate` tokens. + """ + first_window_position = self.entry_count[name] * self.compress_rate + buffered_kv, buffered_gate = self.buffer_kv[name], self.buffer_gate[name] + if buffered_kv is not None and buffered_kv.shape[1]: + kv = torch.cat([buffered_kv, kv], dim=1) + gate = torch.cat([buffered_gate, gate], dim=1) + # only return the longest prefix that's a multiple of compress_rate; the rest stays in the buffer for next time + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + self.buffer_kv[name], self.buffer_gate[name] = kv[:, usable:], gate[:, usable:] + return kv[:, :usable], gate[:, :usable], first_window_position + + def update_compressor_states(self, name: str, compressed: torch.Tensor) -> torch.Tensor: + r""" + Append freshly emitted compressed entries to `compressed_kv[name]` + (`C^{Comp}`, paper §2.3.2 eq. 23), bump `entry_count[name]`, and + return the running `compressed_kv[name]`. + """ + if self.compressed_kv[name] is None: + self.compressed_kv[name] = compressed + elif compressed.shape[1] > 0: + self.compressed_kv[name] = torch.cat([self.compressed_kv[name], compressed], dim=1) + self.entry_count[name] += compressed.shape[1] + return self.compressed_kv[name] + + +class DeepseekV4CSACache(DeepseekV4HCACache): + r"""Cache layer for CSA blocks (paper §2.3.1). Extends :class:`DeepseekV4HCACache` + by adding an `"indexer"` entry to the inherited `buffer_kv` / `buffer_gate` / + `compressed_kv` / `entry_count` dicts, plus per-name *overlap* state for the + two-series window scheme. + + What "overlap" means here: the CSA `kv_proj` / `gate_proj` produce `2 * head_dim` + features per source token — two independent compressed series Ca and Cb stored + in one tensor. Ca occupies `[..., :head_dim]`, Cb occupies `[..., head_dim:]`. + Pooled entry `w` is the softmax-gated convex combination of window `w-1`'s Ca + slice with window `w`'s Cb slice — effective width `2 * compress_rate_csa`, + stride `compress_rate_csa` (paper §2.3.1). + + Because adjacent windows share state only through *the previous window's Ca + slice*, the only thing we need to carry across a forward boundary is + `chunk[:, -1, :, :head_dim]` (Ca) of the last full window — Cb is never read + again. That's what `overlap_kv[name]` / `overlap_gate[name]` persist. + """ + + layer_type = "compressed_sparse_attention" + + def __init__(self, config: "DeepseekV4Config"): + super().__init__(config) + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + self.buffer_kv["indexer"] = None + self.buffer_gate["indexer"] = None + self.compressed_kv["indexer"] = None + self.entry_count["indexer"] = 0 + self.overlap_kv: dict[str, torch.Tensor | None] = {"compressor": None, "indexer": None} + self.overlap_gate: dict[str, torch.Tensor | None] = {"compressor": None, "indexer": None} + + def update_overlap_state( + self, name: str, chunk_kv: torch.Tensor, chunk_gate: torch.Tensor, head_dim: int + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + r""" + Read the `name` entry's prior window's Ca slice (saved on the previous + forward call) and persist the *current* call's last-window Ca slice for + the next call. Only the `:head_dim` slice (Ca) is ever consumed + downstream — Cb has already been folded into the previous window's + emitted compressed entry — so we store half what `chunk[:, -1]` holds. + Returns `(prior_kv, prior_gate)` — both `None` on the very first call. + """ + prior_kv, prior_gate = self.overlap_kv[name], self.overlap_gate[name] + self.overlap_kv[name] = chunk_kv[:, -1, :, :head_dim].clone() + self.overlap_gate[name] = chunk_gate[:, -1, :, :head_dim].clone() + return prior_kv, prior_gate + + +class DeepseekV4GroupedLinear_PT(nn.Linear): + """Block-diagonal grouped linear used by the grouped output projection + The core attention's stacked output is `num_attention_heads* head_dim`-dim, + which is *very* large (V4-Flash: 32768; V4-Pro: 65536). A direct + `num_attention_heads*head_dim → hidden_size` projection would dominate the per-token cost. + + The paper sidesteps that by splitting the heads into `g` groups, projecting + each `num_attention_heads * head_dim/g`-dim group independently to a `d_g`-dim intermediate output + (with `d_g < num_attention_heads * head_dim/g`), and then mixing the resulting `g·d_g` vector to + `hidden_size` through a single follow-up linear (`self_attn.o_b_proj`). This + module owns the per-group block (`self_attn.o_a_proj`). + + For V4-Flash (num_attention_heads=64, head_dim=512, o_groups=8, o_lora_rank=1024, + hidden_size=4096), g=8 groups of 4096-dim each are projected to 1024-dim, then + mixed to 4096-dim; for V4-Pro (num_attention_heads=128, head_dim=512, o_groups=16, + o_lora_rank=1024, hidden_size=7168), g=16 groups of 4096-dim each are projected + to 1024-dim, then mixed to 7168-dim. + """ + + def __init__(self, in_features_per_group: int, out_features: int, n_groups: int, bias: bool = False): + super().__init__(in_features_per_group, out_features, bias=bias) + self.n_groups = n_groups + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_shape = x.shape[:-2] + hidden_dim = x.shape[-1] + w = self.weight.view(self.n_groups, -1, hidden_dim).transpose(1, 2) + x = x.reshape(-1, self.n_groups, hidden_dim).transpose(0, 1) + y = torch.bmm(x, w).transpose(0, 1) + return y.reshape(*input_shape, self.n_groups, -1) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb_PT( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> torch.Tensor: + """V4 interleaved RoPE applied to the *trailing* rope slice of `x`. + + `cos` / `sin` come in half-sized (one entry per interleaved pair, from + `DeepseekV4RotaryEmbedding`); we expand them to the full rope dim with + `repeat_interleave`, then rotate the last `2 * cos.shape[-1]` channels of `x` + with the standard `x*cos + rotate_half(x)*sin` formula in fp32 and leave the + leading nope channels untouched. V4-Flash lays each head out as `[nope | rope]`, + matching the reference's `x[..., -rd:]` indexing. + """ + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + rope_dim = cos.shape[-1] + nope, rope = x[..., :-rope_dim], x[..., -rope_dim:] + rotated = ((rope.float() * cos) + (rotate_half(rope).float() * sin)).to(x.dtype) + return torch.cat([nope, rotated], dim=-1) + + +class DeepseekV4HCACompressor_PT(nn.Module): + """ + Heavily Compressed Attention compressor (paper §2.3.2, eqs. 20–23). compresses + every `compress_rate_hca` (m'=128) source tokens into a single compressed KV + entry. + + Each closed window of m' tokens produces one compressed entry: + `C^{Comp}_i = Σ_{j∈window} softmax(Z_j + B)_j ⊙ C_j`. RoPE on the trailing + `rope_head_dim` slice is applied at the deterministic absolute position + `i * compress_rate_hca + first_window_position` so cross-call concatenation + stays causality-correct. Returns the running list of *all* compressed + entries emitted so far (shape `[B, 1, T, head_dim]` with + `T = entry_count["compressor"]`), so the attention can attend over the + full long-range history. + + When `past_key_values is None` runs in stateless single-shot mode: compress + every complete window from `hidden_states` and discard the remainder + (instead of caching it). + """ + + rope_layer_type = "compress" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.compress_rate = config.compress_rates["heavily_compressed_attention"] + self.head_dim = config.head_dim + self.kv_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False) + self.position_bias = nn.Parameter(torch.empty(self.compress_rate, self.head_dim)) + self.kv_norm = DeepseekV4RMSNorm_PT(self.head_dim, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV4RotaryEmbedding_PT(config) + + def forward( + self, + hidden_states: torch.Tensor, + q_residual: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache | None, + layer_idx: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + batch, _, _ = hidden_states.shape + cache_layer: DeepseekV4HCACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_window_position = kv[:, :usable], gate[:, :usable], 0 + else: + chunk_kv, chunk_gate, first_window_position = cache_layer.store_compression_weights("compressor", kv, gate) + + if chunk_kv.shape[1] > 0: # there were at least self.compress_rate tokens + n_windows = chunk_kv.shape[1] // self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1) + chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1) + self.position_bias.to(chunk_gate.dtype) + compressed = self.kv_norm((chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2)) + positions = torch.arange(n_windows, device=compressed.device) + positions = (positions * self.compress_rate + first_window_position).unsqueeze(0).expand(batch, -1) + cos, sin = self.rotary_emb(compressed, position_ids=positions, layer_type=self.rope_layer_type) + compressed = apply_rotary_pos_emb_PT(compressed.unsqueeze(1), cos, sin).squeeze(1) + else: + compressed = chunk_kv.new_zeros((batch, 0, self.head_dim)) + + if cache_layer is not None: + compressed = cache_layer.update_compressor_states("compressor", compressed) + compressed_kv = compressed.unsqueeze(1) + + compressed_len = compressed_kv.shape[2] + seq_len = position_ids.shape[1] + if seq_len == 1 or compressed_len == 0: + return compressed_kv, None + + # query `t` may only see cache entries at pos `w` t > w * compress_rate (ex: t=7, w=2 t does not attend to it). + entry_indices = torch.arange(compressed_len, device=compressed_kv.device) + causal_threshold = (position_ids + 1) // self.compress_rate # [B, S] + block_bias = compressed_kv.new_zeros((batch, 1, seq_len, compressed_len)) + block_bias = block_bias.masked_fill( + entry_indices.view(1, 1, 1, -1) >= causal_threshold.unsqueeze(1).unsqueeze(-1), + float("-inf"), + ) + return compressed_kv, block_bias + + +class DeepseekV4Indexer_PT(nn.Module): + r"""Lightning Indexer (paper §2.3.1, eqs. 13–17). Used by Compressed Sparse + Attention (CSA) to pick the top-`k` compressed KV blocks per query, with + `k = config.index_topk`. Each query then attends only to those `k` of the + `seq_len / compress_rate_csa` compressed entries — reduction factor + `(seq_len / compress_rate_csa) / index_topk` over full attention against + the entire compressed sequence. + + The indexer runs its own scaled-down compressor at `index_head_dim` over + the same windows as the outer CSA compressor, then scores queries against + the compressed keys with `∑_h w_{t,h} · ReLU(q_{t,h} · K^IComp_s)` and + keeps the top `index_topk` indices. + + The indexer has its own rotary because it applies RoPE to two sets of + tensors: + + * *compressed keys* at deterministic positions + `i * compress_rate + first_window_position`, + * *queries* at the model's current `position_ids` (variable per forward). + + Both must use the same theta as the outer compressor + (`compress_rope_theta`) so query/key inner products are + translation-invariant — if they used different thetas, `q · k` would carry + a residual position-dependent skew. We can't precompute cos/sin once at + init because the query positions vary per call, so the indexer owns its + own rotary and calls it twice per forward (once for compressed keys, once + for queries) with `layer_type=self.rope_layer_type` (always `"compress"`). + """ + + rope_layer_type = "compress" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + self.num_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.index_topk = config.index_topk + self.softmax_scale = self.head_dim**-0.5 + self.weights_scaling = self.num_heads**-0.5 + self.kv_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False) + self.position_bias = nn.Parameter(torch.empty(self.compress_rate, 2 * self.head_dim)) + self.kv_norm = DeepseekV4RMSNorm_PT(self.head_dim, eps=config.rms_norm_eps) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) + self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False) + self.rotary_emb = DeepseekV4RotaryEmbedding_PT(config) + + def forward( + self, + hidden_states: torch.Tensor, + q_residual: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache | None, + layer_idx: int, + ) -> torch.LongTensor: + batch, seq_len, _ = hidden_states.shape + cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_window_position = kv[:, :usable], gate[:, :usable], 0 + else: + chunk_kv, chunk_gate, first_window_position = cache_layer.store_compression_weights("indexer", kv, gate) + + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + ratio = self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, ratio, -1) + chunk_gate = chunk_gate.view(batch, n_windows, ratio, -1) + self.position_bias.to(chunk_gate.dtype) + + # Same Ca / Cb overlap layout as the outer CSA compressor, at index_head_dim. + new_kv = chunk_kv.new_zeros((batch, n_windows, 2 * ratio, self.head_dim)) + new_gate = chunk_gate.new_full((batch, n_windows, 2 * ratio, self.head_dim), float("-inf")) + new_kv[:, :, ratio:] = chunk_kv[..., self.head_dim :] + new_gate[:, :, ratio:] = chunk_gate[..., self.head_dim :] + if n_windows > 1: + new_kv[:, 1:, :ratio] = chunk_kv[:, :-1, :, : self.head_dim] + new_gate[:, 1:, :ratio] = chunk_gate[:, :-1, :, : self.head_dim] + if cache_layer is not None: + prior_kv, prior_gate = cache_layer.update_overlap_state("indexer", chunk_kv, chunk_gate, self.head_dim) + if prior_kv is not None: + new_kv[:, 0, :ratio] = prior_kv.to(new_kv.dtype) + new_gate[:, 0, :ratio] = prior_gate.to(new_gate.dtype) + + compressed = self.kv_norm((new_kv * new_gate.softmax(dim=2, dtype=torch.float32).to(new_kv.dtype)).sum(dim=2)) + positions = torch.arange(n_windows, device=compressed.device) + positions = positions * self.compress_rate + first_window_position + positions = positions.unsqueeze(0).expand(batch, -1) + cos, sin = self.rotary_emb(compressed, position_ids=positions, layer_type=self.rope_layer_type) + compressed = apply_rotary_pos_emb_PT(compressed.unsqueeze(1), cos, sin).squeeze(1) + else: + compressed = chunk_kv.new_zeros((batch, 0, self.head_dim)) + + compressed_kv = compressed if cache_layer is None else cache_layer.update_compressor_states("indexer", compressed) + + cos_q, sin_q = self.rotary_emb(hidden_states, position_ids=position_ids, layer_type=self.rope_layer_type) + q = self.q_b_proj(q_residual).view(batch, seq_len, -1, self.head_dim).transpose(1, 2) + q = apply_rotary_pos_emb_PT(q, cos_q, sin_q).transpose(1, 2) + + # ReLU(q·kᵀ) * weights, then top-k + scores = torch.matmul(q.float(), compressed_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] + scores = F.relu(scores) * self.softmax_scale + weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + compressed_len = compressed_kv.shape[1] + top_k = min(self.index_topk, compressed_len) + + # not all queries can attend to the compressed entries. If a query's position + # is small than the relative position of the key (say m=4, query 2 cannot attend + # to compressed key at position 4, because it compressed info for states at position + # 12 to 16. Thus we need to make sure that top_k does not land in that range. + # Picks that still point past `causal_threshold` (early queries with too few ready + # blocks) are replaced with a `-1` sentinel that the compressor treats as invalid. + if compressed_len > 0: + causal_threshold = (position_ids + 1) // self.compress_rate # [B, S] + entry_indices = torch.arange(compressed_len, device=index_scores.device) + future_mask = entry_indices.view(1, 1, -1) >= causal_threshold.unsqueeze(-1) # [B, S, T] + index_scores = index_scores.masked_fill(future_mask, float("-inf")) + top_k_indices = index_scores.topk(top_k, dim=-1).indices # [B, S, k] + invalid = top_k_indices >= causal_threshold.unsqueeze(-1) + return torch.where(invalid, torch.full_like(top_k_indices, -1), top_k_indices) + + return index_scores.topk(top_k, dim=-1).indices + + +class DeepseekV4CSACompressor_PT(nn.Module): + """Compressed Sparse Attention compressor (paper §2.3.1, eqs. 9–17). Compresses + every `compress_rate_csa` (m=4) source tokens and runs a Lightning Indexer on + top of the compressed KV that scores queries with + `∑_h w_{t,h} · ReLU(q_{t,h} · K^{IComp}_s)` to gather the top `index_topk` + entries per query before they reach core attention. + + `kv_proj` / `gate_proj` / `position_bias` project to `2 * head_dim`: each + token contributes two independent compressed series Ca and Cb stored in + one tensor. Ca = `[..., :head_dim]` (its contribution to the *next* + window's compressed entry), Cb = `[..., head_dim:]` (its contribution to + the *current* window's compressed entry). Compressed entry `w` is the + softmax-gated convex combination of window `w-1`'s Ca slice with window + `w`'s Cb slice over `2 * compress_rate_csa` slots — width + `2 * compress_rate_csa`, stride `compress_rate_csa`. For `w = 0` we need + the previous window's Ca slice from the *previous forward call*; the + cache holds it in `overlap_kv` and hands it back here. On the very first + call (or when there is no cache) that slot stays zero-kv / `-inf`-gate, + which gives it softmax weight 0. + """ + + rope_layer_type = "compress" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + self.head_dim = config.head_dim + self.kv_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False) + self.position_bias = nn.Parameter(torch.empty(self.compress_rate, 2 * self.head_dim)) + self.kv_norm = DeepseekV4RMSNorm_PT(self.head_dim, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV4RotaryEmbedding_PT(config) + self.indexer = DeepseekV4Indexer_PT(config) + + def forward( + self, + hidden_states: torch.Tensor, + q_residual: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache | None, + layer_idx: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + batch, seq_len, _ = hidden_states.shape + cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_window_position = kv[:, :usable], gate[:, :usable], 0 + else: + chunk_kv, chunk_gate, first_window_position = cache_layer.store_compression_weights("compressor", kv, gate) + + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + ratio = self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, ratio, -1) + chunk_gate = chunk_gate.view(batch, n_windows, ratio, -1) + self.position_bias.to(chunk_gate.dtype) + + # Lay out the two series in [B, n_win, 2*ratio, head_dim]: Cb + # (`[..., head_dim:]`) goes in the second half (current window), + # Ca of the previous window (`[..., :head_dim]`) goes in the + # first half. Window 0's first half stays zero-kv / -inf-gate + # (softmax weight 0) on the very first forward call; on later + # calls the cache fills it with the saved Ca slice. + new_kv = chunk_kv.new_zeros((batch, n_windows, 2 * ratio, self.head_dim)) + new_gate = chunk_gate.new_full((batch, n_windows, 2 * ratio, self.head_dim), float("-inf")) + new_kv[:, :, ratio:] = chunk_kv[..., self.head_dim :] + new_gate[:, :, ratio:] = chunk_gate[..., self.head_dim :] + if n_windows > 1: + new_kv[:, 1:, :ratio] = chunk_kv[:, :-1, :, : self.head_dim] + new_gate[:, 1:, :ratio] = chunk_gate[:, :-1, :, : self.head_dim] + if cache_layer is not None: + prior_kv, prior_gate = cache_layer.update_overlap_state("compressor", chunk_kv, chunk_gate, self.head_dim) + if prior_kv is not None: + new_kv[:, 0, :ratio] = prior_kv.to(new_kv.dtype) + new_gate[:, 0, :ratio] = prior_gate.to(new_gate.dtype) + + # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that + # only differ by a small amount, especially with large window widths). + compressed = self.kv_norm((new_kv * new_gate.softmax(dim=2, dtype=torch.float32).to(new_kv.dtype)).sum(dim=2)) + positions = torch.arange(n_windows, device=compressed.device) + positions = positions * self.compress_rate + first_window_position + positions = positions.unsqueeze(0).expand(batch, -1) + cos, sin = self.rotary_emb(compressed, position_ids=positions, layer_type=self.rope_layer_type) + compressed = apply_rotary_pos_emb_PT(compressed.unsqueeze(1), cos, sin).squeeze(1) + else: + compressed = chunk_kv.new_zeros((batch, 0, self.head_dim)) + + if cache_layer is not None: + compressed = cache_layer.update_compressor_states("compressor", compressed) + compressed_kv = compressed.unsqueeze(1) + + # Lightning Indexer: gather top-`index_topk` compressed entries per query. + # in some cases, the output index can return top-k positions that should not be attended to. + # Ex: for query at index 5, m=4, and `index_topk=1024`, 1024 index are return but only 2 should be + # attended to. The indexer marks the rest with `-1`; we clamp before the gather and keep the `valid` + # to drop them from the per-query block mask afterwards. + top_k_indices = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + top_k = top_k_indices.shape[-1] + compressed_len = compressed_kv.shape[2] + valid = top_k_indices >= 0 # [B, S, k] + # Flatten (B, T) into one row axis and shift picks by `b * T`, then index_select once. + # Same kernel as an embedding lookup — cheaper than `gather` over an expanded view. + safe_indices = top_k_indices.clamp(min=0) + batch_offsets = (torch.arange(batch, device=compressed_kv.device) * compressed_len).view(batch, 1, 1) + flat_indices = (safe_indices + batch_offsets).view(-1) # [B*S*k] + flat_kv = compressed_kv.reshape(batch * compressed_len, self.head_dim) + gathered = flat_kv.index_select(0, flat_indices).view(batch, 1, -1, self.head_dim) # [B, 1, S*k, D] + + # Per-query block bias: query `t` may only see the cache entries that are <= `seq_len // m` + # and in these, only the ones marked valid by the indexer. Everything else is `-inf`. + # While the above negated the indexer, here we apply the "causal" masking. + block_bias = gathered.new_full((batch, 1, seq_len, seq_len, top_k), float("-inf")) + allowed = torch.where(valid, gathered.new_zeros(()), gathered.new_full((), float("-inf"))) # [B, S, k] + query_indices = torch.arange(seq_len, device=gathered.device) + block_bias[:, 0, query_indices, query_indices, :] = allowed # diagonal: q_idx == block_idx + block_bias = block_bias.view(batch, 1, seq_len, seq_len * top_k) + return gathered, block_bias + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float | int = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +COMPRESSOR_CLASSES = { + "sliding_attention": None, + "compressed_sparse_attention": DeepseekV4CSACompressor_PT, + "heavily_compressed_attention": DeepseekV4HCACompressor_PT, +} + + +class DeepseekV4Attention_PT(nn.Module): + r""" + Diff with classic attentions: + * Shared-KV Multi-Query Attention: `num_key_value_heads = 1`; `kv_proj` projects + directly to that single KV head and the same tensor is read as both key and + value. + * Partial RoPE on the first `rope_head_dim` of each head ("Partial Rotary + Positional Embedding"). RoPE is also applied with position `-i` to the + attention output's rope slice, so the contribution of each KV entry stays a + function of the *relative* distance to the query. + * Per-head learnable attention sink like gpt OSS. + * Grouped low-rank output projection for perfs. + * 3 different cache mechanisms, sliding, sliding+CSA, sliding+HCA. + """ + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + # Sliding-only layers use the "main" (plain θ=10000) rope; CSA/HCA layers + # share the same yarn-scaled "compress" rope as their compressor. + self.rope_layer_type = "main" if self.layer_type == "sliding_attention" else "compress" + self.num_heads = config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads # single KV head, broadcast to all + self.head_dim = config.head_dim + self.sliding_window = config.sliding_window + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.scaling = self.head_dim**-0.5 + + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False) + self.q_a_norm = DeepseekV4RMSNorm_PT(config.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) + self.q_b_norm = DeepseekV4UnweightedRMSNorm_PT(eps=config.rms_norm_eps) + self.kv_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False) + self.kv_norm = DeepseekV4RMSNorm_PT(self.head_dim, eps=config.rms_norm_eps) + self.o_a_proj = DeepseekV4GroupedLinear_PT( + self.num_heads * self.head_dim // config.o_groups, config.o_groups * config.o_lora_rank, config.o_groups + ) + self.o_b_proj = nn.Linear(config.o_groups * config.o_lora_rank, config.hidden_size, bias=False) + self.sinks = nn.Parameter(torch.empty(self.num_heads)) + self.compressor = COMPRESSOR_CLASSES[self.layer_type](config) if self.layer_type != "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_ids: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cos, sin = position_embeddings + + q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) + q = self.q_b_proj(q_residual).view(*hidden_shape).transpose(1, 2) + q = self.q_b_norm(q) + q = apply_rotary_pos_emb_PT(q, cos, sin) + + kv = self.kv_norm(self.kv_proj(hidden_states)).view(*hidden_shape).transpose(1, 2) + kv = apply_rotary_pos_emb_PT(kv, cos, sin) + + if past_key_values is not None: # sliding where K==V + kv = past_key_values.update(kv, kv, self.layer_idx)[0] + + block_bias = None + if self.compressor is not None: # Compressed KV (CSA or HCA) + compressed_kv, block_bias = self.compressor( + hidden_states, q_residual, position_ids, past_key_values, self.layer_idx + ) + kv = torch.cat([kv, compressed_kv], dim=2) + + # compressor returns a `block_bias` carrying per-query causality + indexer + # selections, which needs to be concatenated to the right of `attention_mask`. + # Eager/flash interfaces consume the combined mask directly. + if isinstance(attention_mask, torch.Tensor): + if block_bias is not None: + attention_mask = torch.cat([attention_mask, block_bias.to(attention_mask.dtype)], dim=-1) + elif kv.shape[2] > attention_mask.shape[-1]: + attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + q, + kv, + kv, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, + **kwargs, + ) + + # K=V in V4, so V picked up rope on its trailing rope slice. Apply the conjugate + # rotation (`-sin`) at the query position to undo it on the rope slice of the + # output before the grouped output projection mixes heads. The transpose pair is + # just a layout fix-up: apply_rotary_pos_emb expects `[B, S, H, D]` (its + # `unsqueeze_dim=1` adds a head-broadcast dim to cos/sin); attention gave us + # `[B, H, S, D]`. + attn_output = apply_rotary_pos_emb_PT(attn_output.transpose(1, 2), cos, -sin).transpose(1, 2) + + grouped = attn_output.reshape(*input_shape, self.config.o_groups, -1) + grouped = self.o_a_proj(grouped).flatten(2) + output = self.o_b_proj(grouped) + return output, attn_weights + + +# ============================================================================== +# 2.2 PyTorch Decoder Reference Blocks +# ============================================================================== + + +class GradientCheckpointingLayer_PT(nn.Module): + pass + + +class DeepseekV4HyperConnection_PT(nn.Module): + r""" + Manifold-Constrained Hyper-Connections + (mHC) (Xie et al., 2026) to strengthen the conventional residual connections between adjacent + Transformer blocks + + Owns the learned (`fn`, `base`, `scale`) + parameters that turn the incoming `hc_mult` residual streams into collapse / expand + weights. The decoder layer instantiates two of these (one for the attention site, + one for the mlp site). + + ASCII shape guide — `B` = batch, `S` = seq, `H` = hc_mult, `D` = hidden_size:: + + hidden_streams flatten(2) RMSNorm-rescale + F.linear(fn) + [B, S, H, D] ──────────► [B, S, H*D] ─────────────────────────────────► + mix-logits + [B, S, (2+H)*H] + │ + ┌───────────────────────────────────────┴──────────────────────────────┐ + ▼ ▼ ▼ + pre logits post logits comb logits + [B, S, H] [B, S, H] [B, S, H, H] + × scale[0] × scale[1] × scale[2] + + base[:H] + base[H:2H] + base[2H:] + σ() + eps σ() + eps σ() + eps + │ │ │ + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement) row/col normalise + │ + comb + (stream mixer) + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + self.input_norm = DeepseekV4UnweightedRMSNorm_PT(eps=config.rms_norm_eps) + mix = (2 + self.hc_mult) * self.hc_mult + self.fn = nn.Parameter(torch.empty(mix, self.hc_mult * config.hidden_size)) + self.base = nn.Parameter(torch.empty(mix)) + # 3 = number of outputs from the mHC mapping: `pre` (input projection + # weights), `post` (sublayer output projection weights), `comb` (the + # H×H residual combine matrix that gets Sinkhorn-projected onto the + # doubly-stochastic manifold). Each output gets its own learned scale. + self.scale = nn.Parameter(torch.empty(3)) + + def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute `pre`, `post`, `comb` from the mHC mapping (paper §2.2 eq. 8). + `comb` is projected onto the doubly-stochastic manifold via Sinkhorn- + Knopp: starting from the sigmoid-positive matrix, alternate row and + column normalisation for `hc_sinkhorn_iters` steps. `pre` then collapses + the `hc_mult` parallel streams into a single sequence (input projection + into the sublayer); `post` and `comb` are returned for the caller to + apply on the sublayer output. + """ + hc = self.hc_mult + flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) + pre_w, post_w, comb_w = F.linear(flat, self.fn.float()).split([hc, hc, hc * hc], dim=-1) + pre_b, post_b, comb_b = self.base.split([hc, hc, hc * hc]) + pre_scale, post_scale, comb_scale = self.scale.unbind(0) + + pre = torch.sigmoid(pre_w * pre_scale + pre_b) + self.hc_eps + post = 2 * torch.sigmoid(post_w * post_scale + post_b) + comb_logits = comb_w.view(*comb_w.shape[:-1], hc, hc) * comb_scale + comb_b.view(hc, hc) + comb = torch.softmax(comb_logits, dim=-1) + self.hc_eps + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + for _ in range(self.hc_sinkhorn_iters - 1): + comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + # Collapse the `hc_mult` parallel streams down to a single sequence using + # the `pre` weights: one weighted sum across the stream axis, ready for + # the sublayer (attn / MLP). + collapsed = (pre.unsqueeze(-1) * hidden_streams).sum(dim=2).to(hidden_streams.dtype) + return post, comb, collapsed + + +DeepseekV4UnweightedRMSNorm = DeepseekV4UnweightedRMSNorm_PT + + +class DeepseekV4HyperHead_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.input_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps) + self.eps = config.hc_eps + self.hc_fn = nn.Parameter(torch.empty(self.hc_mult, self.hc_mult * config.hidden_size)) + self.hc_base = nn.Parameter(torch.empty(self.hc_mult)) + self.hc_scale = nn.Parameter(torch.empty(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + flat = self.input_norm(x.flatten(2).float()) + mixes = F.linear(flat, self.hc_fn.float()) + pre = torch.sigmoid(mixes * self.hc_scale.float() + self.hc_base.float()) + self.eps + return (pre.unsqueeze(-1) * x).sum(dim=2).to(x.dtype) + + +class DeepseekV4MLP_PT(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +@use_experts_implementation +class DeepseekV4Experts_PT(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + self.limit = config.swiglu_limit + + def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor) -> torch.Tensor: + final = torch.zeros_like(hidden_states) + with torch.no_grad(): + mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + hit = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(mask[expert_idx]) + current = self._apply_gate(F.linear(hidden_states[token_idx], self.gate_up_proj[expert_idx])) + current = F.linear(current, self.down_proj[expert_idx]) * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, current.to(final.dtype)) + return final + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate, up = gate_up.chunk(2, dim=-1) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + return self.act_fn(gate) * up + + +class DeepseekV4TopKRouter_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts), persistent=True) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat, self.weight) + scores = self.score_fn(logits) + indices = torch.topk(scores + self.e_score_correction_bias, self.top_k, dim=-1, sorted=False).indices + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4HashRouter_PT(nn.Module): + r""" + Hash routing for the first `mlp_layer_types == "hash_moe"` MoE layers (paper + §2.1). Expert selection is determined by a fixed `tid2eid[input_ids]` lookup — + a frozen token-id → expert-id table — instead of a learned argmax. The learned + gate `weight` still produces the per-expert scores that weight the selected + experts' activations; only the *which-experts* selection is static. + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("tid2eid", torch.zeros(config.vocab_size, self.top_k, dtype=torch.long), persistent=True) + + def forward( + self, hidden_states: torch.Tensor, input_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat, self.weight) + scores = self.score_fn(logits) + indices = self.tid2eid[input_ids.reshape(-1)].long() + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4SparseMoeBlock_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.is_hash = config.mlp_layer_types[layer_idx] == "hash_moe" + self.gate = DeepseekV4HashRouter_PT(config) if self.is_hash else DeepseekV4TopKRouter_PT(config) + self.experts = DeepseekV4Experts_PT(config) + self.shared_experts = DeepseekV4MLP_PT(config) + + def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None) -> torch.Tensor: + batch, seq_len, hidden_dim = hidden_states.shape + residual = hidden_states + flat = hidden_states.view(-1, hidden_dim) + if self.is_hash: + _, weights, indices = self.gate(hidden_states, input_ids) + else: + _, weights, indices = self.gate(hidden_states) + routed = self.experts(flat, indices, weights).view(batch, seq_len, hidden_dim) + return routed + self.shared_experts(residual) + + +class DeepseekV4DecoderLayer_PT(GradientCheckpointingLayer_PT): + r"""DeepSeek-V4 decoder block (paper §2). Differs from a classic residual block in + two places: + + The residual is a stack of `hc_mult` parallel streams kept in shape + `[B, S, hc_mult, D]` throughout the block, mixed in and out via two + :class:`DeepseekV4HyperConnection` modules (Manifold-Constrained Hyper- + Connections / mHC, paper §2.2; Xie et al., 2026). The mHC mappings constrain + the residual transform to the manifold of doubly-stochastic matrices via the + Sinkhorn-Knopp projection — making signal propagation non-expansive across + deep stacks. + + """ + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.self_attn = DeepseekV4Attention_PT(config, layer_idx) + self.mlp = DeepseekV4SparseMoeBlock_PT(config, layer_idx) + self.input_layernorm = DeepseekV4RMSNorm_PT(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV4RMSNorm_PT(config.hidden_size, eps=config.rms_norm_eps) + self.attn_hc = DeepseekV4HyperConnection_PT(config) + self.ffn_hc = DeepseekV4HyperConnection_PT(config) + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + # hidden_states throughout: [B, S, hc_mult, hidden]. + # `post` / `comb` come out of the HC modules in fp32 (Sinkhorn projection runs + # in float); the .to(dtype) puts everything back to the input dtype before mixing + # so both sites stay consistent with `hidden_states`'s entry dtype. + # comb is consumed transposed: indexed as sum_j comb[j, k] * residual[j, d] + # (sum over the FIRST hc axis), equivalent to comb.T @ residual. Sinkhorn + # produces a doubly-stochastic but non-symmetric matrix, so the direction matters. + dtype = hidden_states.dtype + post, comb, collapsed = self.attn_hc(hidden_states) + attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) + hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) + + post, comb, collapsed = self.ffn_hc(hidden_states) + mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) + + +def _make_config(config_pt, B, S, D, **kwargs): + """Return a pyconfig Config object suitable for unit tests.""" + kwargs.pop("layer_types", None) + kwargs.pop("attention_type", None) + num_heads = kwargs.pop("num_attention_heads", config_pt.num_attention_heads) + overrides = { + "run_name": "test_run", + "enable_checkpointing": False, + "model_name": "deepseek_v4-tiny", + "decoder_block": "deepseek_v4", + "dtype": "float32", + "weight_dtype": "float32", + "matmul_precision": "highest", + "per_device_batch_size": B, + "max_target_length": S, + "max_prefill_predict_length": S, + "emb_dim": D, + "mhc_expansion_rate": getattr(config_pt, "hc_mult", 4), + "hc_eps": getattr(config_pt, "hc_eps", 1e-6), + "sinkhorn_iterations": getattr(config_pt, "hc_sinkhorn_iters", 20), + "normalization_layer_epsilon": 1e-6, + "head_dim": config_pt.head_dim, + "dropout_rate": 0.0, + "o_groups": config_pt.o_groups, + "o_lora_rank": config_pt.o_lora_rank, + "compress_ratios": [4] * 43, + "compress_rope_theta": 160000.0, + "sliding_window": config_pt.sliding_window, + "index_n_heads": config_pt.index_n_heads, + "index_head_dim": config_pt.index_head_dim, + "index_topk": config_pt.index_topk, + "base_num_query_heads": num_heads, + "q_lora_rank": config_pt.q_lora_rank, + "qk_rope_head_dim": getattr(config_pt, "qk_rope_head_dim", 64), + "routed_score_func": getattr(config_pt, "scoring_func", "sqrtsoftplus"), + "num_hash_layers": 43, + "rope_max_timescale": config_pt.rope_theta, + "rope_type": "default", + "max_position_embeddings": config_pt.max_position_embeddings, + "shard_mode": "auto", + "debug_sharding": False, + "scan_layers": False, + "remat_policy": "full", + "num_vocab_tiling": 1, + "base_mlp_dim": config_pt.moe_intermediate_size, + "mlp_activations": ["silu"], + "fused_mlp": False, + "megablox": False, + "sparse_matmul": False, + "use_gather_mosaic_kernel": False, + "load_balance_loss_weight": 0.0, + "routed_bias": False, + "dense_init_scale": 1.0, + "moe_expert_input_dim": -1, + "num_experts": 16, + "num_experts_per_tok": 1, + "mlp_bias": False, + "float32_gate_logits": False, + "use_random_routing": False, + "routed_scaling_factor": 1.0, + "attention": "dot_product", + "shared_experts": 1, + "base_moe_mlp_dim": config_pt.moe_intermediate_size, + "vocab_size": getattr(config_pt, "vocab_size", 128), + **kwargs, + } + extra_args = get_decoupled_parallelism_overrides() + merged = {**overrides, **extra_args} + cfg = pyconfig.initialize([sys.argv[0], get_test_config_path()], override_model_config=True, **merged) + if not hasattr(cfg, "trainable_position_size"): + cfg.trainable_position_size = 0 + if not hasattr(cfg, "original_max_position_embeddings"): + cfg.original_max_position_embeddings = cfg.max_position_embeddings + return cfg + + +class DeepSeekV4ParityTest(unittest.TestCase): + + def test_unweighted_rms_norm_parity(self): + # Generate identical input vectors across frameworks. + # Setting a deterministic seed ensures mathematically identical input distributions. + np.random.seed(42) + x_np = np.random.randn(4, 64, 512).astype(np.float32) + + x_torch = torch.tensor(x_np) + x_jax = jnp.array(x_np) + + # Execute PyTorch reference unweighted RMS normalization. + # Unweighted RMSNorm contains no learnable parameters. Epsilon is set to 1e-6. + torch_model = DeepseekV4UnweightedRMSNorm_PT(eps=1e-6) + out_torch = torch_model(x_torch).detach().numpy() + + # Execute JAX equivalent target unweighted RMS normalization. + # Target module instantiated from top-level imports to optimize namespace lookup. + jax_model = DeepSeekV4UnweightedRMSNorm(eps=1e-6) + out_jax = jax_model(x_jax) + + # Compare outputs within numerical precision tolerance limits. + np.testing.assert_allclose(out_torch, out_jax, atol=1e-5, rtol=1e-5) + + def test_rms_norm_parity(self): + # Generate identical input and scaling weights across frameworks. + # Identical weight assignments ensure learnable scale features are verified equivalently. + np.random.seed(42) + x_np = np.random.randn(4, 64, 512).astype(np.float32) + weight_np = np.random.randn(512).astype(np.float32) + + x_torch = torch.tensor(x_np) + x_jax = jnp.array(x_np) + + # Execute PyTorch reference RMS normalization. + torch_model = DeepseekV4RMSNorm_PT(hidden_size=512, eps=1e-6) + # Copy matching scale weights to the reference model parameter state. + torch_model.weight.data.copy_(torch.tensor(weight_np)) + out_torch = torch_model(x_torch).detach().numpy() + + # Execute JAX equivalent target RMS normalization. + # JAX model state parameters are explicitly updated to match the generated weights. + jax_model = DeepSeekV4RMSNorm(hidden_size=512, eps=1e-6) + jax_model.weight.value = jnp.array(weight_np) + out_jax = jax_model(x_jax) + + # Assert numerical parity between output states. + np.testing.assert_allclose(out_torch, out_jax, atol=1e-5, rtol=1e-5) + + def test_rotary_embedding_parity(self): + # Generate identical input sequences, positional values, and batch layouts. + # The sequence is constructed to test interleaved rotary mappings and broadcasting. + np.random.seed(42) + B, S, H, D = 4, 64, 8, 512 + x_np = np.random.randn(B, S, H, D).astype(np.float32) + position_ids_np = np.random.randint(0, 1000, size=(B, S)).astype(np.int64) + + x_torch = torch.tensor(x_np) + position_ids_torch = torch.tensor(position_ids_np) + + x_jax = jnp.array(x_np) + position_ids_jax = jnp.array(position_ids_np) + + # Setup configuration parameters. + config = DeepseekV4Config() + + # Execute PyTorch reference rotary embeddings. + # PyTorch default layout is [B, H, S, D], which requires inputs to be transposed + # prior to calling the embedding layer, then transposed back to native [B, S, H, D]. + torch_emb = DeepseekV4RotaryEmbedding_PT(config) + cos_torch, sin_torch = torch_emb(x_torch, position_ids_torch, layer_type="main") + + x_torch_transposed = x_torch.transpose(1, 2) # [B, H, S, D] + out_torch = apply_rotary_pos_emb_PT(x_torch_transposed, cos_torch, sin_torch, unsqueeze_dim=1) + out_torch_np = out_torch.transpose(1, 2).detach().numpy() # [B, S, H, D] + + # Execute JAX equivalent target rotary embeddings. + # The target JAX layer operates natively on [B, S, H, D] layouts, applying + # dimensional unsqueezing at axis 2 to broadcast across heads. + jax_emb = DeepSeekV4RotaryEmbedding(head_dim=D, partial_rotary_factor=64.0 / 512.0, rope_theta=10000.0) + cos_jax, sin_jax = jax_emb(x_jax, position_ids_jax) + + # Execute JAX target application. + out_jax = apply_rotary_pos_emb(x_jax, cos_jax, sin_jax, unsqueeze_dim=2) + out_jax_np = np.array(out_jax) + + # Compare both the intermediate cos/sin sinusoids and the final rotated values. + np.testing.assert_allclose(cos_torch.detach().numpy(), cos_jax, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(sin_torch.detach().numpy(), sin_jax, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(out_torch_np, out_jax_np, atol=1e-5, rtol=1e-5) + + def test_grouped_linear_parity(self): + # Generate identical input arrays and weight matrices across frameworks. + # Segmented group dimensions and feature boundaries map directly. + np.random.seed(42) + B, S, g, i, o = 2, 8, 4, 128, 256 + out_features_per_group = o // g + + # Input shape layout is [B, S, g, i] + x_np = np.random.randn(B, S, g, i).astype(np.float32) + # PyTorch standard linear weight layout is [o, i] + weight_np = np.random.randn(o, i).astype(np.float32) + + x_torch = torch.tensor(x_np) + x_jax = jnp.array(x_np) + + # Execute PyTorch reference grouped linear block projection. + torch_model = DeepseekV4GroupedLinear_PT( + in_features_per_group=i, + out_features=o, + n_groups=g, + bias=False, + ) + torch_model.weight.data.copy_(torch.tensor(weight_np)) + out_torch = torch_model(x_torch).detach().numpy() + + # Execute JAX equivalent target grouped linear block projection. + # JAX weights are initialized using the deterministic key context. + rngs = nnx.Rngs(42) + jax_model = DeepSeekGroupedLinear( + in_features_per_group=i, + out_features=o, + n_groups=g, + rngs=rngs, + ) + + # Copy the reshaped and transposed weight matrix matching PyTorch's view mapping + # [o, i] -> [g, o_g, i] -> [g, i, o_g] + jax_model.kernel.value = jnp.array(weight_np.reshape(g, out_features_per_group, i).transpose(0, 2, 1)) + out_jax = jax_model(x_jax) + + # Verify numerical output parity between frameworks + np.testing.assert_allclose(out_torch, out_jax, atol=1e-5, rtol=1e-5) + + def test_hca_compressor_parity(self): + # Configure deterministic seeds for parity reproducibility + np.random.seed(42) + B, S, D, D_head, compress_rate = 2, 128, 512, 256, 32 + + # hidden_states: [B, S, D] + x_np = np.random.randn(B, S, D).astype(np.float32) + positions_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + positions_torch = torch.tensor(positions_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + positions_jax = jnp.array(positions_np) + + # Initialize PyTorch configurations matching parameter spaces + config = DeepseekV4Config() + config.hidden_size = D + config.head_dim = D_head + config.qk_rope_head_dim = int(D_head * (64 / 512)) + config.compress_rates["heavily_compressed_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + + # Initialize PyTorch HCA Compressor model + torch_model = DeepseekV4HCACompressor_PT(config) + torch.nn.init.normal_(torch_model.position_bias, std=0.02) + + # Map JAX layer using matching parameters + jax_config = _make_config(config, B, S, D, compress_ratios=[compress_rate] * 43) + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.HCACompressor( + hidden_size=D, + head_dim=D_head, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Set JAX parameters identical to PyTorch states to guarantee numerical parity + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.gate_proj.kernel[...] = jnp.array(torch_model.gate_proj.weight.detach().numpy().T) + jax_model.position_bias[...] = jnp.array(torch_model.position_bias.detach().numpy()) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + + # Execute PyTorch stateless compressor path + # Shape out_torch: [B, 1, W, D_head] where W = S // compress_rate = 4 + out_torch, block_bias_torch = torch_model( + hidden_states=x_torch, + q_residual=None, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + out_torch = out_torch.detach().numpy() + if block_bias_torch is not None: + block_bias_torch = block_bias_torch.detach().numpy() + + # Execute JAX equivalent stateless compressor path + # Shape out_jax: [B, 1, W, D_head] + out_jax, block_bias_jax = jax_model( + hidden_states=x_jax, + position_ids=positions_jax, + ) + out_jax_np = np.array(out_jax) + if block_bias_jax is not None: + block_bias_jax = np.array(block_bias_jax) + + # Validate bit-accurate state outputs matching numerical tolerance thresholds + np.testing.assert_allclose(out_torch, out_jax_np, atol=1e-5, rtol=1e-5) + if block_bias_torch is not None or block_bias_jax is not None: + np.testing.assert_allclose(block_bias_torch, block_bias_jax, atol=1e-5, rtol=1e-5) + + def test_indexer_parity(self): + np.random.seed(42) + B, S, D, D_rank = 2, 128, 512, 1024 + num_heads, index_head_dim, index_topk, compress_rate = 64, 128, 8, 4 + + # hidden_states: [B, S, D] + x_np = np.random.randn(B, S, D).astype(np.float32) + # q_residual: [B, S, D_rank] + q_res_np = np.random.randn(B, S, D_rank).astype(np.float32) + # position_ids: [B, S] + positions_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + q_res_torch = torch.tensor(q_res_np) + positions_torch = torch.tensor(positions_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + q_res_jax = jnp.array(q_res_np) + positions_jax = jnp.array(positions_np) + + # Initialize PyTorch indexer configurations + config = DeepseekV4Config() + config.hidden_size = D + config.q_lora_rank = D_rank + config.index_n_heads = num_heads + config.index_head_dim = index_head_dim + config.index_topk = index_topk + config.compress_rates["compressed_sparse_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + + torch_model = DeepseekV4Indexer_PT(config) + torch.nn.init.normal_(torch_model.position_bias, std=0.02) + + # Map JAX equivalent Indexer module + jax_config = _make_config( + config, + B, + S, + D, + index_n_heads=num_heads, + index_head_dim=index_head_dim, + index_topk=index_topk, + compress_ratios=[compress_rate] * 43, + ) + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.DeepSeekV4Indexer( + hidden_size=D, + q_lora_rank=D_rank, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Synchronize parameter values + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.gate_proj.kernel[...] = jnp.array(torch_model.gate_proj.weight.detach().numpy().T) + jax_model.position_bias[...] = jnp.array(torch_model.position_bias.detach().numpy()) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + jax_model.q_b_proj.kernel[...] = jnp.array(torch_model.q_b_proj.weight.detach().numpy().T) + jax_model.weights_proj.kernel[...] = jnp.array(torch_model.weights_proj.weight.detach().numpy().T) + + # Execute models + out_torch = ( + torch_model( + hidden_states=x_torch, + q_residual=q_res_torch, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + .detach() + .numpy() + ) + + out_jax = jax_model( + hidden_states=x_jax, + q_residual=q_res_jax, + position_ids=positions_jax, + ) + out_jax_np = np.array(out_jax) + + # Check mathematical equivalence of top-k selection indices + np.testing.assert_allclose(out_torch, out_jax_np, atol=1e-5, rtol=1e-5) + + def test_csa_compressor_parity(self): + np.random.seed(42) + B, S, D, D_rank, D_head = 2, 128, 512, 1024, 256 + num_heads, index_head_dim, index_topk, compress_rate = 64, 128, 8, 4 + + # Inputs + x_np = np.random.randn(B, S, D).astype(np.float32) + q_res_np = np.random.randn(B, S, D_rank).astype(np.float32) + positions_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + q_res_torch = torch.tensor(q_res_np) + positions_torch = torch.tensor(positions_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + q_res_jax = jnp.array(q_res_np) + positions_jax = jnp.array(positions_np) + + # Configurations + config = DeepseekV4Config() + config.hidden_size = D + config.q_lora_rank = D_rank + config.head_dim = D_head + config.qk_rope_head_dim = int(D_head * (64 / 512)) + config.index_n_heads = num_heads + config.index_head_dim = index_head_dim + config.index_topk = index_topk + config.compress_rates["compressed_sparse_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + + torch_model = DeepseekV4CSACompressor_PT(config) + torch.nn.init.normal_(torch_model.position_bias, std=0.02) + torch.nn.init.normal_(torch_model.indexer.position_bias, std=0.02) + + jax_config = _make_config( + config, + B, + S, + D, + index_n_heads=num_heads, + index_head_dim=index_head_dim, + index_topk=index_topk, + compress_ratios=[compress_rate] * 43, + ) + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.CSACompressor( + hidden_size=D, + q_lora_rank=D_rank, + head_dim=D_head, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Synchronize outer compressor states + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.gate_proj.kernel[...] = jnp.array(torch_model.gate_proj.weight.detach().numpy().T) + jax_model.position_bias[...] = jnp.array(torch_model.position_bias.detach().numpy()) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + + # Synchronize inner indexer states + jax_model.indexer.kv_proj.kernel[...] = jnp.array(torch_model.indexer.kv_proj.weight.detach().numpy().T) + jax_model.indexer.gate_proj.kernel[...] = jnp.array(torch_model.indexer.gate_proj.weight.detach().numpy().T) + jax_model.indexer.position_bias[...] = jnp.array(torch_model.indexer.position_bias.detach().numpy()) + jax_model.indexer.kv_norm.weight[...] = jnp.array(torch_model.indexer.kv_norm.weight.detach().numpy()) + jax_model.indexer.q_b_proj.kernel[...] = jnp.array(torch_model.indexer.q_b_proj.weight.detach().numpy().T) + jax_model.indexer.weights_proj.kernel[...] = jnp.array(torch_model.indexer.weights_proj.weight.detach().numpy().T) + + # Execute + out_torch, block_bias_torch = torch_model( + hidden_states=x_torch, + q_residual=q_res_torch, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + out_torch = out_torch.detach().numpy() + if block_bias_torch is not None: + block_bias_torch = block_bias_torch.detach().numpy() + + out_jax, block_bias_jax = jax_model( + hidden_states=x_jax, + q_residual=q_res_jax, + position_ids=positions_jax, + ) + out_jax_np = np.array(out_jax) + if block_bias_jax is not None: + block_bias_jax = np.array(block_bias_jax) + + # Diagnose indexer parity + topk_torch = ( + torch_model.indexer( + hidden_states=x_torch, + q_residual=q_res_torch, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + .detach() + .numpy() + ) + + topk_jax = jax_model.indexer( + hidden_states=x_jax, + q_residual=q_res_jax, + position_ids=positions_jax, + ) + topk_jax_np = np.array(topk_jax) + + np.testing.assert_allclose(topk_torch, topk_jax_np, atol=1e-5, rtol=1e-5) + + # Check complete parity of gathered/indexed keys + np.testing.assert_allclose(out_torch, out_jax_np, atol=1e-5, rtol=1e-5) + if block_bias_torch is not None or block_bias_jax is not None: + np.testing.assert_allclose(block_bias_torch, block_bias_jax, atol=1e-5, rtol=1e-5) + + def test_attention_layer_parity(self): + np.random.seed(42) + B, S, D, D_rank, D_head, num_heads = 2, 128, 512, 1024, 256, 16 + compress_rate = 32 + + # Inputs + x_np = np.random.randn(B, S, D).astype(np.float32) + position_ids_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + position_ids_torch = torch.tensor(position_ids_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + position_ids_jax = jnp.array(position_ids_np) + + # Configurations + config = DeepseekV4Config() + config.hidden_size = D + config.q_lora_rank = D_rank + config.head_dim = D_head + config.qk_rope_head_dim = int(D_head * (64 / 512)) + config.num_attention_heads = num_heads + config.num_key_value_heads = 1 + config.compress_rates["heavily_compressed_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + config.layer_types = ["heavily_compressed_attention"] * 10 + + # Generate reference position embeddings (cos, sin) + torch_emb = DeepseekV4RotaryEmbedding_PT(config) + cos_torch, sin_torch = torch_emb(x_torch, position_ids_torch, layer_type="compress") + + cos_jax = jnp.array(cos_torch.detach().numpy()) + sin_jax = jnp.array(sin_torch.detach().numpy()) + + # Initialize PyTorch and JAX coordinate attention layers + torch_model = DeepseekV4Attention_PT(config, layer_idx=0) + torch.nn.init.normal_(torch_model.sinks, std=0.02) + if torch_model.compressor is not None: + torch.nn.init.normal_(torch_model.compressor.position_bias, std=0.02) + + jax_config = _make_config( + config, + B, + S, + D, + num_attention_heads=num_heads, + compress_ratios=[compress_rate] * 10, + layer_types=["heavily_compressed_attention"] * 10, + o_groups=config.o_groups, + o_lora_rank=config.o_lora_rank, + ) + + devices = jax.devices() + mesh = Mesh(np.array(devices), ("data",)) + rngs = nnx.Rngs(42) + jax_model = attention_compressed.DeepSeekV4Attention( + hidden_size=D, + q_lora_rank=D_rank, + head_dim=D_head, + num_heads=num_heads, + config=jax_config, + layer_idx=0, + mesh=mesh, + eps=1e-6, + attention_type="heavily_compressed_attention", + rngs=rngs, + ) + + # Copy projections and normalize weights from PyTorch to JAX + jax_model.q_a_proj.kernel[...] = jnp.array(torch_model.q_a_proj.weight.detach().numpy().T) + jax_model.q_a_norm.weight[...] = jnp.array(torch_model.q_a_norm.weight.detach().numpy()) + jax_model.q_b_proj.kernel[...] = jnp.array(torch_model.q_b_proj.weight.detach().numpy().T) + + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + + # Handle Grouped Output Projection mapping + w_o_a_np = torch_model.o_a_proj.weight.detach().numpy() + in_features_per_group = num_heads * D_head // config.o_groups + w_o_a_np = w_o_a_np.reshape(config.o_groups, -1, in_features_per_group).transpose(0, 2, 1) + jax_model.o_a_proj.kernel[...] = jnp.array(w_o_a_np) + + jax_model.o_b_proj.kernel[...] = jnp.array(torch_model.o_b_proj.weight.detach().numpy().T) + jax_model.sinks[...] = jnp.array(torch_model.sinks.detach().numpy()) + + # Copy Compressor weights if present + if torch_model.compressor is not None: + jax_model.compressor.kv_proj.kernel[...] = jnp.array(torch_model.compressor.kv_proj.weight.detach().numpy().T) + jax_model.compressor.gate_proj.kernel[...] = jnp.array(torch_model.compressor.gate_proj.weight.detach().numpy().T) + jax_model.compressor.position_bias[...] = jnp.array(torch_model.compressor.position_bias.detach().numpy()) + jax_model.compressor.kv_norm.weight[...] = jnp.array(torch_model.compressor.kv_norm.weight.detach().numpy()) + + # Execute PyTorch attention layer + out_torch, _ = torch_model( + hidden_states=x_torch, + position_embeddings=(cos_torch, sin_torch), + position_ids=position_ids_torch, + attention_mask=None, + ) + out_torch_np = out_torch.detach().numpy() + + # Execute JAX attention layer + out_jax, _ = jax_model( + hidden_states=x_jax, + cos=cos_jax, + sin=sin_jax, + position_ids=position_ids_jax, + attention_mask=None, + ) + out_jax_np = np.array(out_jax) + + # Check complete numerical parity of coordination attention layers + np.testing.assert_allclose(out_torch_np, out_jax_np, atol=1e-5, rtol=1e-5) + + def test_topk_router_parity(self): + # Generate deterministic random inputs for the router comparison. + np.random.seed(42) + B, S, D = 2, 8, 64 + num_experts = 16 + top_k = 6 + routed_scaling_factor = 1.5 + + hidden_states_np = np.random.randn(B, S, D).astype(np.float32) + weight_np = np.random.randn(num_experts, D).astype(np.float32) + e_score_correction_bias_np = np.random.randn(num_experts).astype(np.float32) + + # 1. Setup PyTorch Reference Router + config_pt = DeepseekV4Config( + num_experts_per_tok=top_k, + num_local_experts=num_experts, + hidden_size=D, + routed_scaling_factor=routed_scaling_factor, + scoring_func="sqrtsoftplus", + ) + py_router = DeepseekV4TopKRouter_PT(config_pt) + py_router.weight.data = torch.tensor(weight_np) + py_router.e_score_correction_bias.data = torch.tensor(e_score_correction_bias_np) + + # Run forward on PyTorch router + hidden_states_torch = torch.tensor(hidden_states_np) + py_logits, py_weights, py_indices = py_router(hidden_states_torch) + + # 2. Setup JAX/Flax NNX Equivalent Router + class MockJaxConfig: + + def __init__(self): + self.num_experts_per_tok = top_k + self.num_experts = num_experts + self.emb_dim = D + self.moe_expert_input_dim = D + self.routed_scaling_factor = routed_scaling_factor + self.routed_score_func = "sqrtsoftplus" + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_router = DeepSeekV4TopKRouter(config=config_jax, mesh=None, rngs=rngs) + + # Copy weight and correction bias parameters using Flax NNX attribute variable assignments. + jax_router.kernel[...] = jnp.array(weight_np.T) + jax_router.e_score_correction_bias[...] = jnp.array(e_score_correction_bias_np) + + # Run forward on JAX router + hidden_states_jax = jnp.array(hidden_states_np) + jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax) + + # 3. Parity assertions + # Compare raw logits directly. + np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) + + # Symmetrically, the order of the chosen top-k experts can differ (unsorted vs JAX sort). + # Sort both index selections and weight selections row-by-row (token-by-token) before comparison. + py_ind_np = py_indices.numpy() + py_w_np = py_weights.detach().numpy() + jax_ind_np = np.array(jax_indices) + jax_w_np = np.array(jax_weights) + + # Sort index arrays row-by-row, and order the corresponding weights array matching the index sort order. + for i in range(py_ind_np.shape[0]): + py_sort_order = np.argsort(py_ind_np[i]) + py_ind_np[i] = py_ind_np[i][py_sort_order] + py_w_np[i] = py_w_np[i][py_sort_order] + + jax_sort_order = np.argsort(jax_ind_np[i]) + jax_ind_np[i] = jax_ind_np[i][jax_sort_order] + jax_w_np[i] = jax_w_np[i][jax_sort_order] + + # Assert sorted indices and weights are mathematically identical! + np.testing.assert_array_equal(jax_ind_np, py_ind_np) + np.testing.assert_allclose(py_w_np, jax_w_np, atol=1e-5, rtol=1e-5) + + def test_hash_router_parity(self): + # Generate deterministic random inputs for static hash router comparison. + np.random.seed(42) + B, S, D = 2, 8, 64 + num_experts = 16 + top_k = 6 + routed_scaling_factor = 1.5 + vocab_size = 32 + + hidden_states_np = np.random.randn(B, S, D).astype(np.float32) + input_ids_np = np.random.randint(0, vocab_size, size=(B, S)).astype(np.int32) + weight_np = np.random.randn(num_experts, D).astype(np.float32) + tid2eid_np = np.random.randint(0, num_experts, size=(vocab_size, top_k)).astype(np.int32) + + # 1. Setup PyTorch Reference Router + config_pt = DeepseekV4Config( + num_experts_per_tok=top_k, + num_local_experts=num_experts, + hidden_size=D, + routed_scaling_factor=routed_scaling_factor, + vocab_size=vocab_size, + scoring_func="sqrtsoftplus", + ) + py_router = DeepseekV4HashRouter_PT(config_pt) + py_router.weight.data = torch.tensor(weight_np) + py_router.tid2eid.data = torch.tensor(tid2eid_np).long() + + # Run forward on PyTorch router + hidden_states_torch = torch.tensor(hidden_states_np) + input_ids_torch = torch.tensor(input_ids_np) + py_logits, py_weights, py_indices = py_router(hidden_states_torch, input_ids_torch) + + # 2. Setup JAX/Flax NNX Equivalent Router + class MockJaxConfig: + + def __init__(self): + self.num_experts_per_tok = top_k + self.num_experts = num_experts + self.emb_dim = D + self.moe_expert_input_dim = D + self.routed_scaling_factor = routed_scaling_factor + self.routed_score_func = "sqrtsoftplus" + self.vocab_size = vocab_size + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_router = DeepSeekV4HashRouter(config=config_jax, mesh=None, rngs=rngs) + + # Copy weight and lookup table parameter states using clean Flax NNX assignments. + jax_router.kernel[...] = jnp.array(weight_np.T) + jax_router.tid2eid[...] = jnp.array(tid2eid_np, dtype=jnp.int32) + + # Run forward on JAX router + hidden_states_jax = jnp.array(hidden_states_np) + input_ids_jax = jnp.array(input_ids_np) + jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax, input_ids_jax) + + # 3. Parity assertions + # Logits, weights, and selected index array checks. + np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) + np.testing.assert_array_equal(jax_indices, py_indices.numpy()) + np.testing.assert_allclose(py_weights.detach().numpy(), jax_weights, atol=1e-5, rtol=1e-5) + + def test_hyperhead_parity(self): + # Verify isolated parametric collapse HyperHead parity E2E! + np.random.seed(42) + B, S, k, D = 2, 4, 4, 128 + x_np = np.random.randn(B, S, k, D).astype(np.float32) + hc_fn_np = np.random.randn(k, k * D).astype(np.float32) + hc_base_np = np.random.randn(k).astype(np.float32) + hc_scale_np = np.random.randn(1).astype(np.float32) + + config_pt = DeepseekV4Config( + hc_mult=k, + hidden_size=D, + rms_norm_eps=1e-6, + hc_eps=1e-6, + ) + py_head = DeepseekV4HyperHead_PT(config_pt) + py_head.hc_fn.data = torch.tensor(hc_fn_np) + py_head.hc_base.data = torch.tensor(hc_base_np) + py_head.hc_scale.data = torch.tensor(hc_scale_np) + + # Run forward on PyTorch reference + x_torch = torch.tensor(x_np) + out_torch = py_head(x_torch) + + # Setup JAX DeepSeekV4HyperHead equivalent NNX module + class MockJaxConfig: + + def __init__(self): + self.emb_dim = D + self.mhc_expansion_rate = k + self.hc_eps = 1e-6 + self.normalization_layer_epsilon = 1e-6 + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + self.matmul_precision = "default" + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_head = DeepSeekV4HyperHead(config=config_jax, rngs=rngs) + + # Copy weight matrices and parameter states cleanly + # Shape mappings: + # PyTorch: hc_fn has shape [k, k * D], mixes = F.linear(flat, hc_fn) -> flat @ hc_fn.T + # JAX: hc_fn has shape [k * D, k], mixes = flat @ hc_fn + # Therefore, JAX weight = PyTorch weight.T + jax_head.hc_fn[...] = jnp.array(hc_fn_np.T) + jax_head.hc_base[...] = jnp.array(hc_base_np) + jax_head.hc_scale[...] = jnp.array(hc_scale_np) + + # Run forward passes on identical random batch stream inputs [B, S, k, D] + x_jax = jnp.array(x_np) + out_jax = jax_head(x_jax) + + # Assert bit-accurate numerical parity down to atol=1e-5 E2E! + np.testing.assert_allclose(out_torch.detach().numpy(), np.array(out_jax), atol=1e-5, rtol=1e-5) + + def test_full_model_stack_parity(self): + """Verifies complete, scannable multi-layer decoder stack E2E logits parity. + + This E2E test validates that: + 1. Parallel stream transformations [B, S, hc_mult, D] sequence correctly. + 2. Manifold-Constrained Hyper-Connections (mHC) perform identical Sinkhorn + projections across frameworks. + 3. The JAX scanned compiler (scan_layers = True) constructs and executes + identical stacked loop parameters compared to unrolled modes (scan_layers = False). + """ + np.random.seed(42) + B, S, D, H_mult, vocab_size, num_layers = 2, 8, 128, 4, 32, 3 + + # Generate identical input token IDs across frameworks + input_ids_np = np.random.randint(0, vocab_size, size=(B, S)).astype(np.int32) + position_ids_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + input_ids_torch = torch.tensor(input_ids_np).long() + position_ids_torch = torch.tensor(position_ids_np).long() + input_ids_jax = jnp.array(input_ids_np) + + # 1. Build identical configuration configurations + config_pt = DeepseekV4Config() + config_pt.hidden_size = D + config_pt.intermediate_size = 64 + config_pt.moe_intermediate_size = 64 + config_pt.hc_mult = H_mult + config_pt.hc_sinkhorn_iters = 8 + config_pt.rms_norm_eps = 1e-6 + config_pt.vocab_size = vocab_size + config_pt.num_hash_layers = 2 + config_pt.num_local_experts = 4 + config_pt.num_experts_per_tok = 2 + config_pt.num_attention_heads = 4 + config_pt.num_key_value_heads = 1 + config_pt.head_dim = 32 + config_pt.qk_rope_head_dim = 32 + config_pt.rope_parameters["main"]["partial_rotary_factor"] = 1.0 + config_pt.rope_parameters["compress"]["partial_rotary_factor"] = 1.0 + config_pt.q_lora_rank = 64 + config_pt.o_groups = 2 + config_pt.o_lora_rank = 64 + config_pt.index_n_heads = 4 + config_pt.index_head_dim = 32 + config_pt.index_topk = 2 + config_pt.layer_types = ["compressed_sparse_attention", "heavily_compressed_attention", "compressed_sparse_attention"] + config_pt.mlp_layer_types = ["hash_moe", "hash_moe", "topk_moe"] + + class DeepseekV4DecoderStack_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config, num_layers: int): + super().__init__() + self.layers = nn.ModuleList([DeepseekV4DecoderLayer_PT(config, lyr) for lyr in range(num_layers)]) + self.hc_head = DeepseekV4HyperHead_PT(config) + self.norm = DeepseekV4RMSNorm_PT(config.hidden_size, eps=config.rms_norm_eps) + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.logits_dense = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rotary_emb = DeepseekV4RotaryEmbedding_PT(config) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: + y = self.embeddings(input_ids) + y = y.unsqueeze(2).expand(-1, -1, 4, -1) + cos, sin = self.rotary_emb(y[:, :, 0, :], position_ids, layer_type="compress") + for layer in self.layers: + y = layer( + y, input_ids=input_ids, position_embeddings=(cos, sin), position_ids=position_ids, attention_mask=None + ) + collapsed = self.hc_head(y) + normed = self.norm(collapsed) + logits = self.logits_dense(normed) + return logits + + decoder_pt = DeepseekV4DecoderStack_PT(config_pt, num_layers) + + torch.nn.init.normal_(decoder_pt.embeddings.weight, std=0.02) + torch.nn.init.normal_(decoder_pt.norm.weight, std=0.02) + torch.nn.init.normal_(decoder_pt.logits_dense.weight, std=0.02) + torch.nn.init.normal_(decoder_pt.hc_head.hc_fn, std=0.02) + torch.nn.init.normal_(decoder_pt.hc_head.hc_base, std=0.02) + torch.nn.init.normal_(decoder_pt.hc_head.hc_scale, std=0.02) + + for layer_pt in decoder_pt.layers: + for param in [ + layer_pt.attn_hc.fn, + layer_pt.attn_hc.base, + layer_pt.attn_hc.scale, + layer_pt.ffn_hc.fn, + layer_pt.ffn_hc.base, + layer_pt.ffn_hc.scale, + layer_pt.self_attn.q_a_proj.weight, + layer_pt.self_attn.q_a_norm.weight, + layer_pt.self_attn.q_b_proj.weight, + layer_pt.self_attn.kv_proj.weight, + layer_pt.self_attn.kv_norm.weight, + layer_pt.self_attn.o_a_proj.weight, + layer_pt.self_attn.o_b_proj.weight, + layer_pt.self_attn.sinks, + layer_pt.mlp.gate.weight, + layer_pt.mlp.experts.gate_up_proj, + layer_pt.mlp.experts.down_proj, + layer_pt.mlp.shared_experts.gate_proj.weight, + layer_pt.mlp.shared_experts.up_proj.weight, + layer_pt.mlp.shared_experts.down_proj.weight, + layer_pt.input_layernorm.weight, + layer_pt.post_attention_layernorm.weight, + ]: + torch.nn.init.normal_(param, std=0.02) + if layer_pt.self_attn.compressor is not None: + comp_pt = layer_pt.self_attn.compressor + for param in [comp_pt.kv_proj.weight, comp_pt.gate_proj.weight, comp_pt.position_bias, comp_pt.kv_norm.weight]: + torch.nn.init.normal_(param, std=0.02) + if hasattr(comp_pt, "indexer"): + for param in [ + comp_pt.indexer.kv_proj.weight, + comp_pt.indexer.gate_proj.weight, + comp_pt.indexer.position_bias, + comp_pt.indexer.kv_norm.weight, + comp_pt.indexer.q_b_proj.weight, + comp_pt.indexer.weights_proj.weight, + ]: + torch.nn.init.normal_(param, std=0.02) + + logits_torch = decoder_pt(input_ids_torch, position_ids_torch).detach().numpy() + + devices = jax.devices() + mesh = Mesh(np.array(devices), ("data",)) + + for scan_mode in [False, True]: + jax_config = _make_config( + config_pt, + B, + S, + D, + base_num_decoder_layers=num_layers, + logits_via_embedding=False, + logits_dot_in_fp32=True, + parameter_memory_host_offload=False, + param_scan_axis=0, + use_iota_embed=False, + num_experts=config_pt.num_local_experts, + num_experts_per_tok=config_pt.num_experts_per_tok, + num_hash_layers=config_pt.num_hash_layers, + gradient_accumulation_steps=1, + hardware="cpu", + megablox=False, + sparse_matmul=False, + use_gather_mosaic_kernel=False, + num_vocab_tiling=1, + compress_ratios=[4, 128, 4] * 15, + mlp_dim=config_pt.intermediate_size, + num_attention_heads=config_pt.num_attention_heads, + q_lora_rank=config_pt.q_lora_rank, + head_dim=config_pt.head_dim, + o_groups=config_pt.o_groups, + o_lora_rank=config_pt.o_lora_rank, + index_n_heads=config_pt.index_n_heads, + index_head_dim=config_pt.index_head_dim, + index_topk=config_pt.index_topk, + mlp_activations=["silu", "linear"], + scan_layers=scan_mode, + ) + decoder_jax = NNXDecoder(config=jax_config, mesh=mesh, rngs=nnx.Rngs(0)) + + scan_length = (jax_config.num_decoder_layers - jax_config.num_hash_layers) // 2 + + def get_jax_layer(decoder, lyr): + if not scan_mode: + return decoder.layers[lyr] + if lyr < jax_config.num_hash_layers: + return getattr(decoder.pre_layers, f"layers_{lyr}") + elif lyr < jax_config.num_hash_layers + 2 * scan_length: + return getattr(decoder.layers, f"layers_{(lyr - jax_config.num_hash_layers) % 2}") + else: + return getattr( + decoder.post_layers, + f"layers_{lyr - (jax_config.num_hash_layers + 2 * scan_length)}", + ) + + shared_embedding = Embed(vocab_size, D, config=jax_config, mesh=mesh, rngs=nnx.Rngs(0)) + shared_embedding.embedding[...] = jnp.array(decoder_pt.embeddings.weight.detach().numpy()) + decoder_jax.decoder_norm.scale[...] = jnp.array(decoder_pt.norm.weight.detach().numpy()) + decoder_jax.logits_dense.kernel[...] = jnp.array(decoder_pt.logits_dense.weight.detach().numpy().T) + decoder_jax.hc_head.hc_fn[...] = jnp.array(decoder_pt.hc_head.hc_fn.detach().numpy().T) + decoder_jax.hc_head.hc_base[...] = jnp.array(decoder_pt.hc_head.hc_base.detach().numpy()) + decoder_jax.hc_head.hc_scale[...] = jnp.array(decoder_pt.hc_head.hc_scale.detach().numpy()) + + def assign_param(jax_param, pt_value, lyr): + if hasattr(jax_param, "val"): + jax_param[...] = pt_value + else: + is_scanned = scan_mode and (jax_config.num_hash_layers <= lyr < jax_config.num_hash_layers + 2 * scan_length) + if is_scanned: + block_step = (lyr - jax_config.num_hash_layers) // 2 + jax_param[block_step, ...] = pt_value + else: + jax_param[...] = pt_value + + hc = H_mult + for lyr in range(num_layers): + layer_jax, layer_pt = get_jax_layer(decoder_jax, lyr), decoder_pt.layers[lyr] + assign_param( + layer_jax.pre_self_attention_layer_norm.scale, + jnp.array(layer_pt.input_layernorm.weight.detach().numpy()), + lyr, + ) + assign_param( + layer_jax.post_self_attention_layer_norm.scale, + jnp.array(layer_pt.post_attention_layernorm.weight.detach().numpy()), + lyr, + ) + + assign_param( + layer_jax.self_attention.q_a_proj.kernel, + jnp.array(layer_pt.self_attn.q_a_proj.weight.detach().numpy().T), + lyr, + ) + assign_param( + layer_jax.self_attention.q_a_norm.weight, jnp.array(layer_pt.self_attn.q_a_norm.weight.detach().numpy()), lyr + ) + assign_param( + layer_jax.self_attention.q_b_proj.kernel, + jnp.array(layer_pt.self_attn.q_b_proj.weight.detach().numpy().T), + lyr, + ) + assign_param( + layer_jax.self_attention.kv_proj.kernel, jnp.array(layer_pt.self_attn.kv_proj.weight.detach().numpy().T), lyr + ) + assign_param( + layer_jax.self_attention.kv_norm.weight, jnp.array(layer_pt.self_attn.kv_norm.weight.detach().numpy()), lyr + ) + + w_o_a_np = layer_pt.self_attn.o_a_proj.weight.detach().numpy() + in_features_per_group = config_pt.num_attention_heads * config_pt.head_dim // config_pt.o_groups + w_o_a_np = w_o_a_np.reshape(config_pt.o_groups, -1, in_features_per_group).transpose(0, 2, 1) + assign_param(layer_jax.self_attention.o_a_proj.kernel, jnp.array(w_o_a_np), lyr) + + assign_param( + layer_jax.self_attention.o_b_proj.kernel, + jnp.array(layer_pt.self_attn.o_b_proj.weight.detach().numpy().T), + lyr, + ) + assign_param(layer_jax.self_attention.sinks, jnp.array(layer_pt.self_attn.sinks.detach().numpy()), lyr) + + if layer_pt.self_attn.compressor is not None: + comp_pt = layer_pt.self_attn.compressor + comp_jax = layer_jax.self_attention.compressor + assign_param(comp_jax.kv_proj.kernel, jnp.array(comp_pt.kv_proj.weight.detach().numpy().T), lyr) + assign_param(comp_jax.gate_proj.kernel, jnp.array(comp_pt.gate_proj.weight.detach().numpy().T), lyr) + assign_param(comp_jax.position_bias, jnp.array(comp_pt.position_bias.detach().numpy()), lyr) + assign_param(comp_jax.kv_norm.weight, jnp.array(comp_pt.kv_norm.weight.detach().numpy()), lyr) + if hasattr(comp_pt, "indexer"): + assign_param( + comp_jax.indexer.kv_proj.kernel, jnp.array(comp_pt.indexer.kv_proj.weight.detach().numpy().T), lyr + ) + assign_param( + comp_jax.indexer.gate_proj.kernel, jnp.array(comp_pt.indexer.gate_proj.weight.detach().numpy().T), lyr + ) + assign_param(comp_jax.indexer.position_bias, jnp.array(comp_pt.indexer.position_bias.detach().numpy()), lyr) + assign_param(comp_jax.indexer.kv_norm.weight, jnp.array(comp_pt.indexer.kv_norm.weight.detach().numpy()), lyr) + assign_param( + comp_jax.indexer.q_b_proj.kernel, jnp.array(comp_pt.indexer.q_b_proj.weight.detach().numpy().T), lyr + ) + assign_param( + comp_jax.indexer.weights_proj.kernel, + jnp.array(comp_pt.indexer.weights_proj.weight.detach().numpy().T), + lyr, + ) + + moe_pt = layer_pt.mlp + moe_jax = layer_jax.mlp + assign_param(moe_jax.MoeBlock_0.gate.kernel, jnp.array(moe_pt.gate.weight.detach().numpy().T), lyr) + if moe_pt.is_hash: + assign_param( + moe_jax.MoeBlock_0.gate.tid2eid, jnp.array(moe_pt.gate.tid2eid.detach().numpy(), dtype=jnp.int32), lyr + ) + else: + assign_param( + moe_jax.MoeBlock_0.gate.e_score_correction_bias, + jnp.array(moe_pt.gate.e_score_correction_bias.detach().numpy()), + lyr, + ) + + gate_up_np = moe_pt.experts.gate_up_proj.detach().numpy() + intermediate_dim = config_pt.intermediate_size + wi_0_np = gate_up_np[:, :intermediate_dim, :].transpose(0, 2, 1) + wi_1_np = gate_up_np[:, intermediate_dim:, :].transpose(0, 2, 1) + wo_np = moe_pt.experts.down_proj.detach().numpy().transpose(0, 2, 1) + + assign_param(moe_jax.MoeBlock_0.wi_0, jnp.array(wi_0_np), lyr) + assign_param(moe_jax.MoeBlock_0.wi_1, jnp.array(wi_1_np), lyr) + assign_param(moe_jax.MoeBlock_0.wo, jnp.array(wo_np), lyr) + + assign_param( + moe_jax.shared_experts.wi_0.kernel, jnp.array(moe_pt.shared_experts.gate_proj.weight.detach().numpy().T), lyr + ) + assign_param( + moe_jax.shared_experts.wi_1.kernel, jnp.array(moe_pt.shared_experts.up_proj.weight.detach().numpy().T), lyr + ) + assign_param( + moe_jax.shared_experts.wo.kernel, jnp.array(moe_pt.shared_experts.down_proj.weight.detach().numpy().T), lyr + ) + + assign_param(layer_jax.mhc_attention.pre_alpha, jnp.array(layer_pt.attn_hc.fn.detach().numpy()[:hc].T), lyr) + assign_param( + layer_jax.mhc_attention.post_alpha, jnp.array(layer_pt.attn_hc.fn.detach().numpy()[hc : 2 * hc].T), lyr + ) + assign_param(layer_jax.mhc_attention.res_alpha, jnp.array(layer_pt.attn_hc.fn.detach().numpy()[2 * hc :].T), lyr) + assign_param(layer_jax.mhc_attention.pre_beta, jnp.array(layer_pt.attn_hc.base.detach().numpy()[:hc]), lyr) + assign_param( + layer_jax.mhc_attention.post_beta, jnp.array(layer_pt.attn_hc.base.detach().numpy()[hc : 2 * hc]), lyr + ) + assign_param( + layer_jax.mhc_attention.res_beta, + jnp.array(layer_pt.attn_hc.base.detach().numpy()[2 * hc :].reshape(hc, hc)), + lyr, + ) + assign_param(layer_jax.mhc_attention.pre_alpha_scale, jnp.array([layer_pt.attn_hc.scale[0].item()]), lyr) + assign_param(layer_jax.mhc_attention.post_alpha_scale, jnp.array([layer_pt.attn_hc.scale[1].item()]), lyr) + assign_param(layer_jax.mhc_attention.res_alpha_scale, jnp.array([layer_pt.attn_hc.scale[2].item()]), lyr) + + assign_param(layer_jax.mhc_mlp.pre_alpha, jnp.array(layer_pt.ffn_hc.fn.detach().numpy()[:hc].T), lyr) + assign_param(layer_jax.mhc_mlp.post_alpha, jnp.array(layer_pt.ffn_hc.fn.detach().numpy()[hc : 2 * hc].T), lyr) + assign_param(layer_jax.mhc_mlp.res_alpha, jnp.array(layer_pt.ffn_hc.fn.detach().numpy()[2 * hc :].T), lyr) + assign_param(layer_jax.mhc_mlp.pre_beta, jnp.array(layer_pt.ffn_hc.base.detach().numpy()[:hc]), lyr) + assign_param(layer_jax.mhc_mlp.post_beta, jnp.array(layer_pt.ffn_hc.base.detach().numpy()[hc : 2 * hc]), lyr) + assign_param( + layer_jax.mhc_mlp.res_beta, jnp.array(layer_pt.ffn_hc.base.detach().numpy()[2 * hc :].reshape(hc, hc)), lyr + ) + assign_param(layer_jax.mhc_mlp.pre_alpha_scale, jnp.array([layer_pt.ffn_hc.scale[0].item()]), lyr) + assign_param(layer_jax.mhc_mlp.post_alpha_scale, jnp.array([layer_pt.ffn_hc.scale[1].item()]), lyr) + assign_param(layer_jax.mhc_mlp.res_alpha_scale, jnp.array([layer_pt.ffn_hc.scale[2].item()]), lyr) + + if not scan_mode: + y_pt = decoder_pt.embeddings(input_ids_torch) + y_pt = y_pt.unsqueeze(2).expand(-1, -1, 4, -1) + cos_pt, sin_pt = decoder_pt.rotary_emb(y_pt[:, :, 0, :], position_ids_torch, layer_type="compress") + + y_jax = shared_embedding(input_ids_jax.astype("int32"), model_mode="train") + y_jax = jnp.repeat(jnp.expand_dims(y_jax, axis=2), 4, axis=2).astype(y_jax.dtype) + + np.testing.assert_allclose( + y_pt.detach().numpy(), np.array(y_jax), atol=1e-5, rtol=1e-5, err_msg="Embedding mismatch" + ) + + for lyr in range(num_layers): + layer_pt = decoder_pt.layers[lyr] + layer_jax = decoder_jax.layers[lyr] + + y_pt = layer_pt( + y_pt, + input_ids=input_ids_torch, + position_embeddings=(cos_pt, sin_pt), + position_ids=position_ids_torch, + attention_mask=None, + ) + y_jax, _ = layer_jax( + y_jax, + decoder_segment_ids=jnp.zeros((B, S), dtype=jnp.int32), + decoder_positions=jnp.array(position_ids_np, dtype=jnp.int32), + deterministic=True, + model_mode="train", + decoder_input_tokens=input_ids_jax, + ) + + logits_jax, _, _ = decoder_jax( + shared_embedding=shared_embedding, + decoder_input_tokens=input_ids_jax, + decoder_positions=jnp.array(position_ids_np, dtype=jnp.int32), + decoder_segment_ids=jnp.zeros((B, S), dtype=jnp.int32), + deterministic=True, + ) + + np.testing.assert_allclose(logits_torch, np.array(logits_jax), atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index 2525a181f1..27eaac303d 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -767,3 +767,41 @@ def test_gemma4_scanned_layers(self): logits.shape, (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size), ) + + def test_deepseek_v4_scanned_layers(self): + """Test NNXDecoder with deepseek_v4 block and scan_layers=True.""" + cfg = _make_config( + decoder_block="deepseek_v4", + scan_layers=True, + num_decoder_layers=3, + q_lora_rank=1024, + o_lora_rank=1024, + qk_rope_head_dim=64, + compress_ratios=[4, 128, 4], + base_moe_mlp_dim=512, + shared_experts=2, + mhc_expansion_rate=4, + routed_score_func="sqrtsoftplus", + megablox=False, # Disable custom Pallas GMM TPU kernels on CPU testing platforms! + ) + decoder = NNXDecoder( + config=cfg, + mesh=self.mesh, + model_mode=MODEL_MODE_TRAIN, + rngs=self.rngs, + ) + shared_embedding = self._make_shared_embedding(cfg) + ids, segment_ids, positions = self._make_token_inputs(cfg) + + logits, _, _ = decoder( + shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertEqual( + logits.shape, + (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size), + ) From 559fe54f4a25477ada70bd0372b7adf2428e358d Mon Sep 17 00:00:00 2001 From: Param Bole Date: Wed, 20 May 2026 20:17:46 +0000 Subject: [PATCH 2/2] test(dsv4): disable sa_block_kv hardware grid padding to secure unmasked reference numerical parity --- tests/unit/deepseek_v4_vs_reference_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index cc3d5d8391..eec566f54e 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -1831,6 +1831,12 @@ def test_attention_layer_parity(self): layer_types=["heavily_compressed_attention"] * 10, o_groups=config.o_groups, o_lora_rank=config.o_lora_rank, + # Disabling hardware MXU grid alignment padding (sa_block_kv=0). + # By default, AttentionOp enforces sa_block_kv=512 grid bounds, automatically padding trailing sequence length + # (S=128 + W=32 = 160) to 512 with zero vectors. Under dot-product attention without explicit causal padding masks + # (attention_mask=None), Softmax evaluates unmasked zero vectors to positive probability weightings (e^{0.0} = 1.0), + # artificially inflating the local exponential normalizer sum denominator and distorting numerical parity bounds. + sa_block_kv=0, ) devices = jax.devices() @@ -2253,6 +2259,9 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch. index_topk=config_pt.index_topk, mlp_activations=["silu", "linear"], scan_layers=scan_mode, + # Explicitly disable hardware MXU grid sequence padding (sa_block_kv=0) to ensure dot-product Softmax + # normalization sums match unpadded PyTorch reference bounds precisely without exponential denominator drift. + sa_block_kv=0, ) decoder_jax = NNXDecoder(config=jax_config, mesh=mesh, rngs=nnx.Rngs(0))