Skip to content

feat: support sm90_fp8_fp4 kernel#332

Open
zhangxiaolei123456 wants to merge 41 commits into
deepseek-ai:mainfrom
zhangxiaolei123456:main_hopper_fp8_fp4
Open

feat: support sm90_fp8_fp4 kernel#332
zhangxiaolei123456 wants to merge 41 commits into
deepseek-ai:mainfrom
zhangxiaolei123456:main_hopper_fp8_fp4

Conversation

@zhangxiaolei123456

@zhangxiaolei123456 zhangxiaolei123456 commented May 11, 2026

Copy link
Copy Markdown

For contiguous kernel
direct FP32 B scale case: b.second shape = [groups, N, K/128]

groups m/group n k W4 us W4 GB/s W4 diff FP8 us FP8 GB/s FP8 diff Speedup
8 256 4096 7168 529 296 0.0000 512 536 0.0335 0.97x
8 512 4096 7168 1036 182 0.0000 925 331 0.0338 0.89x
8 1024 4096 7168 1978 128 0.0000 1884 196 0.0338 0.95x
8 2048 4096 7168 3875 98 0.0000 3643 137 0.0336 0.94x
16 256 4096 7168 1041 301 0.0000 927 592 0.0339 0.89x
16 512 4096 7168 1988 190 0.0000 1890 324 0.0335 0.95x
16 1024 4096 7168 3872 130 0.0000 3657 202 0.0339 0.94x
16 2048 4096 7168 7689 99 0.0000 7174 139 0.0337 0.93x
24 256 4096 7168 1489 316 0.0000 1408 584 0.0336 0.95x
24 512 4096 7168 2947 192 0.0000 2781 330 0.0337 0.94x
24 1024 4096 7168 5762 131 0.0000 5426 205 0.0340 0.94x
24 2048 4096 7168 11379 100 0.0000 10608 141 0.0337 0.93x
32 256 4096 7168 1979 317 0.0000 1882 583 0.0339 0.95x
32 512 4096 7168 3852 196 0.0000 3637 337 0.0338 0.94x
32 1024 4096 7168 7630 132 0.0000 7126 208 0.0339 0.93x
32 2048 4096 7168 15137 100 0.0000 14094 141 0.0338 0.93x

For mask kernel
direct FP32 B scale case: b.second shape = [groups, N, K/128]

groups m/group n k W4 us W4 GB/s W4 diff FP8 us FP8 GB/s FP8 diff Speedup
8 1 4096 7168 97 1289 0.0000 134 1803 0.0344 1.39x
8 4 4096 7168 97 1287 0.0000 134 1817 0.0350 1.37x
8 8 4096 7168 97 1301 0.0000 133 1824 0.0344 1.38x
8 16 4096 7168 101 1251 0.0000 134 1829 0.0337 1.32x
8 32 4096 7168 113 1142 0.0000 133 1846 0.0344 1.18x
8 1 7168 2048 52 1195 0.0000 79 1540 0.0363 1.50x
8 4 7168 2048 53 1196 0.0000 79 1548 0.0345 1.49x
8 8 7168 2048 52 1219 0.0000 79 1555 0.0340 1.51x
8 16 7168 2048 56 1161 0.0000 79 1567 0.0348 1.42x
8 32 7168 2048 60 1110 0.0000 79 1596 0.0347 1.31x

direct E8M0 B scale case: b.second shape = [groups, N, K/32]

groups m/group n k W4 us W4 GB/s W4 diff FP8 us FP8 GB/s FP8 diff Speedup
8 1 4096 7168 130 1127 0.0000 132 1843 0.0344 1.01x
8 4 4096 7168 131 1122 0.0000 131 1849 0.0351 1.00x
8 8 4096 7168 131 1125 0.0000 132 1849 0.0344 1.00x
8 16 4096 7168 161 925 0.0000 131 1858 0.0337 0.82x
8 1 7168 2048 68 1086 0.0000 77 1566 0.0349 1.14x
8 4 7168 2048 68 1085 0.0000 77 1573 0.0356 1.13x
8 8 7168 2048 68 1097 0.0000 77 1584 0.0346 1.14x
8 16 7168 2048 80 947 0.0000 78 1589 0.0352 0.97x

@JoyFuture

Copy link
Copy Markdown

Hi, thanks for the great work on the SM90 FP8xFP4 kernels.

I have a question about the contiguous grouped GEMM prefill path. Some MXFP4 MoE models, such as DeepSeek-V4 and MiMoV2.5, use FP4 weight scales with K-group size 32, while the current SM90 FP8xFP4 contiguous grouped GEMM seems to mainly target gran_k_b=128.

Is there any plan to support gran_k_b=32 for m_grouped_fp8_fp4_gemm_nt_contiguous_sm90_fused_wgmma?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants