[Graph] Support member ndarrays for qd.checkpoint and qd.graph_do_while#760
hughperkins wants to merge 5 commits into
…h_do_while() arguments

`qd.checkpoint(yield_on=...)` and `qd.graph_do_while(...)` previously required the argument to be a bare kernel parameter (`ast.Name`). With this change they also accept attribute chains -- both `@qd.data_oriented` member ndarrays (`self.flag`, `self.counter`) and `@dataclasses.dataclass` parameter members (`params.flag`, `params.counter`) -- resolved to a flat C++ arg-id at AST-build time via a new shared `ASTTransformer._resolve_ndarray_kernel_arg_id` helper that builds the expression and reads the resolved `ExternalTensorExpression.arg_id` via a new `get_external_tensor_arg_id` accessor on `export_lang.cpp`. Any attribute chain that flattens to a kernel ndarray argument works the same way as a bare parameter name, so users no longer have to forward flag / counter members as top-level kernel parameters.

The launch path now forwards `Kernel.checkpoint_yield_on_cpp_arg_ids` and `GraphDoWhileLevel.cond_cpp_arg_id` directly to the launch context, removing the per-launch name-matching step.

The fast-cache schema bumps to v3 to round-trip the AST-resolved arg-ids alongside the existing graph_do_while level table and the checkpoint yield_on / user-label tables.
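The attribute-chain resolution described above can be illustrated with a small standalone sketch. This is not the helper from the PR: the real `_resolve_ndarray_kernel_arg_id` builds the expression and reads `ExternalTensorExpression.arg_id`, whereas here a plain dict stands in for that lookup. The names `flatten_attribute_chain`, `resolve_arg_id`, and `flat_arg_ids` are hypothetical and exist only for illustration.

```python
import ast


def flatten_attribute_chain(node: ast.expr) -> str:
    """Turn an ast.Name or a chain of ast.Attribute nodes into a dotted path."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        raise TypeError("expected a bare name or an attribute chain")
    parts.append(node.id)
    return ".".join(reversed(parts))


def resolve_arg_id(expr_src: str, flat_arg_ids: dict) -> int:
    """Resolve an expression such as 'self.flag' to a flat arg-id.

    flat_arg_ids is a hypothetical stand-in for the table the compiler
    builds when it flattens kernel arguments.
    """
    node = ast.parse(expr_src, mode="eval").body
    key = flatten_attribute_chain(node)
    if key not in flat_arg_ids:
        raise KeyError(f"{key!r} does not flatten to a kernel ndarray argument")
    return flat_arg_ids[key]


# Hypothetical flattened argument table for one kernel specialization.
flat_arg_ids = {"flag": 0, "self.flag": 1, "params.counter": 2}

print(resolve_arg_id("flag", flat_arg_ids))            # bare parameter
print(resolve_arg_id("self.flag", flat_arg_ids))       # data_oriented member
print(resolve_arg_id("params.counter", flat_arg_ids))  # dataclass member
```

Because the id is resolved once, at AST-build time, a launch path in this style would only need to forward the stored integer rather than matching names on every launch.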
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f785888193
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
    if cache_value.checkpoint_yield_on_args:
        self.checkpoint_yield_on_args = list(cache_value.checkpoint_yield_on_args)
        self.checkpoint_yield_on_cpp_arg_ids = list(cache_value.checkpoint_yield_on_cpp_arg_ids)
        self.checkpoint_user_labels_by_cp_id = list(cache_value.checkpoint_user_labels_by_cp_id)
Clear checkpoint metadata on empty fast-cache hits
When a cached specialization has no checkpoint entries, this guard skips restoring the cached tables, leaving whatever checkpoint_yield_on_args and label metadata were populated by a previously loaded specialization of the same Kernel object. Because fast-cache hits run with only_parse_function_def and do not walk/reset the body, a later specialization with no checkpoints can still forward stale yield_on arg ids, return a GraphStatus, or accept resume labels for checkpoints that are not in the cached kernel. Restore these lists unconditionally, including the empty-list case.
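The stale-state hazard can be reproduced in isolation. The sketch below is a simplified model, not the project's code: `CachedTables` and `KernelState` are hypothetical stand-ins for `CacheValue` and `Kernel`, and only two of the tables are modelled.

```python
from dataclasses import dataclass, field


@dataclass
class CachedTables:
    """Hypothetical stand-in for the checkpoint fields of a cache entry."""
    checkpoint_yield_on_args: list = field(default_factory=list)
    checkpoint_yield_on_cpp_arg_ids: list = field(default_factory=list)


class KernelState:
    """Hypothetical stand-in for the per-Kernel metadata that outlives a load."""

    def __init__(self):
        self.checkpoint_yield_on_args = []
        self.checkpoint_yield_on_cpp_arg_ids = []

    def restore_guarded(self, cached: CachedTables) -> None:
        # Mirrors the reviewed guard: an empty cached table skips the restore.
        if cached.checkpoint_yield_on_args:
            self.checkpoint_yield_on_args = list(cached.checkpoint_yield_on_args)
            self.checkpoint_yield_on_cpp_arg_ids = list(cached.checkpoint_yield_on_cpp_arg_ids)

    def restore_unconditional(self, cached: CachedTables) -> None:
        # Suggested fix: always overwrite, including with empty lists.
        self.checkpoint_yield_on_args = list(cached.checkpoint_yield_on_args)
        self.checkpoint_yield_on_cpp_arg_ids = list(cached.checkpoint_yield_on_cpp_arg_ids)


with_checkpoints = CachedTables(["self.flag"], [3])
no_checkpoints = CachedTables()

stale = KernelState()
stale.restore_guarded(with_checkpoints)
stale.restore_guarded(no_checkpoints)    # guard skips the reset
print(stale.checkpoint_yield_on_cpp_arg_ids)   # still holds the earlier arg-id

fresh = KernelState()
fresh.restore_unconditional(with_checkpoints)
fresh.restore_unconditional(no_checkpoints)
print(fresh.checkpoint_yield_on_cpp_arg_ids)   # empty, as the cached entry says
```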
…ema-v3 round-trip

Follows up on the member-ndarray support commit with a wider test surface:

- Behavioural yield/resume for `yield_on=params.flag` (dataclass mirror of the existing data_oriented test).
- Error paths for `yield_on=self.nonexistent_attr`, `qd.graph_do_while(self.nonexistent_attr)`, and `yield_on=self.scalar_attr` (non-ndarray) -- pins the user-facing diagnostic for the attribute forms.
- Nested `qd.graph_do_while(self.outer)` containing `qd.graph_do_while(self.inner)` -- exercises the level-table machinery with `@qd.data_oriented` member ndarrays end-to-end.
- Direct ``CacheValue`` round-trip unit test for schema v3 (`cachevalue-v3-ast-resolved-ids`): covers the new 3-tuple `graph_do_while_levels` + `checkpoint_yield_on_args` / `checkpoint_yield_on_cpp_arg_ids` / `checkpoint_user_labels_by_cp_id` fields the loader/storer now plumb through.
- Cross-process fast-cache restore test for a `@qd.kernel(graph=True, checkpoints=True, fastcache=True)` kernel with `yield_on=self.flag` -- without the schema-v3 restore the launch path's `forward_yield_on_table_to_ctx` would be a no-op and yield/resume would silently break on fast-cached checkpoint kernels.

Also fixes an existing `test_src_hasher_store_validate` assertion that indexed `src_hasher.load(...)` by position; the loader now returns a `CacheValue` (or `None`), so the test is updated to use attribute access and also asserts that the new default-empty fields are present on the round-tripped object.
9bf106a to d3de129
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d3de1296a7
    # translate `from_checkpoint=` labels without re-running the AST transformer.
    checkpoint_yield_on_args: list[str | None] = []
    checkpoint_yield_on_cpp_arg_ids: list[int] = []
    checkpoint_user_labels_by_cp_id: list[int | None] = []
Preserve IntEnum labels when restoring fast-cache checkpoints
When a checkpoint uses an IntEnum cp_id and the kernel is restored from src_ll_cache, this new cache field is serialized through JSON as plain integers and _try_load_fastcache restores checkpoint_user_labels_by_cp_id as [1] rather than [Stage.LOAD]. maybe_build_graph_status() then returns the raw int for status.checkpoint on cache hits, breaking the documented/API contract that qd.checkpoint(Stage.X, ...) round-trips the enum value rather than the underlying int. Persist enough enum metadata/expression information to reconstruct the label, or avoid lossy restoration for enum labels.
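The lossy round-trip is a property of JSON serialization itself and can be shown with the standard library alone. In this sketch `Stage` is a hypothetical enum standing in for a user's checkpoint labels; nothing here touches the project's cache code.

```python
import json
from enum import IntEnum


class Stage(IntEnum):
    """Hypothetical user-defined checkpoint labels."""
    LOAD = 1
    SOLVE = 2


labels = [Stage.LOAD, None, 7]

# IntEnum members are int subclasses, so json writes them as plain integers.
restored = json.loads(json.dumps(labels))

print(restored)                       # plain ints and None
print(type(restored[0]))              # int, not Stage
print(restored[0] == Stage.LOAD)      # equality still holds, identity does not
```

Equality with the enum member survives, which is why a test that only compares values would not catch the regression; an `isinstance` check is needed.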
Good catch -- confirmed and fixed in 58bba23.
Pydantic coerces IntEnum to int at CacheValue.__init__ time (the field is typed list[int | None]), so just persisting the int column was lossy even before JSON. Schema bumped to cachevalue-v4-intenum-qualnames, which adds a parallel checkpoint_user_label_enum_qualnames column. src_hasher.store derives the per-slot module.ClassQualName.MEMBER string from the live label list (still holding the original IntEnum instances at store time, before pydantic strips identity), and _resolve_intenum_member re-imports the enum class via importlib on load. Mismatch / failed import (enum moved or renamed since the cache was written) falls back to the persisted int rather than raising, so stale caches degrade gracefully.
Tests:
- `test_checkpoint_fastcache_preserves_intenum_label_identity` -- subprocess cache miss + hit, asserts `isinstance(label, _FastcacheStage)` after restore (not just int equality).
- `test_src_hasher_intenum_qualname_round_trip` -- direct `CacheValue` unit test for mixed IntEnum / None / plain-int slots, qualname derivation, and the resolver fallback.
Both pass on x64 and CUDA on the cluster. Older v3 caches just invalidate via the version bump, no migration needed.
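A minimal sketch of the qualname approach follows, assuming a store-side function that records `module.ClassQualName.MEMBER` and a load-side function that re-imports it with `importlib`. The function names `member_qualname` and `resolve_member` are hypothetical and are not the project's `_intenum_member_qualname` / `_resolve_intenum_member`. The standard library's `http.HTTPStatus` (an `IntEnum`) stands in for a user enum so the example is importable anywhere.

```python
import importlib
from enum import IntEnum
from http import HTTPStatus  # stdlib IntEnum used as a stand-in for a user enum


def member_qualname(label):
    """Store side: return 'module.Class.MEMBER' for IntEnum labels, else None."""
    if not isinstance(label, IntEnum):
        return None
    cls = type(label)
    return f"{cls.__module__}.{cls.__qualname__}.{label.name}"


def resolve_member(qualname, persisted_int):
    """Load side: rebuild the enum member, or fall back to the persisted int."""
    if qualname is None:
        return persisted_int
    parts = qualname.split(".")
    # Try the longest importable module prefix first, then walk attributes.
    for split in range(len(parts) - 2, 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
        except ImportError:
            continue
        try:
            for attr in parts[split:]:
                obj = getattr(obj, attr)
        except AttributeError:
            return persisted_int  # enum class or member was renamed
        if isinstance(obj, IntEnum) and int(obj) == persisted_int:
            return obj
        return persisted_int  # resolved to something else, or value mismatch
    return persisted_int  # module could not be imported


name = member_qualname(HTTPStatus.OK)
print(name)
print(repr(resolve_member(name, 200)))
print(resolve_member("http.NoSuchEnum.OK", 200))   # falls back to the int
print(member_qualname(7))                          # plain ints store nothing
```

Returning the persisted int on any failure is what lets a stale cache degrade to the pre-fix behaviour instead of raising.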
…ache restore

Bumps fast-cache schema to v4 (`cachevalue-v4-intenum-qualnames`) to fix a P2 regression flagged on PR #760 where ``qd.checkpoint(Stage.X, ...)`` round-tripped through fast-cache as the raw ``int`` rather than the original ``IntEnum`` member.

Pydantic coerces ``IntEnum`` to ``int`` at ``CacheValue`` construction time (the field is typed ``list[int | None]``), so persisting only the int column was lossy. The fix stores a parallel ``checkpoint_user_label_enum_qualnames`` column with the original member's ``module.ClassQualName.MEMBER`` string and rebuilds the enum on load via ``_resolve_intenum_member`` (importlib + attribute walk; falls back to the persisted int if the enum was moved/renamed since the cache was written, so stale caches degrade gracefully rather than crashing). The store-side helper ``_intenum_member_qualname`` returns ``None`` for plain-int labels so non-enum users pay nothing.

Tests:

- ``test_checkpoint_fastcache_preserves_intenum_label_identity`` -- end-to-end subprocess cache miss + hit, asserts ``isinstance(label, _FastcacheStage)`` after fast-cache restore (not just int equality).
- ``test_src_hasher_intenum_qualname_round_trip`` -- direct ``CacheValue`` unit test covering mixed IntEnum/None/int label slots, the parallel qualname derivation, and the fallback when a qualname no longer resolves.

Older v3 caches drop into the same raw-int fallback path the loader uses for plain-int labels (the missing column defaults to an empty list, padded to None per slot), so no migration is required -- the version bump just invalidates them via ``create_cache_key``.
…ndarray_arg_resolver.py

CI's `Check feature factorization` flagged the ~40-line `ASTTransformer._resolve_ndarray_kernel_arg_id` static method as carving a new feature into the central 1705-line `ast_transformer.py`. The bot's suggested fix matches the existing pattern adjacent to `_is_checkpoint_call` (thin forwarding wrapper on `ASTTransformer`, real logic in a sibling `ast_transformers/*_transformer.py` file).

Moved the resolver to `python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py` as a free function `resolve_ndarray_kernel_arg_id`. `ASTTransformer._resolve_ndarray_kernel_arg_id` is now a one-line forwarding wrapper (kept so existing call sites in `build_While` and any third-party callers continue to work), and `CheckpointTransformer.build_checkpoint_with` imports the free function directly instead of going through the wrapper. The local-import dance dodges the `ast_transformers -> ast_transformer` cycle the same way the existing `_is_checkpoint_call` / `CheckpointTransformer` split does.