Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions xtuner/v1/engine/train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,9 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor:
loss = torch.tensor(0.0, device=DEVICE)
for key in model_outputs.model_fields:
value = getattr(model_outputs, key)
if "loss" in key and isinstance(value, torch.Tensor):
loss += value
if "loss" in key:
loss_values = list(value.values()) if isinstance(value, dict) else [value]
loss_values = [i for i in loss_values if isinstance(i, torch.Tensor)]
for value in loss_values:
loss += value
return loss
5 changes: 4 additions & 1 deletion xtuner/v1/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ZLossContext,
ZLossKwargs,
)
from .mtp_loss import MTPLossContext
from .mtp_loss import MTPLossContext, SciMTPLossContext, MTPLossConfig, SciMTPLossConfig
from .rl_loss import LogProbConfig, LogProbContext


Expand All @@ -31,6 +31,9 @@
"BaseLossKwargs",
"LMHeadLossContext",
"MTPLossContext",
"MTPLossConfig",
"SciMTPLossContext",
"SciMTPLossConfig",
"LogProbConfig",
"LogProbContext",
]
Expand Down
112 changes: 106 additions & 6 deletions xtuner/v1/loss/mtp_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any
from typing import Any, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -53,12 +53,11 @@ class MTPLossConfig(CELossConfig):

Args:
mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses
``mtp_depth=1`` (shift=-1 on top of the existing label shift).
``mtp_depth=1`` (shift=-1 on top of the existing label shift). Default: 1.
detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight.
This is used in RL training. Default is False.
"""

mtp_depth: int
detach_mtp_lm_head_weight: bool = False

@property
Expand Down Expand Up @@ -88,6 +87,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex
MTPLossContext | None: Built loss context, or ``None`` if
``shifted_labels`` is not present in ``data``.
"""

# TODO: Should move the common utils function to public package to avoid from circular import.
from xtuner.v1.module.mtp.utils import roll_packed_tensor

Expand All @@ -96,6 +96,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex

shifted_labels = data["shifted_labels"]
cu_seq_lens = data["seq_ctx"].cu_seq_lens_k
mtp_depth = data["mtp_depth"]

# cu_seq_lens[-1] may be larger than shifted_labels.shape[-1] when seq_ctx
# was split for sequence parallelism (padding is added to make the sequence
Expand All @@ -112,7 +113,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex
)
shifted_labels = torch.cat([shifted_labels, pad], dim=-1)

rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100)
rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-mtp_depth, dim=-1, fill_value=-100)

# Roll logprobs by the same amount as shifted_labels
logprobs = data.get("logprobs", None)
Expand All @@ -126,7 +127,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex
device=logprobs.device,
)
logprobs = torch.cat([logprobs, rp_pad], dim=-1)
rolled_logprobs = roll_packed_tensor(logprobs, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=0)
rolled_logprobs = roll_packed_tensor(logprobs, cu_seq_lens, shifts=-mtp_depth, dim=-1, fill_value=0)

loss_kwargs = MTPLossKwargs(
shifted_labels=rolled,
Expand All @@ -135,7 +136,28 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex
if sp_mesh is not None and sp_mesh.size() > 1:
loss_kwargs = loss_kwargs.sp_split(sp_mesh)

return MTPLossContext(self, loss_kwargs)
loss_context = self.loss_ctx_cls(self, loss_kwargs)
loss_context.bind_mtp_depth(mtp_depth)
return loss_context


class SciMTPLossConfig(MTPLossConfig):
"""Loss configuration for Multi-Token Prediction (MTP).

Extends ``MTPLossConfig`` with a ``mask_type`` field that controls how to mask
``loss_kwargs`` when calculating loss.

Args:
detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight.
This is used in RL training. Default is False.
mask_type (str | None): Mask method when calculating Science MTP.
"""

mask_type: Optional[str] = None

@property
def loss_ctx_cls(self) -> type["SciMTPLossContext"]:
return SciMTPLossContext


class MTPLossContext(LMHeadLossContext):
Expand All @@ -156,13 +178,18 @@ class MTPLossContext(LMHeadLossContext):
loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss
computation.
"""
def __init__(self, loss_cfg: MTPLossConfig, loss_kwargs: MTPLossKwargs):
super().__init__(loss_cfg, loss_kwargs)

self.mtp_depth = None

def forward(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
assert self.mtp_depth is not None, "Please bind mtp depth for MTPLossContext!"
if self.loss_cfg.detach_mtp_lm_head_weight:
head_weight = head_weight.detach()
head_bias = head_bias.detach() if head_bias is not None else None
Expand Down Expand Up @@ -214,3 +241,76 @@ def _kl_loss_fn(
)

return kl_loss, (None, {})

def bind_mtp_depth(self, depth: int) -> None:
"""Bind MTP depth to the given index.

Args:
depth (int): 1-indexed MTP layer depth to bind.
"""
self.mtp_depth = depth


class SciMTPLossContext(MTPLossContext):
"""Loss context for Science Multi-Token Prediction (MTP).

Supports two modes:
- **CE mode** (default): Standard cross-entropy loss on rolled labels.
Used during SFT/pretraining.
- **KL mode**: When ``logprobs`` is available (RL training),
computes KL divergence between MTP's log-probabilities and the
rolled rollout log-probabilities.

Both modes support chunk mode for memory-efficient computation via the
base class's ``forward() → eager_mode()/chunk_mode() → loss_fn()`` dispatch.

Args:
loss_cfg (MTPLossConfig): The MTP loss configuration.
loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss
computation.
"""

def forward(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
mask_type = self.loss_cfg.mask_type
if mask_type == "v1":
self.process_loss_weight_v1()
elif mask_type is not None:
raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}")

return super().forward(hidden_states, head_weight, head_bias)

def process_loss_weight_v1(self):
Comment thread
HAOCHENYE marked this conversation as resolved.
layer_idx = self.mtp_depth - 1
shifted_labels = self.loss_kwargs.shifted_labels
loss_weight = self.loss_kwargs.loss_weight
sum_loss_weight = loss_weight.sum()

easy_to_use = torch.cat(
[
shifted_labels,
torch.zeros((shifted_labels.size(0), 1), dtype=shifted_labels.dtype, device=shifted_labels.device),
],
dim=-1,
)

# TODO: digit and dot token config
is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0)
is_dot = torch.where(easy_to_use == 13, 1, 0)
is_digit_or_dot = is_digit | is_dot

mask = is_digit_or_dot.clone()
for i in range(layer_idx + 1):
mask |= torch.roll(is_digit_or_dot, shifts=i + 1, dims=-1)

mtp_mask = mask.bool()[:, :-1]

loss_weight[mtp_mask == 0.0] = 0.0
if loss_weight.sum().item() != 0:
loss_weight = loss_weight * sum_loss_weight / loss_weight.sum()

self.loss_kwargs.loss_weight = loss_weight
13 changes: 12 additions & 1 deletion xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,18 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
output_copy = output.model_copy()
for name in output_copy.model_fields:
obj = getattr(output_copy, name)
if "loss" in name and isinstance(obj, torch.Tensor):
if name == "mtp_loss" and isinstance(obj, dict):
for key, value in obj.items():
loss_item = value.item()
local_total_loss += loss_item
reduced_name = f"{key}_reduced_mtp_loss"

if reduced_name not in reduced_other_losses:
reduced_other_losses[reduced_name] = loss_item
else:
reduced_other_losses[reduced_name] += loss_item

elif "loss" in name and isinstance(obj, torch.Tensor):
loss_item = obj.item()
local_total_loss += loss_item
reduced_name = f"reduced_{name}"
Expand Down
Loading
Loading