Skip to content

[NPU] refact fused_neighborhood_attention#1265

Open
sunyi0505 wants to merge 1 commit into
linkedin:mainfrom
sunyi0505:fused_neighborhood_attention
Open

[NPU] refact fused_neighborhood_attention#1265
sunyi0505 wants to merge 1 commit into
linkedin:mainfrom
sunyi0505:fused_neighborhood_attention

Conversation

@sunyi0505

@sunyi0505 sunyi0505 commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

Summary

Refactor Ascend fused neighborhood attention from a dense S×S matrix implementation to a sparse band [S, kernel_size] representation, dramatically reducing memory footprint and eliminating redundant compute on masked-out positions. To resolve the UB overflow and OOM issues.

Details

Key changes:

  • Sparse forward: Replace _neighborhood_mask_kernel + dense QK + external softmax with fused _sparse_neighborhood_qk_softmax_kernel that computes only neighborhood dot products and applies row softmax in-register.
  • Sparse AV: Rewrite attention-value product to gather neighbor V positions and accumulate via element-wise multiply instead of dense tl.dot over seq_len.
  • Dual backward paths: Add _sparse_neighborhood_attention_backward for large sequences and _dense_neighborhood_attention_backward with tl.dot-based fast path for small problems (seq_len ≤ 2048, workspace fits in memory).
  • Memory-aware chunking: Introduce batch/batch-head chunking with device memory limits (8% forward, 18% backward) via get_total_gpu_memory().
  • Ascend hardware constraints: UB-aware tiling via compute_default_tiling_strategy, grid program cap at 65535, seq_len-adaptive block sizes.
  • Autograd improvements: Extract backward from inline autograd into fused_neighborhood_attention_backward; add backward result cache for repeated grad calls.

Testing Done

ut:
image

forward:
image

backward:
image

full:
image

memory:
image

  • Hardware Type: Tested with Atlas 800T A3(X86)(triton-ascend=3.2.1)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@sunyi0505

Copy link
Copy Markdown
Contributor Author

Fused Neighborhood Attention (Ascend) Refactoring Analysis

1. Change Overview

To resolve the UB overflow and OOM issues occurring with the fused_neighborhood_attention operator on the Ascend NPU, the following refactoring has been implemented. This refactoring migrates the Ascend backend's Neighborhood Attention from a dense S×S matrix implementation to a sparse neighborhood band implementation, introducing memory-aware tiling, dual-path backward scheduling, and Ascend hardware constraint adaptation in forward/backward passes.

Dimension Old Implementation New Implementation
Attention weight shape [B, H, S, S] dense [B, H, S, kernel_size] sparse band
Neighborhood mask Separate kernel pre-generates [S, S] Implicitly computed within kernel via neighbor index
QK + Softmax Separate: QK kernel → _softmax_forward Fused: _sparse_neighborhood_qk_softmax_kernel
Forward compute dtype Forced FP32 Preserve input dtype, intermediate accumulation in FP32
Backward implementation Fully inline in autograd.Function.backward Standalone fused_neighborhood_attention_backward + sparse/dense dual paths
Memory strategy No tiling, full S² allocation Batch / batch-head tiling + memory upper bound
Grid constraints Not handled _MAX_GRID_PROGRAMS = 65535 enforced

2. Architecture Comparison

2.1 Old Architecture (Dense)

Q, K ──► _neighborhood_mask_kernel ──► mask [S, S]
                │
Q, K, mask ──► _fused_neighborhood_attention_qk_kernel ──► qk_scores [B,H,S,S]
                │
                └──► _softmax_forward ──► attn_weights [B,H,S,S]
                              │
attn_weights, V ──► _fused_neighborhood_attention_av_kernel ──► output

Backward:

grad_output ──► grad_attn ──► _softmax_backward ──► grad_qk
grad_qk ──► grad_Q / grad_K
attn, grad_output ──► grad_V

2.2 New Architecture (Sparse Band + Dual-Path Backward)

Q, K ──► _sparse_neighborhood_qk_softmax_kernel ──► attn [B,H,S,K]
                              │
attn, V ──► _fused_neighborhood_attention_av_kernel (sparse) ──► output

Backward Routing:

fused_neighborhood_attention_backward
    │
    ├─ attn.shape[-1] == kernel_size AND fast-path conditions met
    │       └─► _dense_neighborhood_attention_backward
    │             (expand sparse→dense, tl.dot dense backward)
    │
    ├─ attn.shape[-1] == kernel_size AND fast-path unavailable
    │       └─► _sparse_neighborhood_attention_backward
    │             (full sparse band kernel path)
    │
    └─ attn.shape[-1] == seq_len (legacy dense checkpoint)
            └─► _dense_neighborhood_attention_backward

3. Core Changes Breakdown

3.1 Removal of _neighborhood_mask_kernel

Old approach: Allocate [S, S] mask first; QK kernel reads mask per tile and sets non-neighborhood positions to -inf.

New approach: Within kernel, directly compute neighbor key column indices via for neighbor_idx in range(kernel_size):

offset = neighbor_idx - half_kernel
key_cols = row_offsets + offset * dilation

Rationale:

  • Eliminates [S, S] mask memory and one extra kernel launch
  • Neighborhood structure parameterized by (kernel_size, dilation), no mask materialization needed
  • Naturally aligned with sparse band storage

3.2 QK + Softmax Fusion: _sparse_neighborhood_qk_softmax_kernel

Old _fused_neighborhood_attention_qk_kernel:

  • Outputs full [S, S] QK scores
  • Uses tl.dot for dense matrix multiplication
  • Depends on external _softmax_forward
  • Bug: acc = acc * scale inside head_dim loop, causing repeated scale application

New kernel:

  • Computes dot products only for kernel_size neighbor positions (tl.sum(q * k, axis=1))
  • Performs row-wise softmax entirely in registers (max → exp → normalize)
  • Directly scatters to [S, kernel_size] Attn_ptr
  • Scale applied once uniformly before softmax

Rationale:

  • Memory reduced from O(S²) to O(S·K), with K typically 7–15
  • Fewer kernel launches and reduced global memory reads/writes
  • Dense QK computations are wasteful for neighborhood attention, with most work spent on masked-out zeros

3.3 AV Forward Changed to Sparse Band Reading

Old: attn [B,H,S,S] × V [B,H,S,D] using tl.dot along seq_len dimension

New: For each neighbor_idx, gather V at corresponding key positions and perform element-wise weighted summation:

acc += attn_vals[:, None] * v_chunk

Rationale: Input is already in sparse band format; continuing with dense S² matrix multiplication no longer makes sense. Gather + multiply matches the data structure.


3.4 Backward: From Inline to Modular Dual-Path

Sparse Backward Path (_sparse_neighborhood_attention_backward)

New kernel family:

Kernel Purpose
_sparse_neighborhood_grad_attn_softmax_bwd_kernel Fused grad_attn + softmax backward → grad_qk (sparse)
_sparse_neighborhood_grad_q_kernel grad_Q
_sparse_neighborhood_grad_k_kernel grad_K
_sparse_neighborhood_grad_qkv_fused_kernel Fused grad_Q/K/V (3D grid)
_fused_neighborhood_attention_grad_v_kernel grad_V (sparse gather)

Rationale: At large seq_len, S² intermediate tensors are infeasible; sparse kernels operate entirely on [S, K].

Dense Backward Path (_dense_neighborhood_attention_backward)

New kernel family:

Kernel Purpose
_sparse_attn_expand_to_dense_kernel Sparse band → dense [S,S] (only needed for backward)
_dense_neighborhood_grad_attn_kernel grad_attn = grad_output @ V^T
_dense_neighborhood_grad_qk_fused_kernel Fused grad_Q + grad_K
_dense_neighborhood_grad_qk_fused_persistent_kernel Persistent variant
_dense_neighborhood_grad_v_kernel grad_V = attn^T @ grad_output

Fast path conditions (_sparse_band_dense_fast_path_available):

  • seq_len <= 2048
  • workspace B×H×S×S×3×elem_size does not exceed 18% of available memory
  • grid does not exceed 65535

Rationale: For small sequences, tl.dot is more efficient on Ascend; after one expansion, existing dense matmul kernels outperform handwritten sparse gather.


3.5 Memory and Tiling Infrastructure

New helper layer (largely absent in old version):

Function Purpose
_scores_chunk_byte_limit() Forward attn tiling: per-chunk limit ≤ 8% of available memory
_attention_batch_chunk_size() Split forward along batch dimension
_get_attention_block_sizes() UB-aware tile sizes (compute_default_tiling_strategy)
_dense_backward_byte_limit() Backward workspace upper bound (18% of available memory)
_dense_backward_batch_head_chunk() Chunk backward along batch×head dimension
_get_dense_block_m() / _dense_backward_grid_fits() Adapt to Ascend grid 65535 upper bound
_get_sparse_backward_block_sizes() Adjust block_m/block_n by seq_len tiers

