[FEATURE] Add block-causal attention for dLLM example #2499

Open
perkyfever wants to merge 1 commit into tile-ai:main from perkyfever:block-causal-dllm-attention-example


@perkyfever perkyfever commented Jun 30, 2026


Added a new block-causal attention example for dLLM built on TileLang, including:

  • A full forward/backward TileLang implementation with masking logic, softmax/LSE handling, and autograd integration
  • A PyTorch reference path for validation
  • Kernel JIT caching and dtype validation helpers
  • A CUDA-gated test entrypoint and regression/performance harness for the example

@github-actions


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai Bot commented Jun 30, 2026

Contributor

📝 Walkthrough

Adds a new examples/block_causal_attention/ directory containing a full block-causal attention implementation: TileLang forward and backward kernel templates with a block-causal/noisy-clean masking scheme, JIT compilation caches, a PyTorch autograd Function wrapper, a reference implementation, correctness tests, and a regression benchmarking harness.

Changes

Block-causal attention example

| Layer / File(s) | Summary |
|---|---|
| Mask helpers and compilation config (`examples/block_causal_attention/block_causal_attention.py`) | Environment-variable-driven fast-math flags, TileLang pass config dicts, and helper functions computing block-causal/noisy-clean tile constraints for forward and backward. |
| Forward TileLang kernel (`examples/block_causal_attention/block_causal_attention.py`) | Kernel template with tiled GEMMs, per-tile conditional masking (noisy vs. clean diagonal), masked softmax reduction, output, and LSE writes. |
| Backward TileLang kernels: Delta, dQ, dK/dV (`examples/block_causal_attention/block_causal_attention.py`) | Three backward kernels: Delta preprocessing via fragment dot-products, dQ kernel reconstructing masked softmax scaling, and a pipelined dK/dV kernel accumulating gradients over key tiles. |
| JIT caching, autograd Function, and public API (`examples/block_causal_attention/block_causal_attention.py`) | `get_fwd_kernel`/`get_bwd_kernels` with parameter-keyed caches, `_BlockCausalAttentionTL` autograd Function wiring all kernels, and `block_causal_attention` user-facing wrapper. |
| Reference implementation, tests, and regression (`examples/block_causal_attention/block_causal_attention.py`, `examples/block_causal_attention/test_block_causal_attention.py`, `examples/block_causal_attention/regression_block_causal_attention.py`) | `block_causal_attention_ref` with explicit mask construction, CUDA correctness test comparing TileLang vs. reference, `main()`/`run_regression_perf()`, and pytest/regression entry-point modules. |
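
To make the masking and LSE terminology in the walkthrough concrete, here is a minimal pure-Python sketch. It assumes the plain block-causal rule (a query may attend to any key whose block index is not greater than its own); the PR's noisy/clean split adds further constraints that are not spelled out in this conversation, so this is not the PR's mask. The names `block_causal_mask` and `masked_softmax_row` are hypothetical.

```python
import math


def block_causal_mask(seq_len, block):
    """Assumed plain block-causal rule: key block index <= query block index."""
    return [
        [(j // block) <= (i // block) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def masked_softmax_row(scores, allowed):
    """Softmax over the allowed entries of one row, plus its log-sum-exp (LSE).

    Assumes at least one entry is allowed, which holds for a block-causal
    row because a query always sees its own block.
    """
    kept = [s for s, ok in zip(scores, allowed) if ok]
    row_max = max(kept)  # subtract the max for numerical stability
    denom = sum(math.exp(s - row_max) for s in kept)
    lse = row_max + math.log(denom)
    probs = [math.exp(s - lse) if ok else 0.0 for s, ok in zip(scores, allowed)]
    return probs, lse


mask = block_causal_mask(seq_len=4, block=2)
probs, lse = masked_softmax_row([0.0, 0.0, 5.0, 5.0], mask[0])
print(mask[0], probs, lse)
```

Saving the LSE per row is what lets a backward pass rebuild the softmax probabilities as `exp(score - lse)` without storing the full attention matrix.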

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 Hoppity-hop through each masked tile,
The noisy and clean split done in style,
dQ and dK/dV in a loop,
LSE saved in every group—
Block-causal kernels land with a smile! 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly matches the main change: adding a block-causal attention example for dLLM. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/block_causal_attention/block_causal_attention.py`:
- Around line 67-70: The public `mask_block_size` path in the block causal
attention flow currently accepts unsupported values, which can later produce
incorrect masking or delayed JIT failures. Add an early validation guard in the
wrapper and the forward schedule entry points (including the `forward`/mask
setup near the diagonal tile logic) to reject any `mask_block_size` that is not
supported by the 64-token tile layout and the backward kernels’ fixed
`mask_block == 32` specialization. Keep the check close to the existing
shape/tile assertions so invalid inputs fail fast before any kernel launch.
- Around line 179-190: The kernel in block_causal_attention.py uses the
ambiguous tensor name O, which Ruff E741 flags and can break linting. Rename O
to a clearer identifier throughout the affected function and update every
corresponding use in the same kernel body, including the matching argument and
any related references near dO and the T.copy call, so the naming stays
consistent and unambiguous.
- Around line 175-191: The Delta preprocessing in `prep` still assumes every `k`
block is a full 64-wide tile, so the `T.copy` slices can run past the last valid
columns when `dim` is not a multiple of `block`. Update the `prep` loop to guard
the tail case in `T.ceildiv(dim, block)` by restricting the copied range to the
remaining valid width, or fail fast up front if non-64-wide tail dimensions are
not supported. Keep the fix localized around `prep`, `Delta`, and the two
`T.copy` calls so the block dimension handling stays consistent.
- Around line 488-505: The forward launch in
block_causal_attention/_BlockCausalAttentionTL is still passing through
non-contiguous query, key, and value tensors, which TileLang’s tensor proxy may
reject at the host-kernel boundary. Normalize query/key/value in
block_causal_attention before calling _BlockCausalAttentionTL.apply by checking
is_contiguous() and materializing contiguous tensors when needed, and keep the
backward helper’s contig logic aligned with the same rule instead of relying
only on stride(-1) == 1.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 725afb7a-3d0b-4843-a6af-dbef47cd83ac

📥 Commits

Reviewing files that changed from the base of the PR and between d861ff4 and 0e10494.

📒 Files selected for processing (3)
  • examples/block_causal_attention/block_causal_attention.py
  • examples/block_causal_attention/regression_block_causal_attention.py
  • examples/block_causal_attention/test_block_causal_attention.py

Comment on lines +67 to +70
assert seq_len % 2 == 0, "seq_len must be noisy|clean halves"
half_len = seq_len // 2
assert half_len % block_M == 0, "half_len must be divisible by block_M"
assert block_M == block_N, "forward uses square tiles"

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Fail fast on unsupported mask_block_size values.

The forward schedule only masks diagonal/noisy tiles, which assumes mask blocks align with the 64-token tile schedule, while both backward kernels are explicitly specialized to mask_block == 32. The public wrapper currently accepts any value, so unsupported inputs can either produce wrong forward masks or fail only during backward JIT.

Proposed guard
 def block_causal_attention(query, key, value, mask_block_size: int, softmax_scale=None):
+    if mask_block_size != 32:
+        raise ValueError("block_causal_attention currently supports mask_block_size=32 only")
     if softmax_scale is None:
         softmax_scale = query.shape[-1] ** -0.5

Also applies to: 109-135, 212-213, 303-304, 502-505
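
As a rough illustration of how the fail-fast checks could be gathered in one place, the sketch below combines the proposed `mask_block_size` guard with the shape assertions quoted above. The helper name `validate_block_causal_args` and its default arguments are hypothetical; only the individual conditions come from the review, and where such a helper would live in the example is left open.

```python
def validate_block_causal_args(seq_len, mask_block_size, block_m=64, block_n=64,
                               supported_mask_block=32):
    """Hypothetical helper: reject unsupported inputs before any kernel launch."""
    if mask_block_size != supported_mask_block:
        raise ValueError(
            f"mask_block_size={mask_block_size} is unsupported; "
            f"only {supported_mask_block} is currently handled"
        )
    if seq_len % 2 != 0:
        raise ValueError("seq_len must split into noisy|clean halves")
    half_len = seq_len // 2
    if half_len % block_m != 0:
        raise ValueError("half_len must be divisible by block_M")
    if block_m != block_n:
        raise ValueError("forward uses square tiles")
    return half_len


print(validate_block_causal_args(seq_len=256, mask_block_size=32))
```

Raising `ValueError` instead of using `assert` keeps the checks active when Python runs with `-O`, which matters for a guard on a public entry point.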

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/block_causal_attention/block_causal_attention.py` around lines 67 -
70, The public `mask_block_size` path in the block causal attention flow
currently accepts unsupported values, which can later produce incorrect masking
or delayed JIT failures. Add an early validation guard in the wrapper and the
forward schedule entry points (including the `forward`/mask setup near the
diagonal tile logic) to reject any `mask_block_size` that is not supported by
the 64-token tile layout and the backward kernels’ fixed `mask_block == 32`
specialization. Keep the check close to the existing shape/tile assertions so
invalid inputs fail fast before any kernel launch.

Comment on lines +175 to +191
    block = 64

    @T.prim_func
    def prep(
        O: T.Tensor(shape, dtype),
        dO: T.Tensor(shape, dtype),
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
            o = T.alloc_fragment([block, block], dtype)
            do = T.alloc_fragment([block, block], dtype)
            acc = T.alloc_fragment([block, block], accum_dtype)
            delta = T.alloc_fragment([block], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, block)):
                T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
                T.copy(dO[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], do)

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Guard Delta preprocessing against non-64-wide tail dimensions.

T.ceildiv(dim, block) creates a tail iteration, but each T.copy(... k * block : (k + 1) * block) still copies 64 columns. For dim=32 or dim=96, the last slice extends past the declared tensor dimension.

Fail-fast option
     shape = [batch, seq_len, heads, dim]
     block = 64
+    assert dim % block == 0, "Delta preprocessing requires dim to be divisible by 64"
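
The tail arithmetic can be checked without a GPU. The sketch below lists the column ranges a loop over `ceildiv(dim, block)` tiles would visit, with and without clamping the last tile, and then accumulates a row-wise sum of `o * do` over the clamped ranges. It assumes Delta is the usual row-wise sum of `O * dO`; the quoted kernel is truncated before that step, so this is an assumption, and the helper names `tile_ranges` and `row_delta` are hypothetical.

```python
def ceildiv(a, b):
    """Integer ceiling division, mirroring what T.ceildiv computes."""
    return -(-a // b)


def tile_ranges(dim, block, clamp):
    """Column ranges visited by a loop over ceildiv(dim, block) tiles."""
    ranges = []
    for k in range(ceildiv(dim, block)):
        start, stop = k * block, (k + 1) * block
        if clamp:
            stop = min(stop, dim)  # restrict the tail tile to valid columns
        ranges.append((start, stop))
    return ranges


def row_delta(o_row, do_row, block=64):
    """Assumed Delta for one row: sum of o * do, accumulated tile by tile."""
    total = 0.0
    for start, stop in tile_ranges(len(o_row), block, clamp=True):
        total += sum(o * d for o, d in zip(o_row[start:stop], do_row[start:stop]))
    return total


for dim in (32, 64, 96):
    last_stop = tile_ranges(dim, 64, clamp=False)[-1][1]
    print(dim, last_stop, last_stop > dim)
```

For `dim=32` and `dim=96` the unclamped last slice ends at column 64 and 128 respectively, past the tensor width, while `dim=64` stays in bounds. Clamping is the alternative to the fail-fast assert when tail dimensions need to be supported.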
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     block = 64
+    assert dim % block == 0, "Delta preprocessing requires dim to be divisible by 64"
 
     @T.prim_func
     def prep(
         O: T.Tensor(shape, dtype),
         dO: T.Tensor(shape, dtype),
         Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
     ):
         with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
             o = T.alloc_fragment([block, block], dtype)
             do = T.alloc_fragment([block, block], dtype)
             acc = T.alloc_fragment([block, block], accum_dtype)
             delta = T.alloc_fragment([block], accum_dtype)
             T.clear(acc)
             for k in range(T.ceildiv(dim, block)):
                 T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
                 T.copy(dO[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], do)
🧰 Tools
🪛 Ruff (0.15.20)

[error] 179-179: Ambiguous variable name: O

(E741)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/block_causal_attention/block_causal_attention.py` around lines 175 -
191, The Delta preprocessing in `prep` still assumes every `k` block is a full
64-wide tile, so the `T.copy` slices can run past the last valid columns when
`dim` is not a multiple of `block`. Update the `prep` loop to guard the tail
case in `T.ceildiv(dim, block)` by restricting the copied range to the remaining
valid width, or fail fast up front if non-64-wide tail dimensions are not
supported. Keep the fix localized around `prep`, `Delta`, and the two `T.copy`
calls so the block dimension handling stays consistent.

Comment on lines +179 to +190
        O: T.Tensor(shape, dtype),
        dO: T.Tensor(shape, dtype),
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
            o = T.alloc_fragment([block, block], dtype)
            do = T.alloc_fragment([block, block], dtype)
            acc = T.alloc_fragment([block, block], accum_dtype)
            delta = T.alloc_fragment([block], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, block)):
                T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Rename O to satisfy Ruff E741.

Ruff flags O as an ambiguous variable name; this can fail lint checks.

Proposed rename
-        O: T.Tensor(shape, dtype),
+        Output: T.Tensor(shape, dtype),
         dO: T.Tensor(shape, dtype),
         Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
@@
-                T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
+                T.copy(Output[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
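
For readers unfamiliar with the rule, the sketch below scans function parameters for the single-character names `l`, `O`, and `I`, which is the set E741 treats as ambiguous. It is only an illustration built on the standard `ast` module, not how Ruff implements the check, and it ignores other kinds of bindings. The helper `ambiguous_parameters` and the `SAMPLE` source are hypothetical.

```python
import ast

# Single-character names that E741 treats as ambiguous.
AMBIGUOUS = {"l", "O", "I"}


def ambiguous_parameters(source):
    """Return (function name, parameter name, line) for ambiguous parameters."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            args = node.args
            for param in args.posonlyargs + args.args + args.kwonlyargs:
                if param.arg in AMBIGUOUS:
                    hits.append((node.name, param.arg, param.lineno))
    return hits


# Hypothetical stand-ins for the kernel signature before and after the rename.
SAMPLE = '''
def prep(O, dO, Delta):
    return O

def prep_renamed(Output, dO, Delta):
    return Output
'''

print(ambiguous_parameters(SAMPLE))
```

Only the first signature is reported; `dO` and `Output` pass because the rule targets the bare single letters.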
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        O: T.Tensor(shape, dtype),
+        Output: T.Tensor(shape, dtype),
         dO: T.Tensor(shape, dtype),
         Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
     ):
         with T.Kernel(heads, T.ceildiv(seq_len, block), batch) as (bx, by, bz):
             o = T.alloc_fragment([block, block], dtype)
             do = T.alloc_fragment([block, block], dtype)
             acc = T.alloc_fragment([block, block], accum_dtype)
             delta = T.alloc_fragment([block], accum_dtype)
             T.clear(acc)
             for k in range(T.ceildiv(dim, block)):
-                T.copy(O[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
+                T.copy(Output[bz, by * block : (by + 1) * block, bx, k * block : (k + 1) * block], o)
🧰 Tools
🪛 Ruff (0.15.20)

[error] 179-179: Ambiguous variable name: O

(E741)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/block_causal_attention/block_causal_attention.py` around lines 179 -
190, The kernel in block_causal_attention.py uses the ambiguous tensor name O,
which Ruff E741 flags and can break linting. Rename O to a clearer identifier
throughout the affected function and update every corresponding use in the same
kernel body, including the matching argument and any related references near dO
and the T.copy call, so the naming stays consistent and unambiguous.

Source: Linters/SAST tools

Comment on lines +488 to +505
        def contig(tensor):
            return tensor if tensor.stride(-1) == 1 else tensor.contiguous()

        do, q, k, v, o = (contig(tensor) for tensor in (do, q, k, v, o))
        prep, dq_kernel, dkv_kernel = get_bwd_kernels(batch, heads, seq_len, dim, ctx.mask_block, ctx.softmax_scale, dtype)
        delta = prep(o, do)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        dq_kernel(q, k, v, do, lse, delta, dq)
        dkv_kernel(q, k, v, do, lse, delta, dk, dv)
        return dq, dk, dv, None, None


def block_causal_attention(query, key, value, mask_block_size: int, softmax_scale=None):
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5
    return _BlockCausalAttentionTL.apply(query, key, value, mask_block_size, float(softmax_scale))

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

set -e
python3 - <<'PY'
from pathlib import Path
p = Path('examples/block_causal_attention/block_causal_attention.py')
lines = p.read_text().splitlines()
for start, end in [(1,120),(120,260),(260,420),(420,560)]:
    print(f"\n--- {p} lines {start}-{end} ---")
    for i in range(start, min(end, len(lines))+1):
        print(f"{i:4d}: {lines[i-1]}")
PY
rg -n "is_contiguous|contiguous\(|stride\(-1\)|mask_block|get_bwd_kernels|get_fwd_kernels|TileLang|T\.Tensor" examples/block_causal_attention/block_causal_attention.py

Repository: tile-ai/tilelang

Length of output: 30771


🌐 Web query:

TileLang T.Tensor input contiguity strided tensor layout semantics custom kernel contiguous requirement

💡 Result:

In TileLang, T.Tensor input contiguity is governed by the compiler's auto-validation system, which enforces memory layout requirements at the host-kernel boundary [1]. Key aspects of input contiguity and strided tensor semantics include:

  1. Contiguity requirement: By default, TileLang's TensorProxy assumes that global scope tensors are contiguous [2][3]. The host-side auto-validation system automatically checks arguments passed to kernels (e.g., torch.Tensor). If a tensor's actual strides do not match the expected contiguous layout, the validation system triggers an error [1].
  2. Handling non-contiguous tensors: If a kernel receives a non-contiguous tensor (e.g., a transposed or sliced tensor) that violates the expected stride constraints, the compiler raises a "Strides check failed" error [1]. To resolve this, users are advised to either call .contiguous() on the tensor before passing it to the kernel or adjust the kernel's layout expectations [1].
  3. Strided tensor support: While default proxies assume contiguity, TileLang supports explicit strided tensor layouts through the StridedTensorProxy class [2][3]. Furthermore, recent advancements have introduced support for operations like T.gemm to handle tensors with explicit memory strides [4]. This allows the compiler to propagate stride information, such as leading dimensions, into backend CUDA kernels, enabling correct memory access patterns for non-contiguous inputs [4].
  4. Layout semantics: TileLang manages memory layouts using a composable algebraic approach based on IterVars [5][6]. Layout information is maintained in a LayoutMap, and the compiler uses this to infer access patterns and optimize memory operations (such as swizzling or padding) [5]. For specialized custom kernels, developers can use T.annotate_layout to define explicit, non-standard layouts for buffers, overriding default contiguity assumptions [7].


🏁 Script executed:

set -e
rg -n "stride\(-1\) == 1|is_contiguous\(\)|contiguous\(\)" -g '*.py' .

Repository: tile-ai/tilelang

Length of output: 18785


Normalize query/key/value before the forward launch. TileLang’s default tensor proxy expects contiguous layouts and will reject non-contiguous inputs at the host-kernel boundary; stride(-1) == 1 in backward() still lets strided views through. Use is_contiguous() (or .contiguous()) for both the forward inputs and the backward helper.
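
The gap between the two checks can be shown with stride arithmetic alone. The sketch below is a pure-Python illustration, not TileLang's or PyTorch's implementation: `contiguous_strides` and `is_row_major_contiguous` are hypothetical helpers that model a row-major contiguity test, and the example layout is what a transpose of dims 1 and 2 on a contiguous `(2, 3, 4, 8)` tensor would produce. In PyTorch the corresponding calls are `tensor.is_contiguous()` and `tensor.contiguous()`.

```python
def contiguous_strides(shape):
    """Element strides of a row-major contiguous tensor with this shape."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def is_row_major_contiguous(shape, strides):
    """True if the strides match the row-major layout (size-1 dims are ignored)."""
    expected = contiguous_strides(shape)
    return all(size == 1 or s == e for size, s, e in zip(shape, strides, expected))


# A contiguous (2, 3, 4, 8) tensor has strides (96, 32, 8, 1).
base_shape = (2, 3, 4, 8)
base_strides = contiguous_strides(base_shape)

# Swapping dims 1 and 2 gives a view with shape (2, 4, 3, 8) and
# strides (96, 8, 32, 1): the last stride is still 1.
view_shape = (2, 4, 3, 8)
view_strides = (96, 8, 32, 1)

print(view_strides[-1] == 1)                              # passes a stride(-1) == 1 check
print(is_row_major_contiguous(view_shape, view_strides))  # fails the full contiguity check
```

A transposed view therefore slips through a `stride(-1) == 1` test untouched, which is the case the comment wants both the forward wrapper and the backward helper to normalize.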

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/block_causal_attention/block_causal_attention.py` around lines 488 -
505, The forward launch in block_causal_attention/_BlockCausalAttentionTL is
still passing through non-contiguous query, key, and value tensors, which
TileLang’s tensor proxy may reject at the host-kernel boundary. Normalize
query/key/value in block_causal_attention before calling
_BlockCausalAttentionTL.apply by checking is_contiguous() and materializing
contiguous tensors when needed, and keep the backward helper’s contig logic
aligned with the same rule instead of relying only on stride(-1) == 1.

@LeiWang1999

Member

please fix the lint :)

2 participants