Feat: support ONNX export and loading #103
Conversation
Pull request overview
This PR adds ONNX export and inference capabilities for the PI05 policy model. The implementation separates traceable neural network operations (exported to ONNX) from non-traceable operations like tokenization and state discretization (handled externally in Python).
Changes:
- Adds `onnx_inference.py` script for running inference with exported ONNX models
- Refactors `export_to_onnx.py` to support PI05 using the Dynamo ONNX exporter with pre-tokenized inputs
- Updates dtype handling in PI05 model files to use float32 during ONNX export instead of bfloat16
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| src/opentau/scripts/onnx_inference.py | New inference script for ONNX models with pre-tokenization, image preprocessing, and action sampling |
| src/opentau/scripts/export_to_onnx.py | Refactored to export PI05 models with pre-tokenized inputs using Dynamo exporter |
| src/opentau/policies/pi05/paligemma_with_expert.py | Fixed Cache import from transformers; added dtype selection for ONNX export; guards bfloat16 casting during compilation |
| src/opentau/policies/pi05/modeling_pi05.py | Added dtype selection function for ONNX export compatibility |
| src/opentau/policies/normalize.py | Skips validation assertions during ONNX export and torch compilation |
```python
    Returns:
        Action tensor of shape (batch, n_action_steps, action_dim).
    """
    print("Starting forward pass of the wrapper...")
```
Debug print statements should be removed or replaced with proper logging. These print statements will be executed during ONNX export tracing, which may clutter the output. Consider using logging.debug() instead.
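A minimal sketch of the suggested fix, using the message from the snippet above (the logger name is an assumption, not from the PR):

```python
import logging

# Hypothetical module-level logger for the ONNX export wrapper.
logger = logging.getLogger("onnx_export_wrapper")

# Unlike print(), this emits nothing unless DEBUG logging is configured,
# so the ONNX export tracing output stays clean by default.
logger.debug("Starting forward pass of the wrapper...")
```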
```python
# `policy.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
actions = actions.transpose(0, 1)
print("Finished forward pass of the wrapper")
```
Debug print statement should be removed or replaced with proper logging. Consider using logging.debug() instead.
```python
# `policy.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
actions = actions.transpose(0, 1)
```
The comment states "policy.model.forward returns a (batch_size, n_action_steps, action_dim) tensor" but this is inside the forward method of the wrapper, and the code is calling policy.model.sample_actions, not policy.model.forward. The comment appears to be outdated or incorrect. Additionally, the transpose operation converts from (batch, n_action_steps, action_dim) to (n_action_steps, batch, action_dim), which contradicts the ONNX export that declares the output as "actions" without specifying this transposition in the documentation.
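To see the shape change concretely, here is a minimal shape-only sketch (plain Python, no torch dependency; `transposed_shape` is an illustrative helper, not code from the PR) of what `transpose(0, 1)` does to a tensor's dimensions:

```python
def transposed_shape(shape):
    # transpose(0, 1) swaps the first two dimensions, so a
    # (batch, n_action_steps, action_dim) tensor becomes
    # (n_action_steps, batch, action_dim).
    return (shape[1], shape[0]) + tuple(shape[2:])

print(transposed_shape((2, 50, 32)))  # -> (50, 2, 32)
```

This is why the exported ONNX model's "actions" output no longer matches the docstring's declared (batch, n_action_steps, action_dim) shape.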
```python
providers = provider or (
    ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if ort.get_device() == "GPU"
    else ["CPUExecutionProvider"]
)
```
The provider parameter is a string but is assigned directly to providers which expects a list. When a user passes a string (e.g., "CUDAExecutionProvider"), it will be used as-is instead of being wrapped in a list. This should be: providers = [provider] if provider else (...)
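A minimal sketch of the suggested fix, with the device check factored into a hypothetical helper (`resolve_providers` is not in the PR; it only illustrates wrapping the string in a list):

```python
def resolve_providers(provider=None, device="CPU"):
    # A single provider string must be wrapped in a list before being
    # passed as the providers argument, which expects a sequence.
    if provider:
        return [provider]
    return (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if device == "GPU"
        else ["CPUExecutionProvider"]
    )

print(resolve_providers("CUDAExecutionProvider"))  # -> ['CUDAExecutionProvider']
```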
```python
        *images: Variable number of image tensors, each of shape (batch, 3, H, W).

    Returns:
        Action tensor of shape (batch, n_action_steps, action_dim).
```
The docstring states the return shape is (batch, n_action_steps, action_dim), but the actual return shape after the transpose on line 118 is (n_action_steps, batch, action_dim). This inconsistency will confuse users of the exported ONNX model. The docstring should be updated to match the actual output shape.
Suggested change:
```diff
- Action tensor of shape (batch, n_action_steps, action_dim).
+ Action tensor of shape (n_action_steps, batch, action_dim).
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
What this does
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
Examples:
How it was tested
Explain/show how you tested your changes.
Examples:
- Added `test_something` in `tests/test_stuff.py`.
- Ran `new_feature` and checked that training converges with policy X on dataset/environment Y.
- Optimized `some_function`; it now runs X times faster than previously.

How to checkout & try? (for the reviewer)
Provide a simple way for the reviewer to try out your changes.
Examples:
Checklist
Note: Before submitting this PR, please read the contributor guideline.