Skip to content

Add CuTe DSL implementation of the swiglu liger kernel#1277

Open
Celaena24 wants to merge 3 commits into
linkedin:mainfrom
Celaena24:swiglu-cutedsl
Open

Add CuTe DSL implementation of the swiglu liger kernel#1277
Celaena24 wants to merge 3 commits into
linkedin:mainfrom
Celaena24:swiglu-cutedsl

Conversation

@Celaena24

@Celaena24 Celaena24 commented Jun 29, 2026

Copy link
Copy Markdown

Summary

Adds a new NVIDIA CUTLASS Python-DSL (CuteDSL) implementation of the fused SwiGLU / SiLU-Mul activation, alongside the existing Triton kernel. The kernel computes c = silu(a · gate_mult) · b with an in-place backward (da→a, db→b), exposed via a drop-in autograd Function (LigerSiLUMulCuteDSLFunction).

Result on B200 — three headline outcomes:

  • Isolated kernel speedup (the win): backward ~1.71× Triton. Forward is a smaller win, ranging from parity up to 1.29× depending on the model width.
  • Full-MLP parity: the activation is a small slice of the GEMM-dominated MLP, so at the model level cutedsl is ≈1.0× Triton — no regression.
  • Memory: peak memory bit-identical to Triton everywhere.
File Change
src/liger_kernel/ops/swiglu_cutedsl.py New — CuteDSL fwd/bwd kernels + autograd Function
test/transformers/test_swiglu_cutedsl.py New — 49 correctness tests

Details

Two key optimizations drive the kernel:

  • Packed-f32x2 SFU ops (mul_packed_f32x2 / add_packed_f32x2 / exp_packed_f32x2 / rcp_approx): Processes two fp32 lanes per instruction on Blackwell, roughly halving the math instruction count so bf16/fp16 forward is not instruction-issue bound.
  • TVM-FFI fast path: The kernel is compiled once per dtype against an abstract fake tensor with --enable-tvm-ffi, so PyTorch tensors are passed directly without from_dlpack/memref-construction overhead (~26 µs → ~4 µs per launch). This path is used whenever numel is a multiple of the CTA tile (i.e. every realistic workload); a predicated scalar-math fallback handles odd shapes.

Isolated SiLU·Mul kernel — where the optimization lives

Timed in isolation (activation only, GEMMs excluded), with Triton and CuteDSL measured back-to-back in one process (same clocks/thermal state). Achieved HBM bandwidth on B200, bf16:

Sweep pass CuteDSL vs Triton
T = 4096 → 65536 (llama-8B width 14336) backward 1.70–1.72×, flat across the whole range
T = 4096 → 65536 (llama-8B width 14336) forward 1.00–1.04× (parity)
7 real model widths, T = 2048 forward 1.01× → 1.57× (grows with intermediate width — deepseek_v3, qwen2.5-7B)
7 real model widths, T = 2048 backward 1.03× → 1.22×
bf16 / fp16 (H=11008, T=8192) backward 1.31× (fp32 1.08×)

Headline: at the shapes training actually runs (large T, wide intermediate) the backward kernel is ~1.71× faster and holds that steadily from T=4096 through T=65536; the forward kernel reaches 1.57× at the widest models.

Full MLP — parity, no regression

The activation is a small slice of the GEMM-dominated MLP, so the kernel win amortizes to parity at the model level. All three providers measured in a single run at a stable window (rep=200) for a fair comparison; peak memory is bit-identical between liger and cutedsl:

Model HuggingFace liger (Triton) liger_cutedsl cutedsl vs liger
deepseek_v2_lite 0.805 0.765 0.740 1.03×
llama_2_7b 1.447 1.418 1.398 1.01×
llama_3_8b 1.869 1.818 1.877 0.97×
qwen2.5_7b 2.149 2.090 2.063 1.01×
qwen2.5_14b 2.217 2.168 2.149 1.01×
qwen2.5_72b 7.027 6.891 6.841 1.01×
deepseek_v3 4.037 4.046 3.932 1.03×

Full-MLP fwd+bwd latency, bf16, T=2048, B200 (ms, lower = better). The point is no regression: switching to CuteDSL costs nothing at the model level while delivering the kernel-level wins above. (All three providers are timed in a single run on the same GPU so the comparison is fair.)

Full breakdown: optimization/swiglu_cutedsl/report.md + profile.md.

fp32 full MLP (B200, T=2048, fwd+bwd ms) holds the same parity — the isolated fp32 win is amortized by the GEMMs, identical memory:

Model liger (Triton) liger_cutedsl cutedsl vs liger
llama_2_7b 28.17 28.17 1.000×
llama_3_8b 35.47 35.46 1.000×
qwen2.5_7b 40.92 40.92 1.000×

Figures

Isolated SiLU·Mul kernel

(activation only, GEMMs excluded; Triton vs CuteDSL, B200) — the optimization's real win:

Speedup vs T (fwd & bwd) swiglu_iso_speedup_vs_T_b200
Speedup across 7 model widths swiglu_iso_model_widths_b200
Kernel GB/s, standard sweep range swiglu_iso_std_range_b200
Speedup by dtype swiglu_iso_dtype_b200
Isolated forward GB/s, high-T swiglu_highT_kernel_gbps_b200
Isolated backward GB/s, high-T swiglu_highT_kernel_bwd_gbps_b200

MLP unit

Model-config sweep — 7 real model widths (deepseek/llama/qwen), bf16, T=2048, three providers (huggingface / liger / liger_cutedsl):

Forward speed — MLP forward pass; cutedsl ≈ liger swiglu_speed_forward_model_config
Backward speed — MLP backward pass; cutedsl ≈ liger swiglu_speed_backward_model_config
Full MLP speed — gate+up+down GEMMs + activation; all three at parity swiglu_speed_full_model_config
Full MLP memory — cutedsl == liger to the MB, both ~10% below HF swiglu_memory_full_model_config

Token-length sweep — llama-3-8B width, bf16, T = 1024→8192:

Forward speed vs T swiglu_speed_forward_token_length
Backward speed vs T swiglu_speed_backward_token_length
Full MLP speed vs T — parity (cutedsl ≈ liger), both ≈ HF swiglu_speed_full_token_length
Full MLP memory vs T — exact parity swiglu_memory_full_token_length

Cross-GPU + high-T (B200):

Full MLP speed, B200 vs H100 — parity within each GPU; B200 ~2× faster than H100 (hardware) swiglu_full_speed_b200_vs_h100
Full MLP memory, B200 vs H100 — exact parity swiglu_full_memory_b200_vs_h100
Full MLP ms, high-T — all providers within noise, both Liger backends ≈ HF swiglu_highT_full_ms_b200

Testing Done

  • test/transformers/test_swiglu_cutedsl.py — 49/49 pass: fwd+bwd vs Triton and a pure-torch reference; shapes (4096,11008) / (2,256,512) / non-tile-aligned (6,42,431) / odd-width (3,1023) / tiny (1,1,7) + 1-D/3-D/4-D; fp32 (tol 1e-4), bf16/fp16 (tol 1e-2); gate/down multipliers.

  • Hardware Type: NVIDIA B200 (sm_100)

  • run make test to ensure correctness

  • run make checkstyle to ensure code style

  • run make test-convergence to ensure convergence

@thad0ctor

Copy link
Copy Markdown

perhaps consider adding routing logic as not to regress the older architures? route cuteDSL where it wins, legacy where it loses?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants