diff --git a/aten/plena/isa_attention.py b/aten/plena/isa_attention.py index d6e9baa..f6f1340 100644 --- a/aten/plena/isa_attention.py +++ b/aten/plena/isa_attention.py @@ -219,6 +219,7 @@ def _pv_multiply_asm( v_hbm_offset: int, pv_address: int, rows: int | None = None, + pv_physical_rows: int | None = None, ) -> str: """ Compute PV = P @ V via M_MM. @@ -232,6 +233,14 @@ def _pv_multiply_asm( column blocks; the outer loop iterates blocks, middle loop iterates blen-wide V columns within a block, inner loop iterates blen-wide P rows. """ + # PV is stored column-block-major with `pv_physical_rows` rows (= min(mlen, + # seq_len) for the decoder), so each head_dim col-block spans + # pv_physical_rows*mlen, not mlen*mlen. They coincide at seq>=mlen, but for + # seq str: """Legacy Python-unrolled P @ V emission, kept for A/B comparisons.""" + if pv_physical_rows is None: + pv_physical_rows = mlen gp_regs = self.register_allocator.allocate_gp(5) gp_p = gp_regs[0] gp_v = gp_regs[1] @@ -348,7 +361,7 @@ def _pv_multiply_asm_unrolled( lines.extend(load_large_int(gp_p, p_row_addr)) lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") - pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen + pv_offset = v_col_block * pv_physical_rows * mlen + p_row * blen * mlen + v_col * blen lines.extend(load_large_int(gp_pv, pv_address + pv_offset)) lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") @@ -797,6 +810,7 @@ def compute_pv( v_hbm_offset=v_hbm_offset, pv_address=pv_address, rows=rows, + pv_physical_rows=pv_info.physical_shape[0], ) self.register_allocator.free_gp(gp_regs) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index fb4e137..da97d91 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -1094,6 +1094,14 @@ def _emit_attention_block( O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim, strict=False) o_full_addr = prog.get_vram_addr(O_full.name) + prog.vram_fill_zero(O_full) + # Each head occupies head_dim//mlen col-blocks of O_full, so the per-head span + # is physical_rows*head_dim. seq_len*mlen (= physical_rows*mlen) only spans one + # col-block, overlapping the heads so O_proj reads only the first ~2 correctly + # (the rest fall past the written range). Mirrors the vision o_head_stride; + # identical when head_dim==mlen, correct when head_dim>mlen. + physical_rows = O_full.physical_shape[0] + head_stride = physical_rows * head_dim kv_stored = _emit_kv_stores( prog, @@ -1112,7 +1120,7 @@ def _emit_attention_block( kv_h = h // ratio K_stored, V_stored = kv_stored[kv_h] - q_h_addr = q_full_addr + h * seq_len * prog.mlen + q_h_addr = q_full_addr + h * head_stride Q_h = prog.alloc_at(f"Q_h{h}_{layer_idx}", seq_len, head_dim, q_h_addr) _apply_rope_projection( prog, @@ -1130,9 +1138,12 @@ def _emit_attention_block( V_stored, scale, causal_mask=causal_mask, + batch_size=1, + seq_len=active_seq_len, + kv_seq_len=active_seq_len, ) - o_h_dest_addr = o_full_addr + h * seq_len * prog.mlen + o_h_dest_addr = o_full_addr + h * head_stride _copy_into_vram_view( prog, O_h, @@ -2866,9 +2877,19 @@ def _verbose(message: str = ""): # (BLEN), not a full MLEN block. Columns/K dimensions still use MLEN # until the vector, norm, RoPE, and FFN templates learn true tail masks. # - # Keep an opt-in compatibility path for reproducing older tile-scaling - # reports that padded sequence rows to MLEN. - seq_padding_multiple = mlen if os.environ.get("PLENA_PAD_SEQ_TO_MLEN") == "1" else blen + # Exception: when seq_len < MLEN the flash-attention kernel still requires + # Q/K/V to occupy a full MLEN tile physically (it asserts physical rows per + # batch >= max(MLEN, seq_len)). A BLEN-padded sub-MLEN sequence leaves the + # per-head Q/O slices with only `padded_seq_len` physical rows, so pad the + # sequence up to MLEN in that case. The extra padding query rows are never + # compared (comparison uses the real seq_len), and the causal mask keeps the + # real query rows from attending to the padding KV rows, so this is just the + # proven seq_len==MLEN tile path with fewer logically-active rows. + # + # Also keep an opt-in compatibility path for reproducing older tile-scaling + # reports that padded every sequence (even seq_len >= MLEN) up to MLEN. + pad_seq_to_mlen = os.environ.get("PLENA_PAD_SEQ_TO_MLEN") == "1" or seq_len < mlen + seq_padding_multiple = mlen if pad_seq_to_mlen else blen padded_seq_len = _ceil_to_multiple(seq_len, seq_padding_multiple) rows_per_batch = ( padded_seq_len