Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions aten/plena/isa_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def _pv_multiply_asm(
v_hbm_offset: int,
pv_address: int,
rows: int | None = None,
pv_physical_rows: int | None = None,
) -> str:
"""
Compute PV = P @ V via M_MM.
Expand All @@ -232,6 +233,14 @@ def _pv_multiply_asm(
column blocks; the outer loop iterates blocks, middle loop iterates blen-wide
V columns within a block, inner loop iterates blen-wide P rows.
"""
# PV is stored column-block-major with `pv_physical_rows` rows (= min(mlen,
# seq_len) for the decoder), so each head_dim col-block spans
# pv_physical_rows*mlen, not mlen*mlen. They coincide at seq>=mlen, but for
# seq<mlen using mlen*mlen writes col-blocks 1.. out of bounds, leaving
# head_dim cols mlen.. zero (PV shrinks by mlen/head_dim).
if pv_physical_rows is None:
pv_physical_rows = mlen

if getattr(self, "unroll_attention", False):
return self._pv_multiply_asm_unrolled(
mlen=mlen,
Expand All @@ -241,6 +250,7 @@ def _pv_multiply_asm(
v_hbm_offset_reg=v_hbm_offset_reg,
v_hbm_offset=v_hbm_offset,
pv_address=pv_address,
pv_physical_rows=pv_physical_rows,
)

gp_regs = self.register_allocator.allocate_gp(8)
Expand Down Expand Up @@ -283,7 +293,7 @@ def _pv_multiply_asm(
lines.extend(load_large_int(gp_hbm, v_block_hbm_offset))
lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1")

pv_col_block_base = pv_address + v_col_block * mlen * mlen
pv_col_block_base = pv_address + v_col_block * pv_physical_rows * mlen
lines.extend(load_large_int(gp_pv_col_base, pv_col_block_base))
lines.append(f"C_LOOP_START gp{gp_v_loop}, {tiles_per_mlen}")
lines.extend(load_large_int(gp_p, p_address))
Expand All @@ -310,8 +320,11 @@ def _pv_multiply_asm_unrolled(
v_hbm_offset_reg: int,
v_hbm_offset: int,
pv_address: int,
pv_physical_rows: int | None = None,
) -> str:
"""Legacy Python-unrolled P @ V emission, kept for A/B comparisons."""
if pv_physical_rows is None:
pv_physical_rows = mlen
gp_regs = self.register_allocator.allocate_gp(5)
gp_p = gp_regs[0]
gp_v = gp_regs[1]
Expand Down Expand Up @@ -348,7 +361,7 @@ def _pv_multiply_asm_unrolled(
lines.extend(load_large_int(gp_p, p_row_addr))
lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}")

pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen
pv_offset = v_col_block * pv_physical_rows * mlen + p_row * blen * mlen + v_col * blen
lines.extend(load_large_int(gp_pv, pv_address + pv_offset))
lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0")

Expand Down Expand Up @@ -797,6 +810,7 @@ def compute_pv(
v_hbm_offset=v_hbm_offset,
pv_address=pv_address,
rows=rows,
pv_physical_rows=pv_info.physical_shape[0],
)

self.register_allocator.free_gp(gp_regs)
Expand Down
31 changes: 26 additions & 5 deletions aten/plena_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,14 @@ def _emit_attention_block(

O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim, strict=False)
o_full_addr = prog.get_vram_addr(O_full.name)
prog.vram_fill_zero(O_full)
# Each head occupies head_dim//mlen col-blocks of O_full, so the per-head span
# is physical_rows*head_dim. seq_len*mlen (= physical_rows*mlen) only spans one
# col-block, overlapping the heads so O_proj reads only the first ~2 correctly
# (the rest fall past the written range). Mirrors the vision o_head_stride;
# identical when head_dim==mlen, correct when head_dim>mlen.
physical_rows = O_full.physical_shape[0]
head_stride = physical_rows * head_dim

kv_stored = _emit_kv_stores(
prog,
Expand All @@ -1112,7 +1120,7 @@ def _emit_attention_block(
kv_h = h // ratio
K_stored, V_stored = kv_stored[kv_h]

q_h_addr = q_full_addr + h * seq_len * prog.mlen
q_h_addr = q_full_addr + h * head_stride
Q_h = prog.alloc_at(f"Q_h{h}_{layer_idx}", seq_len, head_dim, q_h_addr)
_apply_rope_projection(
prog,
Expand All @@ -1130,9 +1138,12 @@ def _emit_attention_block(
V_stored,
scale,
causal_mask=causal_mask,
batch_size=1,
seq_len=active_seq_len,
kv_seq_len=active_seq_len,
)

o_h_dest_addr = o_full_addr + h * seq_len * prog.mlen
o_h_dest_addr = o_full_addr + h * head_stride
_copy_into_vram_view(
prog,
O_h,
Expand Down Expand Up @@ -2866,9 +2877,19 @@ def _verbose(message: str = ""):
# (BLEN), not a full MLEN block. Columns/K dimensions still use MLEN
# until the vector, norm, RoPE, and FFN templates learn true tail masks.
#
# Keep an opt-in compatibility path for reproducing older tile-scaling
# reports that padded sequence rows to MLEN.
seq_padding_multiple = mlen if os.environ.get("PLENA_PAD_SEQ_TO_MLEN") == "1" else blen
# Exception: when seq_len < MLEN the flash-attention kernel still requires
# Q/K/V to occupy a full MLEN tile physically (it asserts physical rows per
# batch >= max(MLEN, seq_len)). A BLEN-padded sub-MLEN sequence leaves the
# per-head Q/O slices with only `padded_seq_len` physical rows, so pad the
# sequence up to MLEN in that case. The extra padding query rows are never
# compared (comparison uses the real seq_len), and the causal mask keeps the
# real query rows from attending to the padding KV rows, so this is just the
# proven seq_len==MLEN tile path with fewer logically-active rows.
#
# Also keep an opt-in compatibility path for reproducing older tile-scaling
# reports that padded every sequence (even seq_len >= MLEN) up to MLEN.
pad_seq_to_mlen = os.environ.get("PLENA_PAD_SEQ_TO_MLEN") == "1" or seq_len < mlen
seq_padding_multiple = mlen if pad_seq_to_mlen else blen
padded_seq_len = _ceil_to_multiple(seq_len, seq_padding_multiple)
rows_per_batch = (
padded_seq_len
Expand Down
Loading