Add SM90 MegaMoE by pingpong and cooperative #360

Open
mpdfdfl wants to merge 18 commits into deepseek-ai:main from mpdfdfl:sm90-mega-moe-pingpong-coop

mpdfdfl commented Jun 13, 2026

Add SM90 MegaMoE by pingpong and cooperative

First, thanks to #352 and #323 — this work builds on the ideas from those PRs.

Our implementation differs slightly: we provide the SM90 fused MegaMoE through
two kernels, picked automatically by per-rank token count:

  • pingpong (BLOCK_M=64, small/medium M): each math warpgroup owns a whole
    tile and the two warpgroups run half a cycle out of phase — one does WGMMA
    while the other runs its epilogue — handing off via an OrderedSequenceBarrier.
  • cooperative (BLOCK_M=128, large M): the two math warpgroups cooperatively
    M-split one tile and share a single weight load.
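The half-cycle phase offset of the pingpong variant can be pictured with a toy timeline. This is only an illustration, not the kernel's scheduling code: the function name `pingpong_timeline` is hypothetical, and it assumes, for simplicity, that the MMA phase and the epilogue phase of a tile take one equal-length slot each.

```python
def pingpong_timeline(tiles_per_warpgroup):
    """Toy model: return a list of (wg0_phase, wg1_phase) per time slot.

    Each tile occupies two slots: "mma" then "epilogue". Warpgroup 1 starts
    one slot late, so its MMA overlaps warpgroup 0's epilogue and vice versa.
    """
    def phases(start_offset):
        seq = ["idle"] * start_offset
        for _ in range(tiles_per_warpgroup):
            seq += ["mma", "epilogue"]
        return seq

    wg0, wg1 = phases(0), phases(1)
    length = max(len(wg0), len(wg1))
    wg0 += ["idle"] * (length - len(wg0))
    wg1 += ["idle"] * (length - len(wg1))
    return list(zip(wg0, wg1))


timeline = pingpong_timeline(2)
for slot, (p0, p1) in enumerate(timeline):
    print(f"slot {slot}: wg0={p0:<8} wg1={p1}")
```

In this model the two warpgroups are never in the MMA phase at the same time, which is the ordering the OrderedSequenceBarrier handoff is described as enforcing.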

fp8_mega_moe routes < 256 tokens → pingpong, >= 256 → cooperative (threshold via DG_SM90_MOE_COOPERATIVE_THRESHOLD);
fp8_mega_moe_pingpong / fp8_mega_moe_cooperative force a single variant.
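The routing rule can be sketched in a few lines of Python. This is a minimal sketch, not the code in csrc/apis/mega.hpp: the helper name `pick_sm90_moe_variant` and its `env` parameter are hypothetical, and only the environment variable name and the 256 default come from this PR.

```python
import os

DEFAULT_COOPERATIVE_THRESHOLD = 256  # default stated in this PR


def pick_sm90_moe_variant(num_tokens_per_rank, env=None):
    """Return "pingpong" below the threshold, "cooperative" at or above it.

    `env` is a hypothetical hook for testing; it defaults to os.environ.
    """
    env = os.environ if env is None else env
    threshold = int(
        env.get("DG_SM90_MOE_COOPERATIVE_THRESHOLD", DEFAULT_COOPERATIVE_THRESHOLD)
    )
    return "cooperative" if num_tokens_per_rank >= threshold else "pingpong"


print(pick_sm90_moe_variant(64, env={}))    # below the default threshold
print(pick_sm90_moe_variant(256, env={}))   # at the default threshold
```

Raising the threshold through the environment variable keeps more batch sizes on the pingpong path, which is useful for A/B comparisons around the crossover point.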

The SM90 kernel / scheduler / heuristics are kept separate from the shared SM100
files, so SM100 behavior is unchanged.

Correctness test

tests/test_mega_moe_sm90.py checks the fused kernel against a PyTorch reference
(calc_diff < 0.01) over layered scenarios (smoke → heuristic bands → shapes →
edge cases → random stress).

python tests/test_mega_moe_sm90.py --num-processes 8
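For readers unfamiliar with the calc_diff check, the sketch below shows one plausible form of the metric in pure Python. It assumes calc_diff is the normalized-similarity difference 1 - 2<x,y> / (|x|^2 + |y|^2); the PR does not spell out the formula, so treat this as an assumption rather than the test's actual code.

```python
def calc_diff(x, y):
    """Assumed metric: 1 - 2*<x, y> / (|x|^2 + |y|^2) over flat sequences.

    Returns 0.0 for identical inputs and grows as the outputs diverge.
    """
    numerator = 2.0 * sum(a * b for a, b in zip(x, y))
    denominator = sum(a * a for a in x) + sum(b * b for b in y)
    if denominator == 0.0:
        return 0.0  # both inputs are all zeros, so they match
    return 1.0 - numerator / denominator


reference = [1.0, 2.0, 3.0, 4.0]
fused_out = [1.01, 1.98, 3.02, 3.99]  # made-up values with small FP8-like error
diff = calc_diff(fused_out, reference)
print(diff, diff < 0.01)
```

A threshold such as 0.01 tolerates the rounding noise expected from FP8 arithmetic while still catching wrong results, since a scrambled or mis-scaled output pushes the value far above it.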

Benchmark

tests/bench_mega_moe_sm90.py times the fused kernel end-to-end against the
DeepEP + DeepGEMM unfused baselines (V1 contiguous, V1 low-latency, V2
ElasticBuffer), same CUDA-event timing for all. Full results in
docs/sm90_megamoe_bench.md.

NVSHMEM_IBGDA_ENABLE=0 NVSHMEM_DISABLE_IBGDA=1 EP_DISABLE_GIN=1 \
python tests/bench_mega_moe_sm90.py --num-processes 8 \
  --hidden 4096 --intermediate-hidden 2048 --num-experts 256 --num-topk 8 \
  --baseline --baseline-version both \
  --batches 16 64 256 512 1024 4096 8192

DeepSeek-V4 Flash (hidden=4096, intermediate=2048, experts=256, topk=6)

| tokens | fused (us) | TFLOPS | HBM (GB/s) | v1-contig (us) | v1-contig × | v1-ll (us) | v1-ll × | v2 (us) | v2 × |
|---|---|---|---|---|---|---|---|---|---|
| 16 | 240.8 | 19.4 | 3247 | 1529.9 | 6.35x | 268.7 | 1.12x | 332.7 | 1.38x |
| 64 | 250.6 | 70.7 | 3237 | 1507.7 | 6.02x | 296.2 | 1.18x | 351.0 | 1.40x |
| 256 | 326.8 | 238.1 | 2542 | 1525.0 | 4.67x | 347.6 | 1.06x | 399.3 | 1.22x |
| 512 | 347.1 | 450.5 | 2467 | 1576.1 | 4.54x | 473.7 | 1.36x | 476.7 | 1.37x |
| 1024 | 593.4 | 521.1 | 1527 | 2074.5 | 3.50x | 721.3 | 1.22x | 687.5 | 1.16x |
| 4096 | 1817.8 | 677.6 | 664 | 4445.3 | 2.45x | skipped | - | 2141.2 | 1.18x |
| 8192 | 3605.2 | 684.3 | 446 | 8022.2 | 2.23x | skipped | - | 4059.3 | 1.13x |

Faster than all baselines at every token count.

DeepSeek-V4 Pro (hidden=7168, intermediate=3072, experts=384, topk=6)

| tokens | fused (us) | TFLOPS | HBM (GB/s) | v1-contig (us) | v1-contig × | v1-ll (us) | v1-ll × | v2 (us) | v2 × |
|---|---|---|---|---|---|---|---|---|---|
| 16 | 748.3 | 17.1 | 3888 | 2813.4 | 3.76x | 813.0 | 1.09x | 883.6 | 1.18x |
| 64 | 821.3 | 59.2 | 3873 | 2969.5 | 3.62x | 908.3 | 1.11x | 979.6 | 1.19x |
| 256 | 1039.5 | 192.9 | 3091 | 2513.1 | 2.42x | 1017.8 | 0.98x | 1065.2 | 1.02x |
| 512 | 1104.4 | 372.6 | 2949 | 2587.5 | 2.34x | 1187.1 | 1.07x | 1170.8 | 1.06x |
| 1024 | 1701.7 | 486.9 | 1965 | 3329.4 | 1.96x | 1658.1 | 0.97x | 1626.0 | 0.96x |
| 4096 | 4988.6 | 647.3 | 771 | 7783.4 | 1.56x | skipped | - | 4628.4 | 0.93x |
| 8192 | 10123.4 | 640.0 | 447 | 13612.2 | 1.34x | skipped | - | 8925.8 | 0.88x |

For this (largest) shape the fused kernel is faster than DeepEP V2 at small batch
(tokens per rank ≤ 512), but at tokens per rank ≥ 1024 it falls behind V2
(0.96x → 0.88x as the batch grows) because large-M becomes DRAM-bandwidth bound
on the weight re-reads. It is also marginally behind the V1 low-latency baseline
at 256 (0.98x) and 1024 (0.97x). It stays well ahead of the V1 contiguous
baseline everywhere.
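The speedup columns are the baseline time divided by the fused time, so any value below 1.0x is a regression. As a quick cross-check, the sketch below recomputes those ratios from the DeepSeek-V4 Pro timings in the table above and lists every baseline the fused kernel trails. The helper name `find_regressions` is hypothetical; the numbers are copied from the table, with `None` standing in for the skipped v1-ll runs.

```python
# (tokens, fused_us, v1_contig_us, v1_ll_us, v2_us); None marks a skipped run.
DEEPSEEK_V4_PRO_ROWS = [
    (16, 748.3, 2813.4, 813.0, 883.6),
    (64, 821.3, 2969.5, 908.3, 979.6),
    (256, 1039.5, 2513.1, 1017.8, 1065.2),
    (512, 1104.4, 2587.5, 1187.1, 1170.8),
    (1024, 1701.7, 3329.4, 1658.1, 1626.0),
    (4096, 4988.6, 7783.4, None, 4628.4),
    (8192, 10123.4, 13612.2, None, 8925.8),
]
BASELINES = ("v1-contig", "v1-ll", "v2")


def find_regressions(rows):
    """Return (tokens, baseline, speedup) wherever baseline_us / fused_us < 1."""
    found = []
    for tokens, fused_us, *baseline_times in rows:
        for name, baseline_us in zip(BASELINES, baseline_times):
            if baseline_us is None:
                continue  # baseline was not run at this batch size
            speedup = baseline_us / fused_us
            if speedup < 1.0:
                found.append((tokens, name, round(speedup, 2)))
    return found


for tokens, name, speedup in find_regressions(DEEPSEEK_V4_PRO_ROWS):
    print(f"tokens={tokens:>5}  slower than {name:<9} ({speedup:.2f}x)")
```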

MiMo-V2.5 (hidden=4096, intermediate=2048, experts=256, topk=8)

| tokens | fused (us) | TFLOPS | HBM (GB/s) | v1-contig (us) | v1-contig × | v1-ll (us) | v1-ll × | v2 (us) | v2 × |
|---|---|---|---|---|---|---|---|---|---|
| 16 | 244.9 | 25.3 | 3297 | 1588.1 | 6.49x | 273.7 | 1.12x | 340.2 | 1.39x |
| 64 | 253.9 | 95.6 | 3203 | 1499.9 | 5.91x | 301.4 | 1.19x | 367.4 | 1.45x |
| 256 | 332.4 | 317.8 | 2526 | 1607.7 | 4.84x | 391.1 | 1.18x | 429.9 | 1.29x |
| 512 | 481.8 | 432.3 | 1812 | 2162.3 | 4.49x | 554.7 | 1.15x | 563.9 | 1.17x |
| 1024 | 741.8 | 553.4 | 1266 | 2762.5 | 3.72x | 899.3 | 1.21x | 873.0 | 1.18x |
| 4096 | 2323.3 | 708.6 | 577 | 5330.1 | 2.29x | skipped | - | 2765.5 | 1.19x |
| 8192 | 4795.0 | 686.3 | 391 | 9624.3 | 2.01x | skipped | - | 5330.6 | 1.11x |

Faster than all baselines at every token count.

MiMo-V2.5-Pro (hidden=6144, intermediate=2048, experts=384, topk=8)

| tokens | fused (us) | TFLOPS | HBM (GB/s) | v1-contig (us) | v1-contig × | v1-ll (us) | v1-ll × | v2 (us) | v2 × |
|---|---|---|---|---|---|---|---|---|---|
| 16 | 476.1 | 20.3 | 3336 | 1992.6 | 4.19x | 525.7 | 1.10x | 587.5 | 1.23x |
| 64 | 509.2 | 72.9 | 3580 | 2083.2 | 4.09x | 575.4 | 1.13x | 635.4 | 1.25x |
| 256 | 659.5 | 236.2 | 2818 | 2105.8 | 3.19x | 688.5 | 1.04x | 730.5 | 1.11x |
| 512 | 700.2 | 438.1 | 2718 | 2216.8 | 3.17x | 874.5 | 1.25x | 864.3 | 1.23x |
| 1024 | 1259.8 | 489.7 | 1584 | 3578.8 | 2.84x | 1388.4 | 1.10x | 1334.7 | 1.06x |
| 4096 | 3714.8 | 671.1 | 688 | 7806.0 | 2.10x | skipped | - | 4112.6 | 1.11x |
| 8192 | 7772.7 | 638.5 | 424 | 13098.7 | 1.69x | skipped | - | 7995.5 | 1.03x |

Faster than all baselines at every token count.


Summary

Across DeepSeek-V4 Flash, MiMo-V2.5 and MiMo-V2.5-Pro the fused kernel beats all
three baselines at every measured token count. The regressions are confined to
DeepSeek-V4 Pro: at tokens per rank ≥ 1024 it trails DeepEP V2 (down to 0.88x at
8192) because that shape's large-M GEMMs become DRAM-bandwidth bound, and it is
marginally behind V1 low-latency at 256 (0.98x) and 1024 (0.97x). It remains a
clear win at the other small batches and over the V1 contiguous baseline throughout.

TODO / future work

This is just a first version, and there is still a lot of room to optimize
inside the kernels. We would really love your help — let's optimize this
kernel together!
Planned directions:

  • Block-tile scheduling. The L2 cache hit rate is currently low (~40%);
    tuning the block-tile / expert-wave scheduling strategy should reduce HBM
    traffic and help the large-M DRAM-bandwidth-bound cases.
  • Green-context split-kernel. Explore the approach from PR #357 (Draft: Add
    green-context split-kernel MegaMoE features): split the
    MoE into focused kernels (dispatch+L1+SwiGLU / L2+combine / combine-reduce)
    running concurrently on disjoint SM partitions via CUDA green contexts — to
    overlap stages the fused megakernel serializes internally.

熊梦轩 and others added 15 commits June 12, 2026 19:26
Two SM90 (Hopper) FP8 MegaMoE kernels with their own scheduler and
heuristics, kept isolated from the shared SM100 mega_moe scheduler/heuristics
so the SM100 path is untouched:

  * pingpong (BLOCK_M=64): one math warpgroup per tile, the two warpgroups
    overlap one's MMA with the other's epilogue via an OrderedSequenceBarrier.
    Tuned for small/medium M.
  * cooperative (BLOCK_M=128): the two math warpgroups cooperatively M-split a
    single tile and share one B-tile load, halving weight HBM traffic; a
    256-thread cross-warpgroup barrier closes the L2 epilogue. Tuned for large M.

New files: impls/sm90_fp8_mega_moe_{pingpong,cooperative}.cuh, the SM90 host
runtimes, scheduler/sm90_mega_moe.cuh (adds kClusterSize / kL2NMajorSchedule),
and heuristics/sm90_mega_moe.hpp.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Wire the two SM90 kernels into the host API and Python:

  * csrc/apis/mega.hpp: SM90-aware get_symm_buffer_size_for_mega_moe (per-arch
    SF dtype/granularity, guarded by is_sm90 so the SM100 path is unchanged);
    fp8_mega_moe routes by token count (DG_SM90_MOE_COOPERATIVE_THRESHOLD,
    default 256) between pingpong (<threshold) and cooperative (>=threshold);
    plus fp8_mega_moe_{pingpong,cooperative} forced entry points for A/B.
  * deep_gemm/mega/__init__.py + deep_gemm/__init__.py: expose fp8_mega_moe,
    fp8_mega_moe_{pingpong,cooperative} and transform_weights_for_mega_moe_sm90.
  * comm/barrier.cuh: guard the NVLink-timeout printf behind DG_NO_DEVICE_PRINTF
    (defined by the SM90 kernels) to avoid the ptxas C7510 WGMMA-pipeline
    serialization a function-call boundary would cause; upstream 300s timeout
    is preserved.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
  * tests/test_mega_moe_sm90.py: layered correctness suite (L1 smoke .. L5
    stress) against a PyTorch reference via calc_diff. DG_SM90_MOE_KERNEL
    selects auto / pingpong / cooperative.
  * tests/bench_mega_moe_sm90.py: per-config TFLOPS / HBM / NVLink timing with
    an optional DeepEP (V1 contiguous + V2 ElasticBuffer) baseline comparison.
  * tests/_deepep_v1_baseline.py: the DeepEP V1 contiguous baseline used by the
    bench (Triton SwiGLU + FP8 quant + grouped GEMM pipeline).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@mpdfdfl mpdfdfl force-pushed the sm90-mega-moe-pingpong-coop branch from 1380676 to ec757bd Compare June 17, 2026 08:54
@Rachmanino

hi, which machine are you working on, h20, h100 or h200?

@mpdfdfl

mpdfdfl commented Jun 17, 2026

Author

> hi, which machine are you working on, h20, h100 or h200?

We are working on H200.

熊梦轩 added 3 commits June 17, 2026 22:31
Replace mega MoE implementation with the per-128 N-split cooperative
kernel; keep remote .gitignore.
@mpdfdfl

mpdfdfl commented Jun 26, 2026

Author

We discovered an accuracy issue during end-to-end testing. We suspect the cause is that the L1 output of the MoE layer is normally quantized per-128, whereas our previous implementation used per-64. To fix this, we now implement the kernel using only the cooperative approach: the first warpgroup handles the left 128 columns and the second handles the right 128 columns, which gives per-128 quantization. We were also pleasantly surprised to find that, with this layout, the cooperative implementation delivers better performance at small shapes too.
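To see why the quantization granularity matters, the sketch below computes per-group FP8 scales for one row at group sizes 128 and 64. It is an illustration under stated assumptions, not the kernel's code: it reads "per-128" as one scale per 128 consecutive elements, uses the common amax / 448 scale convention (448 is the largest finite FP8 E4M3 value), and the helper name `group_scales` and the sample row are made up.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def group_scales(row, group_size):
    """One scale per `group_size` consecutive elements: amax / FP8_E4M3_MAX."""
    assert len(row) % group_size == 0, "row length must be a multiple of group_size"
    return [
        max(abs(v) for v in row[start:start + group_size]) / FP8_E4M3_MAX
        for start in range(0, len(row), group_size)
    ]


# 256 columns: small values, then a large outlier block, then moderate values.
row = [0.5] * 64 + [8.0] * 64 + [2.0] * 128

per_128 = group_scales(row, 128)  # 2 scales, one per 128-column half
per_64 = group_scales(row, 64)    # 4 scales, one per 64-column quarter

print("per-128 scales:", per_128)
print("per-64 scales: ", per_64)
```

With per-64 groups the first 64 columns get their own, much smaller scale, so the quantized values differ from what a per-128 pipeline produces. In the layout described above, each warpgroup owns a full 128-column slab, so a single scale per slab falls out naturally.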
