[sm100] Packed FP4×FP4 mega-MoE kernel (W4A4) with per-band BLOCK_K#44
Open
Romaosir wants to merge 22 commits into
Co-Authored-By: rainj-me <rain-jiang@outlook.com>
…el (sgl-project#27)

Two related additions for the DeepSeek-V4-Pro mega-MoE path:

1. **FP4 (E2M1) activations + `kind::mxf4` mainloop opt-in** for `fp8_fp4_mega_moe`.
   - `DG_USE_FP4_ACTS=1` halves the symm-buffer x-slot footprint (E2M1 nibbles vs E4M3 bytes); SF slot unchanged (still `hidden/32` UE8M0 bytes under gran_k=32).
   - `use_mxf4_kind=true` switches the L1+L2 mainloops to `cta_group::2 kind::mxf4` (2-CTA cluster) with dense FP4 smem layout (`_ALIGN8B`, 2 nibbles/byte). Per-stage A/B byte footprint halves → num_stages doubles for the same smem budget.
   - Threads `cumulative_local_expert_recv_stats` through the public mega-MoE API for per-rank expert counters used by sglang's expert-distribution recorder.
   - Block-m heuristic: under `use_mxf4_kind`, bumps `block_m=16 → 32` for the smallest-tokens-per-expert bucket so `load_block_m * block_k / 2` meets the 1024-byte smem alignment.
   - Multi-block_m support via `kCandidateBlockM` array + LCM-aligned pool padding; replaces the static `block_m=192` heuristic with token-density dispatch (8/16/32/64/96/128/192).
2. **`mega_moe_pre_dispatch` kernel**: BF16 → quant + topk-copy + pad-fill in one launch, gated on `kUseFp4Acts` + `kUsePDL`. Templated on `(kGroupSize, kUseFp4Acts, kUsePDL)`. Uses bucketize-style E2M1 encoder for byte-exact match against the `per_token_cast_to_fp4` host helper.
   - New: `deep_gemm.mega_moe_pre_dispatch(x, topk_idx, topk_weights, buf_x, buf_x_sf, buf_topk_idx, buf_topk_weights, num_tokens, group_size, use_fp4_acts)`
   - Test: `tests/test_mega_moe_pre_dispatch.py`: single-GPU bytewise check against host `per_token_cast_to_fp{8,4}` + pad-fill assertion.

Validated end-to-end on 8× B300 with DeepSeek-V4-Pro at 8K input bench:
- FP4 acts + MXF4 kind path produces matching tokens vs the FP8 baseline (rel-RMSE ≤ 0.5 sentinel; GSM8K accuracy parity within run-to-run variance).
PR also includes existing FP4-mega-MoE supporting changes that are required by the kernel:
- `cluster_sync_with_relaxed_arrive` helper (used twice in `sm100_fp8_fp4_mega_moe.cuh`).
- `cvt_pack_f32_to_e2m1x2` / `cvt_pack_f32x4_to_e2m1x4` PTX wrappers.
- `SM100_MMA_MXF4_2x1SM_SS` 2-CTA cluster MMA wrapper.
- Generalized `red_add(int*, int)` for the `cumulative_local_expert_recv_stats` counter.
- `st.L1::no_allocate.relaxed.sys.global.u64` (correctness fix: previous generic-address variant could miss the global state space).

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit bca278e)
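The "bucketize-style E2M1 encoder" and the "2 nibbles/byte" packing mentioned in this commit can be illustrated with a small host-side sketch. This is not the kernel's or the `per_token_cast_to_fp4` helper's actual code: the tie rule (ties go to the code with an even mantissa bit) and the nibble order (first element in the low nibble) are assumptions made here for illustration, and `encode_e2m1` / `pack_nibbles` are hypothetical names.

```python
from bisect import bisect_left, bisect_right

# The eight non-negative magnitudes representable in E2M1
# (1 sign bit, 2 exponent bits, 1 mantissa bit). Index == 3-bit code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Midpoints between neighbouring magnitudes: bucket boundaries.
THRESHOLDS = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def encode_e2m1(x: float) -> int:
    """Return a 4-bit E2M1 code (sign in bit 3) near x, saturating at +/-6."""
    sign = 0x8 if x < 0 else 0x0
    mag = abs(x)
    lo = bisect_left(THRESHOLDS, mag)   # bucket if an exact tie rounds down
    hi = bisect_right(THRESHOLDS, mag)  # bucket if an exact tie rounds up
    # ASSUMPTION: exact midpoints go to the even code (round-to-nearest-even).
    code = lo if (lo == hi or lo % 2 == 0) else hi
    return sign | code


def decode_e2m1(code: int) -> float:
    """Inverse of encode_e2m1 for a single 4-bit code."""
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


def pack_nibbles(codes):
    """Pack 4-bit codes two per byte.

    ASSUMPTION: element 2*i sits in the low nibble, element 2*i+1 in the high.
    """
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


if __name__ == "__main__":
    row = [0.1, -1.4, 2.6, 7.0]
    codes = [encode_e2m1(v) for v in row]
    print([decode_e2m1(c) for c in codes], pack_nibbles(codes).hex())
```

A packed row of `hidden` elements occupies `hidden / 2` bytes, which is the halving of the x-slot footprint described above.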
…bine path) (sgl-project#28)

* Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)

The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink: each token, each topk slot = kHidden * 2 bytes. This commit adds an env-gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte: kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes.

Wiring:
- New `kUseFp8Combine` template flag (default false → keeps BF16 path byte-identical when off).
- New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per (token, slot) when on, zero when off.
- Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a + mainloops; this controls the combine a2a only).

Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh):
- Read 8 BF16 from smem (existing STSM target).
- Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes that share each row tile. Use a 16-lane mask (NOT 0xffffffff): the outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half-warp to exit on padding rows, and a full-warp shfl would deadlock.
- Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`).
- Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64.
- Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane group writes the SF byte to `combine_sf_buffer`.

Consumer side (combine reduce):
- Per-slot SF base ptr cached at slot start.
- TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
- Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2 via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant.
- BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4 (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.
  Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32).

Validation:
- Microbench (`ptx/d_combine_reduce_v{1,2}_*`):
  - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect).
  - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference that uses the same FP8 quant), 50% NVLink bytes savings.
- Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine):
  - b=128: 364 us → 359 us (+1.5%)
  - b=512: 377 us → 386 us (-2.2%)
  - b=2048: 710 us → 739 us (-3.9%)

  Single-GPU is compute-bound (no NVLink saving); production is the point of the change.
- E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output):
  - b=512: 91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine), +17.3%
  - b=2048: 259.4 s (FP8) → 238.2 s, +8.9%
  - b=4096: 489.5 s (FP8) → 444.2 s, +10.2%

  Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes.
- Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations post-SwiGLU + topk-weighting are bounded; production accuracy parity preserved (same GSM8K results as FP4 baseline).

* Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)

Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count and halves the accumulator register pressure (94 regs vs 138 regs).

Inner loop, before:

    cvt.rn.f16x2.e4m3x2    (FP8x2 → FP16x2)
    cvt.f32.f16 ×2         (FP16 → FP32)
    fma.rn.f32 ×2          (acc += sf_f32 * f32_val)
    = 5 ops per FP8x2 (= 2 elements)

After:

    cvt.rn.f16x2.e4m3x2    (FP8x2 → FP16x2)
    fma.rn.f16x2           (acc_fp16x2 += sf_pair * f16x2)
    = 2 ops per FP8x2

SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15. Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max; production activations post-SwiGLU + topk-weighting fit comfortably in FP16 range.

End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem write-back (BF16 output unchanged).
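The "SF in FP16" rule above (UE8M0 byte → 2^(byte-127), re-biased to FP16's exponent bias of 15, with out-of-range bytes clamped) can be sketched on the host as follows. This is one possible reading of the commit text, not the kernel's PTX: `ue8m0_to_fp16_bits` is a hypothetical name, and the choice to build the result by shifting the re-biased exponent into the FP16 exponent field is an assumption.

```python
import struct

FP16_MAX_BITS = 0x7BFF  # bit pattern of 65504.0, the largest finite FP16


def ue8m0_to_fp16_bits(sf_byte: int) -> int:
    """Map a UE8M0 scale byte to the bit pattern of an FP16 power of two.

    UE8M0 encodes 2**(sf_byte - 127); FP16 uses exponent bias 15, so the
    FP16 exponent field is sf_byte - 127 + 15 = sf_byte - 112.
    """
    if not 0 <= sf_byte <= 255:
        raise ValueError("UE8M0 scale must fit in one byte")
    if sf_byte < 112:      # below the range quoted in the commit: clamp to 0
        return 0x0000
    if sf_byte > 142:      # above the range quoted in the commit: clamp to max
        return FP16_MAX_BITS
    # ASSUMPTION: sf_byte == 112 yields exponent field 0 with a zero mantissa,
    # i.e. the 0.0 bit pattern, rather than an FP16 subnormal.
    return (sf_byte - 112) << 10


def fp16_bits_to_float(bits: int) -> float:
    """Reinterpret a 16-bit pattern as an IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


if __name__ == "__main__":
    for b in (100, 113, 127, 142, 143):
        print(b, fp16_bits_to_float(ue8m0_to_fp16_bits(b)))
```

The usable window of 30 exponent steps is what the commit relies on when it says post-SwiGLU, topk-weighted activations fit in FP16 range.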
Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`):
- v1 BF16 baseline: 6,895 cycles/token
- v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1)
- v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)**

E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output:

| batch | FP4+MXF4 | combine FP32 | combine HFMA |
|------:|---------:|-------------:|-------------:|
| 512 | n/a | 7,526 | 7,350 |
| 2048 | 9,814 | 9,903 | **9,992** |
| 4096 | 10,418 | 10,622 | **10,699** |

HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default.

Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4 vs the FP32 reference. Production activations: still within sentinel tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline).

* Revert "Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)"

This reverts commit 48e8101.

---------

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit 8fc78b4)
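The byte accounting behind the FP8 combine path in this commit (BF16 ships `kHidden * 2` bytes per token per slot; FP8 ships `kHidden + kHidden/128`) works out as below. This is a minimal sketch: `combine_bytes_per_slot` is a hypothetical helper, and `k_hidden = 4096` is only an illustrative value, not a figure the commit states for DeepSeek-V4-Pro.

```python
def combine_bytes_per_slot(k_hidden: int, use_fp8_combine: bool,
                           sf_group: int = 128) -> int:
    """Bytes sent per (token, topk slot) on the combine all-to-all.

    BF16 path: 2 bytes per element.
    FP8 path:  1 byte per element plus one UE8M0 scale byte per sf_group
               elements (N=128 in the commit message).
    """
    if k_hidden % sf_group:
        raise ValueError("k_hidden must be a multiple of the SF group size")
    if use_fp8_combine:
        return k_hidden + k_hidden // sf_group
    return 2 * k_hidden


if __name__ == "__main__":
    k_hidden = 4096  # illustrative only
    bf16 = combine_bytes_per_slot(k_hidden, use_fp8_combine=False)
    fp8 = combine_bytes_per_slot(k_hidden, use_fp8_combine=True)
    print(bf16, fp8, fp8 / bf16)
```

Because the scale bytes add 1/128 of the payload, the FP8 path carries slightly more than half the BF16 bytes (a ratio of 0.5 + 1/256), independent of `k_hidden`.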
Co-authored-by: b8zhong <b8zhong@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: Brayden Zhong <brayden@radixark.ai> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…with setuptools>=77 Co-authored-by: Brayden Zhong <brayden@radixark.ai>
Adds sm100_fp4_mega_moe_impl, a packed-FP4-activation x packed-FP4-weight specialization of the FP8/FP4 mega-MoE kernel using tcgen05.mma.kind::mxf4 with packed operands (2 elements/byte in smem and the L1/L2 activation pools). Both activations and weights are packed E2M1; the L1 epilogue casts SwiGLU output to E2M1 + UE8M0 SF. Includes the standalone sm100_fp4_fp4_gemm_1d1d kernel that validated the packed-FP4 mainloop.

The kernel carries a one-code-path BLOCK_K generalization: BLOCK_K=128 uses 64B swizzle / one SF column / one UMMA sub-chunk; BLOCK_K=256 makes a packed row 128B, enabling 128B swizzle and halving the K-iteration (and barrier) count via a two-level umma_k loop with per-column UTCCP.

The per-band BLOCK_K selection (256 for the 32/64/192 token-per-expert bands, 128 otherwise) is kernel-level +4-10% over the FP8/FP4 kUseFp4Acts (0.1.0 W4A4) mode at ntok 256-8192, with bit-exact numerics; vs the FP8-acts (W4A8) path it is +8-12%. Validated on 4x B200 (DSV4-Flash shape, TP4/DP4).

The runtime_utils make_tma_sf_desc gains an sf_box_outer_dim param (default 1, all existing callers unchanged) to deliver multiple SF columns per TMA at BLOCK_K=256. The MXFP4 MMA wrapper is aliased onto the canonical mxf4 PTX structs already present in tcgen05.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
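The relationship this commit describes between BLOCK_K, packed row width, swizzle mode, and K-iteration count is simple arithmetic, sketched below. The helper names are hypothetical and `k = 4096` is an illustrative reduction dimension; the only facts taken from the commit are 2 elements per byte, 64B swizzle at BLOCK_K=128, and 128B swizzle at BLOCK_K=256.

```python
def packed_row_bytes(block_k: int) -> int:
    """Bytes in one packed-E2M1 row of a K-tile (2 elements per byte)."""
    if block_k % 2:
        raise ValueError("block_k must be even for nibble packing")
    return block_k // 2


def swizzle_bytes(block_k: int) -> int:
    """Swizzle width matching the packed row, for the two supported tiles."""
    row = packed_row_bytes(block_k)
    if row not in (64, 128):
        raise ValueError("only BLOCK_K=128 (64B) and BLOCK_K=256 (128B) are described")
    return row


def k_iterations(k: int, block_k: int) -> int:
    """Number of mainloop K-steps (ceil division)."""
    return -(-k // block_k)


if __name__ == "__main__":
    for bk in (128, 256):
        print(bk, packed_row_bytes(bk), swizzle_bytes(bk), k_iterations(4096, bk))
```

Doubling BLOCK_K doubles the contiguous bytes per row and halves the number of K-steps, which is why the commit ties it to fewer barriers per tile.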
…affolds
Wires the packed FP4xFP4 kernel into the host runtime and dispatch:
- sm100_fp4_mega_moe.hpp builds the packed-FP4 TMA descriptors (acts, weights,
L1/L2 pools all packed E2M1) and JIT-launches the kernel. Carries the
per-band BLOCK_K default table (block_m in {32,64,192} -> BLOCK_K=256) and
off-by-default profiling knobs DG_BLOCK_K (force BLOCK_K) and DG_NUM_STAGES
(clamp the auto stage count), used for the in-process A/B measurements; both
unset reproduce the band-default configs.
- apis/mega.hpp routes to the packed kernel when DG_MEGA_MOE_FP4=1 and sizes
the L2 intermediate activation pool at FP4 width (half) accordingly. Default
off -> dev's existing FP8/FP4 kUseFp4Acts dispatch and symm-buffer sizing are
byte-identical. Reached through the existing fp8_fp4_mega_moe entry point, so
no new python/tvm-ffi binding is needed.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
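The band-default table and the two profiling knobs described in this commit could be modelled as follows. This is a sketch of the selection logic only, written in Python rather than the host C++ it actually lives in; the function names are hypothetical, and treating `DG_NUM_STAGES` as an upper bound is an assumption about what "clamp" means here.

```python
import os

# Bands (block_m values) that default to BLOCK_K=256, per the commit message.
BLOCK_K_256_BANDS = frozenset({32, 64, 192})


def select_block_k(block_m: int, env=None) -> int:
    """Pick BLOCK_K: DG_BLOCK_K forces a value, otherwise use the band default."""
    env = os.environ if env is None else env
    forced = env.get("DG_BLOCK_K")
    if forced:
        value = int(forced)
        if value not in (128, 256):
            raise ValueError("DG_BLOCK_K must be 128 or 256")
        return value
    return 256 if block_m in BLOCK_K_256_BANDS else 128


def clamp_num_stages(auto_stages: int, env=None) -> int:
    """Apply DG_NUM_STAGES to the automatically chosen stage count.

    ASSUMPTION: the knob caps the auto value; it never raises it.
    """
    env = os.environ if env is None else env
    limit = env.get("DG_NUM_STAGES")
    if limit:
        return min(auto_stages, int(limit))
    return auto_stages


if __name__ == "__main__":
    for bm in (16, 32, 64, 96, 128, 192):
        print(bm, select_block_k(bm, env={}))
```

With both variables unset the functions return the band defaults unchanged, matching the commit's statement that unset knobs reproduce the band-default configs.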
Extends test_mega_moe.py with the packed FP4xFP4 path: per-token FP4 weight/act casts, a --shape-ntoks bench-sweep mode with per-shape clock ramp-up, and the in-process A/B harness (DG_BLOCKK_AB / DG_STAGES_AB) that interleaves the BLOCK_K and stage-count overrides per rep and tags result lines with bk=/st=. The fused path is reached via deep_gemm.fp8_fp4_mega_moe with DG_MEGA_MOE_FP4=1; all APIs used already exist on dev. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
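The in-process interleaved A/B idea (alternate the overrides on every rep so that clock or thermal drift affects both arms equally) can be sketched as below. This is not the code in `test_mega_moe.py`: `env_override`, `interleaved_ab`, and the `run_once` callable are hypothetical, and only the `DG_BLOCK_K` variable and the `bk=` tag style come from the PR text.

```python
import os
import time
from contextlib import contextmanager


@contextmanager
def env_override(overrides):
    """Temporarily set environment variables, restoring the old state on exit."""
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old


def interleaved_ab(run_once, arms, reps):
    """Time run_once under each arm, switching arms inside every rep.

    arms maps a result tag (e.g. "bk=128") to the env overrides for that arm.
    Returns {tag: [seconds per rep]}.
    """
    results = {tag: [] for tag in arms}
    for _ in range(reps):
        for tag, overrides in arms.items():
            with env_override(overrides):
                start = time.perf_counter()
                run_once()
                results[tag].append(time.perf_counter() - start)
    return results


if __name__ == "__main__":
    arms = {"bk=128": {"DG_BLOCK_K": "128"}, "bk=256": {"DG_BLOCK_K": "256"}}
    timings = interleaved_ab(lambda: sum(range(10_000)), arms, reps=3)
    for tag, samples in timings.items():
        print(tag, min(samples))
```

Interleaving per rep, rather than running all reps of one arm and then the other, keeps slow drift from being attributed to whichever arm ran last.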
Adds a dedicated W4A4 (FP4 weights × FP4 activations) fused mega-MoE kernel for SM100, `sm100_fp4_mega_moe`: a packed-E2M1 specialization that issues `tcgen05.mma.kind::mxf4`, plus a per-band BLOCK_K selection that exploits packed-FP4's lower SMEM footprint for a measured kernel-level speedup over the existing FP8-acts and 0.1.0 W4A4 paths.

Paired with sgl-project/sglang#28210 (version bump + one-line env propagation); the existing W4A4 env path transparently routes to this kernel, with no other sglang code changes.
What's new
- `sm100_fp4_mega_moe_impl`: packed `float_e2m1_t` A and B (no multi-mode branches), hand-built packed-byte SMEM descriptors, dedicated `packed_fp4_2sm_tma_load_2d` (byte-typed tensor maps, per-CTA `cta_group::2` issues with leader-centralized 2× tx accounting), SwiGLU epilogue cast straight to packed FP4. Half the activation bytes + SMEM of the FP8-acts path.
- `DG_BLOCK_K` (128/256), `DG_NUM_STAGES` (clamp).
- `mega_moe_pre_dispatch`: packs activations to E2M1 + UE8M0 SF into the symmetric buffer.
- `csrc/apis/mega.hpp`: env-selected dispatch (`DG_MEGA_MOE_FP4` / `DG_USE_FP4_ACTS`); symm-buffer halves the activation regions under FP4 acts.
- `--shape-ntoks` / `--bench-reps` sweep with in-process interleaved A/B (`DG_BLOCKK_AB` / `DG_STAGES_AB` / `DG_WAVE_AB`).

How it differs from the existing W4A4 path (`sm100_fp8_fp4_mega_moe<kUseFp4Acts, kUseMxf4Kind>`)

Same architecture (warp roles, scheduler, comm/barriers, packed-E2M1 symm buffer) and the same final `tcgen05.mma.kind::mxf4` instruction; the difference is a dedicated specialization (enabling the per-band BLOCK_K win) instead of a template-flag retrofit of the FP8 kernel:

- `float_e2m1_t` A and B; no multi-mode branches
- `make_umma_desc` via `_unpacksmem` / `_ALIGN16B`
- `tma::copy`
- `packed_fp4_2sm_tma_load_2d` (byte tensor maps, 64B/128B swizzle atoms, leader-centralized 2× tx)
- `ptx::SM100_MMA_MXFP4_2x1SM_SS`
- `sm100_fp4_fp4_gemm_1d1d` single-GPU GEMM clone

Performance (4× B200, DSV4-Flash: h=4096 / i=2048 / E=256 / topk=6; bit-exact numerics vs the BLOCK_K=128 path, `torch.equal`)
Per-band BLOCK_K vs all-128 — in-process 4-rank interleaved A/B, 20 reps (the precision claim):
32-band ~+10%, 64-band +2–6%, 192-band ~+4%. Mechanism (NCU): DRAM efficiency from larger contiguous transactions (64% → 85% busy at ntok=512).
vs W4A4-old, the existing 0.1.0-release W4A4 kernel (sgl-project/DeepGEMM#27, `sm100_fp8_fp4_mega_moe<kUseFp4Acts, kUseMxf4Kind>`); both kernels measured through the same binding (`DG_FORCE_FP8FP4` toggle) to isolate kernel from packaging:

`sm100_tf32_hc_prenorm_gemm`, shared with both baselines).

End-to-end 5-concurrency sweep (TP4+DP4, sequential, same GPU group, 1 rep/cell after per-concurrency warmup, num_prompts raised for steady state). Ours and the 0.1.0 W4A4 ("old") run via the same fork pybind; W4A8 (FP8 acts) is the 0.1.0-release reference on its own tvm-ffi binding (quant-scheme baseline, not a same-binding comparison):
Prefill (in=1024, out=1) — input throughput (tok/s) · median TTFT (ms):
Mixed (in=1024, out=256) — total throughput (tok/s) · median TPOT (ms):
Reading: the kernel win surfaces at production concurrency (c ≥ 128) where mega-MoE's share is largest — prefill input-throughput +2–7%, mixed throughput +2%, TPOT down to −11% at c=512. Low concurrency (c ≤ 4) is within the e2e noise band, and the c=1/4 mixed cells (num_prompts = 12) are under-sampled.
NCU Speed-of-Light (1-rank isolation, binding-controlled, ntok=512)
Profiling the fused 4-rank kernel with NCU is invalid (persistent kernel dominated by cross-rank NVLink-barrier spinning → SOL collapses). Single-rank removes the barrier and gives true SOL. All three arms measured at the identical shape (hidden=4096 / inter=2048 / E=256 / topk=6, ntok=512) through the same fork binding:
Ours +24.6% vs W4A4-old, +28.6% vs W4A8 (1-rank). Mechanism: same occupancy/registers (identical launch config) → the win is the data path. Ours' block_k=256 sustains 85% memory SOL / 6.54 TB/s DRAM vs W4A4-old's block_k=128 at 64% / 4.93 TB/s — the larger-contiguous-transaction DRAM-efficiency win. W4A8 carries the heaviest compute (42.5% SM, FP8 acts = 1 byte/act) and the least DRAM headroom. 1-rank over-weights the memory regime (255 experts local, ~2 tokens/expert); use it for the mechanism, the e2e sweep above for the realistic magnitude.
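As a quick consistency check on the mechanism paragraph above, the two data-path figures quoted for ntok=512 (memory SOL 64% → 85%, DRAM throughput 4.93 → 6.54 TB/s) can be turned into relative gains. The sketch uses only numbers stated in the text; `relative_gain_pct` is a hypothetical helper.

```python
def relative_gain_pct(new: float, old: float) -> float:
    """Relative improvement of new over old, in percent."""
    return (new / old - 1.0) * 100.0


if __name__ == "__main__":
    dram = relative_gain_pct(6.54, 4.93)   # TB/s, block_k=256 vs block_k=128
    sol = relative_gain_pct(85.0, 64.0)    # memory SOL percentage points as a ratio
    print(f"DRAM throughput gain: {dram:.1f}%")
    print(f"Memory SOL gain:      {sol:.1f}%")
```

Both ratios come out near +33%, so the two quoted measurements agree with each other.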
Accuracy (eval_v2, DSV4-Flash)
W4A4-Ours is at or above both the W4A4-old and W4A8 baselines on every eval run — the packed FP4×FP4 path + per-band BLOCK_K preserve accuracy.
Notes
`torch.equal` vs deep_ep + grouped GEMM) needs an internal deep_ep build with `ElasticBuffer`; rerun in CI before merge.