feat(shmem/sdma): implement address-based device putmem_nbi_signal on…#445
zjing14 wants to merge 1 commit into
… the SDMA transport

The device-side ShmemPutMemNbiSignalBlockKernel<SDMA> address-based overloads were stubs (TODO), and DISPATCH_TRANSPORT_TYPE_WITH_BOOL asserted on SDMA, so a device putmem_nbi_signal targeting an SDMA peer hit assert(false).

This wires SDMA through the bool dispatch and implements the address-based signal-put as a COPY_LINEAR(source -> peer dest) followed by an ATOMIC on the peer flag, enqueued on the SAME SDMA queue so the DMA engine executes the flag update strictly after the copy completes (in-order queue). Result: a CU-free per-tile push+signal usable to drive an in-kernel wait_until gate (e.g. a fused all-gather + GEMM consumer) without consuming compute units.

Notes:
- Thread scope resolves the heap object from globalGpuStates and computes peer offsets from heapBaseAddr, matching the existing address-based SDMA paths.
- Uses atomic INCREMENT (monotonic-generation semantics); consumers wait `flag >= gen`. signalValue/signalOp are accepted for API parity but the SDMA path currently only wraps increment.
- onlyOneSignal=false forwards to the =true path.

Validated on 4x gfx950 (MI350) with a FlyDSL SDMA-signal probe and a fused AG+GEMM PoC.
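To make the ordering argument concrete, here is a minimal host-side sketch of the copy-then-atomic pattern on a single in-order queue. It is a model only: `Packet`, `InOrderQueue`, `PutSignal`, `Step`, and `TileReady` are hypothetical names invented for illustration, not the real SDMA packet layout or the library's API, and plain `memcpy` stands in for the DMA engine.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical packet kinds, standing in for COPY_LINEAR and ATOMIC.
enum class PktKind { CopyLinear, AtomicInc };

struct Packet {
  PktKind kind;
  const void* src;    // CopyLinear only
  void* dst;          // CopyLinear only
  std::size_t bytes;  // CopyLinear only
  uint64_t* flag;     // AtomicInc only
};

// Hypothetical model of one in-order queue: packets run in submission order.
struct InOrderQueue {
  std::vector<Packet> pkts;
  std::size_t head = 0;

  // Enqueue the payload copy, then the flag increment, on the same queue.
  void PutSignal(void* dst, const void* src, std::size_t bytes, uint64_t* flag) {
    pkts.push_back(Packet{PktKind::CopyLinear, src, dst, bytes, nullptr});
    pkts.push_back(Packet{PktKind::AtomicInc, nullptr, nullptr, 0, flag});
  }

  // Execute exactly one packet; returns false once the queue is drained.
  bool Step() {
    if (head == pkts.size()) return false;
    const Packet& p = pkts[head++];
    if (p.kind == PktKind::CopyLinear) {
      std::memcpy(p.dst, p.src, p.bytes);
    } else {
      *p.flag += 1;  // monotonic generation counter
    }
    return true;
  }
};

// Consumer-side gate: generation `gen` is ready once flag >= gen.
inline bool TileReady(uint64_t flag, uint64_t gen) { return flag >= gen; }
```

In this model the flag can only reach generation `gen` after the copy packet queued ahead of it has run, which is the property the consumer's `flag >= gen` wait relies on.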
jhchouuu left a comment
LGTM.
Merging is fine. Follow-ups for us to complete later:
- Implement the thread/warp scopes (only block-level is done now; the macro also expands those, so they currently hit a silent no-op stub).
- Support the full signal semantics (signalValue/signalOp are currently ignored, always += 1).
- Reduce the runtime dispatch overhead from the added SDMA branch (extra cost and code size).
inline __device__ void ShmemPutMemNbiSignalBlockKernel<application::TransportType::SDMA, true>(
    const void* dest, const void* source, size_t bytes, const void* signalDest,
    uint64_t signalValue, core::atomicType signalOp, int pe, int qpId) {
Only the block-level version is implemented; the thread and warp versions aren't yet. The macro probably expands the thread/warp paths too, so those calls would currently land in an unimplemented stub. It's fine to merge for now; we'll add the complete functionality in a follow-up.
} else if (transportType == application::TransportType::SDMA) { \
  func<application::TransportType::SDMA, boolParam>(__VA_ARGS__); \
This adds a bit more device-side runtime overhead to this API. Previously the branch was just between P2P and IBGDA; now there's an extra SDMA branch, which also grows the code size. There's not much we can do about it for now, since it's inherent to runtime dispatch. We'll look into addressing this in a follow-up.
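As a rough illustration of why each extra transport adds a runtime branch and another template instantiation, here is a minimal host-side sketch of this dispatch style. The enum values, `PutSignalStub`, `DEMO_DISPATCH_WITH_BOOL`, and `Dispatch` are hypothetical stand-ins; the real DISPATCH_TRANSPORT_TYPE_WITH_BOOL macro and its argument list may differ.

```cpp
#include <cassert>

// Hypothetical transport enum for illustration only.
enum class TransportType { RDMA, P2P, SDMA };

// Stand-in for a templated kernel helper: records which instantiation ran.
template <TransportType T, bool OnlyOneSignal>
void PutSignalStub(int* tag) {
  *tag = static_cast<int>(T) * 2 + (OnlyOneSignal ? 1 : 0);
}

// Runtime value -> compile-time template argument. Every supported transport
// costs one more comparison and one more instantiation of `func`.
#define DEMO_DISPATCH_WITH_BOOL(func, transportType, boolParam, ...) \
  do {                                                               \
    if ((transportType) == TransportType::RDMA) {                    \
      func<TransportType::RDMA, boolParam>(__VA_ARGS__);             \
    } else if ((transportType) == TransportType::P2P) {              \
      func<TransportType::P2P, boolParam>(__VA_ARGS__);              \
    } else if ((transportType) == TransportType::SDMA) {             \
      func<TransportType::SDMA, boolParam>(__VA_ARGS__);             \
    } else {                                                         \
      assert(false && "unsupported transport");                      \
    }                                                                \
  } while (0)

// Helper that picks the bool instantiation and returns the recorded tag.
int Dispatch(TransportType t, bool onlyOneSignal) {
  int tag = -1;
  if (onlyOneSignal) {
    DEMO_DISPATCH_WITH_BOOL(PutSignalStub, t, true, &tag);
  } else {
    DEMO_DISPATCH_WITH_BOOL(PutSignalStub, t, false, &tag);
  }
  return tag;
}
```

With three transports and two bool values, six instantiations of the callee are emitted at every dispatch site, which is the code-size growth the comment refers to.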
uint64_t off = 0;
uint64_t base = handle.ReserveQueueSpace(sizeof(SDMA_PKT_ATOMIC), off);
uint64_t wptr = base;
auto pkt = anvil::CreateAtomicIncPacket(sigPtr);
Only the increment semantics are implemented here; signalValue/signalOp are ignored (always atomic += 1). Fine for now; we'll complete the remaining ops in a follow-up.
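The gap between the current behavior and full signal semantics can be sketched on the host as follows. This assumes the two signal operations defined by OpenSHMEM (set and add); `SignalOp`, `ApplySignal`, and `ApplyIncrementOnly` are hypothetical names for illustration and do not reflect the actual `core::atomicType` values.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical signal ops, modeled on OpenSHMEM's SET and ADD.
enum class SignalOp { Set, Add };

// What full signal semantics would compute for the flag.
inline uint64_t ApplySignal(uint64_t flag, uint64_t value, SignalOp op) {
  return op == SignalOp::Set ? value : flag + value;
}

// What an increment-only path computes: value and op are accepted but unused.
inline uint64_t ApplyIncrementOnly(uint64_t flag, uint64_t /*value*/, SignalOp /*op*/) {
  return flag + 1;
}
```

The two agree only for an add with a value of 1, so callers of an increment-only path need to treat the flag as a generation counter and wait on `flag >= gen` rather than on a specific signalValue.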
Also, the CI system currently has some issues, which are being addressed.