Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ topk_routing_group: -1 # number of top groups to route inputs. For EP,
# all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
use_batch_split_schedule: False # whether to split the batch into micro-batches to hide communication, which yields performance benefits.
batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True.
num_hash_layers: 3 # Number of initial MoE layers to apply static Hash Routing.

# For complex architectures like llama4 there are repeated sets of
# inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
Expand Down Expand Up @@ -1227,6 +1228,7 @@ force_q_layout: false
mhc_expansion_rate: 1
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
# The epsilon fallback value for numerical stability in mHC.
hc_eps: 1.0e-6

################################## DeepSeek Engram ##################################
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
Expand Down
78 changes: 78 additions & 0 deletions src/maxtext/configs/models/deepseek_v4-flash.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Default model configs for DeepSeek-V4-Flash (43 Layers)

base_config: base.yml
model_name: deepseek_v4-flash

base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 1
head_dim: 512
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
base_num_decoder_layers: 43
first_num_dense_layers: 0
mlp_activations: ["silu"]
vocab_size: 129280
tokenizer_type: "huggingface"
tokenizer_path: "deepseek-ai/DeepSeek-V3"
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 256
num_experts_per_tok: 6
shared_experts: 1
routed_scaling_factor: 1.5
routed_score_func: "sqrtsoftplus"
routed_bias: True
norm_topk_prob: True
decoder_block: "deepseek_v4"
pure_nnx_decoder: True
enable_nnx: True
use_tokamax_splash: True
# TODO: Dynamic Flash Attention masking currently requires use_indexer: True so that the custom compressor mask (indexer_mask) is routed into Tokamax's make_dynamic_splash_mha dynamic Pallas kernel path. Refactor this MLA configuration parameter out of the deep coordinator layers in a subsequent milestone.
use_indexer: True
remat_policy: minimal_offloaded

# Manifold-Constrained Hyper-Connection configurations
mhc_expansion_rate: 4
sinkhorn_iterations: 20
compress_rope_theta: 160000.0
index_head_dim: 128
index_n_heads: 64
index_topk: 512
o_groups: 8
o_lora_rank: 1024
sliding_window: 128
num_hash_layers: 3
mlp_activations_limit: 10.0
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# Compressed Sparse Attention
q_lora_rank: 1024
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
mscale: 1.0

# RoPE
rope_type: "default"
rope_max_timescale: 10_000
max_position_embeddings: 1048576
original_max_position_embeddings: 65536
rope_factor: 16
beta_fast: 32
74 changes: 74 additions & 0 deletions src/maxtext/configs/models/deepseek_v4-tiny.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny version of DeepSeek-V4 (6 Layers) for local sharding and compilation checks.

base_config: base.yml
model_name: deepseek_v4-tiny

base_emb_dim: 128
base_num_query_heads: 16
base_num_kv_heads: 1
head_dim: 32
base_mlp_dim: 128
base_moe_mlp_dim: 128
base_num_decoder_layers: 6
first_num_dense_layers: 0
mlp_activations: ["silu"]
vocab_size: 129280
tokenizer_type: "huggingface"
tokenizer_path: "deepseek-ai/DeepSeek-V3"
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 8
num_experts_per_tok: 4
shared_experts: 1
routed_scaling_factor: 1.5
routed_score_func: "sqrtsoftplus"
routed_bias: True
norm_topk_prob: True
decoder_block: "deepseek_v4"
pure_nnx_decoder: True
enable_nnx: True

# Manifold-Constrained Hyper-Connection configurations
mhc_expansion_rate: 4
sinkhorn_iterations: 20
compress_rope_theta: 160000.0
index_head_dim: 32
index_n_heads: 16
index_topk: 64
o_groups: 2
o_lora_rank: 64
sliding_window: 32
num_hash_layers: 3
mlp_activations_limit: 10.0
compress_ratios: [0, 4, 128, 4, 128, 0]

# Compressed Attention
q_lora_rank: 64
kv_lora_rank: 32
qk_nope_head_dim: 32
qk_rope_head_dim: 16
v_head_dim: 128
mscale: 1.0

