Support batch_size>1 in native MHA decoder, incl. true sub-64 (head_dim>mlen)#55
Merged
Conversation
The non-packed decoder attention previously raised NotImplementedError for batch_size>1. Rather than re-implement a per-batch loop in the decoder (which NaN-ed under hand-written column-block-major slab slicing), span the per-head Q/K/V/O across all batches and let the flash-attention kernel's own batch loop handle it. head_dim<=mlen works now; head_dim>mlen still raises gracefully via the kernel's guard (sub-64 lands separately). Also make the batch>1 output comparison truthful. The compact golden holds only the active rows (batch_size*seq_len), which live rpb-strided in VRAM (batch b at physical row b*rows_per_batch). Set num_batches to the active-row count and carry rows_per_batch/active_seq_per_batch so the verifier extracts the strided active rows instead of comparing all physical (padding) rows. batch=2 @64/64/16 seq=4: 97.37% allclose PASS (previously a 29% false-negative); batch=1 unchanged at 99.96%.
The non-packed flash-attention kernel guarded head_dim > mlen with batch_size > 1 because the per-batch Q/O views truncated physical_shape to one batch, which is only valid for a single head_dim col-block. Tensors are column-block-major (col-block cb at cb*R*mlen, row r at r*mlen), so for head_dim > mlen a batch's rows are interleaved across head_dim/mlen col-blocks and the truncated views mis-strided every col-block past the first. Fix: preserve the full physical_shape in Q_batch/O_batch so the col-block stride stays R*mlen, and set the per-batch Q base offset to q_rows_per_batch*mlen (the flat distance to batch b's first row within col-block 0) instead of q_rows_per_batch*physical_shape[1] (which over-skips by head_dim/mlen col-blocks). The two expressions coincide at head_dim == mlen, so the existing head_dim <= mlen batched path is unchanged. o_batch_stride (seq_len*mlen) already matched the decoder's O_h read and is kept. Verified: 16/16/4 batch=2 seq=4 (head_dim 64 > mlen 16, 4 col-blocks) = 99.07% allclose PASS, no NaN -- previously raised NotImplementedError. 64/64/16 batch=2 regression = 97.37%, byte-identical to before (no-op at head_dim == mlen). batch=1 unaffected (single-batch fast path).
3 tasks
booth-algo
added a commit
that referenced
this pull request
Jun 3, 2026
…(seq_len > mlen) (#58) The native per-head flash-attention kernel (_flash_attention_mha) added the same static (mlen, mlen) triangular causal mask to every QK score tile. That is correct only when seq_len <= mlen, where the score matrix is a single tile. When seq_len > mlen the seq x seq score matrix tiles into a grid and the mask must depend on tile position: strictly-future key tiles (k_idx > q_idx) need to be fully masked but the triangular mask left their lower triangle open (leaking future keys), and strictly-past key tiles (k_idx < q_idx) need no mask but the triangular mask wrongly masked their upper triangle (dropping valid past keys). Only the diagonal tile was correct, producing a partial numerical divergence (~58% allclose) on decoder runs at seq_len > mlen. The kernel now derives each tile's global causal geometry: strictly-future key tiles contribute exp(-inf)=0 and are skipped entirely, the triangular mask is applied only on the diagonal tile, and strictly-past tiles are left unmasked. A NotImplementedError guards the unsupported kv_seq_len != seq_len (KV-cache query offset) case instead of silently miscomputing. seq_len <= mlen remains the single-tile special case and is byte-identical; non-causal callers (causal_mask=None, e.g. vision) are unaffected. This regime was never exercised before: the sub-64 decoder fixes (#54/#55) validated seq_len=4 < mlen (single tile), the default mlen=256 keeps seq_len=64 < mlen (single tile), and vision is non-causal. Verified by isolation (native_64x64x16 seq=64 PASS vs seq=128 was-FAIL-now-PASS, identical config) and by the full SmolVLM2 1L decoder sweep at mlen=16/32 (all blen) going from allclose FAIL to PASS. Adds test_mha_causal_skips_future_tiles_and_masks_only_diagonal asserting future tiles are skipped and the mask lands on diagonal tiles only.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
batch_size>1to the native non-packed MHA decoder, including true sub-64(mlen 16/32 < head_dim 64) multi-batch — both previously raised
NotImplementedError.Decoder (
_emit_attention_block)Rather than re-implement a per-batch loop in the decoder (which NaN-ed under
hand-written column-block-major slab slicing), span the per-head Q/K/V/O across all
batches and delegate to
_flash_attention_mha's own batch loop. Also make thebatch>1 output comparison truthful: the compact golden holds only the active rows
(
batch*seq), which are rpb-strided in VRAM, so emitnum_batches=active and carryrows_per_batch/active_seq_per_batchfor the verifier (paired with PLENA_Tools).Kernel (
_flash_attention_mha)The per-batch Q/O views truncated
physical_shape, valid only for a single head_dimcol-block. Tensors are column-block-major (col-block stride =
R*mlen), so forhead_dim>mlen a batch's rows span
head_dim/mlencol-blocks. Fix: preserve the fullphysical_shapeand setq_batch_stride = q_rows_per_batch*mlen(this coincides withthe old
*physical_shape[1]at head_dim==mlen, so that path is byte-identical).o_batch_stride = seq_len*mlenis kept (matches the decoder'sO_hread). Remove thehead_dim>mlen guard.
Test plan (numerical, allclose PASS)