
[ROCm] optimizations to reduce temp size prior to training starts #3922

Open

cj401-amd wants to merge 2 commits into AI-Hypercomputer:main from ROCm:cj/tmem-fixes-clean

Conversation


cj401-amd commented May 15, 2026

Description

The purpose of this PR is to reduce temp (compile) memory consumption. The optimizations have been tested on ROCm (MI355) and CUDA (H100), where they reduced temp size in both cases. We would appreciate your help testing them on TPU to check whether they are applicable there as well. The PR is a bit large; we can split it into several PRs if that makes review easier on your side.

Configuration              PP   EP   Temp (main)   Temp (cj/tmem-fixes-clean)   Reduction
ds-proxy-se2-e256-h4096     8    1   36.5 GB       20.4 GB                      −16.1 GB (−44%)
ds-proxy-N1-ep2-pp4         4    2   44.4 GB       22.9 GB                      −21.5 GB (−48%)

Below we detail the files changed, along with the reasons behind each change.

1. src/maxtext/configs/base.yml -- Default float32_weight_sum to false

What changed: float32_weight_sum: true --> false

Why: When MoE models sum expert weights, using fp32 precision requires promoting all expert weight tensors to float32, then summing, then casting back. This materializes ~2 GB of temporary f32 tensors per device. By defaulting to false, we keep the summation in the model's native dtype (bf16), which avoids the temporary buffers entirely.

Trade-off: Slightly reduced numerical precision in expert weight summation. For bf16 training this is unlikely to matter in practice, and the parameter remains configurable for users who need strict fp32 summation.
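
For illustration, a minimal sketch of this precision trade-off, assuming bf16 expert outputs of shape [..., k, d] and routing weights of shape [..., k] (names and shapes are illustrative, not the actual maxtext MoE code):

import jax.numpy as jnp

def combine_expert_outputs(expert_outputs, weights, float32_weight_sum=False):
  # expert_outputs: [..., k, d] in bf16; weights: [..., k] routing weights.
  if float32_weight_sum:
    # Promote both operands to f32, sum, then cast back -- this materializes
    # f32 temporaries the size of the expert outputs.
    summed = jnp.sum(
        expert_outputs.astype(jnp.float32) * weights[..., None].astype(jnp.float32),
        axis=-2)
    return summed.astype(expert_outputs.dtype)
  # Sum in the native dtype (e.g. bf16): no f32 temporaries are created.
  return jnp.sum(expert_outputs * weights[..., None], axis=-2)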


2. src/maxtext/configs/types.py -- Pydantic field default alignment

What changed: float32_weight_sum field default True --> False, description updated.

Why: Mirrors the base.yml change so the Pydantic config schema and the YAML config file agree on the default.


3. src/maxtext/kernels/gather_reduce_sc.py -- JIT decorator restructuring

What changed: @jax.jit(static_argnames=[...]) --> @functools.partial(jax.jit, static_argnames=[...])

Why: The direct @jax.jit(...) decorator form can interact poorly with certain JAX versions when the decorated function is further wrapped or inspected. functools.partial preserves function metadata (name, docstring) and ensures the decorator is properly composed. This is a robustness fix that prevents subtle compilation issues.
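
A minimal sketch of the restructured decorator form (the real kernel in gather_reduce_sc.py has a different signature and body; reduce_along is a stand-in):

import functools
import jax
import jax.numpy as jnp

# Composing jax.jit via functools.partial keeps the wrapped function's
# name and docstring intact and composes cleanly when the function is
# further wrapped or inspected.
@functools.partial(jax.jit, static_argnames=["axis"])
def reduce_along(x, axis: int):
  return jnp.sum(x, axis=axis)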


4. src/maxtext/utils/sharding.py -- skip_trivial_specs + remove_size_one_mesh_axis

What changed: Two additions:

(a) skip_trivial_specs parameter on maybe_shard_with_logical():
When skip_trivial_specs=True, the function checks if all axes in the sharding spec resolve to None or (). If so, it returns the input tensor unchanged -- no with_sharding_constraint call is emitted. This is the foundational optimization used by attentions.py, pipeline.py, and mixtral.py.

Why: jax.lax.with_sharding_constraint is not free even when the spec is trivial. XLA can materialize copy buffers to enforce a constraint that effectively does nothing. Skipping it when the spec is all-None avoids those temporaries.

(b) remove_size_one_mesh_axis() function:
Filters out mesh axes from a PartitionSpec where the corresponding mesh dimension has size 1. For example, on a single-node setup where stage=1, any sharding over "stage" is meaningless and can be dropped.

Why: Size-1 mesh axes create sharding overhead (XLA inserts identity reshards) without providing any actual parallelism. Removing them simplifies the HLO and eliminates associated temporary buffers. The existing logical_to_mesh_axes() was refactored to call this new function automatically.
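
A rough sketch of the behavior described for remove_size_one_mesh_axis() (the real implementation in sharding.py may handle more cases, such as unconstrained axes):

from jax.sharding import Mesh, PartitionSpec

def remove_size_one_mesh_axis_sketch(spec: PartitionSpec, mesh: Mesh) -> PartitionSpec:
  def keep(axis):
    if axis is None:
      return None
    if isinstance(axis, (tuple, list)):
      kept = tuple(a for a in axis if mesh.shape[a] > 1)
      return kept if kept else None
    return axis if mesh.shape[axis] > 1 else None
  # e.g. with mesh sizes {"stage": 1, "fsdp": 8}:
  # PartitionSpec("stage", "fsdp") -> PartitionSpec(None, "fsdp")
  return PartitionSpec(*(keep(a) for a in spec))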


5. src/maxtext/layers/normalizations.py -- Replace einsum with multiply in RMSNorm

What changed:

  • jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=...) --> y * effective_scale + separate jax.lax.with_sharding_constraint
  • Added conditional: effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale

Why: The einsum("...k,k->...k", ...) operation is semantically just element-wise multiplication (broadcasting the 1D scale across the leading dims). However, XLA's einsum lowering can generate larger intermediate buffers than a direct multiply, especially inside lax.scan (pipeline parallelism). Inside a scan body, these intermediates get hoisted into the scan carry by XLA's loop_broadcast_fusion pass, permanently increasing the carry size.

By using direct multiplication, XLA generates a simpler broadcast + multiply HLO sequence with no intermediate allocation. The sharding constraint is decoupled and applied only when needed (out_sharding is not None).

The scale_offset != 0.0 check avoids a pointless addition when offset is zero (the standard RMSNorm case, vs. RMSNorm with learnable offset).
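
Sketched side by side (illustrative only; the real normalizations.py code also handles scale_offset and dtype promotion):

import jax
import jax.numpy as jnp

def scale_with_einsum(y, effective_scale):
  # Before: lowered through einsum, which can generate larger intermediates
  # inside scan bodies.
  return jnp.einsum("...k,k->...k", y, effective_scale)

def scale_with_multiply(y, effective_scale, out_sharding=None):
  # After: a plain broadcast multiply; the sharding constraint is decoupled
  # and applied only when a non-trivial sharding is requested.
  out = y * effective_scale
  if out_sharding is not None:
    out = jax.lax.with_sharding_constraint(out, out_sharding)
  return out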


6. src/maxtext/layers/attentions.py -- skip_trivial_specs=True

What changed: Added skip_trivial_specs=True to maybe_shard_with_logical() call in the Attention class.

Why: On mesh configurations where the attention output sharding resolves to all-None (e.g., no tensor parallelism), this avoids materializing a copy buffer. This is a minor but consistent optimization applied across the codebase.


7. src/maxtext/layers/embeddings.py -- 2 GB iota_embed guard + sharding constraints

What changed:

  • Iota guard: Compute one-hot tensor size as batch * seq * vocab * dtype_size. If >2 GB, fall back to gather-based embedding instead of iota (one-hot) embedding.
  • Conditional sharding: Only create out_pspec/out_sharding when shard_mode == ShardMode.EXPLICIT.
  • Logical constraints: Added nn.with_logical_constraint(output, ...) after the embedding lookup.

Why: Iota embedding works by creating a one-hot matrix of shape [batch, seq, vocab] and multiplying by the embedding table. For large vocabularies (e.g., vocab=102400), this one-hot tensor can be enormous: 8 * 4096 * 102400 * 2 bytes = 6.7 GB in bf16. XLA materializes this entire tensor even though only one element per row is non-zero. The gather-based path (jnp.take) uses constant memory regardless of vocabulary size.

The 2 GB threshold is conservative -- it covers most common model configs while still allowing small-vocab models to use the faster iota path.

The sharding spec was being computed unconditionally even when shard_mode != EXPLICIT. Moving it under a conditional avoids unnecessary mesh lookups.
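
A simplified sketch of the 2 GB guard described above (function and parameter names are illustrative, not the actual embeddings.py code):

import jax
import jax.numpy as jnp

def embed(inputs, embedding_table, max_one_hot_bytes=2 * 1024**3):
  # inputs: int32 [batch, seq]; embedding_table: [vocab, features].
  batch, seq = inputs.shape
  vocab, _ = embedding_table.shape
  one_hot_bytes = batch * seq * vocab * embedding_table.dtype.itemsize
  if one_hot_bytes > max_one_hot_bytes:
    # Gather path: memory use is independent of vocabulary size.
    return jnp.take(embedding_table, inputs, axis=0)
  # Iota (one-hot) path: materializes a [batch, seq, vocab] tensor.
  one_hot = jax.nn.one_hot(inputs, vocab, dtype=embedding_table.dtype)
  return jnp.einsum("bsv,vf->bsf", one_hot, embedding_table)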


8. src/maxtext/layers/attention_op.py -- Causal mask for synthetic data

What changed: For dataset_type == "synthetic", skip materializing the attention mask tensors entirely. Use mask_type="causal" directly instead of computing mask arrays.

Why: With synthetic data, segment IDs are always all-ones (a single segment per sequence). The segment mask is therefore all-True, so the combined attention mask reduces to pure causal masking. By using mask_type="causal" natively, we avoid:

  1. Materializing f32/s32[batch, 1, seq, seq] mask tensors (~5 GiB for seq=4096, batch=8)
  2. These masks being hoisted into the scan carry by XLA's loop_broadcast_fusion pass

Also removed two dead parameters from the DotProductAttention call:

  • scale_factor=1.0 (no-op: default is already 1.0)
  • context_parallel_strategy=... (not used in this non-CP code path)

9. src/maxtext/layers/moe.py -- Tile truncation + remove redundant astype

What changed:
(a) Tile size tuples: Reduced from 9-tuples to 3-tuples (forward-only).

(b) Removed .astype(self.dtype) from return intermediate_layer.

Why (a): The 9-tuple wi_tile_size and wo_tile_size contain tiling parameters for forward (fwd), dlhs (backward LHS), and drhs (backward RHS) passes of the Group Matrix Multiply (GMM). When megablox=False (the JAX ragged_dot path), only the forward pass tile values are used -- the backward-pass tile values are allocated in the tuple but never read. By truncating to 3-tuples, we avoid allocating and passing unused tile parameters.

Why (b): The intermediate_layer is already computed in the correct dtype from the matmul operations. The explicit .astype(self.dtype) creates a redundant copy that XLA cannot always elide (especially inside scan bodies). Removing it avoids one tensor-sized temporary.
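
A small sketch of the dtype point in (b) (illustrative; the actual moe.py intermediate computation is more involved):

import jax.numpy as jnp

def intermediate(x, wi, dtype=jnp.bfloat16):
  # The matmul already yields the desired dtype (here requested via
  # preferred_element_type), so a trailing .astype(dtype) on the result is
  # redundant and only adds a copy that XLA may not elide inside scan bodies.
  return jnp.einsum("bsd,df->bsf", x, wi, preferred_element_type=dtype)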


10. src/maxtext/layers/pipeline.py -- Replace ppermute with slice/concat, reduce remat

What changed: Multiple optimizations in the Pipeline class (not CircularPipeline):

(a) Replace shard_map+ppermute with lax.slice/jnp.concatenate/jnp.pad:

The old _rotate_right used @jax.shard_map with jax.lax.ppermute for circular shifting across pipeline stages. The new version uses lax.slice_in_dim + jnp.concatenate:

# N = arr.shape[0], the leading (stage) dimension
last = lax.slice_in_dim(arr, N - 1, N, axis=0)
rest = lax.slice_in_dim(arr, 0, N - 1, axis=0)
return jnp.concatenate([last, rest], axis=0)

The old _shift_right used ppermute + zero-fill. The new version uses jnp.pad + lax.slice:

# ndim = arr.ndim; pad one zero slot at the front of axis 0, then slice back to the original shape
padded = jnp.pad(arr, [[1, 0]] + [[0, 0]] * (ndim - 1))
return lax.slice(padded, [0] * ndim, arr.shape)

Similarly, _update_state_io was converted from shard_map+ppermute-based left-shift to pad+slice_in_dim-based left-shift.

Why: XLA's loop_broadcast_fusion pass hoists ppermute's collective communication buffers into the scan carry. For pp=4, this added 10-15 GB of temporary memory. Slice-based operations don't use collective scratch space -- they operate on local arrays. The operations are mathematically identical (they perform the same circular/linear shift) but have very different memory footprints inside scan loops.

(b) Remove unnecessary sharding constraints:

  • Removed sharding on shift, first_stage_in, microbatches_processed -- these are small tensors where sharding adds overhead without benefit.
  • Removed out_sharding from broadcasted_iota call.
  • Added skip_trivial_specs=True to the base sharding helper.

(c) Remove decoder_layer_input from remat save list:
Changed from save_only_these_names("iteration_input", "decoder_layer_input") to save_only_these_names("iteration_input"). The iteration_input checkpoint already contains the data needed for backward pass. Saving decoder_layer_input too created two redundant stacked checkpoint buffers (~3.3 GB).

(d) Remove meta.remove_axis call:
The axis removal logic for vmap weights was unnecessary and created intermediate metadata objects.

(e) Replace indexed gather with dynamic_slice_in_dim:
x.at[idx].get(out_sharding=...) creates intermediate index tuples. jax.lax.dynamic_slice_in_dim(x, i, 1, dim) is a single primitive with no intermediate allocation.
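
For (e), a minimal illustration of the single-primitive form (hedged sketch, not the exact pipeline.py call):

import jax

def take_slice(x, i, dim=0):
  # One dynamic_slice_in_dim primitive instead of an indexed gather
  # (x.at[i].get()); the result keeps a size-1 dimension along `dim`.
  return jax.lax.dynamic_slice_in_dim(x, i, 1, axis=dim)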


11. src/maxtext/models/deepseek.py -- Remove reshard calls

What changed:

  • Added import: from maxtext.utils.sharding import remove_size_one_mesh_axis
  • Changed MoE kernel init scale from self.config.dense_init_scale to 1.0
  • Replaced explicit jax.reshard() in/out pair with remove_size_one_mesh_axis()-based PartitionSpec
  • Expanded PartitionSpec to include "fsdp_transpose" and "context" axes

Why: The original code captured the input sharding, resharded inputs to a specific activation PartitionSpec, ran the MoE block, then resharded outputs back to the original input sharding. Each jax.reshard() creates a full-activation-sized temporary buffer. By using remove_size_one_mesh_axis() to simplify the PartitionSpec (dropping axes that have size 1 on the current mesh), we get the same effective sharding without the explicit reshard calls. This eliminates two activation-sized temporaries.

The PartitionSpec was also expanded to include fsdp_transpose and context axes for correctness on meshes that use those dimensions.
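
A hedged sketch of the pattern, reusing the remove_size_one_mesh_axis_sketch helper from the item-4 sketch above (the axis names and call site are illustrative assumptions, not the exact deepseek.py code):

import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def constrain_moe_activations(x, mesh):
  # Drop size-1 mesh axes from the activation spec, then constrain once.
  spec = remove_size_one_mesh_axis_sketch(
      P(("data", "fsdp", "fsdp_transpose"), None, "tensor"), mesh)
  # A single constraint replaces the jax.reshard() in/out pair, avoiding two
  # activation-sized temporaries.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))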


12. src/maxtext/models/mixtral.py -- NNX to Linen conversion

What changed: Complete rewrite from nnx.Module to nn.Module (Linen).

  • Removed from flax import nnx and nnx_wrappers imports
  • Added attention_as_linen, rms_norm, maybe_shard_with_logical imports
  • Changed class from nnx.Module (with __init__) to nn.Module (with @nn.compact __call__)
  • All layer instantiation moved from __init__ into __call__
  • Added local shard() helper with skip_trivial_specs=True
  • Removed MixtralDecoderLayerToLinen = nnx_wrappers.to_linen_class(...) wrapper -- the class IS Linen now

Why: The NNX-to-Linen wrapper (nnx_wrappers.to_linen_class) creates an NNX module graph and then converts it to Linen form. This process materializes intermediate module objects and parameter trees that consume memory during eval_shape and jit tracing. By writing the layer directly in Linen, we bypass this overhead entirely.

The Linen @nn.compact style also defers all parameter creation to call time, avoiding eager allocation of layer parameters during module construction. This is particularly impactful for MoE models where the number of parameters is large.

Weight names are preserved through explicit name= arguments to ensure checkpoint compatibility.
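
A minimal illustrative Linen module in the @nn.compact style (not the actual MixtralDecoderLayer), showing how parameter creation is deferred to call time:

from flax import linen as nn

class TinyDecoderLayer(nn.Module):
  hidden: int

  @nn.compact
  def __call__(self, x):
    # Submodules and their parameters are created lazily inside __call__;
    # nothing is allocated when the module object itself is constructed.
    h = nn.Dense(self.hidden, name="wi")(x)
    return x + nn.Dense(x.shape[-1], name="wo")(nn.silu(h))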


13. src/maxtext/trainers/pre_train/train.py -- Flax config + grad dtype guard

What changed:
(a) Added import flax + flax.config.update("flax_always_shard_variable", False) at module level, guarded by try/except LookupError.

(b) Wrapped the gradient dtype tree_map in if config.grad_dtype != jnp.float32:.

Why (a): When flax_always_shard_variable is True (the default in some Flax versions), Flax inserts with_sharding_constraint on every variable during creation. This generates intermediate sharding copies even for variables that don't need explicit sharding. Setting it to False means sharding is only applied where the model code explicitly requests it.

Why (b): The tree_map that casts gradients from fp32 to config.grad_dtype is a no-op when grad_dtype is already fp32. But tree_map still walks the entire gradient tree and calls the lambda on every leaf, which can trigger unnecessary tensor copies if XLA doesn't fully elide the identity cast. Skipping the tree_map entirely when it's a no-op avoids this overhead.
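
A small sketch of (b), with assumed names (grads, grad_dtype), not the exact train.py code:

import jax
import jax.numpy as jnp

def maybe_cast_grads(grads, grad_dtype):
  # Skip the tree walk entirely when the cast would be a no-op.
  if grad_dtype == jnp.float32:
    return grads
  return jax.tree_util.tree_map(lambda g: g.astype(grad_dtype), grads)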


named_sharding = create_sharding(mesh, logical_axes, rules=rules)

if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec):

Collaborator

Why does this improve performance? Shardy should ignore a sharding hint whose spec is None or empty.


Author

It seems the CustomCall @sharding node it emits still exists in HLO, and XLA's loop_broadcast_fusion can hoist it into the scan carry, materializing large TMEM buffers. skip_trivial_specs=True prevents the node from being inserted at all, eliminating that hoisting opportunity.


Collaborator

very interesting to know. Thank you!


Collaborator

I assume you used shard_mode=auto, right?

Resolves one conflict in train.py: upstream added `nnx` to the flax
import line while the TMEM branch had added the flax_always_shard_variable
config update just before it. Both changes are preserved:

  import flax
  try:
    flax.config.update("flax_always_shard_variable", False)
  except LookupError:
    pass
  from flax import linen as nn, nnx
