Respect PyTorch precision when loading checkpoints #940

Open
taivu1998 wants to merge 1 commit into Physical-Intelligence:main from taivu1998:tdv/issue-788-pytorch-dtype

@taivu1998

Summary

Fixes #788.

This PR makes PyTorch checkpoint loading respect TrainConfig.pytorch_training_precision instead of silently falling back to the nested model config's default dtype.

Root Cause

scripts/train_pytorch.py already updates the model config dtype before constructing PI0Pytorch, but BaseModelConfig.load_pytorch did not. As a result, a config requesting float32 PyTorch precision could still instantiate and load a PyTorch checkpoint through a bfloat16-configured module. The policy-loading path also unconditionally re-applied "bfloat16" after loading, which would have kept float32 policy loading broken even if the loader were fixed.
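
For reviewers who want the mismatch in miniature, here is a rough sketch. ModelConfig, TrainConfig, and both helper functions are hypothetical stand-ins written for illustration, not the openpi classes; the only assumption carried over from the description is that the nested model config has its own dtype default of "bfloat16".

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ModelConfig:
    """Hypothetical stand-in for the nested model config."""

    dtype: str = "bfloat16"  # assumed default, per the description above


@dataclasses.dataclass(frozen=True)
class TrainConfig:
    """Hypothetical stand-in for the training config."""

    model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
    pytorch_training_precision: str = "bfloat16"


def dtype_used_before_fix(train_config: TrainConfig) -> str:
    # The loader reads only the nested model config, so the requested
    # precision never reaches the module that gets constructed.
    return train_config.model.dtype


def dtype_used_after_fix(train_config: TrainConfig) -> str:
    # Derive a copy of the model config that carries the requested precision.
    # dataclasses.replace returns a new object and leaves the input untouched.
    derived = dataclasses.replace(
        train_config.model, dtype=train_config.pytorch_training_precision
    )
    return derived.dtype


cfg = TrainConfig(pytorch_training_precision="float32")
print(dtype_used_before_fix(cfg))  # nested default wins
print(dtype_used_after_fix(cfg))   # requested precision wins
```

The second helper mirrors the "non-mutating" wording in this PR: the original nested config still reports its default afterwards.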

Changes

  • Derive a non-mutating PyTorch model config in load_pytorch with dtype=train_config.pytorch_training_precision before loading safetensors.
  • Add an explicit early error for model types that are not supported by PI0Pytorch.
  • Make PyTorch policy loading re-apply the configured precision instead of hard-coding "bfloat16".
  • Add focused monkeypatched tests for both bfloat16 and float32 loader precision propagation, unsupported model-type rejection, and policy post-load precision forwarding.
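
The shape of these changes can be sketched without the real model. Everything below is an illustrative assumption rather than the code in this PR: FakeModel, apply_precision, load_model, load_policy_model, and the placeholder SUPPORTED_MODEL_TYPES set are hypothetical names, and the fake simply records what it was given, in the spirit of the monkeypatched tests.

```python
import dataclasses

# Placeholder set; the real list of supported model types is not shown here.
SUPPORTED_MODEL_TYPES = {"model_a"}


@dataclasses.dataclass(frozen=True)
class ModelConfig:
    """Hypothetical stand-in for a model config."""

    model_type: str = "model_a"
    dtype: str = "bfloat16"


class FakeModel:
    """Records the dtype it was built with and any later precision calls."""

    def __init__(self, config: ModelConfig):
        self.built_dtype = config.dtype
        self.precision_calls = []

    def apply_precision(self, precision: str) -> None:  # hypothetical method
        self.precision_calls.append(precision)


def load_model(config: ModelConfig, precision: str) -> FakeModel:
    # Fail early and explicitly for model types the loader cannot handle.
    if config.model_type not in SUPPORTED_MODEL_TYPES:
        raise ValueError(f"Unsupported model type: {config.model_type!r}")
    # Non-mutating: build a derived config instead of editing the original.
    derived = dataclasses.replace(config, dtype=precision)
    return FakeModel(derived)


def load_policy_model(config: ModelConfig, precision: str) -> FakeModel:
    model = load_model(config, precision)
    # Forward the configured precision instead of a hard-coded "bfloat16".
    model.apply_precision(precision)
    return model


for requested in ("bfloat16", "float32"):
    model = load_policy_model(ModelConfig(), requested)
    print(requested, model.built_dtype, model.precision_calls)
```

Running the loop for both precisions checks the two propagation paths together: the dtype used at construction and the precision re-applied after loading.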

Validation

  • .venv/bin/python -m pytest src/openpi/models/model_test.py -k pytorch
  • .venv/bin/python -m pytest src/openpi/policies/policy_test.py -k pytorch
  • uvx ruff check src/openpi/models/model.py src/openpi/policies/policy_config.py src/openpi/models/model_test.py src/openpi/policies/policy_test.py
  • uvx ruff format --check src/openpi/models/model.py src/openpi/policies/policy_config.py src/openpi/models/model_test.py src/openpi/policies/policy_test.py
  • git diff --check

Note: on macOS arm64, the normal uv run path attempts to install Linux-only CUDA JAX wheels from the project lockfile. I validated using a temporary environment synced with the CUDA-only packages skipped, then removed the generated environment and caches.

@taivu1998 marked this pull request as ready for review May 11, 2026 03:37
@jimmyt857 removed their request for review May 11, 2026 04:08
@wadeKeith (Contributor) left a comment

Clean fix for #788. Respects pytorch_training_precision during checkpoint loading instead of silently falling back to default dtype. Good test coverage. LGTM! Reviewed by Hermes Agent.

Development

Successfully merging this pull request may close these issues.

Misc bug in dtype setting in load_pytorch

2 participants