fix(attention): correct causal mask for multi-tile decoder attention (seq_len > mlen)#58
Merged
Merged
Conversation
…(seq_len > mlen) The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the score matrix is a single tile. When seq_len > mlen the seq x seq score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) need to be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen. The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case instead of silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. vision) are unaffected. This regime was never exercised before: the sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile), the default mlen=256 keeps seq_len=64 < mlen (single tile), and vision is non-causal. Verified by isolation (native_64x64x16 seq=64 PASS vs seq=128 was-FAIL-now-PASS, identical config) and by the full SmolVLM2 1L decoder sweep at mlen=16/32 (all blen) going from allclose FAIL to PASS. Adds test_mha_causal_skips_future_tiles_and_masks_only_diagonal asserting future tiles are skipped and the mask lands on diagonal tiles only.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The native per-head flash-attention kernel (
_flash_attention_mha) added the same static(mlen, mlen)triangular causal mask to every QK score tile. That is correct only whenseq_len <= mlen, where theseq x seqscore matrix is a single tile. Whenseq_len > mlenthe score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) must be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs atseq_len > mlen.Fix
The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute
exp(-inf)=0and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. ANotImplementedErrorguards the unsupportedkv_seq_len != seq_len(KV-cache query offset) case rather than silently miscomputing.seq_len <= mlenremains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. the vision encoder) are unaffected.Why it was never caught
This regime was simply never exercised. The sub-64 decoder fixes (#54/#55) validated
seq_len=4 < mlen(single tile); the defaultmlen=256keepsseq_len=64 < mlen(single tile); vision is non-causal. The first run atseq_len > mlenfor a causal decoder was the recent sub-64 blen sweep, which surfaced the divergence.Verification
Isolation run, identical config
native_64x64x16, only seq differs —seq=64(1 tile) PASS vsseq=128(2 tiles) FAIL→PASS — pins the cause to multi-tile causal masking, ruling out the #56 RoPE/norm rolling (same rolled code passes at 1 tile). The full SmolVLM2 1-layer decoder sweep atmlen=16andmlen=32across allblengoes from allclose FAIL to PASS:The existing
flash_attention_mhagolden test still passes at 100%, and a new codegen regression test (test_mha_causal_skips_future_tiles_and_masks_only_diagonal) asserts that strictly-future tiles are skipped and the triangular mask lands on diagonal tiles only — it fails on the pre-fix kernel.