# RoPE
rope_type: "default"
rope_max_timescale: 10_000
max_position_embeddings: 163840
original_max_position_embeddings: 4096
rope_factor: 40
beta_fast: 32
7 changes: 7 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ class ProfilerType(str, Enum):
"deepseek3-test",
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek_v4-tiny",
"deepseek_v4-flash",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
Expand Down Expand Up @@ -831,6 +833,10 @@ class DeepSeekMoE(BaseModel):
1,
description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.",
)
num_hash_layers: int = Field(
3,
description="Number of initial MoE layers to apply static Hash Routing.",
)


class Qwen3Next(BaseModel):
Expand Down Expand Up @@ -1381,6 +1387,7 @@ class ManifoldConstrainedHyperConnections(BaseModel):

mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
hc_eps: float = Field(1e-6, description="The epsilon fallback value for numerical stability in mHC.")


class DilocoParams(BaseModel):
Expand Down
114 changes: 113 additions & 1 deletion src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from flax import linen as nn
from flax import nnx
from maxtext.layers import nnx_wrappers
from maxtext.models.deepseek_v4 import DeepSeekV4HyperHead
from flax.linen.partitioning import ScanIn
import jax
from jax.ad_checkpoint import checkpoint_name
Expand All @@ -43,6 +45,7 @@
deepseek,
deepseek_batchsplit,
deepseek_batchsplit_fp8,
deepseek_v4,
gemma,
gemma2,
gemma3,
Expand Down Expand Up @@ -460,6 +463,12 @@ def get_decoder_layers(self):
deepseek.DeepSeekDenseLayerToLinen,
deepseek.DeepSeekMoELayerToLinen,
]
case DecoderBlockType.DEEPSEEK_V4:
return (
[deepseek_v4.DeepSeekV4ScannableBlockToLinen]
if self.config.scan_layers
else [deepseek_v4.DeepSeekV4DecoderLayerToLinen]
)
case DecoderBlockType.GEMMA:
return [gemma.GemmaDecoderLayerToLinen]
case DecoderBlockType.GEMMA2:
Expand Down Expand Up @@ -983,6 +992,20 @@ def __call__(
page_state,
slot,
)
elif cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4:
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
y = self._apply_deepseek_v4_scanned_blocks(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk,
page_state,
slot,
bidirectional_mask=bidirectional_mask_value,
decoder_input_tokens=decoder_input_tokens,
)
else:
RemattedBlockLayer = RemattedBlockLayers[0]
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
Expand Down Expand Up @@ -1089,6 +1112,16 @@ def __call__(
RemattedBlockLayer = RemattedBlockLayers[0]
layer_kwargs = {}
layer_call_kwargs = {}
if cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4:
# Retrieve layer-specific compression ratio from configuration to support sliding window attention
# at boundary layers and alternating compressed sparse/heavily compressed attention.
compress_ratio = self.config.compress_ratios[lyr]
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
layer_kwargs = {"compress_ratio": compress_ratio, "layer_idx": lyr}
layer_call_kwargs = {
"decoder_input_tokens": decoder_input_tokens,
"bidirectional_mask": bidirectional_mask_value,
}
if cfg.decoder_block == DecoderBlockType.GEMMA3:
# Gemma3 uses both global and sliding window attention depending on the layer index.
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
Expand Down Expand Up @@ -1151,7 +1184,11 @@ def __call__(
assert isinstance(y, jax.Array)

# After the final transformer layer, `y` holds the raw, un-normalized hidden state.
if cfg.mhc_expansion_rate > 1:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4:
# Collapse final streams using learnable collapse weights [B, S, k, D] -> [B, S, D]
hc_head = nnx_wrappers.to_linen_class(DeepSeekV4HyperHead, name="hc_head")(config=cfg)
hidden_state = hc_head(y)
elif cfg.mhc_expansion_rate > 1:
# (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
hidden_state = mhc_reduce(y)
else:
Expand Down Expand Up @@ -1335,6 +1372,81 @@ def _apply_gemma4_scanned_blocks(

return y

def _apply_deepseek_v4_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    previous_chunk,
    page_state,
    slot,
    bidirectional_mask=None,
    decoder_input_tokens=None,
):
  """Applies DeepSeek-V4 scanned decoder blocks under Flax Linen, handling main scan and remainders.

  The decoder stack is consumed in blocks of `layers_per_block` layers: a scan over
  `num_decoder_layers // layers_per_block` full blocks, followed by a single un-scanned
  block named "layers_remainder" that holds the `num_decoder_layers % layers_per_block`
  leftover layers. Together they apply exactly `cfg.num_decoder_layers` layers.

  Args:
    y: Hidden state fed to the first block.
    decoder_segment_ids: Broadcast unchanged to every scanned block and passed to the remainder block.
    decoder_positions: Broadcast unchanged to every scanned block and passed to the remainder block.
    deterministic: Broadcast unchanged to every scanned block and passed to the remainder block.
    model_mode: Broadcast unchanged to every scanned block and passed to the remainder block.
    previous_chunk: Forwarded to the remainder block only; the scanned call does not receive it.
    page_state: Forwarded to the remainder block only; the scanned call does not receive it.
    slot: Forwarded to the remainder block only; the scanned call does not receive it.
    bidirectional_mask: Optional mask forwarded as a keyword argument to every block.
    decoder_input_tokens: Optional input tokens forwarded as a keyword argument to every block.

  Returns:
    The hidden state produced after the scanned blocks and the remainder block (if any).
  """
  cfg = self.config
  mesh = self.mesh

  # Length of the repeating layer pattern (2 for cyclical DeepSeek-V4 layers).
  layers_per_block = 2

  # Derive the scan length and the remainder from the SAME layer count so that
  # scanned + remainder layers always add up to cfg.num_decoder_layers.
  # The previous remainder, (num_decoder_layers - num_hash_layers) % 2, disagreed with
  # the scan length whenever num_hash_layers was odd: e.g. 43 layers / 3 hash layers
  # applied only 42 layers, and 6 layers / 3 hash layers applied 7.
  # NOTE(review): no hash-routed prefix layers are unrolled in this method; if that is
  # ever added, both counts must be reduced by cfg.num_hash_layers together.
  scan_length, num_remaining_layers = divmod(cfg.num_decoder_layers, layers_per_block)

  policy = self.get_remat_policy()
  RemattedDSV4Block = self.set_remat_policy([deepseek_v4.DeepSeekV4ScannableBlockToLinen], policy)[0]

  layer_call_kwargs = {
      "decoder_input_tokens": decoder_input_tokens,
      "bidirectional_mask": bidirectional_mask,
  }
  layer_kwargs = {"num_of_layers": layers_per_block}

  # Apply the main scan over the full blocks
  if scan_length > 0:
    broadcast_args = (
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
    )
    # inputs: y shape [B, S, k, D] -> [B, S, k, D]
    y, _ = self.scan_decoder_layers(
        cfg,
        RemattedDSV4Block,
        scan_length,
        "layers",
        mesh,
        in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
        model_mode=self.model_mode,
        **layer_kwargs,
    )(y, *broadcast_args, **layer_call_kwargs)

  # Apply any remaining layers that did not fit into a full scanned block.
  if num_remaining_layers > 0:
    # The distinct "layers_remainder" name keeps this block's parameters apart from the scanned "layers".
    rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
    layer = RemattedDSV4Block(
        config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
    )
    y, _ = layer(
        y,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        previous_chunk=previous_chunk,
        page_state=page_state,
        slot=slot,
        **layer_call_kwargs,
    )
  return y

# TODO(b/490118813): Relocate the following functions to their designated directories
# once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer()
# _apply_scanned_chunk() and _apply_interleaved_scanned_layers().
Expand Down
9 changes: 9 additions & 0 deletions src/maxtext/layers/engram.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,15 @@ class StaticWrapper:
def __init__(self, val):
  """Store `val` on the wrapper as `self.val`."""
  self.val = val

def __getitem__(self, key):
  """Return `self.val[key]`, delegating indexing to the wrapped value."""
  return self.val[key]

def __setitem__(self, key, value):
if key is Ellipsis:
self.val = value
else:
self.val = self.val.at[key].set(value)


class MultiHeadEmbedding(nnx.Module):
"""
Expand Down
Loading
Loading