compat mtp megatron_core main branch #92

Open
Jintao-Huang wants to merge 1 commit into modelscope:main from Jintao-Huang:compat_mtp_megatron_core_main_brancj

@Jintao-Huang
Collaborator

No description provided.

gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces activation recomputation support for the MTP layer by implementing the _checkpointed_forward method, which handles both standard and quantized (FP8/FP4) execution paths. The review feedback focuses on optimizing memory efficiency and ensuring correctness for quantized training. Key suggestions include capturing non-gradient tensors within the custom_forward closure to avoid unnecessary stashing, adding missing inner quantization contexts for FP8/FP4, and providing a functional implementation for the 'block' recompute method. Additionally, the reviewer recommended importing missing FP4 utilities and relaxing a restrictive assertion regarding the number of recompute layers.

Comment on lines +166 to +190
        def custom_forward(
            hidden_states,
            decoder_input,
            attention_mask,
            context,
            context_mask,
            rotary_pos_emb,
            rotary_pos_cos,
            rotary_pos_sin,
            sequence_len_offset,
        ):
            return self._proj_and_transformer_layer(
                hidden_states=hidden_states,
                decoder_input=decoder_input,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            )

high

The custom_forward implementation is missing the inner quantization context logic for FP8 and FP4, which is necessary for correct activation recomputation. Additionally, the signature should be simplified to only take tensors that require gradients (like hidden_states and decoder_input), capturing the rest in the closure for better memory efficiency.

        def custom_forward(hidden_states, decoder_input):
            # Get appropriate inner quantization context
            if self.config.fp8:
                inner_quantization_context = get_fp8_context(self.config, self.layer_number - 1)
            elif self.config.fp4:
                inner_quantization_context = get_fp4_context(self.config, self.layer_number - 1)
            else:
                inner_quantization_context = nullcontext()

            with inner_quantization_context:
                return self._proj_and_transformer_layer(
                    hidden_states=hidden_states,
                    decoder_input=decoder_input,
                    attention_mask=attention_mask,
                    context=context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    rotary_pos_cos=rotary_pos_cos,
                    rotary_pos_sin=rotary_pos_sin,
                    attention_bias=attention_bias,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    sequence_len_offset=sequence_len_offset,
                )
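The context selection in the suggestion above can be tried in isolation. The following is a minimal sketch, not Megatron code: `ToyConfig`, `toy_quant_context`, `select_inner_context` and `run_layer` are hypothetical stand-ins for the real config, for `get_fp8_context` / `get_fp4_context`, and for the layer call. It assumes `layer_number` is 1-based, as the `- 1` in the suggestion implies.

```python
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the real transformer config.
    fp8: bool = False
    fp4: bool = False


entered = []  # records which context was entered, for illustration only


@contextmanager
def toy_quant_context(name, layer_index):
    # Hypothetical stand-in for get_fp8_context / get_fp4_context.
    entered.append((name, layer_index))
    yield


def select_inner_context(config, layer_number):
    # Same if/elif/else shape as the suggestion: FP8 first, then FP4,
    # otherwise a no-op context. layer_number is assumed to be 1-based.
    if config.fp8:
        return toy_quant_context("fp8", layer_number - 1)
    if config.fp4:
        return toy_quant_context("fp4", layer_number - 1)
    return nullcontext()


def run_layer(config, layer_number, x):
    # The context is entered inside the function that gets replayed,
    # so a recompute pass would enter it again.
    with select_inner_context(config, layer_number):
        return x * 2
```

Because the context is created inside the replayed function, the forward pass and any recomputation run under the same quantization setting, which is the point the review comment raises.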

Comment on lines +7 to +8
from megatron.core.enums import Fp8Recipe
from megatron.core.fp8_utils import get_fp8_context

medium

To fully support FP4 quantization recomputation as referenced in the _checkpointed_forward logic, you should also import get_fp4_context.

Suggested change
- from megatron.core.enums import Fp8Recipe
- from megatron.core.fp8_utils import get_fp8_context
+ from megatron.core.enums import Fp8Recipe
+ from megatron.core.fp4_utils import get_fp4_context
+ from megatron.core.fp8_utils import get_fp8_context
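Since this PR is about compatibility across megatron_core versions, one option is to guard the FP4 import. The sketch below assumes (not confirmed by the PR) that `megatron.core.fp4_utils` may be missing in some installed versions; `optional_import` is a hypothetical helper, not a Megatron API.

```python
import importlib


def optional_import(module_name, attr_name):
    """Return module.attr if the module imports cleanly, else None.

    Hypothetical helper for illustration; not part of megatron_core.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr_name, None)


# None when the module (or megatron_core itself) is unavailable.
get_fp4_context = optional_import("megatron.core.fp4_utils", "get_fp4_context")
```

Callers would then need to check `get_fp4_context is not None` before taking the FP4 branch.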

Comment on lines +150 to +158
* Non-tensor objects (``attention_bias``, ``inference_params``,
``packed_seq_params``) are captured by the ``custom_forward``
closure; only tensor / ``None`` arguments flow positionally
through the underlying checkpoint primitive. This is required
by both backends: ``tensor_parallel.checkpoint`` because its
``save_for_backward`` only accepts tensors and ``None``, and
``te_checkpoint`` because its reentrant implementation only
tracks positional tensor inputs as checkpoint inputs (kwarg
tensors are not represented in the recompute backward path).

medium

The docstring incorrectly identifies attention_bias as a non-tensor object. Furthermore, to improve memory efficiency, tensors that do not require gradients (like attention_mask, rotary_pos_emb, and sequence_len_offset) should be captured in the closure rather than passed positionally to the checkpoint function, which avoids unnecessary stashing. This aligns with the implementation in transformer_block.py.

Suggested change
-   * Non-tensor objects (``attention_bias``, ``inference_params``,
-     ``packed_seq_params``) are captured by the ``custom_forward``
-     closure; only tensor / ``None`` arguments flow positionally
-     through the underlying checkpoint primitive. This is required
-     by both backends: ``tensor_parallel.checkpoint`` because its
-     ``save_for_backward`` only accepts tensors and ``None``, and
-     ``te_checkpoint`` because its reentrant implementation only
-     tracks positional tensor inputs as checkpoint inputs (kwarg
-     tensors are not represented in the recompute backward path).
+   * Non-gradient tensors and objects (``attention_mask``, ``rotary_pos_emb``,
+     ``attention_bias``, ``inference_params``, ``packed_seq_params``,
+     ``sequence_len_offset``) are captured by the ``custom_forward``
+     closure; only tensors requiring gradients (``hidden_states``,
+     ``decoder_input``) flow positionally through the underlying
+     checkpoint primitive. This is required by both backends:
+     ``tensor_parallel.checkpoint`` because its ``save_for_backward``
+     only accepts tensors and ``None``, and ``te_checkpoint`` because
+     its reentrant implementation only tracks positional tensor inputs
+     as checkpoint inputs.
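The closure-capture idea in this comment can be illustrated without any Megatron or PyTorch dependency. In the sketch below, `toy_checkpoint` is a hypothetical stand-in for a checkpoint primitive that stashes every positional argument, and plain Python lists stand in for tensors; it only shows which values end up stashed, not real autograd behavior.

```python
def toy_checkpoint(fn, *args):
    # Hypothetical stand-in for a checkpoint primitive: every positional
    # argument is "stashed" so the forward pass could be replayed later.
    stashed = list(args)
    return fn(*args), stashed


def layer(hidden_states, decoder_input, attention_mask, packed_seq_params):
    # Toy computation using lists in place of tensors.
    scale = packed_seq_params["scale"]
    return [
        (h + d) * m * scale
        for h, d, m in zip(hidden_states, decoder_input, attention_mask)
    ]


def checkpointed_layer(hidden_states, decoder_input, attention_mask, packed_seq_params):
    def custom_forward(hidden_states, decoder_input):
        # attention_mask and packed_seq_params are read from the enclosing
        # scope, so they never pass through toy_checkpoint's argument list.
        return layer(hidden_states, decoder_input, attention_mask, packed_seq_params)

    return toy_checkpoint(custom_forward, hidden_states, decoder_input)
```

For example, `checkpointed_layer([1, 2], [3, 4], [1, 0], {"scale": 2})` stashes only the two positional inputs, while the mask and the non-tensor object are still available to the replayed function through the closure.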

Comment on lines +219 to +252
                return te_checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    hidden_states,
                    decoder_input,
                    attention_mask,
                    context,
                    context_mask,
                    rotary_pos_emb,
                    rotary_pos_cos,
                    rotary_pos_sin,
                    sequence_len_offset,
                )
            else:
                # tensor_parallel.checkpoint stashes args via autograd's
                # ``save_for_backward``, which only accepts tensors and ``None``.
                # Pass tensor / ``None`` args positionally and capture the
                # non-tensor objects (``attention_bias``, ``inference_params``,
                # ``packed_seq_params``) via the ``custom_forward`` closure.
                return tensor_parallel.checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    hidden_states,
                    decoder_input,
                    attention_mask,
                    context,
                    context_mask,
                    rotary_pos_emb,
                    rotary_pos_cos,
                    rotary_pos_sin,
                    sequence_len_offset,
                )

medium

Update the checkpoint calls to match the simplified custom_forward signature. Only hidden_states and decoder_input (and potentially context/context_mask if they were supported and required gradients) should be passed positionally.

                return te_checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    hidden_states,
                    decoder_input,
                )
            else:
                return tensor_parallel.checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    hidden_states,
                    decoder_input,
                )

# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage reducing checkpoints.
assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute'

medium

This assertion might be too restrictive if recompute_num_layers is set to a value greater than 1 globally. It is safer to check if it is at least 1, as MTP only contains a single layer to checkpoint.

Suggested change
- assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute'
+ assert (self.config.recompute_num_layers >= 1), 'recompute_num_layers must be at least 1 for MTP recompute'
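The difference between the two assertions shows up when `recompute_num_layers` is set globally to a value above 1. A small sketch, with hypothetical helper names (`check_strict`, `check_relaxed`, `accepts`) that exist only for this illustration:

```python
def check_strict(recompute_num_layers):
    # Original assertion: rejects any value other than 1.
    assert recompute_num_layers == 1, 'recompute_num_layers must be 1 for MTP recompute'


def check_relaxed(recompute_num_layers):
    # Suggested assertion: accepts any value of at least 1.
    assert recompute_num_layers >= 1, 'recompute_num_layers must be at least 1 for MTP recompute'


def accepts(check, value):
    # Report whether a check passes instead of letting it raise.
    try:
        check(value)
    except AssertionError:
        return False
    return True
```

With a global setting of 4, `accepts(check_strict, 4)` is `False` while `accepts(check_relaxed, 4)` is `True`; both reject 0.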

Comment on lines +261 to +278
        elif self.config.recompute_method == 'block':
            # TODO: implement block-based recompute for MTP
            warnings.warn("recompute_method == 'block' is not supported for MTP yet."
                          ' Skipping recompute.')
            outputs = self._proj_and_transformer_layer(
                hidden_states=hidden_states,
                decoder_input=decoder_input,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            )

medium

The 'block' recompute method can be supported for MTP by simply checking if recompute_num_layers >= 1. Since MTP wraps a single transformer layer, this is equivalent to checkpointing the entire module.

Suggested change
- elif self.config.recompute_method == 'block':
-     # TODO: implement block-based recompute for MTP
-     warnings.warn("recompute_method == 'block' is not supported for MTP yet."
-                   ' Skipping recompute.')
-     outputs = self._proj_and_transformer_layer(
-         hidden_states=hidden_states,
-         decoder_input=decoder_input,
-         attention_mask=attention_mask,
-         context=context,
-         context_mask=context_mask,
-         rotary_pos_emb=rotary_pos_emb,
-         rotary_pos_cos=rotary_pos_cos,
-         rotary_pos_sin=rotary_pos_sin,
-         attention_bias=attention_bias,
-         inference_params=inference_params,
-         packed_seq_params=packed_seq_params,
-         sequence_len_offset=sequence_len_offset,
-     )
+ elif self.config.recompute_method == 'block':
+     if self.config.recompute_num_layers >= 1:
+         with outer_quantization_context:
+             outputs = checkpoint_handler()
+     else:
+         outputs = self._proj_and_transformer_layer(
+             hidden_states=hidden_states,
+             decoder_input=decoder_input,
+             attention_mask=attention_mask,
+             context=context,
+             context_mask=context_mask,
+             rotary_pos_emb=rotary_pos_emb,
+             rotary_pos_cos=rotary_pos_cos,
+             rotary_pos_sin=rotary_pos_sin,
+             attention_bias=attention_bias,
+             inference_params=inference_params,
+             packed_seq_params=packed_seq_params,
+             sequence_len_offset=sequence_len_offset,
+         )
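The resulting dispatch can be summarized as a small sketch. It assumes the branch preceding the `elif` is `recompute_method == 'uniform'` (suggested by the "Uniformly divide" comment but not shown in this review), and `run_mtp_layer`, `checkpointed` and `plain` are hypothetical names standing in for the method body, `checkpoint_handler`, and the direct `_proj_and_transformer_layer` call.

```python
def run_mtp_layer(recompute_method, recompute_num_layers, checkpointed, plain):
    """Pick the checkpointed or plain path for a single MTP layer.

    checkpointed / plain are zero-argument callables (hypothetical
    stand-ins for checkpoint_handler and the direct layer call).
    """
    if recompute_method == 'uniform':
        # Assumed branch: MTP has one layer, so any value >= 1 checkpoints it.
        assert recompute_num_layers >= 1, 'recompute_num_layers must be at least 1 for MTP recompute'
        return checkpointed()
    if recompute_method == 'block':
        # With a single wrapped layer, "first N layers" with N >= 1
        # means checkpointing the whole module.
        if recompute_num_layers >= 1:
            return checkpointed()
        return plain()
    raise ValueError(f'Invalid recompute_method: {recompute_method!r}')
```

Under these assumptions, `'block'` with `recompute_num_layers >= 1` behaves the same as `'uniform'`, and only `recompute_num_layers == 0` falls back to the plain forward.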
