Fix native decoder attention for seq_len < mlen (sub-64 tile configs)#54

Merged
booth-algo merged 1 commit into main from fix/vision-smolvlm2
May 31, 2026

@booth-algo

Collaborator

Summary

Three independent codegen bugs made the native decoder emit wrong-but-finite output when seq_len < mlen (e.g. SmolVLM2 at the 16/16/4 tile, seq=4: 23.5% allclose). All three used mlen or seq_len where the physical padded row count belongs, so all were invisible at seq_len == mlen (the blind spot that previously caused a stride fix to be reverted).

  • compile_native_hf_decoder: pad sub-mlen sequences up to MLEN. The flash-attention kernel asserts physical rows/batch ≥ max(mlen, seq_len); a BLEN-padded sub-mlen sequence left the per-head Q/O slices too short. No-op at seq_len ≥ mlen.
  • _emit_attention_block: the per-head Q-read and O-write stride must span the full physical per-head region (physical_rows * head_dim), not seq_len * mlen (one col-block). The short stride made the per-head regions overlap, so O_proj read only the first ~2 heads correctly and the rest fell past the written range (→ zero). Also zero-fills O_full. Mirrors the vision o_head_stride.
  • _pv_multiply_asm / _pv_multiply_asm_unrolled: the PV head_dim col-block base must use the PV physical row count (min(mlen, seq_len)), not mlen*mlen. At seq_len < mlen the upper head_dim col-blocks were written out of bounds, leaving cols mlen..head_dim zero, so the attention output shrank by mlen/head_dim.
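
The first two fixes can be illustrated with a small sketch. This is not the repository's code: the helper names, `head_dim=64` and `num_heads=3` are assumptions chosen for illustration, and only the padding rule and the two stride formulas are taken from the description above.

```python
def blen_padded_rows(seq_len, blen):
    """Round seq_len up to a multiple of blen (the pre-fix padding)."""
    return -(-seq_len // blen) * blen


def physical_rows(seq_len, mlen, blen):
    """Post-fix padding: sub-mlen sequences are padded up to mlen.

    For seq_len >= mlen this equals the blen-padded count (a no-op).
    """
    return max(blen_padded_rows(seq_len, blen), mlen)


def head_offsets(num_heads, stride):
    """Start offset (in elements) of each head's Q/O region."""
    return [h * stride for h in range(num_heads)]


def regions_overlap(offsets, region_size):
    """True if any head's region runs into the next head's region."""
    return any(b < a + region_size for a, b in zip(offsets, offsets[1:]))


# Tile config from the PR (mlen=16, blen=4, seq_len=4).
# head_dim and num_heads are hypothetical values for illustration.
mlen, blen, seq_len, head_dim, num_heads = 16, 4, 4, 64, 3

rows = physical_rows(seq_len, mlen, blen)      # padded row count
region = rows * head_dim                       # full per-head region

short_stride = seq_len * mlen                  # stride described as buggy
full_stride = rows * head_dim                  # stride described as fixed

print(regions_overlap(head_offsets(num_heads, short_stride), region))
print(regions_overlap(head_offsets(num_heads, full_stride), region))
```

With the short stride every head starts inside the previous head's region; with the full stride the regions sit back to back. At seq_len == mlen and head_dim == mlen the two strides coincide, which is consistent with the bug being invisible in that configuration.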

Verification

smolvlm2 native_16x16x4_b1, 1 layer (cycle-accurate emulator):

  • seq=4 (seq < mlen): 100% allclose (was 23.5%)
  • seq=16 (== mlen): 100% allclose — no regression
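
The PR does not say how the allclose percentage is computed. One plausible reading, sketched below, is the share of output elements that fall within tolerance of the reference. The function name and the `rtol`/`atol` defaults are assumptions; the closeness test follows the usual `|a - e| <= atol + rtol * |e|` form.

```python
def allclose_percent(actual, expected, rtol=1e-2, atol=1e-2):
    """Percentage of elements of `actual` within tolerance of `expected`.

    Hypothetical helper: tolerances are illustrative, not the values
    used by the project's test harness.
    """
    if len(actual) != len(expected) or not expected:
        raise ValueError("inputs must be non-empty and of equal length")
    close = sum(
        1
        for a, e in zip(actual, expected)
        if abs(a - e) <= atol + rtol * abs(e)
    )
    return 100.0 * close / len(expected)


# "Wrong-but-finite" output: the second half was never written (zeros).
reference = [1.0, 2.0, 3.0, 4.0]
partial = [1.0, 2.0, 0.0, 0.0]
print(allclose_percent(partial, reference))
```

A metric of this shape explains why a stride or out-of-bounds bug shows up as a partial percentage rather than NaNs: the correctly written elements still match, and the unwritten ones are finite zeros.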

Three independent codegen bugs made the native decoder emit wrong-but-finite
output when seq_len < mlen (e.g. SmolVLM2 at the 16/16/4 tile, seq=4: 23.5%
allclose). All three used mlen or seq_len where the physical padded row count
belongs, so all were invisible at seq_len == mlen.

- compile_native_hf_decoder: pad sub-mlen sequences to MLEN. The flash-attn
  kernel asserts physical rows/batch >= max(mlen, seq_len); a BLEN-padded
  sub-mlen sequence left the per-head Q/O slices too short. No-op at seq>=mlen.
- _emit_attention_block: the per-head Q-read and O-write stride must span the
  full physical per-head region (physical_rows * head_dim), not seq_len * mlen
  (= physical_rows * mlen, one col-block). The short stride overlapped the
  heads so O_proj read only the first ~2 correctly (rest fell past the written
  range -> zero). Also zero-fill O_full. Mirrors the vision o_head_stride.
- _pv_multiply_asm / _pv_multiply_asm_unrolled: the PV head_dim col-block base
  must use the PV physical row count (min(mlen, seq_len)), not mlen*mlen. At
  seq < mlen the upper head_dim col-blocks were written out of bounds, leaving
  cols mlen..head_dim zero so the attention output shrank by mlen/head_dim.

Verified on smolvlm2 @16/16/4, 1 layer: seq=4 now 100% allclose (was 23.5%);
seq=16 (== mlen) unchanged at 100% (no regression).
@booth-algo booth-algo merged commit 7539bdd into main May 31, 2026
3 checks passed
@booth-algo booth-algo deleted the fix/vision-smolvlm2 branch May 31, 2026 22:36
booth-algo added a commit that referenced this pull request Jun 3, 2026
…(seq_len > mlen) (#58)

The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the score matrix is a single tile. When seq_len > mlen, the seq x seq score matrix tiles into a grid and the mask must depend on tile position. Strictly-future key tiles (k_idx > q_idx) need to be fully masked, but the triangular mask left their lower triangle open (leaking future keys). Strictly-past key tiles (k_idx < q_idx) need no mask, but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, which produced a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen.

The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case instead of silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. vision) are unaffected.
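
The per-tile rule described above can be checked against a plain global causal mask. The sketch below is an illustration in pure Python, not the kernel's code: all function names are hypothetical, and a skipped tile is modelled as a fully masked (-inf) tile since it contributes exp(-inf)=0.

```python
NEG_INF = float("-inf")


def tile_action(q_idx, k_idx):
    """Classify a score tile by its position in the tile grid."""
    if k_idx > q_idx:
        return "skip"        # strictly future: contributes nothing
    if k_idx == q_idx:
        return "triangular"  # diagonal: apply the (mlen, mlen) mask
    return "none"            # strictly past: fully visible


def tiled_causal_mask(seq_len, mlen):
    """Additive seq x seq mask assembled from per-tile decisions."""
    mask = [[0.0] * seq_len for _ in range(seq_len)]
    n_tiles = -(-seq_len // mlen)
    for q_idx in range(n_tiles):
        for k_idx in range(n_tiles):
            action = tile_action(q_idx, k_idx)
            for r in range(q_idx * mlen, min((q_idx + 1) * mlen, seq_len)):
                for c in range(k_idx * mlen, min((k_idx + 1) * mlen, seq_len)):
                    local_r, local_c = r - q_idx * mlen, c - k_idx * mlen
                    if action == "skip":
                        mask[r][c] = NEG_INF
                    elif action == "triangular" and local_c > local_r:
                        mask[r][c] = NEG_INF
    return mask


def global_causal_mask(seq_len):
    """Reference: key position c is visible to query r iff c <= r."""
    return [[NEG_INF if c > r else 0.0 for c in range(seq_len)]
            for r in range(seq_len)]


def static_tile_mask(seq_len, mlen):
    """Old behaviour: the same triangular mask stamped on every tile."""
    return [[NEG_INF if (c % mlen) > (r % mlen) else 0.0
             for c in range(seq_len)]
            for r in range(seq_len)]


# seq_len > mlen: the tile-aware mask matches the global causal mask,
# while the static per-tile mask does not.
print(tiled_causal_mask(8, 4) == global_causal_mask(8))
print(static_tile_mask(8, 4) == global_causal_mask(8))
```

At seq_len <= mlen the grid has a single diagonal tile, so the static mask and the tile-aware mask agree, matching the statement that the single-tile case is unchanged.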

This regime was never exercised before: the sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile), the default mlen=256 keeps seq_len=64 < mlen (single tile), and vision is non-causal. Verified by isolation (native_64x64x16 seq=64 PASS vs seq=128 was-FAIL-now-PASS, identical config) and by the full SmolVLM2 1L decoder sweep at mlen=16/32 (all blen) going from allclose FAIL to PASS. Adds test_mha_causal_skips_future_tiles_and_masks_only_diagonal asserting future tiles are skipped and the mask lands on diagonal tiles only.