[FEATURE] Add block-causal attention for dLLM example #2499
📝 Walkthrough
Adds a new …
Changes: Block-causal attention example
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/block_causal_attention/block_causal_attention.py`:
- Around line 67-70: The public `mask_block_size` path in the block causal
attention flow currently accepts unsupported values, which can later produce
incorrect masking or delayed JIT failures. Add an early validation guard in the
wrapper and the forward schedule entry points (including the `forward`/mask
setup near the diagonal tile logic) to reject any `mask_block_size` that is not
supported by the 64-token tile layout and the backward kernels’ fixed
`mask_block == 32` specialization. Keep the check close to the existing
shape/tile assertions so invalid inputs fail fast before any kernel launch.
- Around line 179-190: The kernel in block_causal_attention.py uses the
ambiguous tensor name O, which Ruff E741 flags and can break linting. Rename O
to a clearer identifier throughout the affected function and update every
corresponding use in the same kernel body, including the matching argument and
any related references near dO and the T.copy call, so the naming stays
consistent and unambiguous.
- Around line 175-191: The Delta preprocessing in `prep` still assumes every `k`
block is a full 64-wide tile, so the `T.copy` slices can run past the last valid
columns when `dim` is not a multiple of `block`. Update the `prep` loop to guard
the tail case in `T.ceildiv(dim, block)` by restricting the copied range to the
remaining valid width, or fail fast up front if non-64-wide tail dimensions are
not supported. Keep the fix localized around `prep`, `Delta`, and the two
`T.copy` calls so the block dimension handling stays consistent.
- Around line 488-505: The forward launch in
block_causal_attention/_BlockCausalAttentionTL is still passing through
non-contiguous query, key, and value tensors, which TileLang’s tensor proxy may
reject at the host-kernel boundary. Normalize query/key/value in
block_causal_attention before calling _BlockCausalAttentionTL.apply by checking
is_contiguous() and materializing contiguous tensors when needed, and keep the
backward helper’s contig logic aligned with the same rule instead of relying
only on stride(-1) == 1.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 725afb7a-3d0b-4843-a6af-dbef47cd83ac
📒 Files selected for processing (3)
- examples/block_causal_attention/block_causal_attention.py
- examples/block_causal_attention/regression_block_causal_attention.py
- examples/block_causal_attention/test_block_causal_attention.py
assert seq_len % 2 == 0, "seq_len must be noisy|clean halves"
half_len = seq_len // 2
assert half_len % block_M == 0, "half_len must be divisible by block_M"
assert block_M == block_N, "forward uses square tiles"
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Fail fast on unsupported mask_block_size values.
The forward schedule only masks diagonal/noisy tiles, which assumes mask blocks align with the 64-token tile schedule, while both backward kernels are explicitly specialized to mask_block == 32. The public wrapper currently accepts any value, so unsupported inputs can either produce wrong forward masks or fail only during backward JIT.
Proposed guard
 def block_causal_attention(query, key, value, mask_block_size: int, softmax_scale=None):
+    if mask_block_size != 32:
+        raise ValueError("block_causal_attention currently supports mask_block_size=32 only")
     if softmax_scale is None:
         softmax_scale = query.shape[-1] ** -0.5

Also applies to: 109-135, 212-213, 303-304, 502-505
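As a rough illustration of the fail-fast idea, the sketch below separates the two constraints named in this comment so that each produces its own error message. The helper name `check_mask_block_size` and both constants are hypothetical, and reading "align with the 64-token tile schedule" as "divides 64 evenly" is an assumption, not something the PR states.

```python
# Hypothetical constants, taken from the wording of the review comment.
FORWARD_TILE = 64         # token width of a forward tile
BACKWARD_MASK_BLOCK = 32  # the only value the backward kernels are specialized for


def check_mask_block_size(mask_block_size: int) -> int:
    """Hypothetical guard: raise before any kernel is built or launched."""
    # Assumption: "aligned" means the mask block divides the forward tile evenly.
    if mask_block_size <= 0 or FORWARD_TILE % mask_block_size != 0:
        raise ValueError(
            f"mask_block_size={mask_block_size} does not align with the "
            f"{FORWARD_TILE}-token forward tile"
        )
    # The backward kernels are described as fixed to mask_block == 32.
    if mask_block_size != BACKWARD_MASK_BLOCK:
        raise ValueError(
            f"mask_block_size={mask_block_size} is not supported by the backward "
            f"kernels (only {BACKWARD_MASK_BLOCK})"
        )
    return mask_block_size
```

Calling such a helper at the top of the public wrapper would surface an unsupported value as a `ValueError` at call time instead of as a wrong forward mask or a late JIT failure in backward.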
block = 64

@T.prim_func
def prep(
    O: T.Tensor(shape, dtype),
    dO: T.Tensor(shape, dtype),
    Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
):
    with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
        o = T.alloc_fragment([block, block], dtype)
        do = T.alloc_fragment([block, block], dtype)
        acc = T.alloc_fragment([block, block], accum_dtype)
        delta = T.alloc_fragment([block], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, block)):
            T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
            T.copy(dO[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], do)
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Guard Delta preprocessing against non-64-wide tail dimensions.
T.ceildiv(dim, block) creates a tail iteration, but each T.copy(... k * block : (k + 1) * block) still copies 64 columns. For dim=32 or dim=96, the last slice extends past the declared tensor dimension.
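The arithmetic behind this can be checked without TileLang. The sketch below only models the slice bounds; `ceildiv` and `tail_overrun` are hypothetical helper names that mirror `T.ceildiv` and the `k * block : (k + 1) * block` slice from the kernel.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, mirroring what T.ceildiv computes."""
    return -(-a // b)


def tail_overrun(dim: int, block: int = 64) -> int:
    """Number of columns by which the last k-slice extends past `dim`."""
    last_k = ceildiv(dim, block) - 1
    last_slice_end = (last_k + 1) * block  # upper bound of k*block : (k+1)*block
    return last_slice_end - dim


for dim in (32, 64, 96, 128):
    print(f"dim={dim}: last slice overruns by {tail_overrun(dim)} columns")
```

For `dim=32` and `dim=96` the last slice reaches 32 columns past the declared dimension, while multiples of 64 give an overrun of 0, which is why an up-front `dim % block == 0` assertion is enough to rule the case out.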
Fail-fast option
 shape = [batch, seq_len, heads, dim]
 block = 64
+assert dim % block == 0, "Delta preprocessing requires dim to be divisible by 64"

📝 Committable suggestion
block = 64
assert dim % block == 0, "Delta preprocessing requires dim to be divisible by 64"

@T.prim_func
def prep(
    O: T.Tensor(shape, dtype),
    dO: T.Tensor(shape, dtype),
    Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
):
    with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
        o = T.alloc_fragment([block, block], dtype)
        do = T.alloc_fragment([block, block], dtype)
        acc = T.alloc_fragment([block, block], accum_dtype)
        delta = T.alloc_fragment([block], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, block)):
            T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
            T.copy(dO[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], do)
🧰 Tools
🪛 Ruff (0.15.20)
[error] 179-179: Ambiguous variable name: O
(E741)
    O: T.Tensor(shape, dtype),
    dO: T.Tensor(shape, dtype),
    Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
):
    with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
        o = T.alloc_fragment([block, block], dtype)
        do = T.alloc_fragment([block, block], dtype)
        acc = T.alloc_fragment([block, block], accum_dtype)
        delta = T.alloc_fragment([block], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, block)):
            T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Rename O to satisfy Ruff E741.
Ruff flags O as an ambiguous variable name; this can fail lint checks.
Proposed rename
-    O: T.Tensor(shape, dtype),
+    Output: T.Tensor(shape, dtype),
     dO: T.Tensor(shape, dtype),
     Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
@@
-            T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
+            T.copy(Output[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)

📝 Committable suggestion
    Output: T.Tensor(shape, dtype),
    dO: T.Tensor(shape, dtype),
    Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
):
    with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
        o = T.alloc_fragment([block, block], dtype)
        do = T.alloc_fragment([block, block], dtype)
        acc = T.alloc_fragment([block, block], accum_dtype)
        delta = T.alloc_fragment([block], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, block)):
            T.copy(Output[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
🧰 Tools
🪛 Ruff (0.15.20)
[error] 179-179: Ambiguous variable name: O
(E741)
Source: Linters/SAST tools
        def contig(tensor):
            return tensor if tensor.stride(-1) == 1 else tensor.contiguous()

        do, q, k, v, o = (contig(tensor) for tensor in (do, q, k, v, o))
        prep, dq_kernel, dkv_kernel = get_bwd_kernels(batch, heads, seq_len, dim, ctx.mask_block, ctx.softmax_scale, dtype)
        delta = prep(o, do)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        dq_kernel(q, k, v, do, lse, delta, dq)
        dkv_kernel(q, k, v, do, lse, delta, dk, dv)
        return dq, dk, dv, None, None


def block_causal_attention(query, key, value, mask_block_size: int, softmax_scale=None):
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5
    return _BlockCausalAttentionTL.apply(query, key, value, mask_block_size, float(softmax_scale))
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
set -e
python3 - <<'PY'
from pathlib import Path
p = Path('examples/block_causal_attention/block_causal_attention.py')
lines = p.read_text().splitlines()
for start, end in [(1,120),(120,260),(260,420),(420,560)]:
    print(f"\n--- {p} lines {start}-{end} ---")
    for i in range(start, min(end, len(lines))+1):
        print(f"{i:4d}: {lines[i-1]}")
PY
rg -n "is_contiguous|contiguous\(|stride\(-1\)|mask_block|get_bwd_kernels|get_fwd_kernels|TileLang|T\.Tensor" examples/block_causal_attention/block_causal_attention.py

Repository: tile-ai/tilelang
Length of output: 30771
🌐 Web query:
TileLang T.Tensor input contiguity strided tensor layout semantics custom kernel contiguous requirement
💡 Result:
In TileLang, T.Tensor input contiguity is governed by the compiler's auto-validation system, which enforces memory layout requirements at the host-kernel boundary [1]. Key aspects of input contiguity and strided tensor semantics include:
1. Contiguity Requirement: By default, TileLang's TensorProxy assumes that global scope tensors are contiguous [2][3]. The host-side auto-validation system automatically performs checks on arguments passed to kernels (e.g., torch.Tensor). If a tensor's actual strides do not match the expected contiguous layout, the validation system triggers an error [1].
2. Handling Non-Contiguous Tensors: If a kernel receives a non-contiguous tensor (e.g., a transposed or sliced tensor) that violates the expected stride constraints, the compiler will raise a "Strides check failed" error [1]. To resolve this, users are advised to either call .contiguous() on the tensor before passing it to the kernel or adjust the kernel's layout expectations [1].
3. Strided Tensor Support: While default proxies assume contiguity, TileLang supports explicit strided tensor layouts through the StridedTensorProxy class [2][3]. Furthermore, recent advancements have introduced support for operations like T.gemm to handle tensors with explicit memory strides [4]. This allows the compiler to propagate stride information (such as leading dimensions) into backend CUDA kernels, enabling correct memory access patterns for non-contiguous inputs [4].
4. Layout Semantics: TileLang manages memory layouts using a composable algebraic approach based on IterVars [5][6]. Layout information is maintained in a LayoutMap, and the compiler uses this to infer access patterns and optimize memory operations (such as swizzling or padding) [5]. For specialized custom kernels, developers can use T.annotate_layout to define explicit, non-standard layouts for buffers, overriding default contiguity assumptions [7].
Citations:
- 1: https://tilelang.com/compiler_internals/tensor_checks.html
- 2: https://www.tilelang.com/autoapi/tilelang/language/proxy/index.html
- 3: https://github.com/tile-ai/tilelang/blob/17c4b384/tilelang/language/proxy.py
- 4: [Feat] Support gemm with stride #701
- 5: https://arxiv.org/pdf/2504.17577
- 6: https://openreview.net/pdf/e851ad3eaca44bc18d67d753b03b3664ca2b9e9c.pdf
- 7: https://www.tilelang.com/programming_guides/instructions.html
🏁 Script executed:
set -e
rg -n "stride\(-1\) == 1|is_contiguous\(\)|contiguous\(\)" -g '*.py' .

Repository: tile-ai/tilelang
Length of output: 18785
Normalize query/key/value before the forward launch. TileLang’s default tensor proxy expects contiguous layouts and will reject non-contiguous inputs at the host-kernel boundary; stride(-1) == 1 in backward() still lets strided views through. Use is_contiguous() (or .contiguous()) for both the forward inputs and the backward helper.
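To see why `stride(-1) == 1` is a weaker test than full contiguity, the sketch below models shapes and strides in plain Python rather than calling PyTorch. The helpers `row_major_strides`, `is_contiguous`, and `transpose` are hypothetical stand-ins, and the example shape is illustrative only. The contiguity rule used is the usual row-major one, with size-1 dimensions ignored.

```python
def row_major_strides(shape):
    """Element strides of a densely packed row-major tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Row-major contiguity check; strides of size-1 dims do not matter."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


def transpose(shape, strides, a, b):
    """Swap two dims of a view: shape and strides are swapped, data is not moved."""
    shape, strides = list(shape), list(strides)
    shape[a], shape[b] = shape[b], shape[a]
    strides[a], strides[b] = strides[b], strides[a]
    return tuple(shape), tuple(strides)


# Illustrative layout: data stored as [batch, heads, seq_len, dim] ...
base_shape = (2, 8, 128, 64)
base_strides = row_major_strides(base_shape)
# ... then viewed as [batch, seq_len, heads, dim] without copying.
view_shape, view_strides = transpose(base_shape, base_strides, 1, 2)

print(view_strides[-1] == 1)                    # last-dim check passes
print(is_contiguous(view_shape, view_strides))  # full contiguity check fails
```

The transposed view keeps a unit stride in its last dimension, so a `stride(-1) == 1` test would let it through even though its overall layout is not the dense one a contiguous-only tensor proxy expects.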
please fix the lint :)
Added a new block-causal attention example for dLLM built on TileLang, including: