
perf: optimize grad_weight accumulation with addmm#1239

Open
maskyuanzh wants to merge 5 commits into linkedin:main from maskyuanzh:fix-fused-linear-ce-addmm


@maskyuanzh maskyuanzh commented May 26, 2026


Summary

This PR optimizes grad_weight accumulation in fused linear cross entropy by replacing:

grad_weight += torch.mm(...).float()

with an in-place torch.addmm(..., out=grad_weight)-based accumulation.

For PyTorch >= 2.8 on CUDA, when out_dtype is supported and accumulating fp16/bf16 operands into an fp32 grad_weight, this uses:

torch.addmm(..., out_dtype=torch.float32, out=grad_weight)

This avoids materializing the full [V, H] intermediate from torch.mm(...).float().

Fixes #1232.

Memory Benchmark

I benchmarked a 128k-vocab case with V=131072, H=4096, chunk_size=2048, bf16 inputs, and fp32 grad_weight on an NVIDIA GeForce RTX 4090.

PyTorch 2.1.2+cu121:
  old mm(...).float():                     extra peak 3072 MiB

PyTorch 2.12.0+cu126:
  old mm(...).float():                     extra peak 3072 MiB
  addmm(out_dtype=torch.float32, out=...): extra peak 0 MiB

So on PyTorch >= 2.8, the out_dtype path removes the large [V, H] peak allocation in this configuration. On earlier PyTorch versions, the existing implementation is preserved.
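The 3072 MiB figure is consistent with a simple size calculation. The sketch below is a back-of-the-envelope check, not a measurement: it assumes the bf16 result of torch.mm(...) and its fp32 copy from .float() are both alive at the peak, which the PR does not state explicitly. The helper name `tensor_mib` is hypothetical.

```python
def tensor_mib(num_elements: int, bytes_per_element: int) -> float:
    """Size of a dense tensor in MiB."""
    return num_elements * bytes_per_element / 2**20


V, H = 131072, 4096  # vocab size and hidden size from the benchmark above

# torch.mm on bf16 operands returns a bf16 [V, H] tensor (2 bytes per element).
lowp_result_mib = tensor_mib(V * H, 2)

# .float() then allocates a separate fp32 [V, H] copy (4 bytes per element).
fp32_copy_mib = tensor_mib(V * H, 4)

# Assumption: both temporaries coexist at the peak.
extra_peak_mib = lowp_result_mib + fp32_copy_mib

print(lowp_result_mib, fp32_copy_mib, extra_peak_mib)
```

Under that assumption the two temporaries account for 1024 MiB + 2048 MiB, matching the reported extra peak for the old path. Writing directly into grad_weight needs neither temporary.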

Testing Done

  • Hardware Type: NVIDIA GeForce RTX 4090
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Additional targeted testing:

pytest -q test/transformers/test_fused_linear_cross_entropy.py

Passed. This covers the fused linear cross entropy paths affected by this change.

@maskyuanzh maskyuanzh changed the title Optimize grad_weight accumulation with addmm perf: optimize grad_weight accumulation with addmm May 26, 2026
if ce_weight.stride(-1) != 1:
    ce_weight = ce_weight.contiguous()

IS_TORCH2P12 = Version(torch.__version__.split("+")[0]) >= Version("2.12.0")

Contributor

This could probably be located globally, so it doesn't need to run every forward pass.

Author

Thanks for pointing this out. I’ll move the PyTorch version check to the module-level constants so it is only evaluated once instead of on every forward pass.

grad_weight,
grad_logits_chunk.t(),
_input_chunk.to(dtype=grad_logits_chunk.t().dtype),
out_dtype=torch.float32,

Contributor

Technically, I think out_dtype is available earlier than 2.12; I just didn't do the work to track down which version introduced it. I recall it existing in 2.10 as well.

Author

Thanks for pointing this out. The exact version wasn’t carefully verified here. After checking the PyTorch docs and source tags, I found that out_dtype was added to torch.addmm in PyTorch 2.8.0 for fp16/bf16 CUDA inputs with fp32 output accumulation. I’ll lower the version guard from 2.12.0 to 2.8.0.
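A version guard of this kind can be sketched as follows. This is an illustrative sketch only: the function name `addmm_supports_out_dtype` and the constant `_MIN_OUT_DTYPE_VERSION` are hypothetical, and the version string is passed in as an argument so the snippet runs without importing torch. In the module itself the input would be torch.__version__, evaluated once at import time.

```python
from packaging.version import Version

# Threshold taken from the discussion above (2.8.0); the constant name is hypothetical.
_MIN_OUT_DTYPE_VERSION = Version("2.8.0")


def addmm_supports_out_dtype(torch_version: str) -> bool:
    """Return True if the given PyTorch version string is at least 2.8.0."""
    # Drop the local build suffix such as "+cu126" before parsing.
    base_version = torch_version.split("+")[0]
    return Version(base_version) >= _MIN_OUT_DTYPE_VERSION


print(addmm_supports_out_dtype("2.12.0+cu126"))
print(addmm_supports_out_dtype("2.1.2+cu121"))
```

Parsing with Version rather than comparing strings matters here: as plain strings, "2.12.0" sorts before "2.8.0", while Version compares the numeric components.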

torch.addmm(
grad_weight,
grad_logits_chunk.t(),
_input_chunk.to(dtype=grad_logits_chunk.t().dtype),

Contributor

Why is this cast now necessary?

Author

Thanks for asking. The cast is needed because out_dtype only controls the output/accumulation dtype; torch.addmm still requires mat1 and mat2 to have the same input dtype.

I tested this on PyTorch 2.12.0 + CUDA. With out_dtype=torch.float32, addmm still fails for fp16 x fp32 and bf16 x fp32 inputs:

Half and Float     -> RuntimeError: mat1 and mat2 must have the same dtype
BFloat16 and Float -> RuntimeError: mat1 and mat2 must have the same dtype

After casting mat2 to match mat1, both fp16 and bf16 paths succeed and write into a fp32 output buffer. In the AMP path here, grad_logits_chunk is the low-precision operand while _input_chunk can remain fp32, so this cast aligns _input_chunk with grad_logits_chunk and keeps the multiply in fp16/bf16 while accumulating into fp32.
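The operand-dtype rule can be reproduced in a small CPU-only sketch. Note the assumptions: this uses plain torch.mm on CPU as a stand-in, not the addmm(..., out_dtype=...) CUDA call tested above, so the exact error message differs from the one quoted. The tensor shapes are arbitrary.

```python
import torch

mat1 = torch.randn(4, 8).to(torch.bfloat16)  # low-precision operand
mat2 = torch.randn(8, 3)                     # fp32 operand

# Mixed operand dtypes are rejected by the matmul itself.
try:
    torch.mm(mat1, mat2)
    mixed_dtype_ok = True
except RuntimeError:
    mixed_dtype_ok = False

# Casting mat2 to mat1's dtype aligns the operands, and the multiply runs in bf16.
aligned = torch.mm(mat1, mat2.to(mat1.dtype))

print(mixed_dtype_ok, aligned.dtype, tuple(aligned.shape))
```

When both operands already share a dtype, the .to(...) call returns the tensor unchanged, so the cast only matters if the dtypes can actually differ.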

@Tcc0403 Tcc0403 Jun 23, 2026

Collaborator

Now that I've looked at it, I don't think we ever cast _input_chunk; it should always have the same dtype as grad_logits_chunk. .to is a no-op when both have the same dtype, but I prefer keeping the code clean over guarding defensively against a case that can never happen.

Author

Thanks, that makes sense. I removed the redundant _input_chunk.to(...) and now pass _input_chunk directly to addmm.

Comment on lines +230 to +231
grad_logits_chunk.t().to(grad_weight.dtype),
_input_chunk.to(grad_weight.dtype),

Contributor

So the input, chunk, and weight all need the same dtype on this path? The desired behavior is typically to multiply in bf16 and then accumulate in fp32. Doing the multiply in fp32 as well would be pretty slow, so I don't think this is advisable.

Author

Thanks for pointing this out. I tested the fp32-operand fallback and found that your concern was correct: although it reduces peak memory compared with the old mm(...).float() path, it is slower because the matmul itself runs in fp32.

old: mm(lowp, lowp).float()                 23.5 ms, peak 5688 MiB
old fallback: addmm(fp32, fp32, out=fp32)   27.4 ms, peak 3640 MiB
fast: addmm(lowp, lowp, out_dtype=fp32)     13.4 ms, peak 2632 MiB

I updated the logic so this path no longer promotes both operands to fp32. The addmm(..., out_dtype=torch.float32, out=grad_weight) path is now only used when out_dtype is supported, grad_weight is fp32, and grad_logits_chunk is fp16/bf16. Since addmm(..., out=...) does not autocast the operands, and out_dtype only controls the output dtype, _input_chunk is explicitly cast to grad_logits_chunk’s dtype. This keeps the matmul in fp16/bf16 while writing directly into the fp32 accumulation buffer. For unsupported cases, the code now falls back to the original mm(...).float() behavior.
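The selection logic described here could be sketched roughly as below. This is an illustration under stated assumptions, not the code in the PR: the function name `accumulate_grad_weight`, its signature, and the explicit is_cuda check are hypothetical, and the out_dtype call mirrors the diff fragments quoted in this thread. Both operands are assumed to share a dtype already.

```python
import torch

_LOW_PRECISION_DTYPES = (torch.float16, torch.bfloat16)


def accumulate_grad_weight(grad_weight, grad_logits_chunk, input_chunk, addmm_supports_out_dtype):
    """Accumulate grad_logits_chunk.T @ input_chunk into grad_weight in place."""
    use_out_dtype = (
        addmm_supports_out_dtype
        and grad_weight.is_cuda  # assumption: the fast path is limited to CUDA tensors
        and grad_weight.dtype == torch.float32
        and grad_logits_chunk.dtype in _LOW_PRECISION_DTYPES
    )
    if use_out_dtype:
        # Multiply in fp16/bf16 and write straight into the fp32 buffer,
        # so no [V, H] temporary is materialized.
        torch.addmm(
            grad_weight,
            grad_logits_chunk.t(),
            input_chunk,
            out_dtype=torch.float32,
            out=grad_weight,
        )
    else:
        # Original behavior: materialize the product, upcast, then add.
        grad_weight += torch.mm(grad_logits_chunk.t(), input_chunk).float()
    return grad_weight
```

On a CPU tensor or with fp32 operands, the predicate is false and the call takes the original mm(...).float() branch, so results on unsupported configurations are unchanged.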

@maskyuanzh

Author

@Tcc0403 This PR is ready for review. Thanks!

Tcc0403
Tcc0403 previously approved these changes Jun 9, 2026

@Tcc0403 Tcc0403 left a comment

Collaborator

LGTM, cc @Mecoli1219 and @vaibhavjindal for a double check

@maskyuanzh

Author

@Tcc0403 Gentle ping on this PR.

Thanks again for the approval. As far as I can tell, the review comments have been addressed, and the PR is ready from my side. Since @Mecoli1219 and @vaibhavjindal were cc’ed for double check, I just wanted to ask whether there is any remaining concern or blocker before this can be merged.

Thanks!

# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
_TORCH_VERSION = Version(torch.__version__.split("+")[0])
_SUPPORTS_ADDM_MIXED_PRECISION_OUT_DTYPE = _TORCH_VERSION >= Version("2.8.0")

Collaborator

There's a small typo: it should be 'ADDMM' instead of 'ADDM'.

Also, consider renaming this to _ADDMM_SUPPORTS_OUT_DTYPE to make it less verbose.

Author

Thanks for the suggestion! I renamed it to _ADDMM_SUPPORTS_OUT_DTYPE as suggested.

@vaibhavjindal

Collaborator

@maskyuanzh thanks for the PR. Two minor things:

  1. Added a comment to rename the constant.
  2. Please update the PR body: The PR body describes a fallback that "explicitly aligns operand dtypes with grad_weight.dtype before calling addmm(out=grad_weight)" and benchmarks it at 1056 MiB on torch 2.1.2. But the actual else branch is the unchanged mm().float() line — which the same benchmark lists at 3072 MiB. So on torch < 2.8 this PR gives zero memory benefit, contrary to the description.

@maskyuanzh

Author

@vaibhavjindal Thanks for the review and the helpful suggestions!
I renamed the constant to _ADDMM_SUPPORTS_OUT_DTYPE and updated the PR description to match the current implementation.

@maskyuanzh maskyuanzh requested a review from vaibhavjindal July 1, 2026 02:29

Development

Successfully merging this pull request may close these issues.

Mem and compute inefficiency in fused_linear_cross_entropy_foward

4 participants