From 11f28487977277b9dab411f6d574894c7a33b31c Mon Sep 17 00:00:00 2001 From: x54-729 Date: Tue, 16 Jun 2026 23:21:28 +0800 Subject: [PATCH 01/14] add multi mtp config; add mtp mask type v1 --- xtuner/v1/engine/train_engine.py | 5 +- xtuner/v1/loss/mtp_loss.py | 40 +++- xtuner/v1/model/base.py | 13 +- xtuner/v1/model/moe/moe.py | 314 ++++++++++++++++++++----------- xtuner/v1/module/mtp/config.py | 2 + 5 files changed, 259 insertions(+), 115 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 74a56d4643..6021ff7794 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -569,6 +569,9 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor: loss = torch.tensor(0.0, device=DEVICE) for key in model_outputs.model_fields: value = getattr(model_outputs, key) - if "loss" in key and isinstance(value, torch.Tensor): + if key == "mtp_loss" and isinstance(value, dict): + for mtp_loss_name, mtp_loss in value.items(): + loss += mtp_loss + elif "loss" in key and isinstance(value, torch.Tensor): loss += value return loss diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index a5aebc3b19..2cbeaab132 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Any +from typing import Any, Optional import torch import torch.nn.functional as F @@ -60,6 +60,7 @@ class MTPLossConfig(CELossConfig): mtp_depth: int detach_mtp_lm_head_weight: bool = False + mask_type: Optional[str] = None @property def loss_ctx_cls(self) -> type["MTPLossContext"]: @@ -167,6 +168,12 @@ def forward( head_weight = head_weight.detach() head_bias = head_bias.detach() if head_bias is not None else None # Dispatch to eager_mode/chunk_mode via base class, which calls loss_fn per chunk + + mask_type = self.loss_cfg.mask_type + if mask_type == "v1": + self.process_loss_weight_v1() + elif mask_type is not None: + raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}") return super().forward(hidden_states, head_weight, head_bias) def loss_fn( @@ -214,3 +221,34 @@ def _kl_loss_fn( ) return kl_loss, (None, {}) + + def process_loss_weight_v1(self): + layer_idx = self.loss_cfg.mtp_depth - 1 + shifted_labels = self.loss_kwargs.shifted_labels + loss_weight = self.loss_kwargs.loss_weight + sum_loss_weight = loss_weight.sum() + + easy_to_use = torch.cat( + [ + shifted_labels, + torch.zeros((shifted_labels.size(0), 1), dtype=shifted_labels.dtype, device=shifted_labels.device), + ], + dim=-1, + ) + + # TODO: digit and dot token config + is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0) + is_dot = torch.where(easy_to_use == 13, 1, 0) + is_digit_or_dot = is_digit | is_dot + + mask = is_digit_or_dot.clone() + for i in range(layer_idx + 1): + mask |= torch.roll(is_digit_or_dot, shifts=i + 1, dims=-1) + + mtp_mask = mask.bool()[:, :-1] + + loss_weight[mtp_mask == 0.0] = 0.0 + if loss_weight.sum().item() != 0: + loss_weight = loss_weight * sum_loss_weight / loss_weight.sum() + + self.loss_kwargs.loss_weight = loss_weight diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index c02c4dc9dd..32aef6f355 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -1313,7 +1313,18 @@ def 
post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat output_copy = output.model_copy() for name in output_copy.model_fields: obj = getattr(output_copy, name) - if "loss" in name and isinstance(obj, torch.Tensor): + if name == "mtp_loss" and isinstance(obj, dict): + for key, value in obj.items(): + loss_item = value.item() + local_total_loss += loss_item + reduced_name = f"{key}_reduced_mtp_loss" + + if reduced_name not in reduced_other_losses: + reduced_other_losses[reduced_name] = loss_item + else: + reduced_other_losses[reduced_name] += loss_item + + elif "loss" in name and isinstance(obj, torch.Tensor): loss_item = obj.item() local_total_loss += loss_item reduced_name = f"reduced_{name}" diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 061e2f6e29..799b19887b 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -2,7 +2,7 @@ import os import types from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Literal, Self, Sequence, TypedDict, cast +from typing import TYPE_CHECKING, Annotated, List, Literal, Self, Sequence, TypedDict, cast import torch import torch.distributed as dist @@ -102,7 +102,7 @@ class MoEModelOutputs(ModelOutputs): balancing_loss: torch.Tensor | None = None z_loss: torch.Tensor | None = None tokens_per_expert_global: torch.Tensor - mtp_loss: torch.Tensor | None = None + mtp_loss: dict[str, torch.Tensor] | None = None def free_nongrad_feature(self): """Release large intermediate tensors not needed for backward or @@ -129,7 +129,7 @@ class MoELossContextDict(TypedDict): lm: BaseLossContext balancing: BalancingLossContext | None z_loss: ZLossContext | None - mtp: list[BaseLossContext] | None + mtp: dict[str, list[BaseLossContext]] | None class MoEConfig(TransformerConfig): @@ -151,7 +151,7 @@ class MoEConfig(TransformerConfig): router_compute_dtype: Literal["float32", "native"] = "float32" moe_bias: bool = False moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig() 
- mtp_config: MTPConfig | None = None + mtp_config: List[MTPConfig] | None = None freeze_routers: bool = False router_async_offload: bool = False aux_loss_cfg: AuxLossConfig = AuxLossConfig() @@ -202,7 +202,7 @@ def __init__(self, config: MoEConfig): self.layers = self.build_layers(config) self.rotary_emb = self.build_rotary_embedding(config) self.embed_tokens = self.build_embeddings(config) - self.mtp_block = self.build_mtp_block(config) if config.mtp_config is not None else None + self.mtp_block = self.build_mtp_block_dict(config) if config.mtp_config is not None else None self.fp32_layers = [self.rotary_emb] @@ -338,23 +338,30 @@ def build_loss_ctx_batch( # type: ignore[override] # Add MTP loss contexts if MTP is enabled if self.config.mtp_config is not None: - for mtp_idx in range(self.config.mtp_config.num_layers): - mtp_loss_cfg = MTPLossConfig( - **self.config.lm_loss_cfg.model_dump(), - mtp_depth=mtp_idx + 1, - detach_mtp_lm_head_weight=self.config.mtp_config.detach_mtp_lm_head_weight, - ) - mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) - if mtp_loss_ctx_list is not None: - mtp_loss_ctx_list = MTPLossContext.build_batches( # type: ignore[assignment] - cast(list[MTPLossContext], mtp_loss_ctx_list), # type: ignore[arg-type] - cu_seq_lens_list=cu_seq_lens_list, - sp_mesh=sp_mesh, + # Build MTP loss contexts using the same approach as LM loss + # Each MTP depth needs its own loss context + for mtp_config in self.config.mtp_config: + for mtp_idx in range(mtp_config.num_layers): + mtp_loss_cfg = MTPLossConfig( + **self.config.lm_loss_cfg.model_dump(), + mtp_depth=mtp_idx + 1, + detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, + mask_type=mtp_config.mask_type, ) - for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list): - if "mtp" not in res[i]: - res[i]["mtp"] = [] - res[i]["mtp"].append(mtp_loss_ctx) # type: ignore[union-attr] + # MTP needs to shift labels multiple times. 
Since rebuild the `shifted_labels` in data_batch + mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) + if mtp_loss_ctx_list is not None: + mtp_loss_ctx_list = MTPLossContext.build_batches( # type: ignore[assignment] + cast(list[MTPLossContext], mtp_loss_ctx_list), # type: ignore[arg-type] + cu_seq_lens_list=cu_seq_lens_list, + sp_mesh=sp_mesh, + ) + for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list): + if "mtp" not in res[i]: + res[i]["mtp"] = {} + if mtp_config.name not in res[i]["mtp"]: + res[i]["mtp"][mtp_config.name] = [] + res[i]["mtp"][mtp_config.name].append(mtp_loss_ctx) # type: ignore[union-attr] # Ensure all microbatches have mtp key for loss_ctx_dict in res: @@ -571,34 +578,52 @@ def _micro_batch_forward( ) ) - mtp_outputs_per_mb = self.mtp_block( - *hidden_states_list, - embed_tokens_fn=self.embed_tokens, - position_embeddings=position_embeddings_list, - seq_ctx=mtp_seq_ctx_list, - ) + # Initialize mtp_losses dict to store losses for each mtp_config + mtp_losses_dict: dict[str, torch.Tensor] = {} - mtp_losses = torch.tensor(0.0, device=DEVICE) - has_mtp_loss = False - for micro_batch_idx, (loss_ctx_dict, mtp_outputs) in enumerate(zip(loss_ctx_list, mtp_outputs_per_mb)): - mtp_loss_ctx_list = loss_ctx_dict.get("mtp") - if mtp_loss_ctx_list is None: - continue + # Loop through each mtp_config + for mtp_config in self.config.mtp_config: + name = mtp_config.name - micro_batch_mtp_losses = torch.tensor(0.0, device=DEVICE) - for mtp_idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): - mtp_hidden_states, mtp_router_results, _ = mtp_hidden - mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) - micro_batch_mtp_losses += mtp_loss + # Get the MTP block for this config by name + mtp_outputs_per_mb = self.mtp_block[name]( + *hidden_states_list, + embed_tokens_fn=self.embed_tokens, + position_embeddings=position_embeddings_list, + seq_ctx=mtp_seq_ctx_list, + ) - if keep_router: - 
router_logits_list[micro_batch_idx][f"mtp_layer{mtp_idx}"] = mtp_router_results + mtp_losses = torch.tensor(0.0, device=DEVICE) + has_mtp_loss = False + for micro_batch_idx, (loss_ctx_dict, mtp_outputs) in enumerate(zip(loss_ctx_list, mtp_outputs_per_mb)): + # Get the mtp loss context dict + mtp_loss_ctx_dict = loss_ctx_dict.get("mtp") + if mtp_loss_ctx_dict is None or name not in mtp_loss_ctx_dict: + continue - mtp_losses += micro_batch_mtp_losses / len(mtp_loss_ctx_list) - has_mtp_loss = True + # Get the loss context list for this mtp_config name + mtp_loss_ctx_list = mtp_loss_ctx_dict[name] - if has_mtp_loss: - output["mtp_loss"] = mtp_losses * self.config.mtp_config.loss_scaling_factor + micro_batch_mtp_losses = torch.tensor(0.0, device=DEVICE) + for mtp_idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): + mtp_hidden_states, mtp_router_results, _ = mtp_hidden + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) + micro_batch_mtp_losses += mtp_loss + + if keep_router: + # Add name prefix to router logits key + router_logits_list[micro_batch_idx][f"{name}_mtp_layer{mtp_idx}"] = mtp_router_results + + mtp_losses += micro_batch_mtp_losses / len(mtp_loss_ctx_list) + has_mtp_loss = True + + if has_mtp_loss: + # Use the loss_scaling_factor from current mtp_config + mtp_losses_dict[name] = mtp_losses * mtp_config.loss_scaling_factor + + # Store mtp losses as dict + if mtp_losses_dict: + output["mtp_loss"] = mtp_losses_dict # Apply final norm to all micro-batches cat_hidden_states = torch.cat(hidden_states_list, dim=1) @@ -645,6 +670,70 @@ def _micro_batch_forward( return MoEModelOutputs(**output, logits=logits) + def _mtp_forward( + self, + mtp_config: MTPConfig, + output, + layer_hidden_states, + position_embeddings, + seq_ctx, + balancing_ctx, + z_ctx, + mtp_seq_ctx, + mtp_loss_ctx_dict, + keep_router: bool, + ): + # MTP uses its own mask; main mask's non-pad indices do not apply. 
+ name = mtp_config.name + mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] + mtp_non_pad_token = mtp_nonpad_indices.numel() + mtp_num_tokens_global, mtp_z_world_size = self._z_loss_dist_token_count( + z_ctx, mtp_non_pad_token, mtp_seq_ctx.mask.device + ) + + # Forward through MTP block + mtp_outputs = self.mtp_block[name]( + layer_hidden_states, + embed_tokens_fn=self.embed_tokens, + position_embeddings=position_embeddings, + seq_ctx=mtp_seq_ctx, + ) + + # Compute MTP losses for each depth + mtp_losses = torch.tensor(0.0, device=DEVICE) + mtp_loss_ctx_list = mtp_loss_ctx_dict[name] + for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): + mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden + + if keep_router: + output["router_logits"][f"{name}_mtp_layer{idx}"] = mtp_router_results + output["router_weights"][f"{name}_mtp_layer{idx}"] = mtp_router_weights + # Inject this MTP layer's z-loss before lm_head so backward through mtp_loss + # traverses the AuxLossScaler node and releases this layer's logsumexp activations. 
+ mtp_hidden_states = self.aux_loss.accumulate( + selected_router_weights=mtp_router_weights.index_select(0, mtp_nonpad_indices) + .contiguous() + .float(), + selected_router_logits=mtp_router_results.index_select(0, mtp_nonpad_indices).contiguous().float(), + hidden_states=mtp_hidden_states, + balancing_ctx=balancing_ctx, + z_ctx=z_ctx, + num_tokens_local=mtp_non_pad_token, + num_tokens_global=mtp_num_tokens_global, + world_size=mtp_z_world_size, + ) + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) + mtp_losses += mtp_loss + + # Average MTP losses across depths and scale + mtp_losses = mtp_losses / len(mtp_loss_ctx_list) + scaled_mtp_loss = mtp_losses * mtp_config.loss_scaling_factor # type: ignore + + # Add to total loss + output["mtp_loss"][name] = scaled_mtp_loss + + return scaled_mtp_loss + def _forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch @@ -751,59 +840,28 @@ def _forward( if ( self.mtp_block is not None and loss_ctx is not None - and (mtp_loss_ctx_list := loss_ctx.get("mtp")) is not None + and (mtp_loss_ctx_dict := loss_ctx.get("mtp")) is not None ): + output["mtp_loss"] = {} mtp_seq_ctx = seq_ctx.copy( input_ids=input_ids.clone() if input_ids is not None else None, position_ids=position_ids.clone(), inputs_embeds=seq_ctx.inputs_embeds.clone() if seq_ctx.inputs_embeds is not None else None, ) - # MTP uses its own mask; main mask's non-pad indices do not apply. 
- mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] - mtp_non_pad_token = mtp_nonpad_indices.numel() - mtp_num_tokens_global, mtp_z_world_size = self._z_loss_dist_token_count( - z_ctx, mtp_non_pad_token, mtp_seq_ctx.mask.device - ) - - # Forward through MTP block - mtp_outputs = self.mtp_block( - layer_hidden_states, - embed_tokens_fn=self.embed_tokens, - position_embeddings=position_embeddings, - seq_ctx=mtp_seq_ctx, - ) - - # Compute MTP losses for each depth - mtp_losses = torch.tensor(0.0, device=DEVICE) - for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): - mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden - if keep_router: - output["router_logits"][f"mtp_layer{idx}"] = mtp_router_results - output["router_weights"][f"mtp_layer{idx}"] = mtp_router_weights - # Inject this MTP layer's z-loss before lm_head so backward through mtp_loss - # traverses the AuxLossScaler node and releases this layer's logsumexp activations. 
- mtp_hidden_states = self.aux_loss.accumulate( - selected_router_weights=mtp_router_weights.index_select(0, mtp_nonpad_indices) - .contiguous() - .float(), - selected_router_logits=mtp_router_results.index_select(0, mtp_nonpad_indices).contiguous().float(), - hidden_states=mtp_hidden_states, + for mtp_config in self.config.mtp_config: + self._mtp_forward( + mtp_config=mtp_config, + output=output, + layer_hidden_states=layer_hidden_states, + position_embeddings=position_embeddings, + seq_ctx=seq_ctx, balancing_ctx=balancing_ctx, z_ctx=z_ctx, - num_tokens_local=mtp_non_pad_token, - num_tokens_global=mtp_num_tokens_global, - world_size=mtp_z_world_size, + mtp_seq_ctx=mtp_seq_ctx, + mtp_loss_ctx_dict=mtp_loss_ctx_dict, + keep_router=keep_router, ) - mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) - mtp_losses += mtp_loss - - # Average MTP losses across depths and scale - mtp_losses = mtp_losses / len(mtp_loss_ctx_list) - scaled_mtp_loss = mtp_losses * self.config.mtp_config.loss_scaling_factor # type: ignore - - # Add to total loss - output["mtp_loss"] = scaled_mtp_loss split_aux_output = self.aux_loss.finalize( balancing_ctx=balancing_ctx, @@ -896,16 +954,36 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: layers.__class__.__repr__ = module_dict_repr # type: ignore[method-assign] return layers - def build_mtp_block(self, config: MoEConfig) -> MTPBlock: + def build_mtp_block_dict(self, config): + mtp_block_dict = nn.ModuleDict() + layer_idx_offset = 0 # Cumulative offset for layer indices across all mtp_configs + + for mtp_config in config.mtp_config: + if mtp_config.name not in ("normal", "sci"): + raise ValueError(f"Expected mtp keys to be either `normal` or `sci`, but got `{mtp_config.name}`") + if mtp_config.name in mtp_block_dict.keys(): + raise ValueError(f"Duplicate mtp name: `{mtp_config.name}`") + + # Build the MTP block with the current offset + mtp_block_dict[mtp_config.name] = self.build_mtp_block(config, 
mtp_config, layer_idx_offset) + + # Update offset: number of physical layers for this mtp_config + num_physical_layer = 1 if mtp_config.share_weights else mtp_config.num_layers + layer_idx_offset += num_physical_layer + + return mtp_block_dict + + def build_mtp_block(self, config: MoEConfig, mtp_config: MTPConfig, layer_idx_offset: int) -> MTPBlock: """Build MTP block with MoE decoder layers. Args: config (MoEConfig): Model configuration. + mtp_config (MTPConfig): MTP configuration for this specific block. + layer_idx_offset (int): Offset for layer indices to ensure uniqueness across multiple mtp_configs. Returns: MTPBlock: Constructed MTP block. """ - mtp_config = config.mtp_config assert mtp_config is not None, "mtp_config must be provided" mtp_layers = [] @@ -949,7 +1027,7 @@ def build_mtp_block(self, config: MoEConfig) -> MTPBlock: router_compute_dtype=config.router_compute_dtype, moe_act_fn_cfg=config.moe_act_fn_cfg, float8_cfg=config.float8_cfg, - layer_idx=config.num_hidden_layers + i, + layer_idx=config.num_hidden_layers + layer_idx_offset + i, dispatcher=config.dispatcher, ep_mesh=self.ep_mesh, ) @@ -1084,28 +1162,38 @@ def fully_shard( # Shard MTP block if it exists if self.mtp_block is not None: - for mtp_idx, mtp_layer in enumerate(self.mtp_block.layers): - if self._should_recompute(None, mtp_idx=mtp_idx) or ( - self.config.mtp_config is not None and self.config.mtp_config.share_weights - ): # share mtp head must recompute - mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) - self.mtp_block.layers[mtp_idx] = mtp_layer - - reshard_after_forward = mtp_idx != len(self.mtp_block.layers) - 1 - self._fully_shard( - mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, - mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, - offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, - module=mtp_layer, - ) - if mtp_idx == 0: - layer_next.set_modules_to_forward_prefetch([mtp_layer]) 
# type: ignore - - if self.config.mtp_config is not None and self.config.mtp_config.num_layers > 0: + global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs + for mtp_name in self.mtp_block.keys(): + mtp_block = self.mtp_block[mtp_name] + mtp_config = next((cfg for cfg in self.config.mtp_config if cfg.name == mtp_name), None) # type: ignore + for local_mtp_idx, mtp_layer in enumerate(mtp_block.layers): + if self._should_recompute(None, mtp_idx=global_mtp_idx) or ( + mtp_config is not None and mtp_config.share_weights + ): # share mtp head must recompute + mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) + mtp_block.layers[local_mtp_idx] = mtp_layer + + reshard_after_forward = local_mtp_idx != len(mtp_block.layers) - 1 + self._fully_shard( + mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, + module=mtp_layer, + ) + # Only set prefetch for the first MTP layer across all mtp_configs + if global_mtp_idx == 0: + layer_next.set_modules_to_forward_prefetch([mtp_layer]) # type: ignore + global_mtp_idx += 1 + + # Set up prefetch chains across all MTP blocks + if self.config.mtp_config is not None: + mtp_block_layers = [] + for mtp_config in self.config.mtp_config: + mtp_block_layers.extend(list(self.mtp_block[mtp_config.name].layers)) for prev_mtp_layer, next_mtp_layer in zip( - list(self.mtp_block.layers)[:-1], - list(self.mtp_block.layers)[1:], + mtp_block_layers[:-1], + mtp_block_layers[1:], ): prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore @@ -1329,7 +1417,9 @@ def _should_recompute( """ num_layers = self.config.num_hidden_layers if self.config.mtp_config is not None: - mtp_layers = 1 if self.config.mtp_config.share_weights else self.config.mtp_config.num_layers + mtp_layers = sum( + [1 if mtp_config.share_weights else 
mtp_config.num_layers for mtp_config in self.config.mtp_config] + ) else: mtp_layers = 0 recompute_ratio = self.fsdp_config.recompute_ratio if self.fsdp_config is not None else 0.0 diff --git a/xtuner/v1/module/mtp/config.py b/xtuner/v1/module/mtp/config.py index d6bcf7a9b5..032a4cb10b 100644 --- a/xtuner/v1/module/mtp/config.py +++ b/xtuner/v1/module/mtp/config.py @@ -45,8 +45,10 @@ class MTPConfig(BaseModel): model_config = ConfigDict(extra="forbid") + name: Annotated[str, Parameter(group="model")] num_layers: Annotated[int, Parameter(group="model")] share_weights: Annotated[bool, Parameter(group="model")] = False detach_mtp_lm_head_weight: Annotated[bool, Parameter(group="model")] = False detach_mtp_inputs: Annotated[bool, Parameter(group="model")] = False loss_scaling_factor: Annotated[float, Parameter(group="model")] = 0.1 + mask_type: Annotated[str | None, Parameter(group="model")] From b427a7f3c537863788a950495679734553c67405 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Wed, 17 Jun 2026 16:22:55 +0800 Subject: [PATCH 02/14] Add SciMTPLossContext SciMTPConfig SciMTPLossConfig; Compatible with single MTP --- xtuner/v1/loss/mtp_loss.py | 126 ++++++++++++++++++++++++++++--- xtuner/v1/model/moe/moe.py | 49 ++++++++---- xtuner/v1/module/mtp/__init__.py | 4 +- xtuner/v1/module/mtp/config.py | 43 ++++++++++- 4 files changed, 192 insertions(+), 30 deletions(-) diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index 2cbeaab132..57e86fa462 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -60,7 +60,6 @@ class MTPLossConfig(CELossConfig): mtp_depth: int detach_mtp_lm_head_weight: bool = False - mask_type: Optional[str] = None @property def loss_ctx_cls(self) -> type["MTPLossContext"]: @@ -69,9 +68,9 @@ def loss_ctx_cls(self) -> type["MTPLossContext"]: @property def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]: return MTPLossKwargs - - def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | 
None": - """Build MTPLossContext from data dict. + + def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> torch.Tensor: + """Process loss_kwargs for MTP Loss. Rolls ``shifted_labels`` (and optionally ``logprobs``) by ``-mtp_depth`` positions (per-sequence, respecting packed-sequence @@ -86,8 +85,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex sp_mesh (DeviceMesh | None): Sequence parallel mesh. Returns: - MTPLossContext | None: Built loss context, or ``None`` if - ``shifted_labels`` is not present in ``data``. + torch.Tensor: loss_kwargs """ # TODO: Should move the common utils function to public package to avoid from circular import. from xtuner.v1.module.mtp.utils import roll_packed_tensor @@ -136,9 +134,85 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex if sp_mesh is not None and sp_mesh.size() > 1: loss_kwargs = loss_kwargs.sp_split(sp_mesh) + return loss_kwargs + + def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None": + """Build MTPLossContext from data dict. + + Rolls ``shifted_labels`` (and optionally ``logprobs``) by + ``-mtp_depth`` positions (per-sequence, respecting packed-sequence + boundaries) before constructing the loss context. The roll is performed + on the full sequence prior to any sequence-parallel split so that + boundary positions and ``cu_seq_lens`` are always consistent. + + Args: + data (dict): Data dict containing loss-related fields. + Required keys: ``shifted_labels``, ``seq_ctx``. + Optional keys: ``logprobs``. + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + MTPLossContext | None: Built loss context, or ``None`` if + ``shifted_labels`` is not present in ``data``. 
+ """ + + if "shifted_labels" not in data: + return None + + loss_kwargs = self._process_loss_kwargs(data, sp_mesh) return MTPLossContext(self, loss_kwargs) +class SciMTPLossConfig(MTPLossConfig): + """Loss configuration for Multi-Token Prediction (MTP). + + Extends ``MTPLossConfig`` with a ``mask_type`` field that controls how to mask + ``loss_kwargs`` when calculating loss. + + Args: + mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses + ``mtp_depth=1`` (shift=-1 on top of the existing label shift). + detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. + This is used in RL training. Default is False. + """ + + mask_type: Optional[str] = None + + @property + def loss_ctx_cls(self) -> type["SciMTPLossContext"]: + return SciMTPLossContext + + @property + def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]: + return MTPLossKwargs + + def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "SciMTPLossContext | None": + """Build SciMTPLossContext from data dict. + + Rolls ``shifted_labels`` (and optionally ``logprobs``) by + ``-mtp_depth`` positions (per-sequence, respecting packed-sequence + boundaries) before constructing the loss context. The roll is performed + on the full sequence prior to any sequence-parallel split so that + boundary positions and ``cu_seq_lens`` are always consistent. + + Args: + data (dict): Data dict containing loss-related fields. + Required keys: ``shifted_labels``, ``seq_ctx``. + Optional keys: ``logprobs``. + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + SciMTPLossContext | None: Built loss context, or ``None`` if + ``shifted_labels`` is not present in ``data``. + """ + + if "shifted_labels" not in data: + return None + + loss_kwargs = self._process_loss_kwargs(data, sp_mesh) + return SciMTPLossContext(self, loss_kwargs) + + class MTPLossContext(LMHeadLossContext): """Loss context for Multi-Token Prediction (MTP). 
@@ -168,12 +242,6 @@ def forward( head_weight = head_weight.detach() head_bias = head_bias.detach() if head_bias is not None else None # Dispatch to eager_mode/chunk_mode via base class, which calls loss_fn per chunk - - mask_type = self.loss_cfg.mask_type - if mask_type == "v1": - self.process_loss_weight_v1() - elif mask_type is not None: - raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}") return super().forward(hidden_states, head_weight, head_bias) def loss_fn( @@ -222,6 +290,40 @@ def _kl_loss_fn( return kl_loss, (None, {}) + +class SciMTPLossContext(MTPLossContext): + """Loss context for Science Multi-Token Prediction (MTP). + + Supports two modes: + - **CE mode** (default): Standard cross-entropy loss on rolled labels. + Used during SFT/pretraining. + - **KL mode**: When ``logprobs`` is available (RL training), + computes KL divergence between MTP's log-probabilities and the + rolled rollout log-probabilities. + + Both modes support chunk mode for memory-efficient computation via the + base class's ``forward() → eager_mode()/chunk_mode() → loss_fn()`` dispatch. + + Args: + loss_cfg (MTPLossConfig): The MTP loss configuration. + loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss + computation. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + mask_type = self.loss_cfg.mask_type + if mask_type == "v1": + self.process_loss_weight_v1() + elif mask_type is not None: + raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}") + + return super().forward(hidden_states, head_weight, head_bias) + def process_loss_weight_v1(self): layer_idx = self.loss_cfg.mtp_depth - 1 shifted_labels = self.loss_kwargs.shifted_labels diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 799b19887b..367e41f6f1 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -2,7 +2,7 @@ import os import types from pathlib import Path -from typing import TYPE_CHECKING, Annotated, List, Literal, Self, Sequence, TypedDict, cast +from typing import TYPE_CHECKING, Annotated, Literal, Self, Sequence, TypedDict, cast import torch import torch.distributed as dist @@ -36,7 +36,7 @@ ZLossConfig, ZLossContext, ) -from xtuner.v1.loss.mtp_loss import MTPLossConfig +from xtuner.v1.loss.mtp_loss import MTPLossConfig, SciMTPLossConfig from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -58,7 +58,7 @@ ) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig, MoEBlock, MoEDecoderLayer -from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer +from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer, SciMTPConfig from xtuner.v1.utils import ( get_device, get_logger, @@ -151,7 +151,7 @@ class MoEConfig(TransformerConfig): router_compute_dtype: Literal["float32", "native"] = "float32" moe_bias: bool = False moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig() - mtp_config: List[MTPConfig] | None = None + mtp_config: list[MTPConfig] | MTPConfig | None = None freeze_routers: bool = 
False router_async_offload: bool = False aux_loss_cfg: AuxLossConfig = AuxLossConfig() @@ -186,6 +186,11 @@ class MoE(BaseModel): def __init__(self, config: MoEConfig): super().__init__(config) + + # Normalize mtp_config to always be a list or None for consistent handling + if config.mtp_config is not None and not isinstance(config.mtp_config, list): + config.mtp_config = [config.mtp_config] + if config.ep_size is not None and config.ep_size > 1: world_size = dist.get_world_size() self.ep_mesh = init_device_mesh( @@ -342,17 +347,25 @@ def build_loss_ctx_batch( # type: ignore[override] # Each MTP depth needs its own loss context for mtp_config in self.config.mtp_config: for mtp_idx in range(mtp_config.num_layers): - mtp_loss_cfg = MTPLossConfig( - **self.config.lm_loss_cfg.model_dump(), - mtp_depth=mtp_idx + 1, - detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, - mask_type=mtp_config.mask_type, - ) + # Create the appropriate loss config based on mtp_config type + if isinstance(mtp_config, SciMTPConfig): + mtp_loss_cfg = SciMTPLossConfig( + **self.config.lm_loss_cfg.model_dump(), + mtp_depth=mtp_idx + 1, + detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, + mask_type=mtp_config.mask_type, + ) + else: + mtp_loss_cfg = MTPLossConfig( + **self.config.lm_loss_cfg.model_dump(), + mtp_depth=mtp_idx + 1, + detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, + ) # MTP needs to shift labels multiple times. 
Since rebuild the `shifted_labels` in data_batch mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) if mtp_loss_ctx_list is not None: - mtp_loss_ctx_list = MTPLossContext.build_batches( # type: ignore[assignment] - cast(list[MTPLossContext], mtp_loss_ctx_list), # type: ignore[arg-type] + mtp_loss_ctx_list = type(mtp_loss_ctx_list[0]).build_batches( # type: ignore[assignment] + mtp_loss_ctx_list, # type: ignore[arg-type] cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh, ) @@ -625,6 +638,10 @@ def _micro_batch_forward( if mtp_losses_dict: output["mtp_loss"] = mtp_losses_dict + mtp_loss = 0 + for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): + mtp_loss += mtp_loss + # Apply final norm to all micro-batches cat_hidden_states = torch.cat(hidden_states_list, dim=1) cat_hidden_states = self.norm(cat_hidden_states) @@ -636,7 +653,7 @@ def _micro_batch_forward( loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cast(LMHeadLossContext, cat_loss_ctx)) # Aggregate losses (mean across micro-batches) - output["loss"] = loss.sum() + output["loss"] = loss.sum() + mtp_loss moe_extra_info = ModelForwardExtraLogInfo() if extra_info: moe_extra_info.append(extra_info) @@ -676,7 +693,6 @@ def _mtp_forward( output, layer_hidden_states, position_embeddings, - seq_ctx, balancing_ctx, z_ctx, mtp_seq_ctx, @@ -855,7 +871,6 @@ def _forward( output=output, layer_hidden_states=layer_hidden_states, position_embeddings=position_embeddings, - seq_ctx=seq_ctx, balancing_ctx=balancing_ctx, z_ctx=z_ctx, mtp_seq_ctx=mtp_seq_ctx, @@ -863,6 +878,10 @@ def _forward( keep_router=keep_router, ) + # add mtp_loss to loss + for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): + output["loss"] += mtp_loss + split_aux_output = self.aux_loss.finalize( balancing_ctx=balancing_ctx, z_ctx=z_ctx, diff --git a/xtuner/v1/module/mtp/__init__.py b/xtuner/v1/module/mtp/__init__.py index 8ced4cbaae..6a4bf0f6d6 100644 --- a/xtuner/v1/module/mtp/__init__.py +++ 
b/xtuner/v1/module/mtp/__init__.py @@ -1,7 +1,7 @@ -from .config import MTPConfig +from .config import MTPConfig, SciMTPConfig from .mtp_block import MTPBlock from .mtp_layer import MTPLayer from .utils import roll_packed_tensor, roll_sequence_context -__all__ = ["MTPConfig", "MTPBlock", "MTPLayer", "roll_packed_tensor", "roll_sequence_context"] +__all__ = ["MTPConfig", "SciMTPConfig", "MTPBlock", "MTPLayer", "roll_packed_tensor", "roll_sequence_context"] diff --git a/xtuner/v1/module/mtp/config.py b/xtuner/v1/module/mtp/config.py index 032a4cb10b..1e8334d234 100644 --- a/xtuner/v1/module/mtp/config.py +++ b/xtuner/v1/module/mtp/config.py @@ -18,6 +18,7 @@ class MTPConfig(BaseModel): decoder layers. Args: + name (str): Name of mtp module. num_layers (int): Number of MTP layers (prediction depths). Each layer predicts tokens at increasing future positions (i+1, i+2, ..., i+D). share_weights (bool): Whether to share the weights of the MTP layers. @@ -51,4 +52,44 @@ class MTPConfig(BaseModel): detach_mtp_lm_head_weight: Annotated[bool, Parameter(group="model")] = False detach_mtp_inputs: Annotated[bool, Parameter(group="model")] = False loss_scaling_factor: Annotated[float, Parameter(group="model")] = 0.1 - mask_type: Annotated[str | None, Parameter(group="model")] + +class SciMTPConfig(MTPConfig): + """Configuration for Multi-Token Prediction (MTP). + + MTP extends the prediction scope to multiple future tokens at each position, + creating denser training signals and potentially improving data efficiency. + + This config only contains training-related hyperparameters. The actual + construction of MTP layers (including choosing Dense vs MoE decoder layers) + is handled by the model (Dense/MoE) which knows how to create the appropriate + decoder layers. + + Args: + name (str): Name of mtp module. + num_layers (int): Number of MTP layers (prediction depths). Each layer + predicts tokens at increasing future positions (i+1, i+2, ..., i+D). 
+ share_weights (bool): Whether to share the weights of the MTP layers. + If True, the weights of the MTP layers are shared across all layers. + Default: False. + detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. + This is used in RL training. Default is False. + detach_mtp_inputs (bool): Whether to detach the input embeddings and hidden states. + This is used in RL training. Default is False. + loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss + is computed as the average of losses across all depths, multiplied by + this factor. Default: 0.1. + mask_type (str | None): How to mask loss_kwargs when calculating loss. + + Example: + >>> # In model config + >>> config = TransformerConfig( + ... ..., + ... mtp_config=SciMTPConfig( + ... num_layers=2, + ... share_weights=True, + ... loss_scaling_factor=0.1, + ... ), + ... ) + """ + + mask_type: Annotated[str | None, Parameter(group="model")] = None From 7e80e95e9743727f999fa50633ed7addcafa2d20 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Wed, 17 Jun 2026 20:06:00 +0800 Subject: [PATCH 03/14] remove SciMTPConfig; bind layer_idx before build mtp loss ctx --- xtuner/v1/loss/__init__.py | 5 +++- xtuner/v1/loss/mtp_loss.py | 12 ++++++-- xtuner/v1/model/moe/moe.py | 20 ++++++------- xtuner/v1/module/mtp/__init__.py | 4 +-- xtuner/v1/module/mtp/config.py | 51 ++++++-------------------------- 5 files changed, 35 insertions(+), 57 deletions(-) diff --git a/xtuner/v1/loss/__init__.py b/xtuner/v1/loss/__init__.py index d2f20b3a16..099d735640 100644 --- a/xtuner/v1/loss/__init__.py +++ b/xtuner/v1/loss/__init__.py @@ -10,7 +10,7 @@ ZLossContext, ZLossKwargs, ) -from .mtp_loss import MTPLossContext +from .mtp_loss import MTPLossContext, SciMTPLossContext, MTPLossConfig, SciMTPLossConfig from .rl_loss import LogProbConfig, LogProbContext @@ -31,6 +31,9 @@ "BaseLossKwargs", "LMHeadLossContext", "MTPLossContext", + "MTPLossConfig", + "SciMTPLossContext", + "SciMTPLossConfig", 
"LogProbConfig", "LogProbContext", ] diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index 57e86fa462..4895be2c69 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -53,14 +53,22 @@ class MTPLossConfig(CELossConfig): Args: mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses - ``mtp_depth=1`` (shift=-1 on top of the existing label shift). + ``mtp_depth=1`` (shift=-1 on top of the existing label shift). Default: 1. detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. This is used in RL training. Default is False. """ - mtp_depth: int + mtp_depth: int = 1 detach_mtp_lm_head_weight: bool = False + def bind_mtp_depth(self, idx: int) -> None: + """Bind MTP depth to the given index. + + Args: + idx (int): 1-indexed MTP layer depth to bind. + """ + self.mtp_depth = idx + @property def loss_ctx_cls(self) -> type["MTPLossContext"]: return MTPLossContext diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 367e41f6f1..0c55010a51 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from __future__ import annotations + import os import types from pathlib import Path @@ -58,7 +60,7 @@ ) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig, MoEBlock, MoEDecoderLayer -from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer, SciMTPConfig +from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer from xtuner.v1.utils import ( get_device, get_logger, @@ -347,20 +349,18 @@ def build_loss_ctx_batch( # type: ignore[override] # Each MTP depth needs its own loss context for mtp_config in self.config.mtp_config: for mtp_idx in range(mtp_config.num_layers): - # Create the appropriate loss config based on mtp_config type - if isinstance(mtp_config, SciMTPConfig): - mtp_loss_cfg = SciMTPLossConfig( - **self.config.lm_loss_cfg.model_dump(), - mtp_depth=mtp_idx + 1, - detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, - mask_type=mtp_config.mask_type, - ) + # Get loss_cfg from mtp_config, or create a default one if not provided + if mtp_config.loss_cfg is not None: + # Create a copy to avoid modifying the original config + mtp_loss_cfg = mtp_config.loss_cfg.model_copy() else: + # Create default MTPLossConfig from model's lm_loss_cfg mtp_loss_cfg = MTPLossConfig( **self.config.lm_loss_cfg.model_dump(), - mtp_depth=mtp_idx + 1, detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, ) + # Bind mtp_depth to current layer index + mtp_loss_cfg.bind_mtp_depth(mtp_idx + 1) # MTP needs to shift labels multiple times. 
Since rebuild the `shifted_labels` in data_batch mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) if mtp_loss_ctx_list is not None: diff --git a/xtuner/v1/module/mtp/__init__.py b/xtuner/v1/module/mtp/__init__.py index 6a4bf0f6d6..8ced4cbaae 100644 --- a/xtuner/v1/module/mtp/__init__.py +++ b/xtuner/v1/module/mtp/__init__.py @@ -1,7 +1,7 @@ -from .config import MTPConfig, SciMTPConfig +from .config import MTPConfig from .mtp_block import MTPBlock from .mtp_layer import MTPLayer from .utils import roll_packed_tensor, roll_sequence_context -__all__ = ["MTPConfig", "SciMTPConfig", "MTPBlock", "MTPLayer", "roll_packed_tensor", "roll_sequence_context"] +__all__ = ["MTPConfig", "MTPBlock", "MTPLayer", "roll_packed_tensor", "roll_sequence_context"] diff --git a/xtuner/v1/module/mtp/config.py b/xtuner/v1/module/mtp/config.py index 1e8334d234..6926358125 100644 --- a/xtuner/v1/module/mtp/config.py +++ b/xtuner/v1/module/mtp/config.py @@ -1,10 +1,14 @@ """Configuration for Multi-Token Prediction (MTP).""" -from typing import Annotated +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated from cyclopts import Parameter from pydantic import BaseModel, ConfigDict +from xtuner.v1.loss.mtp_loss import MTPLossConfig + class MTPConfig(BaseModel): """Configuration for Multi-Token Prediction (MTP). @@ -31,6 +35,8 @@ class MTPConfig(BaseModel): loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss is computed as the average of losses across all depths, multiplied by this factor. Default: 0.1. + loss_cfg (MTPLossConfig | None): Loss configuration for MTP. + If None, loss config will be constructed from MTPLossConfig(). Default: None. Example: >>> # In model config @@ -40,6 +46,7 @@ class MTPConfig(BaseModel): ... num_layers=2, ... share_weights=True, ... loss_scaling_factor=0.1, + ... loss_cfg=MTPLossConfig() ... ), ... 
) """ @@ -52,44 +59,4 @@ class MTPConfig(BaseModel): detach_mtp_lm_head_weight: Annotated[bool, Parameter(group="model")] = False detach_mtp_inputs: Annotated[bool, Parameter(group="model")] = False loss_scaling_factor: Annotated[float, Parameter(group="model")] = 0.1 - -class SciMTPConfig(MTPConfig): - """Configuration for Multi-Token Prediction (MTP). - - MTP extends the prediction scope to multiple future tokens at each position, - creating denser training signals and potentially improving data efficiency. - - This config only contains training-related hyperparameters. The actual - construction of MTP layers (including choosing Dense vs MoE decoder layers) - is handled by the model (Dense/MoE) which knows how to create the appropriate - decoder layers. - - Args: - name (str): Name of mtp module. - num_layers (int): Number of MTP layers (prediction depths). Each layer - predicts tokens at increasing future positions (i+1, i+2, ..., i+D). - share_weights (bool): Whether to share the weights of the MTP layers. - If True, the weights of the MTP layers are shared across all layers. - Default: False. - detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. - This is used in RL training. Default is False. - detach_mtp_inputs (bool): Whether to detach the input embeddings and hidden states. - This is used in RL training. Default is False. - loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss - is computed as the average of losses across all depths, multiplied by - this factor. Default: 0.1. - mask_type (str | None): How to mask loss_kwargs when calculating loss. - - Example: - >>> # In model config - >>> config = TransformerConfig( - ... ..., - ... mtp_config=SciMTPConfig( - ... num_layers=2, - ... share_weights=True, - ... loss_scaling_factor=0.1, - ... ), - ... 
) - """ - - mask_type: Annotated[str | None, Parameter(group="model")] = None + loss_cfg: MTPLossConfig | None = None From 5045ea914f7438ab5d2472821107c052ec1dfe49 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Thu, 18 Jun 2026 20:48:03 +0800 Subject: [PATCH 04/14] move bind_mtp_Depth to MTPLossContext --- xtuner/v1/loss/mtp_loss.py | 99 ++++++++++---------------------------- xtuner/v1/model/moe/moe.py | 10 ++-- 2 files changed, 33 insertions(+), 76 deletions(-) diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index 4895be2c69..d9f4433b85 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -58,17 +58,8 @@ class MTPLossConfig(CELossConfig): This is used in RL training. Default is False. """ - mtp_depth: int = 1 detach_mtp_lm_head_weight: bool = False - def bind_mtp_depth(self, idx: int) -> None: - """Bind MTP depth to the given index. - - Args: - idx (int): 1-indexed MTP layer depth to bind. - """ - self.mtp_depth = idx - @property def loss_ctx_cls(self) -> type["MTPLossContext"]: return MTPLossContext @@ -76,9 +67,9 @@ def loss_ctx_cls(self) -> type["MTPLossContext"]: @property def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]: return MTPLossKwargs - - def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> torch.Tensor: - """Process loss_kwargs for MTP Loss. + + def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None": + """Build MTPLossContext from data dict. Rolls ``shifted_labels`` (and optionally ``logprobs``) by ``-mtp_depth`` positions (per-sequence, respecting packed-sequence @@ -93,8 +84,10 @@ def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> sp_mesh (DeviceMesh | None): Sequence parallel mesh. Returns: - torch.Tensor: loss_kwargs + MTPLossContext | None: Built loss context, or ``None`` if + ``shifted_labels`` is not present in ``data``. 
""" + # TODO: Should move the common utils function to public package to avoid from circular import. from xtuner.v1.module.mtp.utils import roll_packed_tensor @@ -103,6 +96,7 @@ def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> shifted_labels = data["shifted_labels"] cu_seq_lens = data["seq_ctx"].cu_seq_lens_k + mtp_depth = data["mtp_depth"] # cu_seq_lens[-1] may be larger than shifted_labels.shape[-1] when seq_ctx # was split for sequence parallelism (padding is added to make the sequence @@ -119,7 +113,7 @@ def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> ) shifted_labels = torch.cat([shifted_labels, pad], dim=-1) - rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100) + rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-mtp_depth, dim=-1, fill_value=-100) # Roll logprobs by the same amount as shifted_labels logprobs = data.get("logprobs", None) @@ -133,7 +127,7 @@ def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> device=logprobs.device, ) logprobs = torch.cat([logprobs, rp_pad], dim=-1) - rolled_logprobs = roll_packed_tensor(logprobs, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=0) + rolled_logprobs = roll_packed_tensor(logprobs, cu_seq_lens, shifts=-mtp_depth, dim=-1, fill_value=0) loss_kwargs = MTPLossKwargs( shifted_labels=rolled, @@ -142,33 +136,9 @@ def _process_loss_kwargs(self, data: dict, sp_mesh: DeviceMesh | None = None) -> if sp_mesh is not None and sp_mesh.size() > 1: loss_kwargs = loss_kwargs.sp_split(sp_mesh) - return loss_kwargs - - def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None": - """Build MTPLossContext from data dict. - - Rolls ``shifted_labels`` (and optionally ``logprobs``) by - ``-mtp_depth`` positions (per-sequence, respecting packed-sequence - boundaries) before constructing the loss context. 
The roll is performed - on the full sequence prior to any sequence-parallel split so that - boundary positions and ``cu_seq_lens`` are always consistent. - - Args: - data (dict): Data dict containing loss-related fields. - Required keys: ``shifted_labels``, ``seq_ctx``. - Optional keys: ``logprobs``. - sp_mesh (DeviceMesh | None): Sequence parallel mesh. - - Returns: - MTPLossContext | None: Built loss context, or ``None`` if - ``shifted_labels`` is not present in ``data``. - """ - - if "shifted_labels" not in data: - return None - - loss_kwargs = self._process_loss_kwargs(data, sp_mesh) - return MTPLossContext(self, loss_kwargs) + loss_context = self.loss_ctx_cls(self, loss_kwargs) + loss_context.bind_mtp_depth(mtp_depth) + return loss_context class SciMTPLossConfig(MTPLossConfig): @@ -190,36 +160,6 @@ class SciMTPLossConfig(MTPLossConfig): def loss_ctx_cls(self) -> type["SciMTPLossContext"]: return SciMTPLossContext - @property - def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]: - return MTPLossKwargs - - def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "SciMTPLossContext | None": - """Build SciMTPLossContext from data dict. - - Rolls ``shifted_labels`` (and optionally ``logprobs``) by - ``-mtp_depth`` positions (per-sequence, respecting packed-sequence - boundaries) before constructing the loss context. The roll is performed - on the full sequence prior to any sequence-parallel split so that - boundary positions and ``cu_seq_lens`` are always consistent. - - Args: - data (dict): Data dict containing loss-related fields. - Required keys: ``shifted_labels``, ``seq_ctx``. - Optional keys: ``logprobs``. - sp_mesh (DeviceMesh | None): Sequence parallel mesh. - - Returns: - SciMTPLossContext | None: Built loss context, or ``None`` if - ``shifted_labels`` is not present in ``data``. 
- """ - - if "shifted_labels" not in data: - return None - - loss_kwargs = self._process_loss_kwargs(data, sp_mesh) - return SciMTPLossContext(self, loss_kwargs) - class MTPLossContext(LMHeadLossContext): """Loss context for Multi-Token Prediction (MTP). @@ -239,6 +179,10 @@ class MTPLossContext(LMHeadLossContext): loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss computation. """ + def __init__(self, loss_cfg: MTPLossConfig, loss_kwargs: MTPLossKwargs): + super().__init__(loss_cfg, loss_kwargs) + + self.mtp_depth = None def forward( self, @@ -246,6 +190,7 @@ def forward( head_weight: torch.Tensor, head_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + assert self.mtp_depth is not None, "Please bind mtp depth for MTPLossContext!" if self.loss_cfg.detach_mtp_lm_head_weight: head_weight = head_weight.detach() head_bias = head_bias.detach() if head_bias is not None else None @@ -298,6 +243,14 @@ def _kl_loss_fn( return kl_loss, (None, {}) + def bind_mtp_depth(self, idx: int, sp_mesh: DeviceMesh | None = None) -> None: + """Bind MTP depth to the given index. + + Args: + idx (int): 1-indexed MTP layer depth to bind. + """ + self.mtp_depth = idx + class SciMTPLossContext(MTPLossContext): """Loss context for Science Multi-Token Prediction (MTP). @@ -333,7 +286,7 @@ def forward( return super().forward(hidden_states, head_weight, head_bias) def process_loss_weight_v1(self): - layer_idx = self.loss_cfg.mtp_depth - 1 + layer_idx = self.mtp_depth - 1 shifted_labels = self.loss_kwargs.shifted_labels loss_weight = self.loss_kwargs.loss_weight sum_loss_weight = loss_weight.sum() diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 0c55010a51..5c3ec5605e 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from __future__ import annotations +import copy import os import types from pathlib import Path @@ -359,10 +360,13 @@ def build_loss_ctx_batch( # type: ignore[override] **self.config.lm_loss_cfg.model_dump(), detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, ) - # Bind mtp_depth to current layer index - mtp_loss_cfg.bind_mtp_depth(mtp_idx + 1) + + # copy data_batch to insert mtp_depth + _new_data_batch = copy.copy(_data_batch) + for _data in _new_data_batch: + _data["mtp_depth"] = mtp_idx + 1 # MTP needs to shift labels multiple times. Since rebuild the `shifted_labels` in data_batch - mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) + mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _new_data_batch, sp_mesh) if mtp_loss_ctx_list is not None: mtp_loss_ctx_list = type(mtp_loss_ctx_list[0]).build_batches( # type: ignore[assignment] mtp_loss_ctx_list, # type: ignore[arg-type] From 92f4ff967a83911fda78120b1e348625f9fa852e Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 15:08:12 +0800 Subject: [PATCH 05/14] record mtp_name when save qwen3_5 to hf --- xtuner/v1/model/moe/qwen3_5_text.py | 39 ++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index e4cf2e2fc3..067c7b9961 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -45,19 +45,28 @@ class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): def to_hf_key_list(self, key: str) -> list[str]: # Handle MTP parameters if key.startswith("mtp_block."): - # Remove "mtp_block." prefix - key = key.replace("mtp_block.", "", 1) + + # Extract MTP name from mtp_block.{mtp_name}.{rest} + match = re.match(r"mtp_block\.(normal|sci)\.(.*)", key) + if not match: + raise ValueError( + f"Invalid mtp_block key format: {key}. 
" + f"Expected 'mtp_block.normal.*' or 'mtp_block.sci.*'" + ) + + mtp_name = match.group(1) + key = match.group(2) # Handle MTP layer-specific parameters - # xtuner: mtp_block.layers.{idx}.decoder_layer.{param} - # HF: mtp.layers.{idx}.{param} + # xtuner: mtp_block.{mtp_name}.layers.{idx}.decoder_layer.{param} + # HF normal: mtp.layers.{idx}.{param} + # HF sci: mtp.sci.layers.{idx}.{param} key = re.sub(r"layers\.(\d+)\.decoder_layer\.", r"layers.\1.", key) # Handle MTP normalization layers - # xtuner: mtp_block.layers.{idx}.enorm -> HF: mtp.pre_fc_norm_embedding - # xtuner: mtp_block.layers.{idx}.hnorm -> HF: mtp.pre_fc_norm_hidden - # xtuner: mtp_block.layers.{idx}.final_layernorm -> HF: mtp.norm - # Note: Currently assuming single MTP layer (idx=0), may need adjustment for multiple layers + # xtuner: mtp_block.{mtp_name}.layers.{idx}.enorm -> HF: mtp[.sci].pre_fc_norm_embedding + # xtuner: mtp_block.{mtp_name}.layers.{idx}.hnorm -> HF: mtp[.sci].pre_fc_norm_hidden + # xtuner: mtp_block.{mtp_name}.layers.{idx}.final_layernorm -> HF: mtp[.sci].norm if ".enorm." in key: key = re.sub(r"layers\.\d+\.enorm\.", "pre_fc_norm_embedding.", key) elif ".hnorm." in key: @@ -66,7 +75,7 @@ def to_hf_key_list(self, key: str) -> list[str]: key = re.sub(r"layers\.\d+\.final_layernorm\.", "norm.", key) # Handle MTP projection layer - # xtuner: mtp_block.layers.{idx}.eh_proj -> HF: mtp.fc + # xtuner: mtp_block.{mtp_name}.layers.{idx}.eh_proj -> HF: mtp.{mtp_name}.fc if ".eh_proj." in key: key = re.sub(r"layers\.\d+\.eh_proj\.", "fc.", key) @@ -74,6 +83,12 @@ def to_hf_key_list(self, key: str) -> list[str]: key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts|shared_expert_gate)", r"layers.\1.mlp.\2", key) key = key.replace("shared_experts", "shared_expert") + # Determine HF prefix based on mtp_name + # Normal MTP (mtp_block.normal.*): mtp.{key} + # Science MTP (mtp_block.sci.*): mtp.sci.{key} + # TODO: normal mtp prefix + hf_prefix = "mtp." 
if mtp_name == "normal" else f"mtp.{mtp_name}." + # Handle fused weights n_routed_experts = self.config.n_routed_experts if "fused_w1w3.weight" in key: @@ -83,15 +98,15 @@ def to_hf_key_list(self, key: str) -> list[str]: w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.gate_proj.weight")) w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.up_proj.weight")) - return [f"mtp.{key}" for key in w1w3_keys] + return [f"{hf_prefix}{key}" for key in w1w3_keys] elif "fused_w2.weight" in key: w2_keys: list[str] = [] for i in range(n_routed_experts): w2_keys.append(key.replace("fused_w2.weight", f"{i}.down_proj.weight")) - return [f"mtp.{key}" for key in w2_keys] + return [f"{hf_prefix}{key}" for key in w2_keys] else: - return ["mtp." + key] + return [hf_prefix + key] # Handle main model parameters if "layers" in key or "embed_tokens" in key: From 077eb54bce1cb78a8bf2c2336b5b675582d6180b Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 15:10:12 +0800 Subject: [PATCH 06/14] fix reshard_after_forward judge using global_mtp_idx --- xtuner/v1/model/moe/moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 5c3ec5605e..24fef27603 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -1185,6 +1185,7 @@ def fully_shard( # Shard MTP block if it exists if self.mtp_block is not None: + total_mtp_layers = sum([len(mtp_block.layers) for mtp_name, mtp_block in self.mtp_block.items()]) global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs for mtp_name in self.mtp_block.keys(): mtp_block = self.mtp_block[mtp_name] @@ -1196,7 +1197,7 @@ def fully_shard( mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) mtp_block.layers[local_mtp_idx] = mtp_layer - reshard_after_forward = local_mtp_idx != len(mtp_block.layers) - 1 + reshard_after_forward = global_mtp_idx != total_mtp_layers - 1 self._fully_shard( 
mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, mp_policy=mp_policy, From 19cbccaa175d702531594230bff7fa5951275f21 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 15:21:33 +0800 Subject: [PATCH 07/14] remove mtp_loss if comment since total loss is calculated in forward --- xtuner/v1/engine/train_engine.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 6021ff7794..74a56d4643 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -569,9 +569,6 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor: loss = torch.tensor(0.0, device=DEVICE) for key in model_outputs.model_fields: value = getattr(model_outputs, key) - if key == "mtp_loss" and isinstance(value, dict): - for mtp_loss_name, mtp_loss in value.items(): - loss += mtp_loss - elif "loss" in key and isinstance(value, torch.Tensor): + if "loss" in key and isinstance(value, torch.Tensor): loss += value return loss From be15d566a58235983028fb90657697560864a117 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 15:22:02 +0800 Subject: [PATCH 08/14] rename param name from idx to depth in bind_mtp_depth --- xtuner/v1/loss/mtp_loss.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index d9f4433b85..5795d54910 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -148,10 +148,9 @@ class SciMTPLossConfig(MTPLossConfig): ``loss_kwargs`` when calculating loss. Args: - mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses - ``mtp_depth=1`` (shift=-1 on top of the existing label shift). detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. This is used in RL training. Default is False. + mask_type (str | None): Mask method when calculating Science MTP. 
""" mask_type: Optional[str] = None @@ -243,13 +242,13 @@ def _kl_loss_fn( return kl_loss, (None, {}) - def bind_mtp_depth(self, idx: int, sp_mesh: DeviceMesh | None = None) -> None: + def bind_mtp_depth(self, depth: int) -> None: """Bind MTP depth to the given index. Args: idx (int): 1-indexed MTP layer depth to bind. """ - self.mtp_depth = idx + self.mtp_depth = depth class SciMTPLossContext(MTPLossContext): From da6293ae152b13112d69fd828d3ddf1f3d14e950 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 15:34:16 +0800 Subject: [PATCH 09/14] remove _mtp_forward in moe --- xtuner/v1/model/moe/moe.py | 117 ++++++++++++++----------------------- 1 file changed, 45 insertions(+), 72 deletions(-) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 24fef27603..1896b357f7 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -691,69 +691,6 @@ def _micro_batch_forward( return MoEModelOutputs(**output, logits=logits) - def _mtp_forward( - self, - mtp_config: MTPConfig, - output, - layer_hidden_states, - position_embeddings, - balancing_ctx, - z_ctx, - mtp_seq_ctx, - mtp_loss_ctx_dict, - keep_router: bool, - ): - # MTP uses its own mask; main mask's non-pad indices do not apply. 
- name = mtp_config.name - mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] - mtp_non_pad_token = mtp_nonpad_indices.numel() - mtp_num_tokens_global, mtp_z_world_size = self._z_loss_dist_token_count( - z_ctx, mtp_non_pad_token, mtp_seq_ctx.mask.device - ) - - # Forward through MTP block - mtp_outputs = self.mtp_block[name]( - layer_hidden_states, - embed_tokens_fn=self.embed_tokens, - position_embeddings=position_embeddings, - seq_ctx=mtp_seq_ctx, - ) - - # Compute MTP losses for each depth - mtp_losses = torch.tensor(0.0, device=DEVICE) - mtp_loss_ctx_list = mtp_loss_ctx_dict[name] - for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): - mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden - - if keep_router: - output["router_logits"][f"{name}_mtp_layer{idx}"] = mtp_router_results - output["router_weights"][f"{name}_mtp_layer{idx}"] = mtp_router_weights - # Inject this MTP layer's z-loss before lm_head so backward through mtp_loss - # traverses the AuxLossScaler node and releases this layer's logsumexp activations. 
- mtp_hidden_states = self.aux_loss.accumulate( - selected_router_weights=mtp_router_weights.index_select(0, mtp_nonpad_indices) - .contiguous() - .float(), - selected_router_logits=mtp_router_results.index_select(0, mtp_nonpad_indices).contiguous().float(), - hidden_states=mtp_hidden_states, - balancing_ctx=balancing_ctx, - z_ctx=z_ctx, - num_tokens_local=mtp_non_pad_token, - num_tokens_global=mtp_num_tokens_global, - world_size=mtp_z_world_size, - ) - mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) - mtp_losses += mtp_loss - - # Average MTP losses across depths and scale - mtp_losses = mtp_losses / len(mtp_loss_ctx_list) - scaled_mtp_loss = mtp_losses * mtp_config.loss_scaling_factor # type: ignore - - # Add to total loss - output["mtp_loss"][name] = scaled_mtp_loss - - return scaled_mtp_loss - def _forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch @@ -870,18 +807,54 @@ def _forward( ) for mtp_config in self.config.mtp_config: - self._mtp_forward( - mtp_config=mtp_config, - output=output, - layer_hidden_states=layer_hidden_states, + name = mtp_config.name + mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] + mtp_non_pad_token = mtp_nonpad_indices.numel() + mtp_num_tokens_global, mtp_z_world_size = self._z_loss_dist_token_count( + z_ctx, mtp_non_pad_token, mtp_seq_ctx.mask.device + ) + + # Forward through MTP block + mtp_outputs = self.mtp_block[name]( + layer_hidden_states, + embed_tokens_fn=self.embed_tokens, position_embeddings=position_embeddings, - balancing_ctx=balancing_ctx, - z_ctx=z_ctx, - mtp_seq_ctx=mtp_seq_ctx, - mtp_loss_ctx_dict=mtp_loss_ctx_dict, - keep_router=keep_router, + seq_ctx=mtp_seq_ctx, ) + # Compute MTP losses for each depth + mtp_losses = torch.tensor(0.0, device=DEVICE) + mtp_loss_ctx_list = mtp_loss_ctx_dict[name] + for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): + mtp_hidden_states, mtp_router_results, 
mtp_router_weights = mtp_hidden + + if keep_router: + output["router_logits"][f"{name}_mtp_layer{idx}"] = mtp_router_results + output["router_weights"][f"{name}_mtp_layer{idx}"] = mtp_router_weights + # Inject this MTP layer's z-loss before lm_head so backward through mtp_loss + # traverses the AuxLossScaler node and releases this layer's logsumexp activations. + mtp_hidden_states = self.aux_loss.accumulate( + selected_router_weights=mtp_router_weights.index_select(0, mtp_nonpad_indices) + .contiguous() + .float(), + selected_router_logits=mtp_router_results.index_select(0, mtp_nonpad_indices).contiguous().float(), + hidden_states=mtp_hidden_states, + balancing_ctx=balancing_ctx, + z_ctx=z_ctx, + num_tokens_local=mtp_non_pad_token, + num_tokens_global=mtp_num_tokens_global, + world_size=mtp_z_world_size, + ) + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) + mtp_losses += mtp_loss + + # Average MTP losses across depths and scale + mtp_losses = mtp_losses / len(mtp_loss_ctx_list) + scaled_mtp_loss = mtp_losses * mtp_config.loss_scaling_factor # type: ignore + + # Add to total loss + output["mtp_loss"][name] = scaled_mtp_loss + # add mtp_loss to loss for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): output["loss"] += mtp_loss From 5c5bba49026d51ecaaedbeac07443b69dce605cb Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 15:36:29 +0800 Subject: [PATCH 10/14] small fix in moe.py --- xtuner/v1/model/moe/moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 1896b357f7..e54469dfbb 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -1162,13 +1162,13 @@ def fully_shard( global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs for mtp_name in self.mtp_block.keys(): mtp_block = self.mtp_block[mtp_name] - mtp_config = next((cfg for cfg in self.config.mtp_config if cfg.name == mtp_name), None) # type: 
ignore + mtp_config = mtp_block.mtp_config for local_mtp_idx, mtp_layer in enumerate(mtp_block.layers): if self._should_recompute(None, mtp_idx=global_mtp_idx) or ( mtp_config is not None and mtp_config.share_weights ): # share mtp head must recompute mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) - mtp_block.layers[local_mtp_idx] = mtp_layer + mtp_block.layers[local_mtp_idx] = mtp_layer reshard_after_forward = global_mtp_idx != total_mtp_layers - 1 self._fully_shard( From be6b91cf13f2b5a5d28840a83b16dbe3d6a1e9f8 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 16:28:28 +0800 Subject: [PATCH 11/14] fix bind_layer_idx doc --- xtuner/v1/loss/mtp_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index 5795d54910..6c6e2bf2d8 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -246,7 +246,7 @@ def bind_mtp_depth(self, depth: int) -> None: """Bind MTP depth to the given index. Args: - idx (int): 1-indexed MTP layer depth to bind. + depth (int): 1-indexed MTP layer depth to bind. 
""" self.mtp_depth = depth From 6e079d3995f4525085c602a28b43f6af46b148af Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 19:22:08 +0800 Subject: [PATCH 12/14] indent mtp_loss --- xtuner/v1/model/moe/moe.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index e54469dfbb..c9a18be752 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -577,6 +577,7 @@ def _micro_batch_forward( assert hidden_states_list, "XTuner Internal Error, found empty hidden states for domino EP" + total_mtp_loss = 0 if self.mtp_block is not None: assert self.config.mtp_config is not None @@ -642,9 +643,8 @@ def _micro_batch_forward( if mtp_losses_dict: output["mtp_loss"] = mtp_losses_dict - mtp_loss = 0 - for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): - mtp_loss += mtp_loss + for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): + total_mtp_loss += mtp_loss # Apply final norm to all micro-batches cat_hidden_states = torch.cat(hidden_states_list, dim=1) @@ -657,7 +657,7 @@ def _micro_batch_forward( loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cast(LMHeadLossContext, cat_loss_ctx)) # Aggregate losses (mean across micro-batches) - output["loss"] = loss.sum() + mtp_loss + output["loss"] = loss.sum() + total_mtp_loss moe_extra_info = ModelForwardExtraLogInfo() if extra_info: moe_extra_info.append(extra_info) @@ -855,9 +855,9 @@ def _forward( # Add to total loss output["mtp_loss"][name] = scaled_mtp_loss - # add mtp_loss to loss - for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): - output["loss"] += mtp_loss + # add mtp_loss to loss + for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): + output["loss"] += mtp_loss split_aux_output = self.aux_loss.finalize( balancing_ctx=balancing_ctx, z_ctx=z_ctx, From 86971dd5c45c7bc0dc83e364e90afcd80c47d84b Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 20:26:22 +0800 Subject: [PATCH 13/14]
Change mtp_block from ModuleDict to ModuleList --- xtuner/v1/model/moe/moe.py | 47 +++++++++++++++-------------- xtuner/v1/model/moe/qwen3_5_text.py | 7 +++-- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index c9a18be752..8fd4442cce 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -210,7 +210,7 @@ def __init__(self, config: MoEConfig): self.layers = self.build_layers(config) self.rotary_emb = self.build_rotary_embedding(config) self.embed_tokens = self.build_embeddings(config) - self.mtp_block = self.build_mtp_block_dict(config) if config.mtp_config is not None else None + self.mtp_block = self.build_mtp_block_list(config) if config.mtp_config is not None else None self.fp32_layers = [self.rotary_emb] @@ -600,11 +600,12 @@ def _micro_batch_forward( mtp_losses_dict: dict[str, torch.Tensor] = {} # Loop through each mtp_config - for mtp_config in self.config.mtp_config: + for mtp_block in self.mtp_block: + mtp_config = mtp_block.mtp_config name = mtp_config.name # Get the MTP block for this config by name - mtp_outputs_per_mb = self.mtp_block[name]( + mtp_outputs_per_mb = mtp_block( *hidden_states_list, embed_tokens_fn=self.embed_tokens, position_embeddings=position_embeddings_list, @@ -806,7 +807,8 @@ def _forward( inputs_embeds=seq_ctx.inputs_embeds.clone() if seq_ctx.inputs_embeds is not None else None, ) - for mtp_config in self.config.mtp_config: + for mtp_block in self.mtp_block: + mtp_config = mtp_block.mtp_config name = mtp_config.name mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] mtp_non_pad_token = mtp_nonpad_indices.numel() @@ -815,7 +817,7 @@ def _forward( ) # Forward through MTP block - mtp_outputs = self.mtp_block[name]( + mtp_outputs = mtp_block( layer_hidden_states, embed_tokens_fn=self.embed_tokens, position_embeddings=position_embeddings, @@ -950,24 +952,26 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: 
layers.__class__.__repr__ = module_dict_repr # type: ignore[method-assign] return layers - def build_mtp_block_dict(self, config): - mtp_block_dict = nn.ModuleDict() + def build_mtp_block_list(self, config): + mtp_block_list = [] layer_idx_offset = 0 # Cumulative offset for layer indices across all mtp_configs + mtp_name_list = [] for mtp_config in config.mtp_config: if mtp_config.name not in ("normal", "sci"): raise ValueError(f"Expected mtp keys to be either `normal` or `sci`, but got `{mtp_config.name}`") - if mtp_config.name in mtp_block_dict.keys(): + if mtp_config.name in mtp_name_list: raise ValueError(f"Duplicate mtp name: `{mtp_config.name}`") + mtp_name_list.append(mtp_config.name) # Build the MTP block with the current offset - mtp_block_dict[mtp_config.name] = self.build_mtp_block(config, mtp_config, layer_idx_offset) + mtp_block_list.append(self.build_mtp_block(config, mtp_config, layer_idx_offset)) # Update offset: number of physical layers for this mtp_config num_physical_layer = 1 if mtp_config.share_weights else mtp_config.num_layers layer_idx_offset += num_physical_layer - return mtp_block_dict + return nn.ModuleList(mtp_block_list) def build_mtp_block(self, config: MoEConfig, mtp_config: MTPConfig, layer_idx_offset: int) -> MTPBlock: """Build MTP block with MoE decoder layers. 
@@ -1158,10 +1162,10 @@ def fully_shard( # Shard MTP block if it exists if self.mtp_block is not None: - total_mtp_layers = sum([len(mtp_block.layers) for mtp_name, mtp_block in self.mtp_block.items()]) + total_mtp_layers = sum([len(mtp_block.layers) for mtp_block in self.mtp_block]) global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs - for mtp_name in self.mtp_block.keys(): - mtp_block = self.mtp_block[mtp_name] + mtp_block_layers = [] + for mtp_block in self.mtp_block: mtp_config = mtp_block.mtp_config for local_mtp_idx, mtp_layer in enumerate(mtp_block.layers): if self._should_recompute(None, mtp_idx=global_mtp_idx) or ( @@ -1183,16 +1187,13 @@ def fully_shard( layer_next.set_modules_to_forward_prefetch([mtp_layer]) # type: ignore global_mtp_idx += 1 - # Set up prefetch chains across all MTP blocks - if self.config.mtp_config is not None: - mtp_block_layers = [] - for mtp_config in self.config.mtp_config: - mtp_block_layers.extend(list(self.mtp_block[mtp_config.name].layers)) - for prev_mtp_layer, next_mtp_layer in zip( - mtp_block_layers[:-1], - mtp_block_layers[1:], - ): - prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore + mtp_block_layers.extend(list(mtp_block.layers)) + + for prev_mtp_layer, next_mtp_layer in zip( + mtp_block_layers[:-1], + mtp_block_layers[1:], + ): + prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore self._fully_shard( mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index 067c7b9961..94fb488051 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -47,14 +47,15 @@ def to_hf_key_list(self, key: str) -> list[str]: if key.startswith("mtp_block."): # Extract MTP name from mtp_block.{mtp_name}.{rest} - match = re.match(r"mtp_block\.(normal|sci)\.(.*)", key) + match = re.match(r"mtp_block\.(\d+)\.(.*)", key) if not match: 
raise ValueError( f"Invalid mtp_block key format: {key}. " - f"Expected 'mtp_block.normal.*' or 'mtp_block.sci.*'" + f"Expected 'mtp_block.{{idx}}.*" ) - mtp_name = match.group(1) + mtp_idx = int(match.group(1)) + mtp_name = self.config.mtp_config[mtp_idx].name key = match.group(2) # Handle MTP layer-specific parameters From 0bcbacef1b99127a607c4ef2817664645399fa4c Mon Sep 17 00:00:00 2001 From: x54-729 Date: Mon, 22 Jun 2026 21:27:09 +0800 Subject: [PATCH 14/14] remove mtp_loss sum in moe.py; change total_loss sum in train_engine.py --- xtuner/v1/engine/train_engine.py | 7 +++++-- xtuner/v1/model/moe/moe.py | 10 +--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 74a56d4643..bb0d3b5441 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -569,6 +569,9 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor: loss = torch.tensor(0.0, device=DEVICE) for key in model_outputs.model_fields: value = getattr(model_outputs, key) - if "loss" in key and isinstance(value, torch.Tensor): - loss += value + if "loss" in key: + loss_values = list(value.values()) if isinstance(value, dict) else [value] + loss_values = [i for i in loss_values if isinstance(i, torch.Tensor)] + for value in loss_values: + loss += value return loss diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 8fd4442cce..50969f8d64 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -577,7 +577,6 @@ def _micro_batch_forward( assert hidden_states_list, "XTuner Internal Error, found empty hidden states for domino EP" - total_mtp_loss = 0 if self.mtp_block is not None: assert self.config.mtp_config is not None @@ -644,9 +643,6 @@ def _micro_batch_forward( if mtp_losses_dict: output["mtp_loss"] = mtp_losses_dict - for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): - total_mtp_loss += mtp_loss - # Apply final norm 
to all micro-batches cat_hidden_states = torch.cat(hidden_states_list, dim=1) cat_hidden_states = self.norm(cat_hidden_states) @@ -658,7 +654,7 @@ def _micro_batch_forward( loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cast(LMHeadLossContext, cat_loss_ctx)) # Aggregate losses (mean across micro-batches) - output["loss"] = loss.sum() + total_mtp_loss + output["loss"] = loss.sum() moe_extra_info = ModelForwardExtraLogInfo() if extra_info: moe_extra_info.append(extra_info) @@ -857,10 +853,6 @@ def _forward( # Add to total loss output["mtp_loss"][name] = scaled_mtp_loss - # add mtp_loss to loss - for mtp_loss_name, mtp_loss in output["mtp_loss"].items(): - output["loss"] += mtp_loss - split_aux_output = self.aux_loss.finalize( balancing_ctx=balancing_ctx, z_ctx=z_ctx,