
[Common, PyTorch] Add Triton MLA attention kernels for SM80#2950

Open
bzantium wants to merge 5 commits into NVIDIA:main from bzantium:mla-triton-sm80

Conversation


@bzantium bzantium commented Apr 30, 2026

Summary

Adds FlashAttention-2 style Triton kernels for MLA-shaped attention (head_dim_qk != head_dim_v, e.g. DeepSeek-V2 192/128) targeted at SM80 (A100). Three kernel families:

  • Prefill / training forward: standard FA-2 online softmax adapted for non-square head dims, right-aligned causal, autotuned. Saves an fp32 LSE for backward.
  • Analytical backward: canonical FA-2 three-pass structure (preprocess for Delta = rowsum(O * dO), then dQ over Q-tile programs and dK/dV over K/V-tile programs). No atomics — each program owns a distinct output slice.
  • Decode forward over compressed KV cache: c_kv [B, S_kv, R] and k_rope [B, S_kv, R_rope] with absorbed up-projection (Q's nope side pre-multiplied by W_uk^T). c_kv plays both K (nope side) and V; per-head K/V are never materialized. Returns O_inter [B, H, S_q, R]; the caller applies W_uv.
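For orientation, the prefill/training forward above is functionally equivalent to this pure-PyTorch fp32 sketch (the function name, bhsd layout, and exact masking arithmetic here are illustrative, not the PR's actual `mla_attention_ref`):

```python
import torch

def mla_attention_ref_sketch(q, k, v, *, softmax_scale, is_causal):
    """Illustrative reference for MLA-shaped attention (head_dim_qk != head_dim_v).

    q: [B, H, S_q, D_qk], k: [B, H, S_kv, D_qk], v: [B, H, S_kv, D_v] (bhsd layout).
    """
    q, k, v = q.float(), k.float(), v.float()
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    if is_causal:
        s_q, s_kv = q.shape[2], k.shape[2]
        # Right-aligned causal: query i attends to keys j <= i + (S_kv - S_q).
        i = torch.arange(s_q, device=q.device).unsqueeze(-1)
        j = torch.arange(s_kv, device=q.device).unsqueeze(0)
        scores = scores.masked_fill(j > i + (s_kv - s_q), float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v)  # [B, H, S_q, D_v]
```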

PyTorch wrapper (transformer_engine/pytorch/triton/mla.py):

  • mla_attention(q, k, v, *, softmax_scale, is_causal, qkv_format) via torch.autograd.Function. Layouts: bshd, bhsd, sbhd.
  • mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope, *, softmax_scale, is_causal).
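A hedged usage sketch of both wrappers; only the keyword signatures above are stated in the PR, so the decode q tensor shapes and the softmax_scale choices below are assumptions:

```python
import torch
from transformer_engine.pytorch.triton.mla import mla_attention, mla_decode_attention

# Prefill / training, DeepSeek-V2-style dims (D_qk=192, D_v=128), bshd layout.
q = torch.randn(2, 1024, 16, 192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 1024, 16, 192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = mla_attention(q, k, v, softmax_scale=192 ** -0.5, is_causal=True, qkv_format="bshd")
out.sum().backward()  # Triton analytical backward supplies dQ/dK/dV

# Decode over the compressed KV cache (forward-only), R=512, R_rope=64.
# Shapes for q_nope_abs / q_rope ([B, H, S_q, R] / [B, H, S_q, R_rope]) are assumed here.
q_nope_abs = torch.randn(1, 128, 1, 512, device="cuda", dtype=torch.bfloat16)
q_rope = torch.randn(1, 128, 1, 64, device="cuda", dtype=torch.bfloat16)
c_kv = torch.randn(1, 4096, 512, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(1, 4096, 64, device="cuda", dtype=torch.bfloat16)
o_inter = mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope,
                               softmax_scale=192 ** -0.5,  # caller's choice; 1/sqrt(192) mirrors the original head_dim_qk
                               is_causal=False)
# o_inter is [B, H, S_q, R]; the caller applies W_uv to recover per-head outputs.
```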

Optional DotProductAttention hookup behind NVTE_MLA_TRITON=1 (default off). The existing get_attention_backend() return signature is preserved; the early-out lives only in DotProductAttention.forward and falls through to the regular FA / Fused / Unfused cascade unless every supported feature flag matches.
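Opting in is just the environment variable; a sketch of what the flag gates (preconditions paraphrased from the commit message further down):

```python
import os
os.environ["NVTE_MLA_TRITON"] = "1"  # must be exactly "1"; unset or any other value keeps the default dispatch

# DotProductAttention.forward() then takes the Triton MLA path only when every
# precondition holds (MLA-shaped head dims, bshd/sbhd, no FP8 / dropout / context
# parallel / bias / padding or sliding-window mask / inference cache, SM80+);
# otherwise it falls through to the regular FA / Fused / Unfused cascade.
```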

Motivation

On SM80, MLA-shaped attention currently has these options:

  • FusedAttention via cuDNN — works but couples MLA users to cuDNN
  • FA4 SM80 MLA — limited
  • UnfusedDotProductAttention — slow PyTorch fallback

This PR adds a Triton-based path that's hackable / open-coded for users experimenting with MLA on A100, and includes a compressed-KV decode kernel (FlashMLA is SM90+ only).

Benchmarks (A100, bf16, autotuned, NVTE_DISABLE_TRITON_AUTOTUNING=0)

Prefill (D_qk=192, D_v=128, causal, BHSD)

| Shape | Triton fwd | xFormers EFFICIENT fwd | Triton fwd+bwd | xFormers fwd+bwd |
|---|---|---|---|---|
| B=2, H=8, S=512 | 0.33 ms | 0.34 ms | 0.96 ms | 1.36 ms (1.42x slower) |
| B=1, H=16, S=1024 | 0.28 ms | 0.19 ms (1.4x faster) | 1.30 ms | 3.63 ms (2.79x slower) |
| B=1, H=16, S=2048 | 0.88 ms | 1.08 ms | 3.55 ms | 9.10 ms (2.56x slower) |
| B=1, H=16, S=4096 | 2.40 ms | 2.85 ms | | |

Forward is competitive with xFormers' memory-efficient attention (the only PyTorch SDPA backend on this version that accepts D_qk=192). Combined forward+backward is consistently 1.4–2.8x faster.

Decode KV-cache footprint (B=1, H=128, S_kv=4096, R=512, R_rope=64 — DSv2 inference)

| | Compressed (this kernel) | Full per-head K + V |
|---|---|---|
| KV cache size | 4.5 MB | 320 MB (71x larger) |
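Both numbers follow from the shapes in the heading, assuming 2-byte (bf16) cache entries and a full-cache baseline that stores per-head K at D_qk=192 and V at D_v=128:

```python
B, H, S_kv, R, R_rope = 1, 128, 4096, 512, 64
D_qk, D_v, bytes_per_el = 192, 128, 2  # bf16

compressed = B * S_kv * (R + R_rope) * bytes_per_el        # c_kv + k_rope
full = B * S_kv * H * (D_qk + D_v) * bytes_per_el          # per-head K + V
print(compressed / 2**20, full / 2**20, full / compressed)  # ~4.5 MiB, 320 MiB, ~71x
```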

No SDPA-accessible alternative on SM80 implements the absorbed-projection path.

Caveats

  • PyTorch SDPA's CUDNN_ATTENTION backend rejects head_dim > 128 on this version, so cuDNN-via-SDPA is not a valid baseline. cuDNN through TE's FusedAttention (which uses the cuDNN frontend directly with different head_dim limits) is not benchmarked here; follow-up.
  • All Triton numbers use autotune ON (warm). First call per unique (S_q, S_kv, D_qk, D_v, IS_CAUSAL) tuple is slower (~5s) due to autotune compilation.
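If the per-shape warm-up matters for a workload, autotuning can presumably be skipped via the toggle named in the benchmark heading (semantics inferred from the variable name, not verified here):

```python
import os
os.environ["NVTE_DISABLE_TRITON_AUTOTUNING"] = "1"  # assumed: skip autotuning for a faster first call, at the cost of tuned tile sizes
```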


greptile-apps Bot commented Apr 30, 2026

Greptile Summary

This PR adds three Triton kernel families for MLA-shaped attention on SM80 — prefill forward, FA-2 analytical backward (no atomics), and a compressed-KV decode forward — along with PyTorch autograd wrappers and an opt-in DotProductAttention dispatch behind NVTE_MLA_TRITON=1. The implementation is well-structured: the FA-2 online-softmax logic is correctly adapted for non-square head dims, the decode path correctly guards against silent gradient loss, and the causal dispatch covers both attn_mask_type and window_size=(-1,0). All findings are P2 style/coverage observations; no new P0/P1 issues were identified beyond those already discussed in prior review threads.

Confidence Score: 4/5

Safe to merge with minor polish; all new findings are P2 style/coverage observations.

No new P0 or P1 issues found. The kernel math (FA-2 online softmax, three-pass backward, absorbed-projection decode) is correct. The l_i=1.0 initializer is unconventional but provably safe with the -1e6 sentinel. Remaining concerns are hardcoded stride-1 assumptions (safe today, fragile long-term), an overly conservative head-dim whitelist in the DPA dispatch, and missing backward test coverage for non-bshd layouts. Score is 4 due to the complexity of the Triton kernel suite and open items from prior review threads.
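For context on that initializer: in the FA-2 online softmax the running denominator is rescaled by exp(m_old - m_new) whenever the row max grows, so an initial l_i = 1.0 paired with an m_i sentinel of -1e6 is wiped out on the first unmasked tile, while a fully masked row keeps l_i = 1.0 and presumably avoids a 0/0 in the final O = acc / l_i. A standalone numerical check (not the PR's kernel code):

```python
import torch

m_i, l_i = torch.tensor(-1.0e6), torch.tensor(1.0)   # sentinel init instead of (-inf, 0)
scores = torch.randn(128)                             # one row of scaled QK^T from the first tile

m_new = torch.maximum(m_i, scores.max())
alpha = torch.exp(m_i - m_new)                        # exp(-1e6 - m_new) underflows to exactly 0.0 in fp32
l_i = l_i * alpha + torch.exp(scores - m_new).sum()   # the spurious initial 1.0 contributes nothing
assert alpha.item() == 0.0
```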

Files flagged: transformer_engine/pytorch/triton/__init__.py (eager triton import), transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py (head_dim whitelist), tests/pytorch/test_mla_triton.py (backward layout coverage).

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/triton/mla.py | New file: Triton kernels for MLA prefill/training forward+backward and compressed-KV decode forward. FA-2 style, no atomics. One non-obvious initialisation (l_i=1.0 with -1e6 sentinel) that is correct but undocumented. |
| transformer_engine/pytorch/triton/mla.py | New file: PyTorch wrappers (mla_attention via autograd.Function, mla_decode_attention forward-only). Hardcodes last-dim stride as 1 throughout launchers; safe only because all inputs are made contiguous first. Backward layout coverage in tests is incomplete. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Adds optional MLA Triton dispatch guarded by NVTE_MLA_TRITON=1. Causal detection covers both attn_mask_type and window_size=(-1, 0). Head-dim whitelist (128, 192, 256) is overly conservative relative to what the kernel actually supports. |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds MLATritonAttention nn.Module with lazy import of mla_attention inside forward(). Clean wrapper; no issues. |
| transformer_engine/pytorch/triton/__init__.py | Adds an eager `from transformer_engine.pytorch.triton import mla`, making Triton a hard transitive dependency for all TE-PyTorch users. |
| tests/pytorch/test_mla_triton.py | New test file: good forward/decode coverage across shapes, dtypes, layouts, and causal flags. Backward tests only exercise bshd, leaving bhsd/sbhd gradient paths untested. |
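The lazy-import pattern credited to backends.py above, sketched; the forward signature and body are assumptions, only the class name and the import-inside-forward detail come from the table:

```python
import torch

class MLATritonAttentionSketch(torch.nn.Module):
    """Illustrative wrapper that defers the Triton import to the first forward call."""

    def forward(self, q, k, v, *, softmax_scale, is_causal, qkv_format):
        # Importing here, rather than at module import time, keeps Triton optional
        # for users who never set NVTE_MLA_TRITON=1.
        from transformer_engine.pytorch.triton.mla import mla_attention

        return mla_attention(q, k, v, softmax_scale=softmax_scale,
                             is_causal=is_causal, qkv_format=qkv_format)
```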

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["DotProductAttention.forward()"] -->|NVTE_MLA_TRITON=1 + preconditions| B["MLATritonAttention.forward()"]
    A -->|fallthrough| C["FA / FusedAttn / Unfused cascade"]
    B --> D["mla_attention()"]
    D --> E["MLAttentionFn.apply()"]
    E -->|forward| F["_launch_mla_fwd()\n_mla_attn_fwd Triton kernel\n-> O, LSE"]
    E -->|backward| G["_launch_mla_bwd()\n1. preprocess -> Delta\n2. dq -> dQ\n3. dkv -> dK,dV"]
    H["mla_decode_attention()"] -->|forward-only, no autograd| I["_launch_mla_decode_fwd()\n_mla_decode_attn_fwd Triton kernel\n-> O_inter (no LSE)"]
    subgraph kernels["common/triton/mla.py"]
        F
        G
        I
    end
    subgraph wrappers["pytorch/triton/mla.py"]
        D
        E
        H
    end
```

Reviews (3). Last reviewed commit: "[Common, PyTorch] Drop unused LSE plumbi..."

FlashAttention-2 style Triton kernels for MLA-shaped attention
(head_dim_qk != head_dim_v, e.g. DeepSeek-V2 192/128) targeted at
SM80 (A100), where FlashMLA / FA4-MLA SM80 paths are not available.

Three kernel families in transformer_engine/common/triton/mla.py:

- Prefill / training forward: standard FA-2 online softmax adapted
  for non-square head dims, right-aligned causal, autotuned tile
  sizes. Saves an fp32 LSE for backward.
- Analytical backward: canonical FA-2 three-pass structure
  (preprocess for Delta = rowsum(O * dO), then dQ over Q-tile
  programs and dK/dV over K/V-tile programs). No atomics — each
  program owns a distinct output slice.
- Decode forward over compressed KV cache: c_kv [B, S_kv, R] and
  k_rope [B, S_kv, R_rope] with absorbed up-projection (Q's nope
  side pre-multiplied by W_uk^T). c_kv plays both K (nope side)
  and V; per-head K/V are never materialized. Returns
  O_inter [B, H, S_q, R]; the caller applies W_uv.

PyTorch wrapper (transformer_engine/pytorch/triton/mla.py):

- mla_attention(q, k, v, *, softmax_scale, is_causal, qkv_format)
  via torch.autograd.Function (Triton fwd + Triton bwd). Layouts:
  bshd, bhsd, sbhd. Pure-PyTorch mla_attention_ref kept as the test
  reference.
- mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope, *,
  softmax_scale, is_causal). softmax_scale is required (R is not
  the original head_dim_qk so no sane default).

Optional DotProductAttention hookup behind NVTE_MLA_TRITON=1
(default off). MLATritonAttention added to backends.py;
DotProductAttention.forward gains a strict-precondition early-out
that falls through to the regular FA / Fused / Unfused cascade
unless every supported feature flag matches (no FP8, no dropout,
no context parallel, no alibi/bias, no padding/sliding-window mask,
no inference cache, bshd/sbhd, MLA-shaped, SM80+). The existing
get_attention_backend() return signature is preserved, so existing
dispatch is untouched when the env var is unset.

Tests at tests/pytorch/test_mla_triton.py exercise:

- Prefill forward across {bf16, fp16} x {causal, non-causal} x
  {bshd, bhsd, sbhd} for shapes including DeepSeek-V2 prefill,
  cross-attention (S_q != S_kv), and non-multiple-of-block seqlens.
- Backward dQ/dK/dV vs fp32 PyTorch reference within bf16/fp16
  tolerances.
- Decode forward across DeepSeek-V2 dims (R=512, R_rope=64) and
  smoke shapes; plus dim-mismatch and dtype rejection.
- DPA dispatch: equality with direct mla_attention call when
  NVTE_MLA_TRITON=1, and fall-through preservation when unset.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
pre-commit-ci Bot and others added 4 commits April 30, 2026 23:23
- Map ``window_size=(-1, 0)`` (FlashAttention causal-via-window
  convention) to ``is_causal=True`` in the MLA Triton dispatch.
  Previously a caller setting ``attn_mask_type="no_mask"`` together
  with ``window_size=(-1, 0)`` would land in the kernel with
  ``is_causal=False`` and silently attend to future tokens.
- Use string equality (``== "1"``) instead of ``int(os.getenv(...))``
  for ``NVTE_MLA_TRITON``, so non-integer values like ``"true"`` no
  longer raise ``ValueError`` at dispatch time.
- ``mla_decode_attention`` now raises ``NotImplementedError`` when
  any input has ``requires_grad=True``. The decode kernel is launched
  directly (no autograd.Function wrapper) and would otherwise drop
  gradients silently at the kernel boundary; v1 is forward-only.

Test: tests/pytorch/test_mla_triton.py adds
``test_mla_decode_rejects_requires_grad``.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
The decode forward kernel was allocating and writing an fp32 ``LSE``
buffer that was immediately discarded by the wrapper (v1 ships no
analytical decode backward, and the prefill backward kernels only
consume the prefill kernel's LSE). Remove the buffer, the kernel
parameter, and the trailing ``m_i + tl.log(l_i)`` store. A short
comment marks the spot for re-introduction if a future change adds
decode backward.

No semantic change. Decode correctness re-verified across DSv2 dims
(R=512, R_rope=64) and smaller smoke configs against the
fp32-internal PyTorch reference.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>