Add SM90 MegaMoE by pingpong and cooperative#360
Open
mpdfdfl wants to merge 18 commits into
Open
Conversation
Two SM90 (Hopper) FP8 MegaMoE kernels with their own scheduler and
heuristics, kept isolated from the shared SM100 mega_moe scheduler/heuristics
so the SM100 path is untouched:
* pingpong (BLOCK_M=64): one math warpgroup per tile, the two warpgroups
overlap one's MMA with the other's epilogue via an OrderedSequenceBarrier.
Tuned for small/medium M.
* cooperative (BLOCK_M=128): the two math warpgroups cooperatively M-split a
single tile and share one B-tile load, halving weight HBM traffic; a
256-thread cross-warpgroup barrier closes the L2 epilogue. Tuned for large M.
New files: impls/sm90_fp8_mega_moe_{pingpong,cooperative}.cuh, the SM90 host
runtimes, scheduler/sm90_mega_moe.cuh (adds kClusterSize / kL2NMajorSchedule),
and heuristics/sm90_mega_moe.hpp.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Wire the two SM90 kernels into the host API and Python:
* csrc/apis/mega.hpp: SM90-aware get_symm_buffer_size_for_mega_moe (per-arch
SF dtype/granularity, guarded by is_sm90 so the SM100 path is unchanged);
fp8_mega_moe routes by token count (DG_SM90_MOE_COOPERATIVE_THRESHOLD,
default 256) between pingpong (<threshold) and cooperative (>=threshold);
plus fp8_mega_moe_{pingpong,cooperative} forced entry points for A/B.
* deep_gemm/mega/__init__.py + deep_gemm/__init__.py: expose fp8_mega_moe,
fp8_mega_moe_{pingpong,cooperative} and transform_weights_for_mega_moe_sm90.
* comm/barrier.cuh: guard the NVLink-timeout printf behind DG_NO_DEVICE_PRINTF
(defined by the SM90 kernels) to avoid the ptxas C7510 WGMMA-pipeline
serialization a function-call boundary would cause; upstream 300s timeout
is preserved.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
* tests/test_mega_moe_sm90.py: layered correctness suite (L1 smoke .. L5
stress) against a PyTorch reference via calc_diff. DG_SM90_MOE_KERNEL
selects auto / pingpong / cooperative.
* tests/bench_mega_moe_sm90.py: per-config TFLOPS / HBM / NVLink timing with
an optional DeepEP (V1 contiguous + V2 ElasticBuffer) baseline comparison.
* tests/_deepep_v1_baseline.py: the DeepEP V1 contiguous baseline used by the
bench (Triton SwiGLU + FP8 quant + grouped GEMM pipeline).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Two SM90 (Hopper) FP8 MegaMoE kernels with their own scheduler and
heuristics, kept isolated from the shared SM100 mega_moe scheduler/heuristics
so the SM100 path is untouched:
* pingpong (BLOCK_M=64): one math warpgroup per tile, the two warpgroups
overlap one's MMA with the other's epilogue via an OrderedSequenceBarrier.
Tuned for small/medium M.
* cooperative (BLOCK_M=128): the two math warpgroups cooperatively M-split a
single tile and share one B-tile load, halving weight HBM traffic; a
256-thread cross-warpgroup barrier closes the L2 epilogue. Tuned for large M.
New files: impls/sm90_fp8_mega_moe_{pingpong,cooperative}.cuh, the SM90 host
runtimes, scheduler/sm90_mega_moe.cuh (adds kClusterSize / kL2NMajorSchedule),
and heuristics/sm90_mega_moe.hpp.
Wire the two SM90 kernels into the host API and Python:
* csrc/apis/mega.hpp: SM90-aware get_symm_buffer_size_for_mega_moe (per-arch
SF dtype/granularity, guarded by is_sm90 so the SM100 path is unchanged);
fp8_mega_moe routes by token count (DG_SM90_MOE_COOPERATIVE_THRESHOLD,
default 256) between pingpong (<threshold) and cooperative (>=threshold);
plus fp8_mega_moe_{pingpong,cooperative} forced entry points for A/B.
* deep_gemm/mega/__init__.py + deep_gemm/__init__.py: expose fp8_mega_moe,
fp8_mega_moe_{pingpong,cooperative} and transform_weights_for_mega_moe_sm90.
* comm/barrier.cuh: guard the NVLink-timeout printf behind DG_NO_DEVICE_PRINTF
(defined by the SM90 kernels) to avoid the ptxas C7510 WGMMA-pipeline
serialization a function-call boundary would cause; upstream 300s timeout
is preserved.
* tests/test_mega_moe_sm90.py: layered correctness suite (L1 smoke .. L5
stress) against a PyTorch reference via calc_diff. DG_SM90_MOE_KERNEL
selects auto / pingpong / cooperative.
* tests/bench_mega_moe_sm90.py: per-config TFLOPS / HBM / NVLink timing with
an optional DeepEP (V1 contiguous + V2 ElasticBuffer) baseline comparison.
* tests/_deepep_v1_baseline.py: the DeepEP V1 contiguous baseline used by the
bench (Triton SwiGLU + FP8 quant + grouped GEMM pipeline).
1380676 to
ec757bd
Compare
|
hi, which machine are you working on, h20, h100 or h200? |
Author
we working on H200 |
added 3 commits
June 17, 2026 22:31
Replace mega MoE implementation with the per-128 N-split cooperative kernel; keep remote .gitignore.
Author
|
We discovered an accuracy issue during end-to-end testing. We suspect the cause is that the L1 output of the MoE layer is normally quantized with per-128, whereas my previous implementation used per-64. To fix this, we chose to implement this kernel using only the cooperative approach, where the first warp group handles the left 128 columns and the other handles the right 128 columns, thereby achieving per-128 quantization. At the same time, we were pleasantly surprised to find that, with this layout, the cooperative implementation also delivers better performance at small shapes. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add SM90 MegaMoE by pingpong and cooperative
First, thanks to #352 and #323 — this work builds on the ideas from those PRs.
Our implementation differs slightly: we provide the SM90 fused MegaMoE through
two kernels, picked automatically by per-rank token count:
BLOCK_M=64, small/medium M): each math warpgroup owns a wholetile and the two warpgroups run half a cycle out of phase — one does WGMMA
while the other runs its epilogue — handing off via an
OrderedSequenceBarrier.BLOCK_M=128, large M): the two math warpgroups cooperativelyM-split one tile and share a single weight load
fp8_mega_moeroutes< 256tokens → pingpong,>= 256→ cooperative (threshold viaDG_SM90_MOE_COOPERATIVE_THRESHOLD);fp8_mega_moe_pingpong/fp8_mega_moe_cooperativeforce a single variant.The SM90 kernel / scheduler / heuristics are kept separate from the shared SM100
files, so SM100 behavior is unchanged.
Correctness test
tests/test_mega_moe_sm90.pychecks the fused kernel against a PyTorch reference(
calc_diff < 0.01) over layered scenarios (smoke → heuristic bands → shapes →edge cases → random stress).
Benchmark
tests/bench_mega_moe_sm90.pytimes the fused kernel end-to-end against theDeepEP + DeepGEMM unfused baselines (V1 contiguous, V1 low-latency, V2
ElasticBuffer), same CUDA-event timing for all. Full results in
docs/sm90_megamoe_bench.md.DeepSeek-V4 Flash (hidden=4096, intermediate=2048, experts=256, topk=6)
Faster than all baselines at every token count.
DeepSeek-V4 Pro (hidden=7168, intermediate=3072, experts=384, topk=6)
For this (largest) shape the fused kernel is faster than DeepEP V2 at small batch
(token-per-rank ≤ 512), but at token-per-rank ≥ 1024 it falls behind V2
(0.96x → 0.88x as the batch grows) because large-M becomes DRAM-bandwidth bound
on the weight re-reads. It stays well ahead of the V1 contiguous baseline
everywhere.
MiMo-V2.5 (hidden=4096, intermediate=2048, experts=256, topk=8)
Faster than all baselines at every token count.
MiMo-V2.5-Pro (hidden=6144, intermediate=2048, experts=384, topk=8)
Faster than all baselines at every token count.
Summary
Across DeepSeek-V4 Flash, MiMo-V2.5 and MiMo-V2.5-Pro the fused kernel beats all
three baselines at every measured token count. The only regression is
DeepSeek-V4 Pro at token-per-rank ≥ 1024, where it trails DeepEP V2 (down to
0.88x at 8192) because that shape's large-M GEMMs become DRAM-bandwidth bound; it
remains a clear win at smaller batches and over the V1 baselines throughout.
TODO / future work
This is just a first version, and there is still a lot of room to optimize
inside the kernels. We would really love your help — let's optimize this
kernel together! Planned directions:
tuning the block-tile / expert-wave scheduling strategy should reduce HBM
traffic and help the large-M DRAM-bandwidth-bound cases.
MoE into focused kernels (dispatch+L1+SwiGLU / L2+combine / combine-reduce)
running concurrently on disjoint SM partitions via CUDA green contexts — to
overlap stages the fused megakernel serializes internally.