perf(flce): fuse grad_weight accumulation via addmm_ (cuBLAS beta=1) + memory cleanups#1270
Open
justinhh4 wants to merge 1 commit into
Open
perf(flce): fuse grad_weight accumulation via addmm_ (cuBLAS beta=1) + memory cleanups#1270justinhh4 wants to merge 1 commit into
justinhh4 wants to merge 1 commit into
Conversation
…+ memory cleanups
kolehma8
approved these changes
Jun 26, 2026
Tcc0403
reviewed
Jun 27, 2026
Comment on lines
+224
to
+227
| if grad_weight.dtype == grad_logits_chunk.dtype == _input_chunk.dtype: | ||
| grad_weight.addmm_(grad_logits_chunk.t(), _input_chunk) | ||
| else: | ||
| grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float() |
Collaborator
There was a problem hiding this comment.
#1239 is solving same issue by passing out_dtype and out args
torch.addmm(..., out_dtype=torch.float32, out=grad_weight)note that out_dtype is only supported in torch>2.8.0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
perf(flce): fuse grad_weight accumulation via
addmm_(cuBLAS β=1) — B200 + H100Measured against the frozen upstream baseline (
v0_base, == HEAD). Correctness: the full FLCE suite (test_fused_linear_cross_entropy.py, 137 cases) plus a torch-reference gate (42/42 — 3 shapes incl. a multi-chunk BT=16384 shape × 7 feature toggles × bf16+fp32, so the cross-chunk grad_weight accumulation is covered) pass.What & why
In the per-chunk backward, the weight gradient
grad_weight += grad_logits_chunkᵀ @ _input_chunkwas computed asgrad_weight += torch.mm(...).float()— a GEMM, then a V×H fp32 temporary (the.float()copy), then a separate elementwise add, every chunk. Replace it withgrad_weight.addmm_(grad_logits_chunkᵀ, _input_chunk): cuBLAS folds the accumulate into the GEMM epilogue (β=1) — one GEMM, no fp32 temp, no separate add kernel..float()materialized a V×H fp32 buffer (2.1 GB at V=128256, H=4096) and the+=was a second full-tensor memory pass — by nsys ~28% of the forward (direct_copy12% +add16%). Removing both cuts latency and peak memory.grad_weight == grad_logits_chunk == _input_chunk). Autocast (bf16 matmul + fp32 input) and the fp32-accumulator path (accum_dtype=fp32) keep the originalmm + upcast. bf16 tensor cores already accumulate the MMA in fp32 internally, so the fused path is numerically equivalent to the old one (same bf16 running accumulator).Also includes two small memory cleanups in the same loop:
grad_input:torch.zeros_like→torch.empty_like— it is fully overwritten per chunk, so the zero-fill was wasted work.logits_chunk += biasin place when dtypes match (out-of-place fallback under autocast) — avoids a secondchunk×Vtemporary for models with an lm_head bias; no-op otherwise.Headline shape (BT=8192, V=128256, full fwd+bwd)
Across context length (V=128256, BT 1024–65536)
Figures
Latency and peak memory vs context length and vs vocab, baseline vs this PR, on H100 and B200.