Rationale:

  • Neighborhood attention is sparse, but backward fast path still requires S² workspace
  • Ascend NPU has UB (Unified Buffer) and grid constraints requiring dynamic tiling
  • Large batch + long sequence combinations would OOM without chunking

3.6 Forward API Changes

Item Old New
Return value (output, attn_weights, softmax_params) (output, attn_for_backward)
softmax_params Saved for backward use No longer needed (softmax fused)
dtype handling Input cast to FP32, output cast back Intermediate FP32 accumulation, attn cast to input dtype
Batch tiling None Controlled by _attention_batch_chunk_size

3.7 Autograd Layer Enhancements

Old LigerFusedNeighborhoodAttentionFunction.backward:

  • 200+ lines of inline backward logic
  • Re-executed on every backward call

New:

  • Delegates to fused_neighborhood_attention_backward
  • Added backward cache: identical grad_output tensor (same storage/version/shape/stride) reuses grad_Q/K/V
  • Cache capacity of 4 entries, LRU eviction

Rationale: Separation of concerns; cache avoids redundant computation on double backward or repeated grad calls.


4. Why These Changes — Motivation Summary

4.1 Memory Bottleneck (Primary)

Effective connectivity in Neighborhood Attention is O(S·K), but the old implementation allocates O(S²) tensors for qk_scores, attn_weights, and grad_*. At S=4096, H=32, B=8:

  • Old: Single [B,H,S,S] FP32 tensor ≈ 8×32×4096²×4 ≈ 17 GB
  • New: Single [B,H,S,K] FP32 tensor (K=7) ≈ 30 MB scale

This is the primary driver for this refactoring.

4.2 Computational Efficiency

The old QK stage runs tl.dot over the entire S×S matrix and then masks out invalid positions — computational cost equivalent to full attention. The new implementation computes dot products only for K neighbors, reducing computation to O(S·K·D).

4.3 Ascend Hardware Adaptation

  • Introduces compute_default_tiling_strategy for UB-aware tiling
  • Explicitly handles grid program upper bound of 65535
  • Adjusts block sizes by seq_len tiers (512 / 2048 / 4096)
  • Uses get_total_gpu_memory() for dynamic chunk threshold calculation

4.4 Correctness and Maintainability

  • Fixes old QK kernel bug where scale was repeatedly applied inside the K dimension loop
  • Backward logic modularized; sparse/dense paths can be independently tested and optimized
  • Retains backward compatibility with legacy dense attn_weights shapes

@sunyi0505

Copy link
Copy Markdown
Contributor Author

@Tcc0403 This pr is ready for review. Thanks!

@sunyi0505 sunyi0505 force-pushed the fused_neighborhood_attention branch from 4225067 to bba28ce Compare June 27, 2026 08:00

@Tcc0403 Tcc0403 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left my comments under one kernel but other kernels have similar concerns, once they are resolved, it should be applicable to other kernels as well.

Comment on lines +103 to +106
partial = tl.sum(q_chunk * k_chunk, axis=1)
col_select = (tl.arange(0, kernel_size) == neighbor_idx)[None, :]
dot_acc += tl.where(col_select, partial[:, None], tl.zeros_like(dot_acc))
neighbor_valid = tl.maximum(neighbor_valid, tl.where(col_select & key_valid[:, None], 1.0, 0.0))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is quite difficult to understand. Let me know if I'm reading it wrong.

In for neighbor_idx in range(kernel_size): loop, we iteratively compute QK tile column by column with tl.sum instead of computing the whole tile once with tl.dot. What's the rationale behind it? Why couldn't we simply leverage tl.dot? Is kernel size being one of mma tile dimension not supported?

To handle partial tile access that isn't support in triton rerquires couple of tl.where as workaround, which is suboptimal for performance and readability. If we can make tile level instructions work with some modification (tile size, padding, etc), that would be our first choice. However, it's still acceptable if such workaround can achieve better result.

Some suggestions if partial pattern is a must:

  1. tl.where and tl.zeros_like: I suppose that tl.zeros_like is for matching tl.where x/y shapes, but register pressure and frequent allocations could be performance issues. Setting out-of-bound values and proper masking when loading k_chunk would be a better approach. https://triton-lang.org/main/python-api/generated/triton.language.where.html
  2. neighbor_valid: It seems to be a flag tensor which is filled with 1 iteratively in the loop. Instead of max + where, it should be equivalent to logical OR with col_select & key_valid[:, None], so we can avoid max and where.

Note: above suggestions are not verified applicable and guaranteed for perf gain.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The neighborhood QK here is a per-row gather pattern, not a contiguous matmul tile:

For each query row i and neighbor j, we need dot(Q[i], K[key_col(i,j)]), where key_col(i,j) = i + (j - half_kernel) * dilation.
K indices differ per row, so K cannot be laid out as a contiguous [N, D] tile for a single tl.dot(Q, K^T).
kernel_size is small (typically 7–15), so gather + element-wise multiply + tl.sum is the natural fit.
Agreed on the tl.where / neighbor_valid readability concerns — addressed in the latest push:

Replaced tl.zeros_like scatter with in-place tl.where(..., dot_acc).
Precomputed key_valid_all and removed the iterative neighbor_valid flag tensor.
We explored tl.dot but it would require materializing a [BLOCK_M, kernel_size, BLOCK_K] K tensor via gather first, which adds UB pressure without a clear win on Ascend for small K.

Comment on lines +114 to +128
for neighbor_idx in range(kernel_size):
offset = neighbor_idx - half_kernel
key_cols = row_offsets + offset * dilation
key_valid = row_valid & (key_cols >= 0) & (key_cols < seq_len)

col_select = (tl.arange(0, kernel_size) == neighbor_idx)[None, :]
attn_val = tl.sum(attn_local * col_select, axis=1)
attn_ptrs = (
Attn_ptr
+ batch_id * attn_batch_stride
+ head_id * attn_head_stride
+ safe_row * attn_seq_stride
+ neighbor_idx * attn_neighbor_stride
)
tl.store(attn_ptrs, attn_val, mask=key_valid)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, is it possible to store a tile instead of column by column?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — attn weights are now stored as a full [BLOCK_M, kernel_size] tile in one tl.store, using a precomputed key_valid_all mask. Applied the same pattern to grad_qk store in the sparse softmax-backward kernel.

@sunyi0505 sunyi0505 force-pushed the fused_neighborhood_attention branch 2 times, most recently from 05a263f to 532b4bc Compare June 29, 2026 02:47
@sunyi0505

Copy link
Copy Markdown
Contributor Author

Kernels with identical issues have all been fixed accordingly. A summary is provided below:

Fixed Kernels (5 in total)

Kernel Name Optimization Details
_sparse_neighborhood_qk_softmax_kernel Removed iterative computation of neighbor_valid; implemented tile-based storage for attention weights
_sparse_neighborhood_grad_attn_softmax_bwd_kernel Apply tile load to y_cols; adopt tile-based storage for grad_qk
_sparse_neighborhood_grad_q_kernel Load grad_qk via tile load; only execute key gathering within the loop
_sparse_neighborhood_grad_qkv_fused_kernel Use tile load for gq in the Q branch (query gathering is still used for gk and attention weights)
_sparse_attn_expand_to_dense_kernel Adopt tile load for sparse attention; dense scattering remains column-wise due to memory layout constraints

Kernels Not Eligible for the Same Optimization (3 in total)

These kernels adopt query-side index gathering (safe_query = row - offset * dilation). Their sparse bands are non-contiguous tiles along the sequence dimension, making single-pass tile loading infeasible:

  1. _fused_neighborhood_attention_grad_v_kernel — Attention weights are gathered by query row
  2. _sparse_neighborhood_grad_k_kernelgrad_qk is gathered by query row
  3. _dense_neighborhood_* series — These already leverage standard dense matrix multiplication via tl.dot and are free of the above limitations

_fused_neighborhood_attention_av_kernel cannot be modified using the aforementioned optimization scheme. Preloading the entire attention tile would require multiple buffers of shape M × kernel_size to reside in memory simultaneously, including attn_tile, key_cols_all, key_valid_all and col_select. With multi-buffer enabled on the Ascend platform, the total UB memory requirement reaches approximately 782 KB, exceeding the 192 KB hardware limit.

@sunyi0505 sunyi0505 requested a review from Tcc0403 June 29, 2026 02:58
Co-authored-by: Cursor <cursoragent@cursor.com>
@sunyi0505 sunyi0505 force-pushed the fused_neighborhood_attention branch from 532b4bc to 183cc88 Compare June 30, 2026 01:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants