Skip to content

fix(attention): correct causal mask for multi-tile decoder attention (seq_len > mlen)#58

Merged
booth-algo merged 1 commit into
mainfrom
fix/causal-mask-multitile
Jun 3, 2026
Merged

fix(attention): correct causal mask for multi-tile decoder attention (seq_len > mlen)#58
booth-algo merged 1 commit into
mainfrom
fix/causal-mask-multitile

Conversation

@booth-algo

Copy link
Copy Markdown
Collaborator

The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the seq x seq score matrix is a single tile. When seq_len > mlen the score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) must be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen.

Fix

The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case rather than silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. the vision encoder) are unaffected.

Why it was never caught

This regime was simply never exercised. The sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile); the default mlen=256 keeps seq_len=64 < mlen (single tile); vision is non-causal. The first run at seq_len > mlen for a causal decoder was the recent sub-64 blen sweep, which surfaced the divergence.

Verification

Isolation run, identical config native_64x64x16, only seq differs — seq=64 (1 tile) PASS vs seq=128 (2 tiles) FAIL→PASS — pins the cause to multi-tile causal masking, ruling out the #56 RoPE/norm rolling (same rolled code passes at 1 tile). The full SmolVLM2 1-layer decoder sweep at mlen=16 and mlen=32 across all blen goes from allclose FAIL to PASS:

mlen blen sim_lat allclose
16 4 19.74ms PASS (97.8%)
16 8 6.86ms PASS
16 16 3.63ms PASS
32 4 16.72ms PASS (98.5%)
32 8 4.76ms PASS
32 16 1.76ms PASS
32 32 1.01ms PASS

The existing flash_attention_mha golden test still passes at 100%, and a new codegen regression test (test_mha_causal_skips_future_tiles_and_masks_only_diagonal) asserts that strictly-future tiles are skipped and the triangular mask lands on diagonal tiles only — it fails on the pre-fix kernel.

…(seq_len > mlen)

The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the score matrix is a single tile. When seq_len > mlen the seq x seq score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) need to be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen.

The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case instead of silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. vision) are unaffected.

This regime was never exercised before: the sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile), the default mlen=256 keeps seq_len=64 < mlen (single tile), and vision is non-causal. Verified by isolation (native_64x64x16 seq=64 PASS vs seq=128 was-FAIL-now-PASS, identical config) and by the full SmolVLM2 1L decoder sweep at mlen=16/32 (all blen) going from allclose FAIL to PASS. Adds test_mha_causal_skips_future_tiles_and_masks_only_diagonal asserting future tiles are skipped and the mask lands on diagonal tiles only.
@booth-algo booth-algo merged commit ebdba9e into main Jun 3, 2026
3 checks passed
@booth-algo booth-algo deleted the fix/causal-mask-multitile branch June 3, 2026 15:22
@booth-algo booth-algo restored the fix/causal-mask-multitile branch June 3, 2026 15:22
@booth-algo booth-algo deleted the fix/causal-mask-multitile branch June 3, 2026 15:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant