Fix native decoder attention for seq_len < mlen (sub-64 tile configs)#54
Merged
Conversation
Three independent codegen bugs made the native decoder emit wrong-but-finite output when seq_len < mlen (e.g. SmolVLM2 at the 16/16/4 tile, seq=4: 23.5% allclose). All three used mlen or seq_len where the physical padded row count belongs, so all were invisible at seq_len == mlen. - compile_native_hf_decoder: pad sub-mlen sequences to MLEN. The flash-attn kernel asserts physical rows/batch >= max(mlen, seq_len); a BLEN-padded sub-mlen sequence left the per-head Q/O slices too short. No-op at seq>=mlen. - _emit_attention_block: the per-head Q-read and O-write stride must span the full physical per-head region (physical_rows * head_dim), not seq_len * mlen (= physical_rows * mlen, one col-block). The short stride overlapped the heads so O_proj read only the first ~2 correctly (rest fell past the written range -> zero). Also zero-fill O_full. Mirrors the vision o_head_stride. - _pv_multiply_asm / _pv_multiply_asm_unrolled: the PV head_dim col-block base must use the PV physical row count (min(mlen, seq_len)), not mlen*mlen. At seq < mlen the upper head_dim col-blocks were written out of bounds, leaving cols mlen..head_dim zero so the attention output shrank by mlen/head_dim. Verified on smolvlm2 @16/16/4, 1 layer: seq=4 now 100% allclose (was 23.5%); seq=16 (== mlen) unchanged at 100% (no regression).
booth-algo
added a commit
that referenced
this pull request
Jun 3, 2026
…(seq_len > mlen) (#58) The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the score matrix is a single tile. When seq_len > mlen the seq x seq score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) need to be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen. The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case instead of silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. vision) are unaffected. This regime was never exercised before: the sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile), the default mlen=256 keeps seq_len=64 < mlen (single tile), and vision is non-causal. Verified by isolation (native_64x64x16 seq=64 PASS vs seq=128 was-FAIL-now-PASS, identical config) and by the full SmolVLM2 1L decoder sweep at mlen=16/32 (all blen) going from allclose FAIL to PASS. Adds test_mha_causal_skips_future_tiles_and_masks_only_diagonal asserting future tiles are skipped and the mask lands on diagonal tiles only.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Three independent codegen bugs made the native decoder emit wrong-but-finite output when
seq_len < mlen(e.g. SmolVLM2 at the 16/16/4 tile,seq=4: 23.5% allclose). All three usedmlenorseq_lenwhere the physical padded row count belongs, so all were invisible atseq_len == mlen(the blind spot that previously caused a stride fix to be reverted).compile_native_hf_decoder— pad sub-mlensequences up toMLEN. The flash-attention kernel asserts physical rows/batch ≥max(mlen, seq_len); aBLEN-padded sub-mlensequence left the per-head Q/O slices too short. No-op atseq ≥ mlen._emit_attention_block— the per-head Q-read and O-write stride must span the full physical per-head region (physical_rows * head_dim), notseq_len * mlen(one col-block). The short stride overlapped the heads, soO_projread only the first ~2 correctly and the rest fell past the written range (→ zero). Also zero-fillsO_full. Mirrors the visiono_head_stride._pv_multiply_asm/_pv_multiply_asm_unrolled— the PVhead_dimcol-block base must use the PV physical row count (min(mlen, seq_len)), notmlen*mlen. Atseq < mlenthe upperhead_dimcol-blocks were written out of bounds, leaving colsmlen..head_dimzero, so the attention output shrank bymlen/head_dim.Verification
smolvlm2
native_16x16x4_b1, 1 layer (cycle-accurate emulator):seq=4(seq < mlen): 100% allclose (was 23.5%)seq=16(== mlen): 100% allclose — no regression