Skip to content

[sm100] Packed FP4×FP4 mega-MoE kernel (W4A4) with per-band BLOCK_K#44

Open
Romaosir wants to merge 22 commits into
sgl-project:devfrom
Romaosir:w4a4-packed-fp4-mega-moe-dev
Open

[sm100] Packed FP4×FP4 mega-MoE kernel (W4A4) with per-band BLOCK_K#44
Romaosir wants to merge 22 commits into
sgl-project:devfrom
Romaosir:w4a4-packed-fp4-mega-moe-dev

Conversation

@Romaosir

@Romaosir Romaosir commented Jun 14, 2026

Copy link
Copy Markdown

Adds a dedicated W4A4 (FP4 weights × FP4 activations) fused mega-MoE kernel for SM100 — sm100_fp4_mega_moe — a packed-E2M1 specialization that issues tcgen05.mma.kind::mxf4, plus a per-band BLOCK_K selection that exploits packed-FP4's lower SMEM footprint for a measured kernel-level speedup over the existing FP8-acts and 0.1.0 W4A4 paths.

Paired with sgl-project/sglang#28210 (sgl-project/sglang#28210) — version bump + one-line env propagation; the existing W4A4 env path transparently routes to this kernel, no other sglang code changes.

What's new

  • sm100_fp4_mega_moe_impl: packed float_e2m1_t A and B (no multi-mode branches), hand-built packed-byte SMEM descriptors, dedicated packed_fp4_2sm_tma_load_2d (byte-typed tensor maps, per-CTA cta_group::2 issues with leader-centralized 2× tx accounting), SwiGLU epilogue cast straight to packed FP4. Half the activation bytes + SMEM of the FP8-acts path.
  • Per-band BLOCK_K=256 + 128B swizzle for the block_m ∈ {32, 64, 192} bands ({96, 128} stay 128). Packed FP4's 256-deep K stage costs ~31 KB SMEM (vs ~57 KB for FP8 acts), so it sustains 6–10 pipeline stages where the FP8-acts retrofit cannot afford the knob. Larger, more-contiguous TMA transactions raise DRAM efficiency. Overrides: DG_BLOCK_K (128/256), DG_NUM_STAGES (clamp).
  • mega_moe_pre_dispatch: packs activations to E2M1 + UE8M0 SF into the symmetric buffer.
  • csrc/apis/mega.hpp: env-selected dispatch (DG_MEGA_MOE_FP4 / DG_USE_FP4_ACTS); symm-buffer halves the activation regions under FP4 acts.
  • Tests: FP4-acts coverage + --shape-ntoks/--bench-reps sweep with in-process interleaved A/B (DG_BLOCKK_AB/DG_STAGES_AB/DG_WAVE_AB).

How it differs from the existing W4A4 path (sm100_fp8_fp4_mega_moe<kUseFp4Acts, kUseMxf4Kind>)

Same architecture (warp roles, scheduler, comm/barriers, packed-E2M1 symm buffer) and the same final tcgen05.mma.kind::mxf4 instruction — the difference is a dedicated specialization (enabling the per-band BLOCK_K win) instead of a template-flag retrofit of the FP8 kernel:

aspect existing W4A4 (retrofit) this kernel (specialized)
code shape 1 kernel + 4 template bools; conditional dtypes, swizzle/SMEM ternaries, per-phase MMA-desc dtype flips unconditional packed float_e2m1_t A and B; no multi-mode branches
BLOCK_K fixed 128 per-band 256/128 (the measured speedup)
SMEM descriptors generic make_umma_desc via _unpacksmem/_ALIGN16B hand-built packed-byte strides (generic helper miscounts K for packed operands)
TMA loads generic tma::copy dedicated packed_fp4_2sm_tma_load_2d (byte tensor maps, 64B/128B swizzle atoms, leader-centralized 2× tx)
MMA wrapper generic instr-desc + dtype-field flip dedicated ptx::SM100_MMA_MXFP4_2x1SM_SS
standalone validation sm100_fp4_fp4_gemm_1d1d single-GPU GEMM clone

Performance (4× B200, DSV4-Flash: h=4096 / i=2048 / E=256 / topk=6; bit-exact numerics vs the BLOCK_K=128 path, torch.equal)

Per-band BLOCK_K vs all-128 — in-process 4-rank interleaved A/B, 20 reps (the precision claim):
32-band ~+10%, 64-band +2–6%, 192-band ~+4%. Mechanism (NCU): DRAM efficiency from larger contiguous transactions (64% → 85% busy at ntok=512).

vs W4A4-old — the existing 0.1.0-release W4A4 kernel (sgl-project/DeepGEMM#27, sm100_fp8_fp4_mega_moe<kUseFp4Acts, kUseMxf4Kind>) — both kernels measured through the same binding (DG_FORCE_FP8FP4 toggle) to isolate kernel from packaging:

  • 1-rank NCU 537.6 µs vs 712.5 µs (+24.6%), DRAM 85.2% vs 64.3% — the block_k=256 data path (the 0.1.0 retrofit is fixed at block_k=128).
  • Fairness: the W4A4-old arm is true W4A4 — L2 contracts FP4-width (L2_SHAPE_K identical, compile-time); its DRAM 64.3% = the expected block_k=128 efficiency, no traffic confound.
  • 1-rank over-weights the memory regime; end-to-end the kernel win dilutes through Amdahl (mega-MoE is a fraction of the workload — the decode path is ~82% sm100_tf32_hc_prenorm_gemm, shared with both baselines).

End-to-end 5-concurrency sweep (TP4+DP4, sequential, same GPU group, 1 rep/cell after per-concurrency warmup, num_prompts raised for steady state). Ours and the 0.1.0 W4A4 ("old") run via the same fork pybind; W4A8 (FP8 acts) is the 0.1.0-release reference on its own tvm-ffi binding (quant-scheme baseline, not a same-binding comparison):

Prefill (in=1024, out=1) — input throughput (tok/s) · median TTFT (ms):

conc Ours tput old tput W4A8 tput Δ tput Ours/old Ours TTFT old TTFT Δ TTFT Ours/old
1 5847 5886 5683 −0.7% 156.8 154.8 +1.4%
4 14802 14478 10095 +2.2% 289.5 292.2 −0.9%
32 54114 54755 52762 −1.2% 417.6 402.2 +3.8%
128 126557 117838 123033 +7.4% 1110.6 1145.8 −3.1%
512 133116 130299 124708 +2.2% 4575.7 4672.6 −2.1%

Mixed (in=1024, out=256) — total throughput (tok/s) · median TPOT (ms):

conc Ours tput old tput W4A8 tput Δ tput Ours/old Ours TPOT old TPOT Δ TPOT Ours/old
1 342 338 340 +1.1% 14.02 14.17 −1.1%
4 1268 1291 1274 −1.8% 14.55 14.82 −1.8%
32 8212 8055 8149 +2.0% 17.82 18.11 −1.6%
128 24982 24544 24204 +1.8% 22.40 22.63 −1.0%
512 51295 50201 50363 +2.2% 33.50 37.81 −11.4%

Reading: the kernel win surfaces at production concurrency (c ≥ 128) where mega-MoE's share is largest — prefill input-throughput +2–7%, mixed throughput +2%, TPOT down to −11% at c=512. Low concurrency (c ≤ 4) is within the e2e noise band, and the c=1/4 mixed cells (num_prompts = 12) are under-sampled.

NCU Speed-of-Light (1-rank isolation, binding-controlled, ntok=512)

Profiling the fused 4-rank kernel with NCU is invalid (persistent kernel dominated by cross-rank NVLink-barrier spinning → SOL collapses). Single-rank removes the barrier and gives true SOL. All three arms measured at the identical shape (hidden=4096 / inter=2048 / E=256 / topk=6, ntok=512) through the same fork binding:

metric W4A4-Ours (packed, bk256) W4A4-old (fp8_fp4, bk128) W4A8 (fp8 acts)
Duration (µs) 537.6 712.5 753.2
Compute (SM) % 33.6 28.6 42.5
Memory % 85.2 64.3 61.4
DRAM throughput (TB/s) 6.54 4.93 4.71
L2 % 56.1 42.4 40.3
Achieved occupancy % 22.5 22.4 22.5
Registers/thread 128 128 128

Ours +24.6% vs W4A4-old, +28.6% vs W4A8 (1-rank). Mechanism: same occupancy/registers (identical launch config) → the win is the data path. Ours' block_k=256 sustains 85% memory SOL / 6.54 TB/s DRAM vs W4A4-old's block_k=128 at 64% / 4.93 TB/s — the larger-contiguous-transaction DRAM-efficiency win. W4A8 carries the heaviest compute (42.5% SM, FP8 acts = 1 byte/act) and the least DRAM headroom. 1-rank over-weights the memory regime (255 experts local, ~2 tokens/expert); use it for the mechanism, the e2e sweep above for the realistic magnitude.

Accuracy (eval_v2, DSV4-Flash)

eval W4A4-Ours W4A4-old W4A8
GSM8K 0.966 0.962 0.962
HumanEval (pass@1) 0.898 0.888 0.887

W4A4-Ours is at or above both the W4A4-old and W4A8 baselines on every eval run — the packed FP4×FP4 path + per-band BLOCK_K preserve accuracy.

Notes

Fridge003 and others added 22 commits April 25, 2026 11:47
Co-Authored-By: rainj-me <rain-jiang@outlook.com>
…el (sgl-project#27)

Two related additions for the DeepSeek-V4-Pro mega-MoE path:

1. **FP4 (E2M1) activations + `kind::mxf4` mainloop opt-in** for `fp8_fp4_mega_moe`.
   - `DG_USE_FP4_ACTS=1` halves the symm-buffer x-slot footprint (E2M1 nibbles
     vs E4M3 bytes); SF slot unchanged (still `hidden/32` UE8M0 bytes under
     gran_k=32).
   - `use_mxf4_kind=true` switches the L1+L2 mainloops to `cta_group::2 kind::mxf4`
     (2-CTA cluster) with dense FP4 smem layout (`_ALIGN8B`, 2 nibbles/byte).
     Per-stage A/B byte footprint halves → num_stages doubles for the same
     smem budget.
   - Threads `cumulative_local_expert_recv_stats` through the public mega-MoE
     API for per-rank expert counters used by sglang's expert-distribution
     recorder.
   - Block-m heuristic: under `use_mxf4_kind`, bumps `block_m=16 → 32` for the
     smallest-tokens-per-expert bucket so `load_block_m * block_k / 2` meets
     the 1024-byte smem alignment.
   - Multi-block_m support via `kCandidateBlockM` array + LCM-aligned pool
     padding; replaces the static `block_m=192` heuristic with token-density
     dispatch (8/16/32/64/96/128/192).

2. **`mega_moe_pre_dispatch` kernel**: BF16 → quant + topk-copy + pad-fill in
   one launch, gated on `kUseFp4Acts` + `kUsePDL`. Templated on
   `(kGroupSize, kUseFp4Acts, kUsePDL)`. Uses bucketize-style E2M1 encoder for
   byte-exact match against the `per_token_cast_to_fp4` host helper.
   - New: `deep_gemm.mega_moe_pre_dispatch(x, topk_idx, topk_weights, buf_x,
     buf_x_sf, buf_topk_idx, buf_topk_weights, num_tokens, group_size, use_fp4_acts)`
   - Test: `tests/test_mega_moe_pre_dispatch.py` — single-GPU bytewise check
     against host `per_token_cast_to_fp{8,4}` + pad-fill assertion.

Validated end-to-end on 8× B300 with DeepSeek-V4-Pro at 8K input bench:
- FP4 acts + MXF4 kind path produces matching tokens vs the FP8 baseline
  (rel-RMSE ≤ 0.5 sentinel; GSM8K accuracy parity within run-to-run variance).

PR also includes existing FP4-mega-MoE supporting changes that are required
by the kernel:
- `cluster_sync_with_relaxed_arrive` helper (used twice in `sm100_fp8_fp4_mega_moe.cuh`).
- `cvt_pack_f32_to_e2m1x2` / `cvt_pack_f32x4_to_e2m1x4` PTX wrappers.
- `SM100_MMA_MXF4_2x1SM_SS` 2-CTA cluster MMA wrapper.
- Generalized `red_add(int*, int)` for the `cumulative_local_expert_recv_stats`
  counter.
- `st.L1::no_allocate.relaxed.sys.global.u64` (correctness fix: previous
  generic-address variant could miss the global state space).

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit bca278e)
…bine path) (sgl-project#28)

* Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)

