Add CuTe DSL implementation of the swiglu liger kernel #1277
Open
Celaena24 wants to merge 3 commits into
Perhaps consider adding routing logic so as not to regress the older architectures? Route CuteDSL where it wins, legacy where it loses?
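A minimal sketch of the routing suggested above, assuming dispatch is keyed on CUDA compute capability. `pick_swiglu_backend`, `CUTEDSL_ARCHS`, and the backend strings are hypothetical names, not part of this PR. The allowlist holds only sm_100 because B200 is the only GPU with results reported here.

```python
from typing import Tuple

# Hypothetical backend labels (not identifiers from this PR).
CUTEDSL = "cutedsl"
TRITON = "triton"

# Assumption: route to CuteDSL only on architectures where it has been
# measured to win. The PR reports B200 (sm_100) only, so that is the
# single entry; extend it as more GPUs are benchmarked.
CUTEDSL_ARCHS = {(10, 0)}


def pick_swiglu_backend(capability: Tuple[int, int]) -> str:
    """Choose the SiLU*Mul implementation for a (major, minor) capability.

    `capability` has the shape returned by
    torch.cuda.get_device_capability(); anything not on the allowlist
    keeps the existing Triton kernel, so older GPUs cannot regress.
    """
    return CUTEDSL if capability in CUTEDSL_ARCHS else TRITON
```

An allowlist rather than a `major >= 10` threshold keeps unmeasured architectures on the legacy path by default.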
Summary
Adds a new NVIDIA CUTLASS Python-DSL (CuteDSL) implementation of the fused SwiGLU / SiLU-Mul activation, alongside the existing Triton kernel. The kernel computes `c = silu(a · gate_mult) · b` with an in-place backward (`da→a`, `db→b`), exposed via a drop-in autograd `Function` (`LigerSiLUMulCuteDSLFunction`).

Result on B200, three headline outcomes:

- `src/liger_kernel/ops/swiglu_cutedsl.py` (`Function`)
- `test/transformers/test_swiglu_cutedsl.py`

Details
Two key optimizations drive the kernel:

- `mul_packed_f32x2` / `add_packed_f32x2` / `exp_packed_f32x2` / `rcp_approx`: processes two fp32 lanes per instruction on Blackwell, roughly halving the math instruction count so bf16/fp16 forward is not instruction-issue bound.
- `--enable-tvm-ffi`: PyTorch tensors are passed directly, without `from_dlpack` / memref-construction overhead (~26 µs → ~4 µs per launch). This path is used whenever `numel` is a multiple of the CTA tile (i.e. every realistic workload); a predicated scalar-math fallback handles odd shapes.

Isolated SiLU·Mul kernel: where the optimization lives
Timed in isolation (activation only, GEMMs excluded), with Triton and CuteDSL measured back-to-back in one process (same clocks/thermal state). Achieved HBM bandwidth on B200, bf16:
Headline: at the shapes training actually runs (large T, wide intermediate) the backward kernel is ~1.71× faster and holds that steadily from T=4096 through T=65536; the forward kernel reaches 1.57× at the widest models.
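Achieved bandwidth here is bytes moved divided by kernel time. A rough sketch of that arithmetic, assuming the forward pass reads `a` and `b` and writes `c` (3 tensors), and the in-place backward reads `dc`, `a`, `b` and writes `da`, `db` (5 tensors). This accounting and the function name are assumptions, not taken from the PR's benchmark script.

```python
# Assumed tensor traffic per pass (each tensor has `numel` elements).
FWD_TENSORS = 3  # read a, read b, write c
BWD_TENSORS = 5  # read dc, read a, read b, write da, write db


def achieved_bandwidth_gbps(numel: int, bytes_per_elem: int,
                            tensors_moved: int, elapsed_ms: float) -> float:
    """Return effective memory bandwidth in GB/s (1 GB = 1e9 bytes)."""
    total_bytes = numel * bytes_per_elem * tensors_moved
    seconds = elapsed_ms * 1e-3
    return total_bytes / seconds / 1e9
```

For bf16, `bytes_per_elem` is 2. Comparing the result against the GPU's peak HBM bandwidth shows how close a memory-bound kernel is to its ceiling.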
Full MLP — parity, no regression
The activation is a small slice of the GEMM-dominated MLP, so the kernel win amortizes to parity at the model level. All three providers measured in a single run at a stable window (rep=200) for a fair comparison; peak memory is bit-identical between liger and cutedsl:
Full-MLP fwd+bwd latency, bf16, T=2048, B200 (ms, lower = better). The point is no regression: switching to CuteDSL costs nothing at the model level while delivering the kernel-level wins above.
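Why a kernel-level win amortizes to parity can be illustrated with Amdahl's law. The sketch below is illustrative only: the 5% activation share is a hypothetical value, not a measurement from this PR, and `end_to_end_speedup` is a name chosen for the example. The 1.71× figure is the backward-kernel speedup quoted above.

```python
def end_to_end_speedup(kernel_fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: overall speedup when only a fraction of the runtime
    (`kernel_fraction`, between 0 and 1) is accelerated by `kernel_speedup`."""
    return 1.0 / ((1.0 - kernel_fraction) + kernel_fraction / kernel_speedup)


# Hypothetical: the activation takes 5% of full-MLP time, sped up 1.71x.
overall = end_to_end_speedup(0.05, 1.71)
print(f"overall speedup: {overall:.3f}x")
```

Under that assumed share the MLP as a whole speeds up by only roughly 1.02×, which is consistent with the tables reading as parity.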
Full breakdown: `optimization/swiglu_cutedsl/report.md` + `profile.md`.

fp32 full MLP (B200, T=2048, fwd+bwd ms) holds the same parity: the isolated fp32 win is amortized by the GEMMs, and memory is identical:
Figures
Isolated SiLU·Mul kernel
(activation only, GEMMs excluded; Triton vs CuteDSL, B200) — the optimization's real win:
MLP unit
Model-config sweep — 7 real model widths (deepseek/llama/qwen), bf16, T=2048, three providers (huggingface / liger / liger_cutedsl):
Token-length sweep — llama-3-8B width, bf16, T = 1024→8192:
Cross-GPU + high-T (B200):
Testing Done
`test/transformers/test_swiglu_cutedsl.py`: 49/49 pass. Covers fwd+bwd vs Triton and a pure-torch reference; shapes `(4096,11008)` / `(2,256,512)` / non-tile-aligned `(6,42,431)` / odd-width `(3,1023)` / tiny `(1,1,7)` + 1-D/3-D/4-D; fp32 (tol 1e-4), bf16/fp16 (tol 1e-2); gate/down multipliers.

- Hardware Type: NVIDIA B200 (sm_100)
- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence
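For readers checking the math behind the tests, here is a scalar reference for `c = silu(a · gate_mult) · b` and its gradients. It is a plain-Python sketch, not the pure-torch reference used in the test file; the function names are chosen for illustration. The sigmoid is written as a reciprocal of `1 + exp(-x)`, mirroring the exp/rcp formulation mentioned under Details.

```python
import math


def silu_mul_forward(a: float, b: float, gate_mult: float = 1.0) -> float:
    """c = silu(a * gate_mult) * b, with silu(x) = x * sigmoid(x)."""
    x = a * gate_mult
    sig = 1.0 / (1.0 + math.exp(-x))  # overflows for x below about -709
    return x * sig * b


def silu_mul_backward(dc: float, a: float, b: float, gate_mult: float = 1.0):
    """Return (da, db) for upstream gradient dc.

    d silu(x)/dx = sig * (1 + x * (1 - sig)); the chain rule through
    x = a * gate_mult contributes the extra gate_mult factor on da.
    """
    x = a * gate_mult
    sig = 1.0 / (1.0 + math.exp(-x))
    dsilu = sig * (1.0 + x * (1.0 - sig))
    da = dc * b * dsilu * gate_mult
    db = dc * (x * sig)
    return da, db


def central_diff(f, x: float, eps: float = 1e-6) -> float:
    """Numerical derivative of f at x, used to cross-check the gradients."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)
```

Comparing `silu_mul_backward` against `central_diff` at a few points gives a quick sanity check that does not depend on a GPU.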