
[PyTorch] Add workaround for cuteDSL stride requirement for zero-token expert#2947

Merged
ksivaman merged 2 commits into NVIDIA:main from ksivaman:cutedsl_zero_token_stride_war on May 1, 2026

Conversation

@ksivaman
Member

Description

cudnn-frontend and cuteDSL do not relax their stride divisibility requirements for wgrad input tensors with 0 elements. This PR adds a workaround that can be removed once the proper fix lands in cuteDSL.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Create dummy empty tensors to pass to cuDNN for the case where we have zero tokens.
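The dummy-tensor trick can be sketched as follows. This is a minimal, standalone illustration using torch.empty_strided: the shapes, the uint8 stand-in dtype (so the sketch runs on CPU-only builds without FP8 support), and the leading stride of 16 are taken from or modeled on the PR, but this is not the actual code in backward_grouped_mlp.py.

```python
import torch

# Hypothetical sketch of the zero-token guard: allocate 0-element tensors
# whose reported strides still satisfy the frontend's stride checks.
out_features, in_features = 128, 256  # assumed example sizes
device = "cpu"  # the real path runs on CUDA; CPU keeps this sketch runnable

# Zero elements along the token dimension, but a leading stride of 16 so
# that stride validation in the kernel frontend passes (value from the PR).
a_tensor = torch.empty_strided(
    (out_features, 0), (16, 1), dtype=torch.uint8, device=device
)
b_tensor = torch.empty_strided(
    (0, in_features), (in_features, 1), dtype=torch.uint8, device=device
)

# The tensors hold no data, yet they advertise the compliant strides.
print(a_tensor.numel(), a_tensor.stride())  # 0 (16, 1)
print(b_tensor.numel(), b_tensor.stride())  # 0 (256, 1)
```

Because the tensors are empty, the wgrad kernel sees a valid layout but contributes nothing to the gradient, which is exactly the desired no-op behavior for a zero-token expert.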

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 30, 2026 21:22
@greptile-apps
Contributor

greptile-apps Bot commented Apr 30, 2026

Greptile Summary

This PR adds a workaround in _cudnn_compute_wgrad to handle the zero-token expert case, where cuteDSL/cudnn-frontend enforces stride divisibility requirements even on 0-element tensors. The fix creates dummy empty tensors with compliant strides for a_tensor, b_tensor, sfa_tensor, and sfb_tensor when total_tokens == 0, bypassing the kernel's validation while producing a correct (no-op) result.

Confidence Score: 5/5

Safe to merge — the change is a narrowly scoped, no-op workaround that only activates when total_tokens == 0 and leaves the hot path untouched.

No P0 or P1 issues found. The zero-token guard correctly creates 0-element tensors with cuteDSL-compliant strides. The non-zero path is unchanged. The b_tensor stride of (in_features, 1) is consistent with the non-zero path and is always a multiple of the required alignment given MXFP8's block-size constraints. Previously flagged concerns (missing TODO, unexplained constant 16) are already tracked in existing review threads.

No files require special attention beyond the previously noted style concerns.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Adds zero-token guard in _cudnn_compute_wgrad; dummy empty tensors with cuteDSL-compliant strides are created and forwarded to the wgrad kernel, which is a no-op for 0 tokens. Refactors sfa_leading_dim/sfb_leading_dim to be computed before the branch.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_cudnn_compute_wgrad called"] --> B{"total_tokens == 0?"}
    B -- Yes --> C["Create dummy empty_strided tensors\n(a, b, sfa, sfb)\nwith cuteDSL-compliant strides"]
    B -- No --> D["Slice real data from\ngrouped_dy & grouped_x\n(unchanged path)"]
    C --> E["wgrad_kernel_fn\n(no-op: 0 tokens → no\ngradient contribution)"]
    D --> E
    E --> F{"single_grouped_weight?"}
    F -- Yes --> G["dense mode:\nwgrad_tensor written"]
    F -- No --> H["discrete mode:\nper-expert wgrad_ptrs written"]

Reviews (2): Last reviewed commit: "Merge branch 'main' into cutedsl_zero_to..."

Comment on lines +64 to +67
if total_tokens == 0:
    # A workaround for the case with zero-token experts.
    # Even for this case, cuteDSL still requires the same
    # stride requirements for the input and scale tensors.
Contributor


P2 Missing TODO for temporary workaround

The PR description states this workaround will be removed once the upstream fix lands in cutedsl, but the in-code comment has no corresponding TODO or issue-tracker reference. Without one, there's no actionable reminder to clean this up once the upstream fix is released.

Suggested change
if total_tokens == 0:
    # A workaround for the case with zero-token experts.
    # Even for this case, cuteDSL still requires the same
    # stride requirements for the input and scale tensors.

if total_tokens == 0:
    # TODO: Remove this workaround once cuteDSL relaxes stride
    # divisibility requirements for zero-element tensors (tracked in
    # <upstream issue link>).
    # A workaround for the case with zero-token experts.
    # Even for this case, cuteDSL still requires the same
    # stride requirements for the input and scale tensors.

Comment on lines +69 to +84
a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device)
b_tensor = torch.empty_strided(
    (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device
)
sfa_tensor = torch.empty_strided(
    (sfa_leading_dim, 0),
    (16, 1),
    dtype=torch.float8_e8m0fnu,
    device=device,
)
sfb_tensor = torch.empty_strided(
    (sfb_leading_dim, 0),
    (16, 1),
    dtype=torch.float8_e8m0fnu,
    device=device,
)
Contributor

P2 Hardcoded stride 16 undocumented

The value 16 is used as the leading stride for both a_tensor and the scale tensors (sfa_tensor, sfb_tensor) in the zero-token path, but there is no comment explaining why 16 specifically satisfies cuteDSL's divisibility requirement. In the non-zero path the leading stride of a_tensor is 1 (column-major after transpose), so this value is not derived from the tensor layout. If the cuteDSL requirement ever changes (e.g. requires 32 or 128 alignment), this silent constant will be wrong without any indication of why it was chosen. A brief comment citing the minimum stride constraint would make future maintenance safer.
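One way to address this review point is to name the alignment and derive the stride from it instead of hardcoding 16. The helper below is a hypothetical sketch, not part of the PR; the constant name and the assumption that 16 is the required alignment come from the PR's usage, and the real value should be taken from the cuteDSL/cudnn-frontend documentation.

```python
# Hypothetical: name the alignment requirement instead of a bare 16.
# The value 16 is an assumption mirroring the PR; verify it against
# cuteDSL's documented stride divisibility requirement.
CUTEDSL_MIN_LEADING_STRIDE = 16


def compliant_leading_stride(
    natural_stride: int, alignment: int = CUTEDSL_MIN_LEADING_STRIDE
) -> int:
    """Round a leading stride up to the next multiple of `alignment`."""
    return ((natural_stride + alignment - 1) // alignment) * alignment


print(compliant_leading_stride(1))    # 16 (column-major a_tensor case)
print(compliant_leading_stride(256))  # 256 (already aligned, unchanged)
print(compliant_leading_stride(260))  # 272 (rounded up to next multiple)
```

With this, a future change in cuteDSL's requirement (say, to 32 or 128) becomes a one-constant edit, and the zero-token path stays consistent with the non-zero path by construction.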

@ksivaman
Member Author

/te-ci pytorch

Collaborator

@timmoon10 timmoon10 left a comment


LGTM

@ksivaman ksivaman merged commit 36fc336 into NVIDIA:main May 1, 2026
21 of 24 checks passed
@ksivaman ksivaman deleted the cutedsl_zero_token_stride_war branch May 1, 2026 06:23