[NPU] refact fused_neighborhood_attention#1265
Conversation
Fused Neighborhood Attention (Ascend) Refactoring Analysis1. Change OverviewTo resolve the UB overflow and OOM issues occurring with the fused_neighborhood_attention operator on the Ascend NPU, the following refactoring has been implemented. This refactoring migrates the Ascend backend's Neighborhood Attention from a dense S×S matrix implementation to a sparse neighborhood band implementation, introducing memory-aware tiling, dual-path backward scheduling, and Ascend hardware constraint adaptation in forward/backward passes.
2. Architecture Comparison2.1 Old Architecture (Dense)Backward: 2.2 New Architecture (Sparse Band + Dual-Path Backward)Backward Routing: 3. Core Changes Breakdown3.1 Removal of
|
| Kernel | Purpose |
|---|---|
_sparse_neighborhood_grad_attn_softmax_bwd_kernel |
Fused grad_attn + softmax backward → grad_qk (sparse) |
_sparse_neighborhood_grad_q_kernel |
grad_Q |
_sparse_neighborhood_grad_k_kernel |
grad_K |
_sparse_neighborhood_grad_qkv_fused_kernel |
Fused grad_Q/K/V (3D grid) |
_fused_neighborhood_attention_grad_v_kernel |
grad_V (sparse gather) |
Rationale: At large seq_len, S² intermediate tensors are infeasible; sparse kernels operate entirely on [S, K].
Dense Backward Path (_dense_neighborhood_attention_backward)
New kernel family:
| Kernel | Purpose |
|---|---|
_sparse_attn_expand_to_dense_kernel |
Sparse band → dense [S,S] (only needed for backward) |
_dense_neighborhood_grad_attn_kernel |
grad_attn = grad_output @ V^T |
_dense_neighborhood_grad_qk_fused_kernel |
Fused grad_Q + grad_K |
_dense_neighborhood_grad_qk_fused_persistent_kernel |
Persistent variant |
_dense_neighborhood_grad_v_kernel |
grad_V = attn^T @ grad_output |
Fast path conditions (_sparse_band_dense_fast_path_available):
seq_len <= 2048- workspace
B×H×S×S×3×elem_sizedoes not exceed 18% of available memory - grid does not exceed 65535
Rationale: For small sequences, tl.dot is more efficient on Ascend; after one expansion, existing dense matmul kernels outperform handwritten sparse gather.
3.5 Memory and Tiling Infrastructure
New helper layer (largely absent in old version):
| Function | Purpose |
|---|---|
_scores_chunk_byte_limit() |
Forward attn tiling: per-chunk limit ≤ 8% of available memory |
_attention_batch_chunk_size() |
Split forward along batch dimension |
_get_attention_block_sizes() |
UB-aware tile sizes (compute_default_tiling_strategy) |
_dense_backward_byte_limit() |
Backward workspace upper bound (18% of available memory) |
_dense_backward_batch_head_chunk() |
Chunk backward along batch×head dimension |
_get_dense_block_m() / _dense_backward_grid_fits() |
Adapt to Ascend grid 65535 upper bound |
_get_sparse_backward_block_sizes() |
Adjust block_m/block_n by seq_len tiers |
Rationale:
- Neighborhood attention is sparse, but backward fast path still requires S² workspace
- Ascend NPU has UB (Unified Buffer) and grid constraints requiring dynamic tiling
- Large batch + long sequence combinations would OOM without chunking
3.6 Forward API Changes
| Item | Old | New |
|---|---|---|
| Return value | (output, attn_weights, softmax_params) |
(output, attn_for_backward) |
| softmax_params | Saved for backward use | No longer needed (softmax fused) |
| dtype handling | Input cast to FP32, output cast back | Intermediate FP32 accumulation, attn cast to input dtype |
| Batch tiling | None | Controlled by _attention_batch_chunk_size |
3.7 Autograd Layer Enhancements
Old LigerFusedNeighborhoodAttentionFunction.backward:
- 200+ lines of inline backward logic
- Re-executed on every backward call
New:
- Delegates to
fused_neighborhood_attention_backward - Added backward cache: identical
grad_outputtensor (same storage/version/shape/stride) reuses grad_Q/K/V - Cache capacity of 4 entries, LRU eviction
Rationale: Separation of concerns; cache avoids redundant computation on double backward or repeated grad calls.
4. Why These Changes — Motivation Summary
4.1 Memory Bottleneck (Primary)
Effective connectivity in Neighborhood Attention is O(S·K), but the old implementation allocates O(S²) tensors for qk_scores, attn_weights, and grad_*. At S=4096, H=32, B=8:
- Old: Single
[B,H,S,S]FP32 tensor ≈ 8×32×4096²×4 ≈ 17 GB - New: Single
[B,H,S,K]FP32 tensor (K=7) ≈ 30 MB scale
This is the primary driver for this refactoring.
4.2 Computational Efficiency
The old QK stage runs tl.dot over the entire S×S matrix and then masks out invalid positions — computational cost equivalent to full attention. The new implementation computes dot products only for K neighbors, reducing computation to O(S·K·D).
4.3 Ascend Hardware Adaptation
- Introduces
compute_default_tiling_strategyfor UB-aware tiling - Explicitly handles grid program upper bound of 65535
- Adjusts block sizes by seq_len tiers (512 / 2048 / 4096)
- Uses
get_total_gpu_memory()for dynamic chunk threshold calculation
4.4 Correctness and Maintainability
- Fixes old QK kernel bug where scale was repeatedly applied inside the K dimension loop
- Backward logic modularized; sparse/dense paths can be independently tested and optimized
- Retains backward compatibility with legacy dense
attn_weightsshapes
|
@Tcc0403 This pr is ready for review. Thanks! |
4225067 to
bba28ce
Compare
Tcc0403
left a comment
There was a problem hiding this comment.
Left my comments under one kernel but other kernels have similar concerns, once they are resolved, it should be applicable to other kernels as well.
| partial = tl.sum(q_chunk * k_chunk, axis=1) | ||
| col_select = (tl.arange(0, kernel_size) == neighbor_idx)[None, :] | ||
| dot_acc += tl.where(col_select, partial[:, None], tl.zeros_like(dot_acc)) | ||
| neighbor_valid = tl.maximum(neighbor_valid, tl.where(col_select & key_valid[:, None], 1.0, 0.0)) |
There was a problem hiding this comment.
This part is quite difficult to understand. Let me know if I'm reading it wrong.
In for neighbor_idx in range(kernel_size): loop, we iteratively compute QK tile column by column with tl.sum instead of computing the whole tile once with tl.dot. What's the rationale behind it? Why couldn't we simply leverage tl.dot? Is kernel size being one of mma tile dimension not supported?
To handle partial tile access that isn't support in triton rerquires couple of tl.where as workaround, which is suboptimal for performance and readability. If we can make tile level instructions work with some modification (tile size, padding, etc), that would be our first choice. However, it's still acceptable if such workaround can achieve better result.
Some suggestions if partial pattern is a must:
tl.whereandtl.zeros_like: I suppose thattl.zeros_likeis for matchingtl.wherex/y shapes, but register pressure and frequent allocations could be performance issues. Setting out-of-bound values and proper masking when loading k_chunk would be a better approach. https://triton-lang.org/main/python-api/generated/triton.language.where.html- neighbor_valid: It seems to be a flag tensor which is filled with 1 iteratively in the loop. Instead of max + where, it should be equivalent to logical OR with
col_select & key_valid[:, None], so we can avoid max and where.
Note: above suggestions are not verified applicable and guaranteed for perf gain.
There was a problem hiding this comment.
The neighborhood QK here is a per-row gather pattern, not a contiguous matmul tile:
For each query row i and neighbor j, we need dot(Q[i], K[key_col(i,j)]), where key_col(i,j) = i + (j - half_kernel) * dilation.
K indices differ per row, so K cannot be laid out as a contiguous [N, D] tile for a single tl.dot(Q, K^T).
kernel_size is small (typically 7–15), so gather + element-wise multiply + tl.sum is the natural fit.
Agreed on the tl.where / neighbor_valid readability concerns — addressed in the latest push:
Replaced tl.zeros_like scatter with in-place tl.where(..., dot_acc).
Precomputed key_valid_all and removed the iterative neighbor_valid flag tensor.
We explored tl.dot but it would require materializing a [BLOCK_M, kernel_size, BLOCK_K] K tensor via gather first, which adds UB pressure without a clear win on Ascend for small K.
| for neighbor_idx in range(kernel_size): | ||
| offset = neighbor_idx - half_kernel | ||
| key_cols = row_offsets + offset * dilation | ||
| key_valid = row_valid & (key_cols >= 0) & (key_cols < seq_len) | ||
|
|
||
| col_select = (tl.arange(0, kernel_size) == neighbor_idx)[None, :] | ||
| attn_val = tl.sum(attn_local * col_select, axis=1) | ||
| attn_ptrs = ( | ||
| Attn_ptr | ||
| + batch_id * attn_batch_stride | ||
| + head_id * attn_head_stride | ||
| + safe_row * attn_seq_stride | ||
| + neighbor_idx * attn_neighbor_stride | ||
| ) | ||
| tl.store(attn_ptrs, attn_val, mask=key_valid) |
There was a problem hiding this comment.
ditto, is it possible to store a tile instead of column by column?
There was a problem hiding this comment.
Done — attn weights are now stored as a full [BLOCK_M, kernel_size] tile in one tl.store, using a precomputed key_valid_all mask. Applied the same pattern to grad_qk store in the sparse softmax-backward kernel.
05a263f to
532b4bc
Compare
|
Kernels with identical issues have all been fixed accordingly. A summary is provided below: Fixed Kernels (5 in total)
Kernels Not Eligible for the Same Optimization (3 in total)These kernels adopt query-side index gathering (
|
Co-authored-by: Cursor <cursoragent@cursor.com>
532b4bc to
183cc88
Compare
Summary
Refactor Ascend fused neighborhood attention from a dense S×S matrix implementation to a sparse band
[S, kernel_size]representation, dramatically reducing memory footprint and eliminating redundant compute on masked-out positions. To resolve the UB overflow and OOM issues.Details
Key changes:
_neighborhood_mask_kernel+ dense QK + external softmax with fused_sparse_neighborhood_qk_softmax_kernelthat computes only neighborhood dot products and applies row softmax in-register.tl.dotover seq_len._sparse_neighborhood_attention_backwardfor large sequences and_dense_neighborhood_attention_backwardwithtl.dot-based fast path for small problems (seq_len ≤ 2048, workspace fits in memory).get_total_gpu_memory().compute_default_tiling_strategy, grid program cap at 65535, seq_len-adaptive block sizes.fused_neighborhood_attention_backward; add backward result cache for repeated grad calls.Testing Done
ut:

forward:

backward:

full:

memory:

make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence