feat: MegaMOE adaptation for SM90#323
Conversation
|
Do you have benchmark data? |
I’m testing the benefits of DeepSeek V4 Flash on H20, and I’ll share the data soon. |
|
看起来效果并不理想:
Performance:
"""
H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。
结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8:
* fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`),
使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。
* baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine,
使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation
per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA
同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照,
不是 bitwise apples-to-apples correctness oracle。
* 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us /
reduction us / `t_baseline / t_fused` legacy 比。
"""
import deep_ep
import argparse
import math
import os
import random
import torch
import torch.distributed as dist
import triton
import triton.language as tl
from typing import Tuple
import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto
# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名,
# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段
SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl"
# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准
FP8_E4M3_MAX = 448.0
# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例,
# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。
_FP8_E4M3_MAX_TL = tl.constexpr(448.0)
L1_ACT_SF_GRAN = 128
FUSED_L2_ACT_SF_GRAN = 64
BASELINE_L2_ACT_SF_GRAN = 128
WEIGHT_SF_GRAN_MN = 128
WEIGHT_SF_GRAN_K = 128
# ============================================================================
# 模块 1:Triton SwiGLU + FP8 量化内核
# ----------------------------------------------------------------------------
# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按
# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则,
# 避免再额外引入 exact-FP32-scale 差异。
# 输入 x : (M, 2*H) bf16,内层是 [gate_part | up_part]
# 输入 topk_w : (M,) fp32,可选
# 输出 y : (M, H) fp8_e4m3fn
# 输出 y_sf : (M, H/BLOCK_K) fp32 行主序
# ============================================================================
@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
x_ptr,
topk_w_ptr,
y_ptr,
y_sf_ptr,
M,
H, # 运行时形状
stride_xm,
stride_xn, # x: (M, 2H) 的 stride
stride_ym,
stride_yn, # y: (M, H) 的 stride
stride_sfm,
stride_sfk, # y_sf: (M, H/BLOCK_K) 的 stride
clamp_value, # 当 HAS_CLAMP=False 时这个参数无意义
HAS_TOPK: tl.constexpr,
HAS_CLAMP: tl.constexpr,
USE_UE8M0_SCALE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_K: tl.constexpr, # = num_per_channels
):
# 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列)
pid_m = tl.program_id(0)
pid_k = tl.program_id(1)
# 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
# 当前 K-block 内的列索引(在 H 维度,不是 2H)
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
mask_m = offs_m < M
# ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))----
# 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的
gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
# ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)----
if HAS_CLAMP:
gate = tl.minimum(gate, clamp_value)
up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)
# ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)----
y = gate * tl.sigmoid(gate) * up
# ---- 4) 可选 MoE 权重缩放(per-token 标量)----
if HAS_TOPK:
w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
y = y * w[:, None]
# ---- 5) 当前 K-block 内每行 absmax → scale ----
amax = tl.max(tl.abs(y), axis=1) # (BLOCK_M,)
sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
if USE_UE8M0_SCALE:
# 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv:
# scale = 2 ** ceil(log2(amax / 448)).
sf = tl.exp2(tl.ceil(tl.log2(sf)))
# ---- 6) 量化为 FP8 e4m3fn ----
y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)
# ---- 7) 写回 y 和 sf ----
y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])
sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
tl.store(sf_ptrs, sf, mask=mask_m)
def swiglu_apply_weight_to_fp8_triton(
x: torch.Tensor,
topk_weights: torch.Tensor | None,
clamp_value: float | None = None,
num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""SwiGLU + FP8 量化。语义等价于 PyTorch reference:
gate, up = x[:, :H], x[:, H:]
y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)
"""
assert x.is_cuda and x.dtype == torch.bfloat16
assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
M, two_H = x.shape
H = two_H // 2
assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"
y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)
# BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调
BLOCK_M = 16
grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)
# HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位
topk_ptr = topk_weights if topk_weights is not None else x
_swiglu_apply_weight_to_fp8_kernel[grid](
x,
topk_ptr,
y,
y_sf,
M,
H,
x.stride(0),
x.stride(1),
y.stride(0),
y.stride(1),
y_sf.stride(0),
y_sf.stride(1),
float(clamp_value) if clamp_value is not None else 0.0,
HAS_TOPK=topk_weights is not None,
HAS_CLAMP=clamp_value is not None,
USE_UE8M0_SCALE=use_ue8m0_scale,
BLOCK_M=BLOCK_M,
BLOCK_K=num_per_channels,
)
return y, y_sf
# ============================================================================
# 模块 2:grouped weight 的 (128, 128) FP8 块量化
# ----------------------------------------------------------------------------
# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定:
# 每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。
# 与 SM100 FP4 路径的差异:
# * 不需要 deep_gemm.transform_sf_into_required_layout
# * SF 是 FP32,不是 UE8M0 packed
# ============================================================================
def _quantize_grouped_fp8_block_128_128(
w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。"""
g, n, k = w.shape
assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"
# 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部
w_view = w.view(g, n // 128, 128, k // 128, 128).float()
# 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块
amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) # (G, N/128, K/128)
sf = amax / FP8_E4M3_MAX
# 量化:每个元素除以所属子块的 sf 后转 FP8
# sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度
w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
return w_fp8.view(g, n, k).contiguous(), sf.contiguous()
# ============================================================================
# 模块 3:尝试导入 deep_ep(用于 dispatch / combine)
# ============================================================================
def _import_deep_ep():
try:
import deep_ep
return deep_ep
except Exception as ex:
dist_print(f"Failed to import deep_ep: {ex}", once_in_node=True)
return None
# ============================================================================
# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖)
# ============================================================================
def _bench_cuda_events(
fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0
) -> float:
"""返回 fn 的中位数耗时(秒)。"""
for _ in range(num_warmup):
fn()
torch.cuda.synchronize()
times_ms = []
for _ in range(num_repeat):
# L2 flush,避免重复访问命中 cache 让测时偏低
if l2_flush_gb > 0:
free_bytes, _ = torch.cuda.mem_get_info()
flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5))
if flush_bytes >= 4:
torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn()
e.record()
e.synchronize()
times_ms.append(s.elapsed_time(e))
times_ms.sort()
return times_ms[len(times_ms) // 2] / 1e3
# ============================================================================
# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline
# ============================================================================
def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
# 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group
rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
torch.manual_seed(rank_idx)
random.seed(rank_idx)
# 形状参数(与 test_mega_moe.py 同名同义)
num_max_tokens_per_rank = args.num_max_tokens_per_rank
num_tokens = args.num_tokens if args.num_tokens > 0 else num_max_tokens_per_rank
hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
num_experts, num_topk = args.num_experts, args.num_topk
num_experts_per_rank = num_experts // num_ranks
assert num_tokens <= num_max_tokens_per_rank
assert num_experts % num_ranks == 0, (
f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除"
)
# SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe):
# * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF)
# * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列)
assert hidden % 128 == 0
assert intermediate_hidden % 128 == 0
assert intermediate_hidden // 64 <= 64, (
f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}"
)
# ---- 创建 BF16 输入:token 与两层 weight ----
# x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维
x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
# L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起)
l1_weights_bf16 = torch.randn(
(num_experts_per_rank, intermediate_hidden * 2, hidden),
dtype=torch.bfloat16,
device="cuda",
)
# L2 weight: 每个 expert 把 intermediate_hidden → hidden
l2_weights_bf16 = torch.randn(
(num_experts_per_rank, hidden, intermediate_hidden),
dtype=torch.bfloat16,
device="cuda",
)
# 路由:scores → topk_idx (M, K) + topk_weights (M, K)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda")
topk_weights, topk_idx = torch.topk(
scores, num_topk, dim=-1, largest=True, sorted=False
)
# 累计接收统计:fused 与 baseline 各持一份避免相互覆盖
cum_stats_fused = torch.zeros(
(num_experts_per_rank,), dtype=torch.int, device="cuda"
)
cum_stats_baseline = cum_stats_fused.clone()
# ---- BF16 → FP8 量化 ----
# x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序)
# 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF
x_fp8 = per_token_cast_to_fp8(
x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False
)
# weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF)
# baseline(DeepEP grouped GEMM)直接用这两个未变换的元组
l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16)
l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16)
# fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变
transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
l1_weights, l2_weights
)
# SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致)
clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None
# ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)----
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
# ---- 分配 fused 的 SymmBuffer 与输出 buffer ----
sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group,
num_experts,
num_max_tokens_per_rank,
num_topk,
hidden,
intermediate_hidden,
)
y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
def run_fused():
# NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时
# kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入
sym_buffer.x[:num_tokens].copy_(x_fp8[0])
sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1])
sym_buffer.topk_idx[:num_tokens].copy_(topk_idx)
sym_buffer.topk_weights[:num_tokens].copy_(topk_weights)
deep_gemm.fp8_mega_moe(
y_fused,
transformed_l1,
transformed_l2,
sym_buffer,
cumulative_local_expert_recv_stats=cum_stats_fused,
recipe=(128, 128, 128),
activation="swiglu",
activation_clamp=clamp_arg,
fast_math=bool(args.fast_math),
)
return y_fused
# ---- 分配 DeepEP buffer(baseline 用)----
deep_ep = _import_deep_ep()
ep_buffer = None
if deep_ep is not None:
ep_buffer = deep_ep.ElasticBuffer(
group,
num_max_tokens_per_rank=num_max_tokens_per_rank,
hidden=hidden,
num_topk=num_topk,
use_fp8_dispatch=True,
explicitly_destroy=True,
allow_multiple_reduction=False,
)
# ----------------------------------------------------------------
# baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine
# 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换**
# 的版本(baseline grouped GEMM 不需要 gate/up interleave)
# ----------------------------------------------------------------
def run_baseline():
recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
x_fp8,
topk_idx=topk_idx,
topk_weights=topk_weights,
cumulative_local_expert_recv_stats=cum_stats_baseline,
num_experts=num_experts,
expert_alignment=alignment,
do_cpu_sync=False,
do_handle_copy=False,
do_expand=True,
use_tma_aligned_col_major_sf=False, # SM90: row-major float SF
)
n = recv_x[0].size(0)
# L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接)
l1_y = torch.empty(
(n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda"
)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
recv_x,
l1_weights,
l1_y,
handle.psum_num_recv_tokens_per_expert,
use_psum_layout=True,
disable_ue8m0_cast=True,
)
# Triton SwiGLU + FP8 量化(含 topk 权重乘法)
# 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K;
# 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline
# 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。
l1_y = swiglu_apply_weight_to_fp8_triton(
x=l1_y,
topk_weights=recv_topk_weights,
clamp_value=clamp_arg,
num_per_channels=BASELINE_L2_ACT_SF_GRAN,
use_ue8m0_scale=True,
)
# L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16
l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda")
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
l1_y,
l2_weights,
l2_y,
handle.psum_num_recv_tokens_per_expert,
use_psum_layout=True,
disable_ue8m0_cast=True,
)
# DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank
return ep_buffer.combine(l2_y, handle=handle)[0]
# ---- 打印 config ----
dist_print("Config (H200 fused mega-MoE):", once_in_node=True)
dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True)
dist_print(
f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True
)
dist_print(
f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})",
once_in_node=True,
)
dist_print(
f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, "
f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 "
f"(SM90 grouped GEMM constraint)",
once_in_node=True,
)
dist_print(
f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True
)
dist_print(once_in_node=True)
# ---- 跑一次确保不报错(fused + 可选 baseline)----
y = run_fused()
assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, (
f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}"
)
if ep_buffer is not None:
out_b = run_baseline()
assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, (
f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}"
)
if args.check_output_diff:
diff = (y.float() - out_b.float()).abs()
denom = out_b.float().abs().mean().clamp_min(1e-12)
dist_print(
"Output diff (fused vs legacy-per128 baseline):", once_in_node=True
)
dist_print(
f" > max_abs={diff.max().item():.6e}, "
f"mean_abs={diff.mean().item():.6e}, "
f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}",
once_in_node=True,
)
dist_print(once_in_node=True)
# ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ----
# 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目
# 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。
gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
gathered_topk_idx[
(gathered_topk_idx < rank_idx * num_experts_per_rank)
| (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)
] = -1
num_recv_tokens = int((gathered_topk_idx != -1).sum().item())
num_touched_experts = max(torch.unique(gathered_topk_idx.flatten()).numel() - 1, 0)
# ---- benchmark ----
# fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead)
t_fused = bench_kineto(
run_fused,
SM90_KERNEL_NAME,
num_tests=args.num_bench_tests,
barrier=lambda: ep_buffer.barrier(use_comm_stream=False)
if ep_buffer is not None
else dist.barrier(),
trace_path=(
f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json"
if args.dump_profile_traces
else None
),
)
# baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events)
t_baseline = (
_bench_cuda_events(
run_baseline,
num_warmup=args.num_warmup,
num_repeat=args.num_repeat,
l2_flush_gb=args.l2_flush_gb,
)
if ep_buffer is not None
else 0.0
)
def safe_div(a, b):
return float("nan") if b == 0 else a / b
# 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens
tflops = safe_div(
2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused
)
# HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同)
l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden
l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden
l1_weight_sf_bytes = (
num_touched_experts
* (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN)
* (hidden // WEIGHT_SF_GRAN_K)
* 4
)
l2_weight_sf_bytes = (
num_touched_experts
* (hidden // WEIGHT_SF_GRAN_MN)
* (intermediate_hidden // WEIGHT_SF_GRAN_K)
* 4
)
l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4
l2_act_sf_bytes = (
num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4
)
num_hbm_bytes = (
l1_weight_bytes
+ l2_weight_bytes # weights (FP8)
+ l1_weight_sf_bytes
+ l2_weight_sf_bytes # weight SF (FP32)
+ num_recv_tokens * hidden
+ l1_input_sf_bytes # L1 输入读 (FP8 + SF)
+ num_recv_tokens * intermediate_hidden
+ l2_act_sf_bytes # L1 输出写 (FP8 + SF)
+ num_recv_tokens * intermediate_hidden
+ l2_act_sf_bytes # L2 输入读 (FP8 + SF)
+ num_recv_tokens * hidden * 2 # L2 输出写 (BF16)
)
hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)
# NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16
num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2)
nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)
# combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s)
t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12
# overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐
approx_factor = t_fused / max(t_fused - t_reduction, 1e-12)
# baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline
tflops_baseline = safe_div(
2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline
)
hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline)
nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline)
dist_print("Performance:", once_in_node=True)
dist_print(
f" > [fused] EP {rank_idx:2}/{num_ranks} | "
f"{tflops:4.0f} TFLOPS | "
f"overlap: {tflops * approx_factor:4.0f} TFLOPS, "
f"HBM {hbm_gbs * approx_factor:4.0f} GB/s, "
f"NVL {nvlink_gbs * approx_factor:3.0f} GB/s | "
f"{t_fused * 1e6:6.0f} us, "
f"reduction: {t_reduction * 1e6:5.1f} us"
)
if ep_buffer is not None:
speedup = safe_div(t_baseline, t_fused)
dist_print(
f" > [baseline] EP {rank_idx:2}/{num_ranks} | "
f"{tflops_baseline:4.0f} TFLOPS | "
f" HBM {hbm_gbs_baseline:4.0f} GB/s, "
f"NVL {nvlink_gbs_baseline:3.0f} GB/s | "
f"{t_baseline * 1e6:6.0f} us | "
f"t_baseline/t_fused = {speedup:.2f}x "
f"({'fused 更快' if speedup > 1 else 'baseline 更快'})"
)
else:
dist_print(" > [baseline] (no baseline: deep_ep unavailable)", once_in_node=True)
# ---- 清理 ----
dist.barrier()
sym_buffer.destroy()
if ep_buffer is not None:
ep_buffer.destroy()
dist.destroy_process_group()
# ============================================================================
# 模块 6:argparse + spawn
# ============================================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
)
# 资源
parser.add_argument(
"--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
)
# 模型形状
# 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096
parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
parser.add_argument(
"--num-tokens",
type=int,
default=0,
help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
)
parser.add_argument("--hidden", type=int, default=7168)
parser.add_argument(
"--intermediate-hidden",
type=int,
default=3072,
help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
)
parser.add_argument(
"--activation-clamp",
type=float,
default=10.0,
help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
)
parser.add_argument("--num-experts", type=int, default=384)
parser.add_argument("--num-topk", type=int, default=6)
parser.add_argument(
"--fast-math",
type=int,
default=1,
help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
)
# 测时
parser.add_argument(
"--num-bench-tests",
type=int,
default=30,
help="bench_kineto 抓 fused 时的迭代数",
)
parser.add_argument(
"--num-warmup", type=int, default=5, help="baseline cuda events warmup"
)
parser.add_argument(
"--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
)
parser.add_argument(
"--l2-flush-gb",
type=float,
default=8.0,
help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
)
parser.add_argument(
"--check-output-diff",
type=int,
default=0,
help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
)
parser.add_argument(
"--dump-profile-traces",
type=str,
default="",
help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
)
args = parser.parse_args()
if args.dump_profile_traces:
os.makedirs(args.dump_profile_traces, exist_ok=True)
# 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group
torch.multiprocessing.spawn(
test, args=(args.num_processes, args), nprocs=args.num_processes
) |
deepseek-ai/DeepEP#629 没有RDMA 的8卡需要依赖这个PR 才能跑ElasticBuffer接口,另外EP4 是不是太小了?整体应该还是bound 在HBM 读取上面了,看不到megamoe 的收益。 |
|
你跑过B300的ncu报告吗?我跑出来的B300的报告SM和Memory利用率非常低,不知道是不是跑错了,感觉有点奇怪呢。#336 |
就是官方的Mega MoE的kernel的B300的报告,我在8卡上跑的 |
|
@Stone749990226 我只有H20环境 |
请问,这是用最新代码跑的吗?baseline是和sm100上一样,deepep v2+deepgemm完全没任何overlap的吗?我之前跑出来的结果如下:
|
test_mega_moe_sm90.py |
源码有吗? |
测试脚本是 |
|
按照最新的测试跑了一下,确实没有spill,不过相比H800
1900+TFLOPS的峰值tflops比较低的现象还是存在的(379 TFLOPS)。这里的C7510我记得是由于vprinf导致的,但是解了之后提升不大。 Config:
Tokens: 8192/8192
Hidden: 4096
Intermediate: 2048
Experts: 6/256
Buffer: 2.507 GiB
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743189-40fbd983-7dc7f586-d638089b/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743189-40fbd983-7dc7f586-d638089b/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743187-8f593979-ce0bfdbd-e1ee5922/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743187-8f593979-ce0bfdbd-e1ee5922/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743188-7615d764-76bb6375-b809a2fe/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743188-7615d764-76bb6375-b809a2fe/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743186-397a9b0a-45b0e0be-7ec9317a/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743186-397a9b0a-45b0e0be-7ec9317a/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 315.840 ms
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 307.963 ms
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 308.712 ms
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 325.611 ms
Performance:
EP: 0/4 | 375 TFLOPS | overlap: 379 TFLOPS, HBM 370 GB/s, NVL 92
GB/s | 6605 us, reduction: 72.3 us | 0.00x legacy
foobar2023xx ***@***.***> 于2026年5月25日周一 11:12写道:
… *foobar2023xx* left a comment (deepseek-ai/DeepGEMM#323)
<#323 (comment)>
在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: [image:
image]
<https://private-user-images.githubusercontent.com/199557527/595469978-c8551ec8-eac4-49c0-b41b-7ea9d7532f48.png?jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Nzk0Mzk4MjksIm5iZiI6MTc3OTQzOTUyOSwicGF0aCI6Ii8xOTk1NTc1MjcvNTk1NDY5OTc4LWM4NTUxZWM4LWVhYzQtNDljMC1iNDFiLTdlYTlkNzUzMmY0OC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjYwNTIyJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI2MDUyMlQwODQ1MjlaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1jMzQ0ZDZkM2MyYWZmMDkzMmJjZTlhYzdhMWExZTgxZTAzZjM5MzlmYjM0ZDk2NzMyOWZiZWNhMDk5ZWY1ODRmJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZyZXNwb25zZS1jb250ZW50LXR5cGU9aW1hZ2UlMkZwbmcifQ.fuwLWBUSUuRpl04rM4JIfpv6YfaydLq7LG04NYmb_Jc>
会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4
--num-experts 128 --hidden 3072 on 4*H800 [image: image]
<https://private-user-images.githubusercontent.com/199557527/595470385-62183cd2-a858-4569-a73f-aca982a0a8bd.png?jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Nzk0Mzk4MjksIm5iZiI6MTc3OTQzOTUyOSwicGF0aCI6Ii8xOTk1NTc1MjcvNTk1NDcwMzg1LTYyMTgzY2QyLWE4NTgtNDU2OS1hNzNmLWFjYTk4MmEwYThiZC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjYwNTIyJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI2MDUyMlQwODQ1MjlaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1kNDQzMmY4OTFkNzNkNjQ3NGYzY2Y3Y2E3ZGMxNWNiM2ZiMzQwMzUwNGU0NDk3NjQyMjI4MzU2NTZiY2I1ZGFjJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZyZXNwb25zZS1jb250ZW50LXR5cGU9aW1hZ2UlMkZwbmcifQ.emAemklJJR0mz482dNzDEAf_9Q_7B57qQ-cmv6gV03w>
看上去性能合理
我在H800上测试的时候,没有观察到明显的寄存器spill,请问你们是基于当前最新版本测试的吗?
Running NVCC command: cd /tmp/dg_sm90_spill_v1/tmp && /usr/local/cuda/bin/nvcc /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cu -cubin -o /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 --ptxas-options=--verbose,--warn-on-local-memory-usage -I/workspace/DeepGEMM/deep_gemm/include --gpu-architecture=sm_90a --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async instructions are serialized due to wgmma pipeline crossing function boundary at a function call in the function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_', size of stack frame: 56 bytes
ptxas info : 474 bytes gmem
ptxas info : Compiling entry function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_' for 'sm_90a'
ptxas info : Function properties for _ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative stack size
ptxas info : Compile time = 582.487 ms
你的测试脚本和参数是什么
测试脚本是
bench_mega_moe_sm90.py
<https://github.com/user-attachments/files/28204259/bench_mega_moe_sm90.py>
DG_JIT_PTXAS_VERBOSE=1 \
DG_JIT_PRINT_COMPILER_COMMAND=1 \
python tests/bench_mega_moe_sm90.py \
--num-processes 4 \
--num-max-tokens-per-rank 8192 \
--num-tokens 8192 \
--hidden 4096 \
--intermediate-hidden 2048 \
--num-experts 256 \
--num-topk 6
—
Reply to this email directly, view it on GitHub
<#323 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/BPSQDF2J5JJ7TX5SBDGY4DD44O2ZZAVCNFSM6AAAAACYLY4UHGVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHM2DKMZRGE4TIOJVGQ>
.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>.
You are receiving this because you commented.Message ID:
***@***.***>
|
|
H200 跑下来看起来 收益是负面的。。。 |
我也是这个结果,只有Deepep+GEMM的2/3 |
|
https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 |
This looks better than PR. |
但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈? |
最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点 |
H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。 |
|
DSV4参数 immediate_size = 3072 hidden_size = 7168 experts=384 tokens=448~640之间有一个性能低谷,tokens=648 相比 tokens=640 MFU直接暴涨 50%,感觉schedule有问题。 |
有可能,现在token少的时候确实性能一般。这个你有时间看吗?我目前还在看token多的时候tma不是瓶颈,L2 wait更多的问题。我的commit有一个简单的profiler: export DG_PHASE_PROFILING |
tokens_per_expert 分界线改成56 可以解决 PRO参数下 性能低谷问题。 感觉还是block_m 64还是128的调度问题。 |
|
cluster_size改为2利用TMA的Cluster多播特性一次搬运是否能加快速度呢?收益会很多吗?我看目前Cluster为1 |
40b508b to
23f46aa
Compare
Co-authored-by: rainj-me <rain-jiang@outlook.com>
Co-authored-by: b8zhong <b8zhong@users.noreply.github.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: yinding <yinding@bytedance.com>
…el (deepseek-ai#27) Two related additions for the DeepSeek-V4-Pro mega-MoE path: 1. **FP4 (E2M1) activations + `kind::mxf4` mainloop opt-in** for `fp8_fp4_mega_moe`. - `DG_USE_FP4_ACTS=1` halves the symm-buffer x-slot footprint (E2M1 nibbles vs E4M3 bytes); SF slot unchanged (still `hidden/32` UE8M0 bytes under gran_k=32). - `use_mxf4_kind=true` switches the L1+L2 mainloops to `cta_group::2 kind::mxf4` (2-CTA cluster) with dense FP4 smem layout (`_ALIGN8B`, 2 nibbles/byte). Per-stage A/B byte footprint halves → num_stages doubles for the same smem budget. - Threads `cumulative_local_expert_recv_stats` through the public mega-MoE API for per-rank expert counters used by sglang's expert-distribution recorder. - Block-m heuristic: under `use_mxf4_kind`, bumps `block_m=16 → 32` for the smallest-tokens-per-expert bucket so `load_block_m * block_k / 2` meets the 1024-byte smem alignment. - Multi-block_m support via `kCandidateBlockM` array + LCM-aligned pool padding; replaces the static `block_m=192` heuristic with token-density dispatch (8/16/32/64/96/128/192). 2. **`mega_moe_pre_dispatch` kernel**: BF16 → quant + topk-copy + pad-fill in one launch, gated on `kUseFp4Acts` + `kUsePDL`. Templated on `(kGroupSize, kUseFp4Acts, kUsePDL)`. Uses bucketize-style E2M1 encoder for byte-exact match against the `per_token_cast_to_fp4` host helper. - New: `deep_gemm.mega_moe_pre_dispatch(x, topk_idx, topk_weights, buf_x, buf_x_sf, buf_topk_idx, buf_topk_weights, num_tokens, group_size, use_fp4_acts)` - Test: `tests/test_mega_moe_pre_dispatch.py` — single-GPU bytewise check against host `per_token_cast_to_fp{8,4}` + pad-fill assertion. Validated end-to-end on 8× B300 with DeepSeek-V4-Pro at 8K input bench: - FP4 acts + MXF4 kind path produces matching tokens vs the FP8 baseline (rel-RMSE ≤ 0.5 sentinel; GSM8K accuracy parity within run-to-run variance). PR also includes existing FP4-mega-MoE supporting changes that are required by the kernel: - `cluster_sync_with_relaxed_arrive` helper (used twice in `sm100_fp8_fp4_mega_moe.cuh`). - `cvt_pack_f32_to_e2m1x2` / `cvt_pack_f32x4_to_e2m1x4` PTX wrappers. - `SM100_MMA_MXF4_2x1SM_SS` 2-CTA cluster MMA wrapper. - Generalized `red_add(int*, int)` for the `cumulative_local_expert_recv_stats` counter. - `st.L1::no_allocate.relaxed.sys.global.u64` (correctness fix: previous generic-address variant could miss the global state space). Co-authored-by: pranjalssh <adkz.photos@gmail.com> (cherry picked from commit bca278e)
…bine path) (deepseek-ai#28) * Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path) The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink: each token, each topk slot = kHidden * 2 bytes. This commit adds an env- gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte — kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes. Wiring: - New `kUseFp8Combine` template flag (default false → keeps BF16 path byte-identical when off). - New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per (token, slot) when on, zero when off. - Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a + mainloops; this controls the combine a2a only). Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh): - Read 8 BF16 from smem (existing STSM target). - Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half- warp to exit on padding rows, and a full-warp shfl would deadlock. - Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`). - Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64. - Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane group writes the SF byte to `combine_sf_buffer`. Consumer side (combine reduce): - Per-slot SF base ptr cached at slot start. - TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine). - Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2 via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant. - BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4 (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}. Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32). Validation: - Microbench (`ptx/d_combine_reduce_v{1,2}_*`): - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect). - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference that uses the same FP8 quant), 50% NVLink bytes savings. - Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine): - b=128: 364 us → 359 us (+1.5%) - b=512: 377 us → 386 us (-2.2%) - b=2048: 710 us → 739 us (-3.9%) Single-GPU is compute-bound (no NVLink saving); production is the point of the change. - E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output): - b=512: 91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3% - b=2048: 259.4 s (FP8) → 238.2 s — +8.9% - b=4096: 489.5 s (FP8) → 444.2 s — +10.2% Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes. Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations post-SwiGLU + topk-weighting are bounded; production accuracy parity preserved (same GSM8K results as FP4 baseline). * Combine reduce: HFMA path (FP16 accumulator + fma.f16x2) Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count and halves the accumulator register pressure (94 regs vs 138 regs). Inner loop, before: cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2) cvt.f32.f16 ×2 (FP16 → FP32) fma.rn.f32 ×2 (acc += sf_f32 * f32_val) = 5 ops per FP8x2 (= 2 elements) After: cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2) fma.rn.f16x2 (acc_fp16x2 += sf_pair * f16x2) = 2 ops per FP8x2 SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15. Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production activations post-SwiGLU + topk-weighting fit comfortably in FP16 range. End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem write-back (BF16 output unchanged). Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`): v1 BF16 baseline: 6,895 cycles/token v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1) v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)** E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output: | batch | FP4+MXF4 | combine FP32 | combine HFMA | |------:|---------:|-------------:|-------------:| | 512 | — | 7,526 | 7,350 | | 2048 | 9,814 | 9,903 | **9,992** | | 4096 | 10,418 | 10,622 | **10,699** | HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default. Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4 vs the FP32 reference. Production activations: still within sentinel tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline). * Revert "Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)" This reverts commit 48e8101. --------- Co-authored-by: pranjalssh <adkz.photos@gmail.com> (cherry picked from commit 8fc78b4)
Add swapAB code paths for small-batch SM90 FP8 MegaMoE and remove ptxas C7510 sources from hot device code. (cherry picked from commit 0074938)
(cherry picked from commit b230085)
(cherry picked from commit 34fe473)
(cherry picked from commit 15a6f42)
Add the SM90 FP8xFP4 MegaMoE runtime, kernel path, Python API, Hopper correctness and benchmark coverage, tuned runtime decode heuristics, swapAB support, synchronization/spill fixes, and the SM90 MegaMoE alignment export.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…plementation # Conflicts: # csrc/apis/sm90_mega.hpp # sgl_deep_gemm/run_tests.sh













Summary
Add SM90/Hopper FP8 MegaMoE decode support to DeepGEMM.
This change introduces a fused SM90 MegaMoE kernel path for FP8 weights and FP8 activations, including the C++/CUDA implementation, JIT entry point, Python API wrapper, Hopper-specific weight transform, scheduling heuristics, and a unified Hopper accuracy/performance test script.
Main Changes
Added the SM90 FP8 MegaMoE fused kernel implementation:
deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuhcsrc/jit_kernels/impls/sm90_fp8_mega_moe.hppExtended MegaMoE API bindings with
fp8_mega_moefor SM90.Added Hopper-specific Python entry points:
deep_gemm.fp8_mega_moedeep_gemm.transform_weights_for_mega_moe_sm90Added SM90 MegaMoE scheduling/config heuristics.
Updated MegaMoE symmetric buffer handling for SM90 FP32 scale-factor layouts.
Added
tests/test_mega_moe_hopper.py, covering:DeepSeekV4Flash(8 card H20)
DeepSeekV4Pro(8 card H20)
Benchmark:DeepSeekV4Flash CP8/EP8
SLO-Compliant Total Throughput
Max Throughput
@LyricZhao Could you help review this PR when you have time? If the patch is too large for one PR, I’m happy to split it into smaller parts following your preference.