Support batch_size>1 in native MHA decoder, incl. true sub-64 (head_dim>mlen) #55

Merged
booth-algo merged 2 commits into main from feat/multibatch-decoder
Jun 2, 2026

Conversation

@booth-algo

Collaborator

Summary

Adds batch_size>1 support to the native non-packed MHA decoder, including true
sub-64 multi-batch (mlen 16/32 < head_dim 64). Both cases previously raised
NotImplementedError.

Decoder (_emit_attention_block)

Rather than re-implement a per-batch loop in the decoder (which NaN-ed under
hand-written column-block-major slab slicing), span the per-head Q/K/V/O across all
batches and delegate to _flash_attention_mha's own batch loop. Also make the
batch>1 output comparison truthful: the compact golden holds only the active rows
(batch*seq), which are rpb-strided in VRAM, so emit num_batches=active and carry
rows_per_batch/active_seq_per_batch for the verifier (paired with PLENA_Tools).

Kernel (_flash_attention_mha)

The per-batch Q/O views truncated physical_shape, valid only for a single head_dim
col-block. Tensors are column-block-major (col-block stride = R*mlen), so for
head_dim>mlen a batch's rows span head_dim/mlen col-blocks. Fix: preserve the full
physical_shape and set q_batch_stride = q_rows_per_batch*mlen (this coincides with
the old *physical_shape[1] at head_dim==mlen, so that path is byte-identical).
o_batch_stride = seq_len*mlen is kept (matches the decoder's O_h read). Remove the
head_dim>mlen guard.

Test plan (numerical, allclose PASS)

  • batch=2 @64/64/16 seq=4 (head_dim==mlen) = 97.37%, byte-identical ISA vs before
  • batch=2 @16/16/4 seq=4 (head_dim 64>mlen 16, 4 col-blocks) = 99.07%, no NaN
  • batch=2 @32/32/4 seq=4 (head_dim 64>mlen 32, 2 col-blocks) = 99.67%, no NaN
  • batch=1 @64/64/16 seq=4 = 99.96% (no regression)

The non-packed decoder attention previously raised NotImplementedError for
batch_size>1. Rather than re-implement a per-batch loop in the decoder (which
NaN-ed under hand-written column-block-major slab slicing), span the per-head
Q/K/V/O across all batches and let the flash-attention kernel's own batch loop
handle it. head_dim<=mlen works now; head_dim>mlen still raises gracefully via
the kernel's guard (sub-64 lands separately).

Also make the batch>1 output comparison truthful. The compact golden holds only
the active rows (batch_size*seq_len), which live rpb-strided in VRAM (batch b at
physical row b*rows_per_batch). Set num_batches to the active-row count and carry
rows_per_batch/active_seq_per_batch so the verifier extracts the strided active
rows instead of comparing all physical (padding) rows.
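
The strided extraction described above can be sketched as follows. This is a minimal illustration, not the PLENA_Tools verifier: the function name `extract_active_rows`, the list-of-rows representation of VRAM, and the example value rows_per_batch=16 are all assumptions made for the example.

```python
def extract_active_rows(vram_rows, batch_size, rows_per_batch, active_seq_per_batch):
    """Gather each batch's active rows from an rpb-strided list of physical rows.

    Hypothetical helper: batch b is assumed to start at physical row
    b * rows_per_batch, with only its first active_seq_per_batch rows active.
    """
    active = []
    for b in range(batch_size):
        base = b * rows_per_batch
        active.extend(vram_rows[base:base + active_seq_per_batch])
    return active


# Illustrative sizes only: 2 batches, 4 active rows each, 16 physical rows per batch.
batch_size, rows_per_batch, active_seq = 2, 16, 4
vram = [("pad", r) for r in range(batch_size * rows_per_batch)]
for b in range(batch_size):
    for s in range(active_seq):
        vram[b * rows_per_batch + s] = ("active", b, s)

rows = extract_active_rows(vram, batch_size, rows_per_batch, active_seq)
```

Comparing `rows` against the compact golden (batch_size*seq_len rows) avoids counting padding rows as mismatches, which is the kind of false negative the 29% figure below refers to.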

batch=2 @64/64/16 seq=4: 97.37% allclose PASS (previously a 29% false-negative);
batch=1 unchanged at 99.96%.

The non-packed flash-attention kernel guarded head_dim > mlen with batch_size > 1
because the per-batch Q/O views truncated physical_shape to one batch, which is
only valid for a single head_dim col-block. Tensors are column-block-major
(col-block cb at cb*R*mlen, row r at r*mlen), so for head_dim > mlen a batch's
rows are interleaved across head_dim/mlen col-blocks and the truncated views
mis-strided every col-block past the first.

Fix: preserve the full physical_shape in Q_batch/O_batch so the col-block stride
stays R*mlen, and set the per-batch Q base offset to q_rows_per_batch*mlen (the
flat distance to batch b's first row within col-block 0) instead of
q_rows_per_batch*physical_shape[1] (which over-skips by head_dim/mlen col-blocks).
The two expressions coincide at head_dim == mlen, so the existing head_dim <= mlen
batched path is unchanged. o_batch_stride (seq_len*mlen) already matched the
decoder's O_h read and is kept.
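
The offset arithmetic above can be checked with a small sketch. The helper names and the example sizes are hypothetical, and it assumes physical_shape[1] equals head_dim; it models only the flat addressing stated in the commit message (col-block cb at cb*R*mlen, row r at r*mlen), not the kernel itself.

```python
def flat_offset(row, col, total_rows, mlen):
    """Flat element offset in a column-block-major tensor with R = total_rows."""
    cb, c = divmod(col, mlen)  # col-block index and column within the block
    return cb * total_rows * mlen + row * mlen + c


def new_batch_base(b, q_rows_per_batch, mlen):
    """Fixed base offset: distance to batch b's first row within col-block 0."""
    return b * q_rows_per_batch * mlen


def old_batch_base(b, q_rows_per_batch, physical_cols):
    """Previous base offset, scaled by the full physical width."""
    return b * q_rows_per_batch * physical_cols


# Illustrative sizes: head_dim 64 > mlen 16 (4 col-blocks), 2 batches of 16 rows.
mlen, head_dim, q_rows_per_batch, batch_size = 16, 64, 16, 2
total_rows = batch_size * q_rows_per_batch

expected = flat_offset(q_rows_per_batch, 0, total_rows, mlen)  # batch 1, row 0, col 0
fixed = new_batch_base(1, q_rows_per_batch, mlen)
stale = old_batch_base(1, q_rows_per_batch, head_dim)

# With head_dim == mlen the two expressions give the same base.
same_new = new_batch_base(1, q_rows_per_batch, 64)
same_old = old_batch_base(1, q_rows_per_batch, 64)
```

With these sizes the fixed base lands on batch 1's first row, while the old expression lands at the start of col-block 2, four times too far.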

Verified: 16/16/4 batch=2 seq=4 (head_dim 64 > mlen 16, 4 col-blocks) = 99.07%
allclose PASS, no NaN -- previously raised NotImplementedError. 64/64/16 batch=2
regression = 97.37%, byte-identical to before (no-op at head_dim == mlen).
batch=1 unaffected (single-batch fast path).
@booth-algo booth-algo merged commit a4c80f8 into main Jun 2, 2026
3 checks passed
@booth-algo booth-algo deleted the feat/multibatch-decoder branch June 2, 2026 00:41
booth-algo added a commit that referenced this pull request Jun 3, 2026
…(seq_len > mlen) (#58)

The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the score matrix is a single tile. When seq_len > mlen the seq x seq score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) need to be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen.

The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case instead of silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. vision) are unaffected.
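
The per-tile rule can be sketched as follows. This is an illustrative model under stated assumptions (square mlen x mlen tiles, kv_seq_len == seq_len, hypothetical function names), not the code in _flash_attention_mha. It builds a boolean visibility matrix tile by tile and compares it with the plain causal rule j <= i, alongside the old behaviour of applying the same triangular mask to every tile.

```python
def classify_tile(q_idx, k_idx):
    """Causal role of score tile (q_idx, k_idx) in the tile grid."""
    if k_idx > q_idx:
        return "skip"      # strictly future: contributes exp(-inf) = 0
    if k_idx == q_idx:
        return "diagonal"  # needs the triangular mask
    return "unmasked"      # strictly past: every key is visible


def tiled_visibility(seq_len, mlen, position_aware=True):
    """visible[i][j] is True if query i attends key j, built one tile at a time."""
    n_tiles = (seq_len + mlen - 1) // mlen
    visible = [[False] * seq_len for _ in range(seq_len)]
    for q_idx in range(n_tiles):
        for k_idx in range(n_tiles):
            # position_aware=False models the old static mask on every tile.
            kind = classify_tile(q_idx, k_idx) if position_aware else "diagonal"
            if kind == "skip":
                continue
            for qi in range(q_idx * mlen, min((q_idx + 1) * mlen, seq_len)):
                for kj in range(k_idx * mlen, min((k_idx + 1) * mlen, seq_len)):
                    local_q, local_k = qi - q_idx * mlen, kj - k_idx * mlen
                    visible[qi][kj] = kind == "unmasked" or local_k <= local_q
    return visible


def reference_visibility(seq_len):
    """Global causal rule: key j is visible to query i iff j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


fixed = tiled_visibility(8, 4)                        # seq_len > mlen, 2x2 tile grid
static = tiled_visibility(8, 4, position_aware=False)
single = tiled_visibility(4, 16, position_aware=False)  # seq_len <= mlen, one tile
```

In the 2x2 case the static mask lets query 0 see key 4 (a future key) and hides key 1 from query 4 (a valid past key), matching the two failure modes described above; in the single-tile case the static mask is already correct.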

This regime was never exercised before: the sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile), the default mlen=256 keeps seq_len=64 < mlen (single tile), and vision is non-causal. Verified by isolation (native_64x64x16 seq=64 PASS vs seq=128 was-FAIL-now-PASS, identical config) and by the full SmolVLM2 1L decoder sweep at mlen=16/32 (all blen) going from allclose FAIL to PASS. Adds test_mha_causal_skips_future_tiles_and_masks_only_diagonal asserting future tiles are skipped and the mask lands on diagonal tiles only.