[PyTorch] Add workaround for cuteDSL stride requirement for zero-token expert#2947
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR adds a workaround for the cuteDSL stride requirement on zero-token experts.

Confidence Score: 5/5 — Safe to merge. The change is a narrowly scoped, no-op workaround that only activates when `total_tokens == 0` and leaves the hot path untouched. No P0 or P1 issues found. The zero-token guard correctly creates 0-element tensors with cuteDSL-compliant strides; the non-zero path is unchanged. The `b_tensor` stride of `(in_features, 1)` is consistent with the non-zero path and is always a multiple of the required alignment given MXFP8's block-size constraints. Previously flagged concerns (missing TODO, unexplained constant `16`) are already tracked in existing review threads. No files require special attention beyond the previously noted style concerns.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_cudnn_compute_wgrad called"] --> B{"total_tokens == 0?"}
    B -- Yes --> C["Create dummy empty_strided tensors\n(a, b, sfa, sfb)\nwith cuteDSL-compliant strides"]
    B -- No --> D["Slice real data from\ngrouped_dy & grouped_x\n(unchanged path)"]
    C --> E["wgrad_kernel_fn\n(no-op: 0 tokens → no\ngradient contribution)"]
    D --> E
    E --> F{"single_grouped_weight?"}
    F -- Yes --> G["dense mode:\nwgrad_tensor written"]
    F -- No --> H["discrete mode:\nper-expert wgrad_ptrs written"]
```
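The zero-token branch of the flowchart can be sketched in plain Python. This is an illustrative sketch, not the PR's actual code: the helper name, its parameters, and the leading-stride constant of 16 are assumptions taken from the snippet under review, and plain `(shape, stride)` tuples stand in for the real `torch.empty_strided` calls.

```python
# Sketch of the "total_tokens == 0" branch: choose shapes and strides for
# the dummy tensors (a, b, sfa, sfb) so they still satisfy the stride
# layout cuteDSL expects. All names here are illustrative assumptions.
def zero_token_dummy_layouts(out_features, in_features,
                             sfa_leading_dim, sfb_leading_dim,
                             leading_stride=16):
    """Return (shape, stride) pairs for a, b, sfa, sfb when total_tokens == 0.

    `leading_stride` mirrors the hardcoded 16 in the diff; it is an
    assumption chosen only to match the snippet under review.
    """
    return {
        "a":   ((out_features, 0), (leading_stride, 1)),
        "b":   ((0, in_features), (in_features, 1)),
        "sfa": ((sfa_leading_dim, 0), (leading_stride, 1)),
        "sfb": ((sfb_leading_dim, 0), (leading_stride, 1)),
    }

layouts = zero_token_dummy_layouts(128, 256, 4, 8)
```

Because every tensor has a 0-sized dimension, the kernel launch contributes no gradient, matching the "no-op" node in the flowchart.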
Reviews (2): Last reviewed commit: "Merge branch 'main' into cutedsl_zero_to..."
```python
if total_tokens == 0:
    # A workaround for the case with zero-token experts.
    # Even for this case, cuteDSL still requires the same
    # stride requirements for the input and scale tensors.
```
Missing TODO for temporary workaround
The PR description states this workaround will be removed once the upstream fix lands in cutedsl, but the in-code comment has no corresponding TODO or issue-tracker reference. Without one, there's no actionable reminder to clean this up once the upstream fix is released.
Suggested change:

```python
if total_tokens == 0:
    # TODO: Remove this workaround once cuteDSL relaxes stride
    # divisibility requirements for zero-element tensors (tracked in
    # <upstream issue link>).
    # A workaround for the case with zero-token experts.
    # Even for this case, cuteDSL still requires the same
    # stride requirements for the input and scale tensors.
```
```python
a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device)
b_tensor = torch.empty_strided(
    (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device
)
sfa_tensor = torch.empty_strided(
    (sfa_leading_dim, 0),
    (16, 1),
    dtype=torch.float8_e8m0fnu,
    device=device,
)
sfb_tensor = torch.empty_strided(
    (sfb_leading_dim, 0),
    (16, 1),
    dtype=torch.float8_e8m0fnu,
    device=device,
)
```
Hardcoded stride `16` undocumented
The value 16 is used as the leading stride for both a_tensor and the scale tensors (sfa_tensor, sfb_tensor) in the zero-token path, but there is no comment explaining why 16 specifically satisfies cuteDSL's divisibility requirement. In the non-zero path the leading stride of a_tensor is 1 (column-major after transpose), so this value is not derived from the tensor layout. If the cuteDSL requirement ever changes (e.g. requires 32 or 128 alignment), this silent constant will be wrong without any indication of why it was chosen. A brief comment citing the minimum stride constraint would make future maintenance safer.
/te-ci pytorch
Description
`cudnn-frontend` and `cutedsl` do not relax their stride divisibility requirements for input tensors with 0 elements for wgrad. This is a workaround that will be removed after the proper fix lands in `cutedsl`.

Type of change
Changes
Checklist: