[Common, PyTorch] Add Triton MLA attention kernels for SM80 #2950
bzantium wants to merge 5 commits into NVIDIA:main from
Conversation
Greptile Summary
This PR adds three Triton kernel families for MLA-shaped attention on SM80 — prefill forward, FA-2 analytical backward (no atomics), and a compressed-KV decode forward — along with PyTorch autograd wrappers and an opt-in NVTE_MLA_TRITON=1 dispatch path in DotProductAttention.
Confidence Score: 4/5 — Safe to merge with minor polish; all new findings are P2 style/coverage observations. No new P0 or P1 issues found. The kernel math (FA-2 online softmax, three-pass backward, absorbed-projection decode) is correct.
Important Files Changed: transformer_engine/pytorch/triton/__init__.py (eager triton import), transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py (head_dim whitelist), tests/pytorch/test_mla_triton.py (backward layout coverage)
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["DotProductAttention.forward()"] -->|NVTE_MLA_TRITON=1 + preconditions| B["MLATritonAttention.forward()"]
A -->|fallthrough| C["FA / FusedAttn / Unfused cascade"]
B --> D["mla_attention()"]
D --> E["MLAttentionFn.apply()"]
E -->|forward| F["_launch_mla_fwd()\n_mla_attn_fwd Triton kernel\n-> O, LSE"]
E -->|backward| G["_launch_mla_bwd()\n1. preprocess -> Delta\n2. dq -> dQ\n3. dkv -> dK,dV"]
H["mla_decode_attention()"] -->|forward-only, no autograd| I["_launch_mla_decode_fwd()\n_mla_decode_attn_fwd Triton kernel\n-> O_inter (no LSE)"]
subgraph kernels["common/triton/mla.py"]
F
G
I
end
subgraph wrappers["pytorch/triton/mla.py"]
D
E
H
end
FlashAttention-2 style Triton kernels for MLA-shaped attention
(head_dim_qk != head_dim_v, e.g. DeepSeek-V2 192/128) targeted at
SM80 (A100), where FlashMLA / FA4-MLA SM80 paths are not available.
Three kernel families in transformer_engine/common/triton/mla.py:
- Prefill / training forward: standard FA-2 online softmax adapted
for non-square head dims, right-aligned causal, autotuned tile
sizes. Saves an fp32 LSE for backward.
- Analytical backward: canonical FA-2 three-pass structure
(preprocess for Delta = rowsum(O * dO), then dQ over Q-tile
programs and dK/dV over K/V-tile programs). No atomics — each
program owns a distinct output slice.
- Decode forward over compressed KV cache: c_kv [B, S_kv, R] and
k_rope [B, S_kv, R_rope] with absorbed up-projection (Q's nope
side pre-multiplied by W_uk^T). c_kv plays both K (nope side)
and V; per-head K/V are never materialized. Returns
O_inter [B, H, S_q, R]; the caller applies W_uv.
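To make the decode math concrete, here is a minimal fp32 PyTorch sketch of what the compressed-KV kernel computes. This is illustrative only — the function name and query layouts are assumptions from the shapes listed above, and causal masking is omitted for brevity:

```python
import torch

def mla_decode_reference(q_nope_abs, q_rope, c_kv, k_rope, *, softmax_scale):
    # q_nope_abs: [B, H, S_q, R]       (Q's nope side pre-multiplied by W_uk^T)
    # q_rope:     [B, H, S_q, R_rope]
    # c_kv:       [B, S_kv, R]         (compressed cache; plays K-nope and V)
    # k_rope:     [B, S_kv, R_rope]
    # Returns O_inter: [B, H, S_q, R]; the caller applies W_uv afterwards.
    scores = torch.einsum("bhqr,bkr->bhqk", q_nope_abs.float(), c_kv.float())
    scores = scores + torch.einsum("bhqr,bkr->bhqk", q_rope.float(), k_rope.float())
    probs = torch.softmax(scores * softmax_scale, dim=-1)
    # c_kv doubles as V: per-head K/V are never materialized.
    return torch.einsum("bhqk,bkr->bhqr", probs, c_kv.float())
```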
PyTorch wrapper (transformer_engine/pytorch/triton/mla.py):
- mla_attention(q, k, v, *, softmax_scale, is_causal, qkv_format)
via torch.autograd.Function (Triton fwd + Triton bwd). Layouts:
bshd, bhsd, sbhd. Pure-PyTorch mla_attention_ref kept as the test
reference.
- mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope, *,
softmax_scale, is_causal). softmax_scale is required (R is not
the original head_dim_qk so no sane default).
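A hedged usage sketch of the two entry points, assuming the signatures quoted above and the module path named in this PR. The decode query layouts shown are an assumption; DeepSeek-V2 dims are used for illustration:

```python
import torch
from transformer_engine.pytorch.triton.mla import mla_attention, mla_decode_attention

# Prefill / training: bshd layout, DeepSeek-V2 head dims (192 / 128).
B, H, S, D_qk, D_v = 2, 16, 1024, 192, 128
q = torch.randn(B, S, H, D_qk, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, S, H, D_qk, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, S, H, D_v, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = mla_attention(q, k, v, softmax_scale=D_qk ** -0.5, is_causal=True, qkv_format="bshd")
out.sum().backward()  # Triton analytical backward populates q.grad / k.grad / v.grad

# Decode over the compressed KV cache (forward-only in v1; softmax_scale is required).
R, R_rope, S_kv = 512, 64, 4096
q_nope_abs = torch.randn(B, H, 1, R, device="cuda", dtype=torch.bfloat16)
q_rope = torch.randn(B, H, 1, R_rope, device="cuda", dtype=torch.bfloat16)
c_kv = torch.randn(B, S_kv, R, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(B, S_kv, R_rope, device="cuda", dtype=torch.bfloat16)
o_inter = mla_decode_attention(
    q_nope_abs, q_rope, c_kv, k_rope,
    softmax_scale=192 ** -0.5,  # scale comes from the original head_dim_qk, not R
    is_causal=False,
)
```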
Optional DotProductAttention hookup behind NVTE_MLA_TRITON=1
(default off). MLATritonAttention added to backends.py;
DotProductAttention.forward gains a strict-precondition early-out
that falls through to the regular FA / Fused / Unfused cascade
unless every supported feature flag matches (no FP8, no dropout,
no context parallel, no alibi/bias, no padding/sliding-window mask,
no inference cache, bshd/sbhd, MLA-shaped, SM80+). The existing
get_attention_backend() return signature is preserved, so existing
dispatch is untouched when the env var is unset.
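A sketch of how a caller might opt in. The DotProductAttention constructor arguments (in particular the tuple-valued kv_channels for MLA head dims) follow current TE conventions and are an assumption here; the only piece specific to this PR is the env var:

```python
import os
import torch
import transformer_engine.pytorch as te

os.environ["NVTE_MLA_TRITON"] = "1"  # opt in; default off

dpa = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=(192, 128),   # MLA-shaped: head_dim_qk != head_dim_v
    attention_dropout=0.0,    # dropout would fall through to the regular cascade
    qkv_format="bshd",
    attn_mask_type="causal",
)

q = torch.randn(2, 1024, 16, 192, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 16, 192, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
out = dpa(q, k, v)  # MLATritonAttention when all preconditions hold, else FA / Fused / Unfused
```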
Tests at tests/pytorch/test_mla_triton.py exercise:
- Prefill forward across {bf16, fp16} x {causal, non-causal} x
{bshd, bhsd, sbhd} for shapes including DeepSeek-V2 prefill,
cross-attention (S_q != S_kv), and non-multiple-of-block seqlens.
- Backward dQ/dK/dV vs fp32 PyTorch reference within bf16/fp16
tolerances.
- Decode forward across DeepSeek-V2 dims (R=512, R_rope=64) and
smoke shapes; plus dim-mismatch and dtype rejection.
- DPA dispatch: equality with direct mla_attention call when
NVTE_MLA_TRITON=1, and fall-through preservation when unset.
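The prefill checks above roughly reduce to comparisons like the following sketch, assuming mla_attention_ref shares mla_attention's signature (the exact tolerances and parametrization live in tests/pytorch/test_mla_triton.py):

```python
import pytest
import torch
from transformer_engine.pytorch.triton.mla import mla_attention, mla_attention_ref

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("is_causal", [True, False])
def test_mla_prefill_forward(dtype, is_causal):
    B, H, S, D_qk, D_v = 2, 8, 257, 192, 128  # non-multiple-of-block seqlen
    q = torch.randn(B, S, H, D_qk, device="cuda", dtype=dtype)
    k = torch.randn(B, S, H, D_qk, device="cuda", dtype=dtype)
    v = torch.randn(B, S, H, D_v, device="cuda", dtype=dtype)
    scale = D_qk ** -0.5
    out = mla_attention(q, k, v, softmax_scale=scale, is_causal=is_causal, qkv_format="bshd")
    ref = mla_attention_ref(q, k, v, softmax_scale=scale, is_causal=is_causal, qkv_format="bshd")
    atol = 2e-2 if dtype is torch.bfloat16 else 1e-2
    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
```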
Signed-off-by: Minho Ryu <ryumin93@gmail.com>
- Map ``window_size=(-1, 0)`` (FlashAttention causal-via-window convention) to ``is_causal=True`` in the MLA Triton dispatch. Previously a caller setting ``attn_mask_type="no_mask"`` together with ``window_size=(-1, 0)`` would land in the kernel with ``is_causal=False`` and silently attend to future tokens.
- Use string equality (``== "1"``) instead of ``int(os.getenv(...))`` for ``NVTE_MLA_TRITON``, so non-integer values like ``"true"`` no longer raise ``ValueError`` at dispatch time.
- ``mla_decode_attention`` now raises ``NotImplementedError`` when any input has ``requires_grad=True``. The decode kernel is launched directly (no autograd.Function wrapper) and would otherwise drop gradients silently at the kernel boundary; v1 is forward-only.

Test: tests/pytorch/test_mla_triton.py adds ``test_mla_decode_rejects_requires_grad``.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
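The two dispatch-side fixes in the commit above amount to roughly the following (illustrative helper names, not the exact TE internals):

```python
import os

def _mla_triton_enabled() -> bool:
    # String comparison: values like "true" no longer raise ValueError,
    # they simply leave the path disabled.
    return os.getenv("NVTE_MLA_TRITON", "0") == "1"

def _resolve_is_causal(attn_mask_type: str, window_size) -> bool:
    # FlashAttention convention: window_size == (-1, 0) means causal,
    # even when attn_mask_type is "no_mask".
    return "causal" in attn_mask_type or tuple(window_size or ()) == (-1, 0)
```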
The decode forward kernel was allocating and writing an fp32 ``LSE`` buffer that was immediately discarded by the wrapper (v1 ships no analytical decode backward, and the prefill backward kernels only consume the prefill kernel's LSE). Remove the buffer, the kernel parameter, and the trailing ``m_i + tl.log(l_i)`` store. A short comment marks the spot for re-introduction if a future change adds decode backward.

No semantic change. Decode correctness re-verified across DSv2 dims (R=512, R_rope=64) and smaller smoke configs against the fp32-internal PyTorch reference.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
Summary
Adds FlashAttention-2 style Triton kernels for MLA-shaped attention (head_dim_qk != head_dim_v, e.g. DeepSeek-V2 192/128) targeted at SM80 (A100). Three kernel families:
- Prefill / training forward: FA-2 online softmax adapted for non-square head dims, right-aligned causal, autotuned tile sizes; saves an fp32 LSE for backward.
- Analytical backward: canonical FA-2 three-pass structure (preprocess for Delta = rowsum(O * dO), then dQ over Q-tile programs and dK/dV over K/V-tile programs). No atomics — each program owns a distinct output slice.
- Decode forward over compressed KV cache: c_kv [B, S_kv, R] and k_rope [B, S_kv, R_rope] with absorbed up-projection (Q's nope side pre-multiplied by W_uk^T). c_kv plays both K (nope side) and V; per-head K/V are never materialized. Returns O_inter [B, H, S_q, R]; the caller applies W_uv.

PyTorch wrapper (transformer_engine/pytorch/triton/mla.py):
- mla_attention(q, k, v, *, softmax_scale, is_causal, qkv_format) via torch.autograd.Function. Layouts: bshd, bhsd, sbhd.
- mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope, *, softmax_scale, is_causal).

Optional DotProductAttention hookup behind NVTE_MLA_TRITON=1 (default off). The existing get_attention_backend() return signature is preserved; the early-out lives only in DotProductAttention.forward and falls through to the regular FA / Fused / Unfused cascade unless every supported feature flag matches.
Motivation
On SM80, MLA-shaped attention currently has these options:
- FusedAttention via cuDNN — works but couples MLA users to cuDNN
- UnfusedDotProductAttention — slow PyTorch fallback

This PR adds a Triton-based path that's hackable / open-coded for users experimenting with MLA on A100, and includes a compressed-KV decode kernel (FlashMLA is SM90+ only).
Benchmarks (A100, bf16, autotuned, NVTE_DISABLE_TRITON_AUTOTUNING=0)
Prefill (D_qk=192, D_v=128, causal, BHSD)
Forward is competitive with xFormers' memory-efficient attention (the only PyTorch SDPA backend on this version that accepts D_qk=192). Backward is consistently 1.4–2.8x faster.
Decode KV-cache footprint (B=1, H=128, S_kv=4096, R=512, R_rope=64 — DSv2 inference)
No SDPA-accessible alternative on SM80 implements the absorbed-projection path.
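For scale, a back-of-the-envelope footprint comparison at the dims above (bf16, 2 bytes/element). The per-head figure assumes naively materialized K/V at the DSv2 192/128 head dims from the prefill section:

```python
B, H, S_kv = 1, 128, 4096
R, R_rope = 512, 64        # compressed latent + rope dims
D_qk, D_v = 192, 128       # per-head dims (DSv2)
bytes_per_elem = 2         # bf16

compressed = B * S_kv * (R + R_rope) * bytes_per_elem       # c_kv + k_rope
per_head = B * S_kv * H * (D_qk + D_v) * bytes_per_elem     # materialized K + V
print(compressed / 2**20, per_head / 2**20)                 # ~4.5 MiB vs ~320 MiB
```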
Caveats
- PyTorch SDPA's cuDNN backend rejects head_dim > 128 on this version, so cuDNN-via-SDPA is not a valid baseline. cuDNN through TE's FusedAttention (which uses the cuDNN frontend directly with different head_dim limits) is not benchmarked here; follow-up.
- The first call for each new (S_q, S_kv, D_qk, D_v, IS_CAUSAL) tuple is slower (~5s) due to autotune compilation.