[ROCm] optimizations to reduce temp size prior to training starts #3922
Open
cj401-amd wants to merge 2 commits into
Conversation
NuojCheng reviewed on May 15, 2026
named_sharding = create_sharding(mesh, logical_axes, rules=rules)
if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec):
Collaborator
Why does this improve performance? Shardy should ignore a sharding hint whose spec is None or empty.
Author
It seems the CustomCall @sharding node it emits still exists in HLO, and XLA's loop_broadcast_fusion can hoist it into the scan carry, materializing large TMEM buffers. skip_trivial_specs=True prevents the node from being inserted at all, eliminating that hoisting opportunity.
Collaborator
very interesting to know. Thank you!
Collaborator
I assume you used shard_mode=auto, right?
Resolves one conflict in train.py: upstream added `nnx` to the flax
import line while the TMEM branch had added the flax_always_shard_variable
config update just before it. Both changes are preserved:
import flax

try:
  flax.config.update("flax_always_shard_variable", False)
except LookupError:
  pass
from flax import linen as nn, nnx
Description
The purpose of this PR is to reduce temp (compile-time) memory consumption. The optimizations have been tested on ROCm (MI355) and CUDA (H100), where they reduced temp size. We would appreciate help testing them on TPU to check whether they are applicable there as well. The PR is fairly large, but we can split it into several smaller PRs if that makes review easier on your side.
Below we detail each changed file along with the reasoning behind the change.
1. src/maxtext/configs/base.yml -- Default `float32_weight_sum` to false
What changed: `float32_weight_sum: true` --> `false`.
Why: When MoE models sum expert weights in fp32, every expert weight tensor is promoted to float32, summed, and cast back. This materializes ~2 GB of temporary f32 tensors per device. Defaulting to false keeps the summation in the model's native dtype (bf16), which avoids the temporary buffers entirely.
Trade-off: Slightly reduced numerical precision in the expert weight summation. For bf16 training this is unlikely to matter in practice, and the parameter remains configurable for users who need strict fp32 summation.
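As a rough illustration of the two summation paths (hypothetical function name and shapes, not the actual MaxText MoE code):

```python
import jax.numpy as jnp

def combine_expert_outputs(expert_outputs, routing_weights, float32_weight_sum):
  # expert_outputs: [experts, batch, seq, model] in bf16
  # routing_weights: [experts, batch, seq] in bf16
  if float32_weight_sum:
    # Promote to f32, sum, cast back: allocates f32 copies of every operand.
    summed = jnp.einsum(
        "ebs,ebsm->bsm",
        routing_weights.astype(jnp.float32),
        expert_outputs.astype(jnp.float32),
    )
    return summed.astype(expert_outputs.dtype)
  # Native-dtype (bf16) summation: no temporary f32 promotion buffers.
  return jnp.einsum("ebs,ebsm->bsm", routing_weights, expert_outputs)
```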
2. src/maxtext/configs/types.py -- Pydantic field default alignment
What changed: `float32_weight_sum` field default `True` --> `False`, description updated.
Why: Mirrors the base.yml change so the Pydantic config schema and the YAML config file agree on the default.
3. src/maxtext/kernels/gather_reduce_sc.py -- JIT decorator restructuring
What changed: `@jax.jit(static_argnames=[...])` --> `@functools.partial(jax.jit, static_argnames=[...])`.
Why: The direct `@jax.jit(...)` decorator form can interact poorly with certain JAX versions when the decorated function is further wrapped or inspected. `functools.partial` preserves function metadata (name, docstring) and ensures the decorator is properly composed. This is a robustness fix that prevents subtle compilation issues.
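A hedged illustration of the decorator change, with a placeholder function body and a hypothetical `block_size` static argument (not the real kernel signature):

```python
import functools
import jax

# Before: decorator factory applied directly; further wrapping or inspection
# can behave inconsistently across JAX versions.
# @jax.jit(static_argnames=["block_size"])

# After: compose jax.jit via functools.partial, which keeps the function's
# metadata intact and composes cleanly with other wrappers.
@functools.partial(jax.jit, static_argnames=["block_size"])
def gather_reduce(x, indices, block_size: int):
  # Placeholder body standing in for the real kernel logic.
  return x[indices[:block_size]].sum(axis=0)
```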
4. src/maxtext/utils/sharding.py -- `skip_trivial_specs` + `remove_size_one_mesh_axis`
What changed: Two additions.
(a) A `skip_trivial_specs` parameter on `maybe_shard_with_logical()`: when `skip_trivial_specs=True`, the function checks whether all axes in the sharding spec resolve to `None` or `()`. If so, it returns the input tensor unchanged -- no `with_sharding_constraint` call is emitted. This is the foundational optimization used by attentions.py, pipeline.py, and mixtral.py (a hedged sketch appears at the end of this description).
Why: `jax.lax.with_sharding_constraint` is not free even when the spec is trivial. XLA can materialize copy buffers to enforce a constraint that effectively does nothing. Skipping it when the spec is all-None avoids those temporaries.
(b) A `remove_size_one_mesh_axis()` function: filters out mesh axes from a `PartitionSpec` where the corresponding mesh dimension has size 1. For example, on a single-node setup where `stage=1`, any sharding over `"stage"` is meaningless and can be dropped.
Why: Size-1 mesh axes create sharding overhead (XLA inserts identity reshards) without providing any actual parallelism. Removing them simplifies the HLO and eliminates the associated temporary buffers. The existing `logical_to_mesh_axes()` was refactored to call this new function automatically.

5. src/maxtext/layers/normalizations.py -- Replace einsum with multiply in RMSNorm
What changed: `jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=...)` --> `y * effective_scale` plus a separate `jax.lax.with_sharding_constraint`, with `effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale`.
Why: The `einsum("...k,k->...k", ...)` operation is semantically just element-wise multiplication (broadcasting the 1D scale across the leading dims). However, XLA's einsum lowering can generate larger intermediate buffers than a direct multiply, especially inside `lax.scan` (pipeline parallelism). Inside a scan body, these intermediates get hoisted into the scan carry by XLA's `loop_broadcast_fusion` pass, permanently increasing the carry size. With direct multiplication, XLA generates a simpler broadcast + multiply HLO sequence with no intermediate allocation. The sharding constraint is decoupled and applied only when needed (`out_sharding is not None`). The `scale_offset != 0.0` check avoids a pointless addition when the offset is zero (the standard RMSNorm case, as opposed to RMSNorm with a learnable offset).

6. src/maxtext/layers/attentions.py -- `skip_trivial_specs=True`
What changed: Added `skip_trivial_specs=True` to the `maybe_shard_with_logical()` call in the Attention class.
Why: On mesh configurations where the attention output sharding resolves to all-None (e.g., no tensor parallelism), this avoids materializing a copy buffer. This is a minor but consistent optimization applied across the codebase.

7. src/maxtext/layers/embeddings.py -- 2 GB iota_embed guard + sharding constraints
What changed:
- Guard on the estimated one-hot size `batch * seq * vocab * dtype_size`: if it exceeds 2 GB, fall back to gather-based embedding instead of iota (one-hot) embedding.
- Compute `out_pspec` / `out_sharding` only when `shard_mode == ShardMode.EXPLICIT`.
- Add `nn.with_logical_constraint(output, ...)` after the embedding lookup.
Why: Iota embedding works by creating a one-hot matrix of shape `[batch, seq, vocab]` and multiplying it by the embedding table. For large vocabularies (e.g., vocab=102400), this one-hot tensor can be enormous: `8 * 4096 * 102400 * 2 bytes = 6.7 GB` in bf16. XLA materializes this entire tensor even though only one element per row is non-zero. The gather-based path (`jnp.take`) uses constant memory regardless of vocabulary size. The 2 GB threshold is conservative -- it covers most common model configs while still allowing small-vocab models to use the faster iota path. The sharding spec was also being computed unconditionally even when `shard_mode != EXPLICIT`; moving it under a conditional avoids unnecessary mesh lookups.

8. src/maxtext/layers/attention_op.py -- Causal mask for synthetic data
What changed: For `dataset_type == "synthetic"`, skip materializing the attention mask tensors entirely. Use `mask_type="causal"` directly instead of computing mask arrays.
Why: With synthetic data, segment IDs are always all-ones (a single segment per sequence). The segment mask is therefore all-True, so the combined attention mask reduces to pure causal masking. By using `mask_type="causal"` natively, we avoid materializing the `f32`/`s32` `[batch, 1, seq, seq]` mask tensors (~5 GiB for seq=4096, batch=8) and their hoisting into the scan carry by the `loop_broadcast_fusion` pass. Also removed two dead parameters from the `DotProductAttention` call: `scale_factor=1.0` (a no-op: the default is already 1.0) and `context_parallel_strategy=...` (not used in this non-CP code path).

9. src/maxtext/layers/moe.py -- Tile truncation + remove redundant astype
What changed:
(a) Tile size tuples: reduced from 9-tuples to 3-tuples (forward-only).
(b) Removed `.astype(self.dtype)` from `return intermediate_layer`.
Why (a): The 9-tuple `wi_tile_size` and `wo_tile_size` contain tiling parameters for the forward (fwd), dlhs (backward LHS), and drhs (backward RHS) passes of the Group Matrix Multiply (GMM). When `megablox=False` (the JAX `ragged_dot` path), only the forward-pass tile values are used -- the backward-pass tile values are allocated in the tuple but never read. Truncating to 3-tuples avoids allocating and passing unused tile parameters.
Why (b): `intermediate_layer` is already computed in the correct dtype by the matmul operations. The explicit `.astype(self.dtype)` creates a redundant copy that XLA cannot always elide (especially inside scan bodies). Removing it avoids one tensor-sized temporary.

10. src/maxtext/layers/pipeline.py -- Replace ppermute with slice/concat, reduce remat
What changed: Multiple optimizations in the `Pipeline` class (not `CircularPipeline`):
(a) Replace `shard_map` + `ppermute` with `lax.slice` / `jnp.concatenate` / `jnp.pad`. The old `_rotate_right` used `@jax.shard_map` with `jax.lax.ppermute` for circular shifting across pipeline stages; the new version uses `lax.slice_in_dim` + `jnp.concatenate`. The old `_shift_right` used `ppermute` + zero-fill; the new version uses `jnp.pad` + `lax.slice`. Similarly, `_update_state_io` was converted from a `shard_map` + `ppermute`-based left-shift to a `pad` + `slice_in_dim`-based left-shift. (A hedged sketch of the slice-based shifts appears at the end of this description.)
Why: XLA's `loop_broadcast_fusion` pass hoists ppermute's collective communication buffers into the scan carry. For `pp=4`, this added 10-15 GB of temporary memory. Slice-based operations don't use collective scratch space -- they operate on local arrays. The operations are mathematically identical (they perform the same circular/linear shift) but have very different memory footprints inside scan loops.
(b) Remove unnecessary sharding constraints on `shift`, `first_stage_in`, and `microbatches_processed` -- these are small tensors where sharding adds overhead without benefit. Also remove `out_sharding` from the `broadcasted_iota` call and pass `skip_trivial_specs=True` to the base sharding helper.
(c) Remove `decoder_layer_input` from the remat save list: changed `save_only_these_names("iteration_input", "decoder_layer_input")` to `save_only_these_names("iteration_input")`. The `iteration_input` checkpoint already contains the data needed for the backward pass; saving `decoder_layer_input` too created two redundant stacked checkpoint buffers (~3.3 GB).
(d) Remove the `meta.remove_axis` call: the axis removal logic for vmap weights was unnecessary and created intermediate metadata objects.
(e) Replace indexed gather with `dynamic_slice_in_dim`: `x.at[idx].get(out_sharding=...)` creates intermediate index tuples, while `jax.lax.dynamic_slice_in_dim(x, i, 1, dim)` is a single primitive with no intermediate allocation.

11. src/maxtext/models/deepseek.py -- Remove reshard calls
What changed:
- Import `remove_size_one_mesh_axis` from `maxtext.utils.sharding`.
- Set `self.config.dense_init_scale` to `1.0`.
- Replace the `jax.reshard()` in/out pair with a `remove_size_one_mesh_axis()`-based PartitionSpec.
- Add the `"fsdp_transpose"` and `"context"` axes.
Why: The original code captured the input sharding, resharded inputs to a specific activation PartitionSpec, ran the MoE block, then resharded outputs back to the original input sharding. Each `jax.reshard()` creates a full-activation-sized temporary buffer. By using `remove_size_one_mesh_axis()` to simplify the PartitionSpec (dropping axes that have size 1 on the current mesh), we get the same effective sharding without the explicit reshard calls, eliminating two activation-sized temporaries. The PartitionSpec was also expanded to include the `fsdp_transpose` and `context` axes for correctness on meshes that use those dimensions.

12. src/maxtext/models/mixtral.py -- NNX to Linen conversion
What changed: Complete rewrite from `nnx.Module` to `nn.Module` (Linen):
- Removed the `from flax import nnx` and `nnx_wrappers` imports.
- Added `attention_as_linen`, `rms_norm`, and `maybe_shard_with_logical` imports.
- Converted the layer from `nnx.Module` (with `__init__`) to `nn.Module` (with an `@nn.compact` `__call__`), folding `__init__` into `__call__`.
- Added a `shard()` helper with `skip_trivial_specs=True`.
- Removed the `MixtralDecoderLayerToLinen = nnx_wrappers.to_linen_class(...)` wrapper -- the class IS Linen now.
Why: The NNX-to-Linen wrapper (`nnx_wrappers.to_linen_class`) creates an NNX module graph and then converts it to Linen form. This process materializes intermediate module objects and parameter trees that consume memory during `eval_shape` and `jit` tracing. Writing the layer directly in Linen bypasses this overhead entirely. The Linen `@nn.compact` style also defers all parameter creation to call time, avoiding eager allocation of layer parameters during module construction; this is particularly impactful for MoE models where the number of parameters is large. Weight names are preserved through explicit `name=` arguments to ensure checkpoint compatibility.

13. src/maxtext/trainers/pre_train/train.py -- Flax config + grad dtype guard
What changed:
(a) Added `import flax` + `flax.config.update("flax_always_shard_variable", False)` at module level, guarded by `try`/`except LookupError`.
(b) Wrapped the gradient dtype tree_map in `if config.grad_dtype != jnp.float32:`.
Why (a): When `flax_always_shard_variable` is True (the default in some Flax versions), Flax inserts `with_sharding_constraint` on every variable during creation. This generates intermediate sharding copies even for variables that don't need explicit sharding. Setting it to False means sharding is only applied where the model code explicitly requests it.
Why (b): The tree_map that casts gradients from fp32 to `config.grad_dtype` is a no-op when `grad_dtype` is already fp32. But `tree_map` still walks the entire gradient tree and calls the lambda on every leaf, which can trigger unnecessary tensor copies if XLA doesn't fully elide the identity cast. Skipping the tree_map entirely when it's a no-op avoids this overhead.
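To make item 4 concrete, here is a minimal sketch of the trivial-spec check, using a stand-in for the `create_sharding` helper shown in the review diff above; it is illustrative, not the exact MaxText implementation:

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec

def create_sharding(mesh, logical_axes, rules):
  # Stand-in for the helper from the review diff: map each logical axis
  # through the rules table to a mesh axis (None if unmapped).
  return NamedSharding(mesh, PartitionSpec(*(rules.get(a) for a in logical_axes)))

def maybe_shard_with_logical(x, mesh, logical_axes, rules, *, skip_trivial_specs=False):
  """Apply a logical sharding constraint unless the resolved spec is trivial."""
  named_sharding = create_sharding(mesh, logical_axes, rules=rules)
  if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec):
    # Every axis resolved to None/(): the constraint would be a no-op, so skip it
    # and avoid emitting the CustomCall @sharding node into the HLO.
    return x
  return jax.lax.with_sharding_constraint(x, named_sharding)
```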
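Similarly, a minimal sketch of the slice/concat shifts described in item 10(a), written as free functions with illustrative shapes (the real versions are methods on the `Pipeline` class):

```python
import jax.numpy as jnp
from jax import lax

def rotate_right(x, axis=0):
  # Circular shift by one along `axis` using local slice + concatenate instead of
  # a shard_map + lax.ppermute collective. Inside lax.scan this avoids collective
  # scratch buffers being hoisted into the scan carry.
  last = lax.slice_in_dim(x, x.shape[axis] - 1, x.shape[axis], axis=axis)
  rest = lax.slice_in_dim(x, 0, x.shape[axis] - 1, axis=axis)
  return jnp.concatenate([last, rest], axis=axis)

def shift_right(x, axis=0):
  # Linear shift by one with zero fill, using pad + slice instead of ppermute.
  pad_width = [(0, 0)] * x.ndim
  pad_width[axis] = (1, 0)
  padded = jnp.pad(x, pad_width)
  return lax.slice_in_dim(padded, 0, x.shape[axis], axis=axis)
```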