The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink:
each token, each topk slot = kHidden * 2 bytes. This commit adds an env-
gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte —
kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes.

Wiring:
- New `kUseFp8Combine` template flag (default false → keeps BF16 path
  byte-identical when off).
- New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per
  (token, slot) when on, zero when off.
- Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of
  `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a +
  mainloops; this controls the combine a2a only).

Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh):
- Read 8 BF16 from smem (existing STSM target).
- Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes
  that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the
  outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half-
  warp to exit on padding rows, and a full-warp shfl would deadlock.
- Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`).
- Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64.
- Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane
  group writes the SF byte to `combine_sf_buffer`.

Consumer side (combine reduce):
- Per-slot SF base ptr cached at slot start.
- TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
- Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2
  via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then
  `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant.
- BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4
  (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.
  Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32).

Validation:
- Microbench (`ptx/d_combine_reduce_v{1,2}_*`):
  - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect).
  - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference
    that uses the same FP8 quant), 50% NVLink bytes savings.
- Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine):
  - b=128:  364 us → 359 us (+1.5%)
  - b=512:  377 us → 386 us (-2.2%)
  - b=2048: 710 us → 739 us (-3.9%)
  Single-GPU is compute-bound (no NVLink saving); production is the
  point of the change.
- E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output):
  - b=512:  91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3%
  - b=2048: 259.4 s (FP8) → 238.2 s — +8.9%
  - b=4096: 489.5 s (FP8) → 444.2 s — +10.2%
  Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes.

Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs
BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations
post-SwiGLU + topk-weighting are bounded; production accuracy parity
preserved (same GSM8K results as FP4 baseline).

* Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)

Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar
fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count
and halves the accumulator register pressure (94 regs vs 138 regs).

Inner loop, before:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  cvt.f32.f16  ×2     (FP16 → FP32)
  fma.rn.f32   ×2     (acc += sf_f32 * f32_val)
  = 5 ops per FP8x2 (= 2 elements)

After:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  fma.rn.f16x2        (acc_fp16x2 += sf_pair * f16x2)
  = 2 ops per FP8x2

SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15.
Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production
activations post-SwiGLU + topk-weighting fit comfortably in FP16 range.

End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem
write-back (BF16 output unchanged).

Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`):
  v1 BF16 baseline: 6,895 cycles/token
  v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1)
  v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)**

