Skip to content

perf(flce): fuse grad_weight accumulation via addmm_ (cuBLAS beta=1) + memory cleanups#1270

Open
justinhh4 wants to merge 1 commit into
linkedin:mainfrom
justinhh4:flce-opt-addmm
Open

perf(flce): fuse grad_weight accumulation via addmm_ (cuBLAS beta=1) + memory cleanups#1270
justinhh4 wants to merge 1 commit into
linkedin:mainfrom
justinhh4:flce-opt-addmm

Conversation

@justinhh4

@justinhh4 justinhh4 commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

perf(flce): fuse grad_weight accumulation via addmm_ (cuBLAS β=1) — B200 + H100

Measured against the frozen upstream baseline (v0_base, == HEAD). Correctness: the full FLCE suite (test_fused_linear_cross_entropy.py, 137 cases) plus a torch-reference gate (42/42 — 3 shapes incl. a multi-chunk BT=16384 shape × 7 feature toggles × bf16+fp32, so the cross-chunk grad_weight accumulation is covered) pass.

What & why

In the per-chunk backward, the weight gradient grad_weight += grad_logits_chunkᵀ @ _input_chunk was computed as grad_weight += torch.mm(...).float() — a GEMM, then a V×H fp32 temporary (the .float() copy), then a separate elementwise add, every chunk. Replace it with grad_weight.addmm_(grad_logits_chunkᵀ, _input_chunk): cuBLAS folds the accumulate into the GEMM epilogue (β=1) — one GEMM, no fp32 temp, no separate add kernel.

  • Why: the .float() materialized a V×H fp32 buffer (2.1 GB at V=128256, H=4096) and the += was a second full-tensor memory pass — by nsys ~28% of the forward (direct_copy 12% + add 16%). Removing both cuts latency and peak memory.
  • Correctness: gated to the pure same-dtype case (grad_weight == grad_logits_chunk == _input_chunk). Autocast (bf16 matmul + fp32 input) and the fp32-accumulator path (accum_dtype=fp32) keep the original mm + upcast. bf16 tensor cores already accumulate the MMA in fp32 internally, so the fused path is numerically equivalent to the old one (same bf16 running accumulator).

Also includes two small memory cleanups in the same loop:

  • grad_input: torch.zeros_liketorch.empty_like — it is fully overwritten per chunk, so the zero-fill was wasted work.
  • bias add: logits_chunk += bias in place when dtypes match (out-of-place fallback under autocast) — avoids a second chunk×V temporary for models with an lm_head bias; no-op otherwise.

Headline shape (BT=8192, V=128256, full fwd+bwd)

GPU dtype baseline this PR latency throughput peak mem (base → PR)
H100 bf16 160.882 ms 63.204 ms -60.7% +154.5% 6.64 → 3.55 GB (-46%)
H100 fp32 596.720 ms 563.251 ms -5.6% +5.9% 9.01 → 7.04 GB (-22%)
B200 bf16 124.027 ms 31.484 ms -74.6% +293.9% 6.64 → 3.55 GB (-46%)
B200 fp32 472.132 ms 476.846 ms +1.0% -1.0% 9.01 → 7.04 GB (-22%)

Across context length (V=128256, BT 1024–65536)

GPU dtype latency min/avg/max peak-mem min/avg/max
H100 bf16 -25% / -52% / -62% -31% / -44% / -49%
H100 fp32 -1% / -9% / -24% -8% / -19% / -24%
B200 bf16 -36% / -66% / -79% -31% / -44% / -49%
B200 fp32 +3% / -4% / -16% -8% / -19% / -24%

Figures

Latency and peak memory vs context length and vs vocab, baseline vs this PR, on H100 and B200.

flce_addmm_bf16_bt flce_addmm_bf16_vocab flce_addmm_fp32_bt flce_addmm_fp32_vocab flce_addmm_mem_bf16_bt flce_addmm_mem_bf16_vocab flce_addmm_mem_fp32_bt flce_addmm_mem_fp32_vocab

Comment on lines +224 to +227
if grad_weight.dtype == grad_logits_chunk.dtype == _input_chunk.dtype:
grad_weight.addmm_(grad_logits_chunk.t(), _input_chunk)
else:
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1239 is solving same issue by passing out_dtype and out args

torch.addmm(..., out_dtype=torch.float32, out=grad_weight)

note that out_dtype is only supported in torch>2.8.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants