
[Graph] Support member ndarrays for qd.checkpoint and qd.graph_do_while#760

Open
hughperkins wants to merge 5 commits into main from hp/member-ndarray-yield-condition

@hughperkins

Collaborator

`qd.checkpoint(yield_on=...)` and `qd.graph_do_while(...)` previously required the argument to be a bare kernel parameter (`ast.Name`). With this change they also accept attribute chains: both `@qd.data_oriented` member ndarrays (`self.flag`, `self.counter`) and `@dataclasses.dataclass` parameter members (`params.flag`, `params.counter`). Each chain is resolved to a flat C++ arg-id at AST-build time by a new shared helper, `ASTTransformer._resolve_ndarray_kernel_arg_id`, which builds the expression and reads the resolved `ExternalTensorExpression.arg_id` through a new `get_external_tensor_arg_id` accessor in `export_lang.cpp`. Any attribute chain that flattens to a kernel ndarray argument now works the same way as a bare parameter name, so users no longer have to forward flag / counter members as top-level kernel parameters.

The launch path now forwards `Kernel.checkpoint_yield_on_cpp_arg_ids` and `GraphDoWhileLevel.cond_cpp_arg_id` directly to the launch context, removing the per-launch name-matching step. The fast-cache schema is bumped to v3 so that the AST-resolved arg-ids round-trip alongside the existing `graph_do_while` level table and the checkpoint `yield_on` / user-label tables.
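
To make the "attribute chain resolved to a flat arg-id" idea concrete, here is a minimal stand-alone sketch using only the standard-library `ast` module. It is not the Quadrants implementation: the PR's helper builds the expression and reads `ExternalTensorExpression.arg_id`, whereas this toy looks the chain up in a hand-written table. The names `attribute_chain`, `resolve_arg_id` and `FLAT_ARG_IDS` are hypothetical.

```python
import ast
from typing import Dict, List, Optional, Tuple


def attribute_chain(node: ast.expr) -> Optional[Tuple[str, ...]]:
    """Return ('self', 'flag') for `self.flag`, ('flag',) for `flag`, else None."""
    parts: List[str] = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None  # calls, subscripts, etc. are not plain name/attribute chains
    parts.append(node.id)
    return tuple(reversed(parts))


def resolve_arg_id(expr_src: str, flat_arg_ids: Dict[Tuple[str, ...], int]) -> int:
    """Map a bare name or attribute chain to its flat argument id (toy lookup)."""
    node = ast.parse(expr_src, mode="eval").body
    chain = attribute_chain(node)
    if chain is None or chain not in flat_arg_ids:
        raise ValueError(f"{expr_src!r} does not resolve to a kernel ndarray argument")
    return flat_arg_ids[chain]


# Hypothetical flattened argument table for one kernel (assumption, not real data).
FLAT_ARG_IDS = {("flag",): 0, ("self", "flag"): 1, ("params", "counter"): 2}

bare_id = resolve_arg_id("flag", FLAT_ARG_IDS)
member_id = resolve_arg_id("self.flag", FLAT_ARG_IDS)
dataclass_id = resolve_arg_id("params.counter", FLAT_ARG_IDS)
```

Because the id is computed once when the AST is built, a launch path in this model only needs the integer and never has to match parameter names again.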


…h_do_while() arguments

@hughperkins marked this pull request as ready for review June 24, 2026 13:13

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f785888193

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/quadrants/lang/kernel.py Outdated
Comment on lines +438 to +441
if cache_value.checkpoint_yield_on_args:
    self.checkpoint_yield_on_args = list(cache_value.checkpoint_yield_on_args)
    self.checkpoint_yield_on_cpp_arg_ids = list(cache_value.checkpoint_yield_on_cpp_arg_ids)
    self.checkpoint_user_labels_by_cp_id = list(cache_value.checkpoint_user_labels_by_cp_id)


P2: Clear checkpoint metadata on empty fast-cache hits

When a cached specialization has no checkpoint entries, this guard skips restoring the cached tables, leaving whatever checkpoint_yield_on_args and label metadata were populated by a previously loaded specialization of the same Kernel object. Because fast-cache hits run with only_parse_function_def and do not walk/reset the body, a later specialization with no checkpoints can still forward stale yield_on arg ids, return a GraphStatus, or accept resume labels for checkpoints that are not in the cached kernel. Restore these lists unconditionally, including the empty-list case.
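
The stale-metadata scenario can be reproduced with a small self-contained model. This is only a sketch under assumptions: `CachedTables` and `KernelState` are hypothetical stand-ins for the cached value and the `Kernel` object, and only the three field names are taken from the snippet above.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class CachedTables:
    """Hypothetical stand-in for the checkpoint tables stored in the fast cache."""
    checkpoint_yield_on_args: List[Optional[str]] = field(default_factory=list)
    checkpoint_yield_on_cpp_arg_ids: List[int] = field(default_factory=list)
    checkpoint_user_labels_by_cp_id: List[Optional[int]] = field(default_factory=list)


class KernelState:
    """Hypothetical stand-in for the per-kernel checkpoint metadata."""

    def __init__(self) -> None:
        self.checkpoint_yield_on_args: List[Optional[str]] = []
        self.checkpoint_yield_on_cpp_arg_ids: List[int] = []
        self.checkpoint_user_labels_by_cp_id: List[Optional[int]] = []

    def _copy_from(self, cached: CachedTables) -> None:
        self.checkpoint_yield_on_args = list(cached.checkpoint_yield_on_args)
        self.checkpoint_yield_on_cpp_arg_ids = list(cached.checkpoint_yield_on_cpp_arg_ids)
        self.checkpoint_user_labels_by_cp_id = list(cached.checkpoint_user_labels_by_cp_id)

    def restore_guarded(self, cached: CachedTables) -> None:
        # Mirrors the guarded form: an empty cached table leaves old state in place.
        if cached.checkpoint_yield_on_args:
            self._copy_from(cached)

    def restore_unconditional(self, cached: CachedTables) -> None:
        # Suggested form: always overwrite, so an empty table clears old state.
        self._copy_from(cached)


with_checkpoints = CachedTables(["self.flag"], [3], [1])
without_checkpoints = CachedTables()

stale = KernelState()
stale.restore_guarded(with_checkpoints)
stale.restore_guarded(without_checkpoints)      # still holds arg id 3

clean = KernelState()
clean.restore_unconditional(with_checkpoints)
clean.restore_unconditional(without_checkpoints)  # tables are empty again
```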


@hughperkins marked this pull request as draft June 24, 2026 13:19
…ema-v3 round-trip

Follows up on the member-ndarray support commit with a wider test surface:

- Behavioural yield/resume for `yield_on=params.flag` (dataclass mirror of the existing data_oriented test).
- Error paths for `yield_on=self.nonexistent_attr`, `qd.graph_do_while(self.nonexistent_attr)`, and
  `yield_on=self.scalar_attr` (non-ndarray) -- pins the user-facing diagnostic for the attribute forms.
- Nested `qd.graph_do_while(self.outer)` containing `qd.graph_do_while(self.inner)` -- exercises the level-table
  machinery with `@qd.data_oriented` member ndarrays end-to-end.
- Direct ``CacheValue`` round-trip unit test for schema v3 (`cachevalue-v3-ast-resolved-ids`): covers the new
  3-tuple `graph_do_while_levels` + `checkpoint_yield_on_args` / `checkpoint_yield_on_cpp_arg_ids` /
  `checkpoint_user_labels_by_cp_id` fields the loader/storer now plumb through.
- Cross-process fast-cache restore test for a `@qd.kernel(graph=True, checkpoints=True, fastcache=True)` kernel
  with `yield_on=self.flag` -- without the schema-v3 restore the launch path's
  `forward_yield_on_table_to_ctx` would be a no-op and yield/resume would silently break on fast-cached
  checkpoint kernels.

Also fixes an existing `test_src_hasher_store_validate` assertion that indexed `src_hasher.load(...)` by
position. The loader now returns a `CacheValue` (or `None`), so the test uses attribute access instead and
additionally asserts that the new default-empty fields are present on the round-tripped object.
@hughperkins force-pushed the hp/member-ndarray-yield-condition branch from 9bf106a to d3de129 June 24, 2026 13:29
@hughperkins marked this pull request as ready for review June 24, 2026 13:30

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d3de1296a7


# translate `from_checkpoint=` labels without re-running the AST transformer.
checkpoint_yield_on_args: list[str | None] = []
checkpoint_yield_on_cpp_arg_ids: list[int] = []
checkpoint_user_labels_by_cp_id: list[int | None] = []


P2: Preserve IntEnum labels when restoring fast-cache checkpoints

When a checkpoint uses an IntEnum cp_id and the kernel is restored from src_ll_cache, this new cache field is serialized through JSON as plain integers and _try_load_fastcache restores checkpoint_user_labels_by_cp_id as [1] rather than [Stage.LOAD]. maybe_build_graph_status() then returns the raw int for status.checkpoint on cache hits, breaking the documented/API contract that qd.checkpoint(Stage.X, ...) round-trips the enum value rather than the underlying int. Persist enough enum metadata/expression information to reconstruct the label, or avoid lossy restoration for enum labels.
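
The loss described here is easy to see with the standard library alone. The sketch below assumes a hypothetical `Stage` enum and a plain JSON round trip; it does not touch the real `CacheValue` or `_try_load_fastcache`.

```python
import enum
import json


class Stage(enum.IntEnum):
    """Hypothetical checkpoint labels, for illustration only."""
    LOAD = 1
    SOLVE = 2


labels = [Stage.LOAD, None, 7]

# json serializes an IntEnum member as its underlying integer ...
encoded = json.dumps(labels)

# ... so the decoded list compares equal but has lost the enum identity.
restored = json.loads(encoded)
still_equal = restored == labels
kept_identity = isinstance(restored[0], Stage)
```

Equality on the integer value survives the round trip, which is why a test that only checks `==` would not catch the regression; an `isinstance` check does.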


Collaborator Author


Good catch -- confirmed and fixed in 58bba23.

Pydantic coerces IntEnum to int at CacheValue.__init__ time (the field is typed list[int | None]), so just persisting the int column was lossy even before JSON. Schema bumped to cachevalue-v4-intenum-qualnames, which adds a parallel checkpoint_user_label_enum_qualnames column. src_hasher.store derives the per-slot module.ClassQualName.MEMBER string from the live label list (still holding the original IntEnum instances at store time, before pydantic strips identity), and _resolve_intenum_member re-imports the enum class via importlib on load. Mismatch / failed import (enum moved or renamed since the cache was written) falls back to the persisted int rather than raising, so stale caches degrade gracefully.

Tests:

  • test_checkpoint_fastcache_preserves_intenum_label_identity -- subprocess cache miss + hit, asserts isinstance(label, _FastcacheStage) after restore (not just int equality).
  • test_src_hasher_intenum_qualname_round_trip -- direct CacheValue unit test for mixed IntEnum / None / plain-int slots, qualname derivation, and the resolver fallback.

Both pass on x64 and CUDA on the cluster. Older v3 caches just invalidate via the version bump, no migration needed.
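
For readers who want the shape of the qualname approach without opening the diff, here is a rough stand-alone sketch. It is an approximation under assumptions, not the code from 58bba23: `member_qualname` and `resolve_member` are hypothetical names, and the standard library's `http.HTTPStatus` (an `IntEnum`) stands in for a user-defined label enum.

```python
import enum
import http
import importlib
from typing import Optional, Union


def member_qualname(label: object) -> Optional[str]:
    """Return 'module.ClassQualName.MEMBER' for IntEnum members, else None."""
    if isinstance(label, enum.IntEnum):
        cls = type(label)
        return f"{cls.__module__}.{cls.__qualname__}.{label.name}"
    return None


def resolve_member(qualname: Optional[str], persisted: int) -> Union[int, enum.IntEnum]:
    """Re-import the enum member; fall back to the persisted int on any problem."""
    if qualname is None:
        return persisted
    parts = qualname.split(".")
    # Try the longest candidate module prefix first, leaving at least Class.MEMBER.
    for split in range(len(parts) - 2, 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
        except ImportError:
            continue
        try:
            for attr in parts[split:]:
                obj = getattr(obj, attr)
        except AttributeError:
            return persisted  # enum class or member renamed since the cache was written
        if isinstance(obj, enum.IntEnum) and int(obj) == persisted:
            return obj
        return persisted  # resolved to something else, or the value no longer matches
    return persisted  # module could not be imported


name = member_qualname(http.HTTPStatus.OK)             # 'http.HTTPStatus.OK'
hit = resolve_member(name, 200)                        # enum member restored
moved = resolve_member("no_such_module_xyz.Stage.LOAD", 1)  # falls back to 1
mismatch = resolve_member(name, 404)                   # value changed: falls back to 404
```

The fallback branches all return the persisted integer instead of raising, which is the "stale caches degrade gracefully" behaviour described above.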

…ache restore

Bumps fast-cache schema to v4 (`cachevalue-v4-intenum-qualnames`) to fix a P2 regression flagged on PR #760
where ``qd.checkpoint(Stage.X, ...)`` round-tripped through fast-cache as the raw ``int`` rather than the
original ``IntEnum`` member. Pydantic coerces ``IntEnum`` to ``int`` at ``CacheValue`` construction time (the
field is typed ``list[int | None]``), so persisting only the int column was lossy.

The fix stores a parallel ``checkpoint_user_label_enum_qualnames`` column with the original member's
``module.ClassQualName.MEMBER`` string and rebuilds the enum on load via ``_resolve_intenum_member`` (importlib +
attribute walk; falls back to the persisted int if the enum was moved/renamed since the cache was written, so
stale caches degrade gracefully rather than crashing). The store-side helper ``_intenum_member_qualname`` returns
``None`` for plain-int labels so non-enum users pay nothing.

Tests:
- ``test_checkpoint_fastcache_preserves_intenum_label_identity`` -- end-to-end subprocess cache miss + hit, asserts
  ``isinstance(label, _FastcacheStage)`` after fast-cache restore (not just int equality).
- ``test_src_hasher_intenum_qualname_round_trip`` -- direct ``CacheValue`` unit test covering mixed
  IntEnum/None/int label slots, the parallel qualname derivation, and the fallback when a qualname no longer
  resolves.

Older v3 caches drop into the same raw-int fallback path the loader uses for plain-int labels (the missing
column defaults to an empty list, padded to None per slot), so no migration is required -- the version bump just
invalidates them via ``create_cache_key``.
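
Two mechanisms from this message can be sketched in isolation: mixing the schema version string into the cache key, and padding a missing qualname column to one `None` per label slot. The signatures below are assumptions for illustration only; the real `create_cache_key` may hash different inputs, and `pad_qualnames` is a hypothetical helper.

```python
import hashlib
from typing import List, Optional, Sequence


def create_cache_key_sketch(schema_version: str, kernel_source: str) -> str:
    """Hypothetical key: any change to the schema version changes the key."""
    digest = hashlib.sha256()
    digest.update(schema_version.encode("utf-8"))
    digest.update(b"\0")  # separator so the two fields cannot run together
    digest.update(kernel_source.encode("utf-8"))
    return digest.hexdigest()


def pad_qualnames(qualnames: Sequence[Optional[str]], n_slots: int) -> List[Optional[str]]:
    """Pad a short or missing qualname column with None, one entry per label slot."""
    padded = list(qualnames)
    padded.extend([None] * (n_slots - len(padded)))
    return padded


source = "def step(self): ..."
key_v3 = create_cache_key_sketch("cachevalue-v3-ast-resolved-ids", source)
key_v4 = create_cache_key_sketch("cachevalue-v4-intenum-qualnames", source)

# An entry without the new column behaves like "all plain-int labels".
legacy_column = pad_qualnames([], 3)
```

With a key built this way, entries written under the v3 string are simply never looked up under v4, so no explicit migration step is needed in this model.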
…ndarray_arg_resolver.py

CI's `Check feature factorization` flagged the ~40-line `ASTTransformer._resolve_ndarray_kernel_arg_id` static
method as carving a new feature into the central 1705-line `ast_transformer.py`. The bot's suggested fix
matches the existing pattern adjacent to `_is_checkpoint_call` (thin forwarding wrapper on `ASTTransformer`,
real logic in a sibling `ast_transformers/*_transformer.py` file).

Moved the resolver to `python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py` as a free function
`resolve_ndarray_kernel_arg_id`. `ASTTransformer._resolve_ndarray_kernel_arg_id` is now a one-line forwarding
wrapper (kept so existing call sites in `build_While` and any third-party callers continue to work), and
`CheckpointTransformer.build_checkpoint_with` imports the free function directly instead of going through the
wrapper. The local-import dance dodges the `ast_transformers -> ast_transformer` cycle the same way the
existing `_is_checkpoint_call` / `CheckpointTransformer` split does.