E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output:
  | batch | FP4+MXF4 | combine FP32 | combine HFMA |
  |------:|---------:|-------------:|-------------:|
  | 512   | —        | 7,526        | 7,350        |
  | 2048  | 9,814    | 9,903        | **9,992**    |
  | 4096  | 10,418   | 10,622       | **10,699**   |

HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default.

Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4
vs the FP32 reference. Production activations: still within sentinel
tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline).

* Revert "Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)"

This reverts commit 48e8101.

---------

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit 8fc78b4)
Co-authored-by: b8zhong <b8zhong@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: Brayden Zhong <brayden@radixark.ai>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…with setuptools>=77

Co-authored-by: Brayden Zhong <brayden@radixark.ai>
Adds sm100_fp4_mega_moe_impl, a packed-FP4-activation x packed-FP4-weight
specialization of the FP8/FP4 mega-MoE kernel using tcgen05.mma.kind::mxf4
with packed operands (2 elements/byte in smem and the L1/L2 activation pools).
Both activations and weights are packed E2M1; the L1 epilogue casts SwiGLU
output to E2M1 + UE8M0 SF. Includes the standalone sm100_fp4_fp4_gemm_1d1d
kernel that validated the packed-FP4 mainloop.

The kernel carries a one-code-path BLOCK_K generalization: BLOCK_K=128 uses
64B swizzle / one SF column / one UMMA sub-chunk; BLOCK_K=256 makes a packed
row 128B, enabling 128B swizzle and halving the K-iteration (and barrier)
count via a two-level umma_k loop with per-column UTCCP. The per-band BLOCK_K
selection (256 for the 32/64/192 token-per-expert bands, 128 otherwise) is
kernel-level +4-10% over the FP8/FP4 kUseFp4Acts (0.1.0 W4A4) mode at ntok
256-8192, with bit-exact numerics; vs the FP8-acts (W4A8) path it is +8-12%.
Validated on 4x B200 (DSV4-Flash shape, TP4/DP4).

The runtime_utils make_tma_sf_desc gains an sf_box_outer_dim param (default 1,
all existing callers unchanged) to deliver multiple SF columns per TMA at
BLOCK_K=256. The MXFP4 MMA wrapper is aliased onto the canonical mxf4 PTX
structs already present in tcgen05.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…affolds

Wires the packed FP4xFP4 kernel into the host runtime and dispatch:
  - sm100_fp4_mega_moe.hpp builds the packed-FP4 TMA descriptors (acts, weights,
    L1/L2 pools all packed E2M1) and JIT-launches the kernel. Carries the
    per-band BLOCK_K default table (block_m in {32,64,192} -> BLOCK_K=256) and
    off-by-default profiling knobs DG_BLOCK_K (force BLOCK_K) and DG_NUM_STAGES
    (clamp the auto stage count), used for the in-process A/B measurements; both
    unset reproduce the band-default configs.
  - apis/mega.hpp routes to the packed kernel when DG_MEGA_MOE_FP4=1 and sizes
    the L2 intermediate activation pool at FP4 width (half) accordingly. Default
    off -> dev's existing FP8/FP4 kUseFp4Acts dispatch and symm-buffer sizing are
    byte-identical. Reached through the existing fp8_fp4_mega_moe entry point, so
    no new python/tvm-ffi binding is needed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Extends test_mega_moe.py with the packed FP4xFP4 path: per-token FP4 weight/act
casts, a --shape-ntoks bench-sweep mode with per-shape clock ramp-up, and the
in-process A/B harness (DG_BLOCKK_AB / DG_STAGES_AB) that interleaves the
BLOCK_K and stage-count overrides per rep and tags result lines with bk=/st=.
The fused path is reached via deep_gemm.fp8_fp4_mega_moe with DG_MEGA_MOE_FP4=1;
all APIs used already exist on dev.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants