diff --git a/.github/scripts/check_accumulate_grad_sync.py b/.github/scripts/check_accumulate_grad_sync.py
new file mode 100644
index 0000000..5dc318e
--- /dev/null
+++ b/.github/scripts/check_accumulate_grad_sync.py
@@ -0,0 +1,46 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from utils import grep_file
+
+from opentau.configs.parser import wrap
+
+
+@dataclass
+class Arg:
+    log_path: str
+    expected_length: int
+    re_pattern: str = r"accelerator\.sync_gradients=(True|False)"
+    gradient_accumulation_steps: int = 2
+
+
+@wrap()
+def main(arg: Arg) -> None:
+    # NOTE: the regex captures the literal strings "True"/"False"; bool("False")
+    # is True (non-empty string), so compare against "True" explicitly.
+    sync_grads = grep_file(arg.log_path, arg.re_pattern, processor=lambda s: s == "True")
+    assert len(sync_grads) == arg.expected_length, (
+        f"Expected {arg.expected_length} sync_gradients, found {len(sync_grads)} in {arg.log_path}."
+    )
+    assert all(sg == ((i + 1) % arg.gradient_accumulation_steps == 0) for i, sg in enumerate(sync_grads)), (
+        f"Sync gradients should be set according to "
+        f"gradient_accumulation_steps={arg.gradient_accumulation_steps}, "
+        f"got {sync_grads}."
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/scripts/check_loss_drop.py b/.github/scripts/check_loss_drop.py
new file mode 100644
index 0000000..8b4d3a1
--- /dev/null
+++ b/.github/scripts/check_loss_drop.py
@@ -0,0 +1,99 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+
+import numpy as np
+from utils import grep_file
+
+from opentau.configs.parser import wrap
+
+
+@dataclass
+class Arg:
+    log_path: str
+    expected_length: int
+    re_pattern: str = r"mse_loss:([0-9.eE+-]+)"
+    gauss_sigma: float = 4.0
+    gauss_truncate: float = 4.0
+    pad_mode: str = "reflect"
+    resume_log_path: str | None = None
+    resume_expected_length: int | None = None
+
+
+def gaussian_smooth(
+    series: list[float], sigma: float, *, truncate: float = 4.0, mode: str = "reflect"
+) -> list[float]:
+    if sigma <= 0:
+        raise ValueError("sigma must be positive")
+
+    x = np.asarray(series, dtype=np.float64)
+
+    radius = int(math.ceil(truncate * sigma))
+    k = np.arange(-radius, radius + 1, dtype=np.float64)
+    kernel = np.exp(-(k**2) / (2 * sigma**2))
+    kernel /= kernel.sum()  # normalize
+
+    pad_width = (radius, radius)
+    x_padded = np.pad(x, pad_width, mode=mode)
+    smoothed = np.convolve(x_padded, kernel, mode="valid")
+
+    return smoothed.tolist()
+
+
+def check_smooth_loss(losses: list[float], expected_length: int, arg: Arg, prefix: str, log_path: str) -> list[float]:
+    print(f"{prefix} raw losses:", losses)
+    assert len(losses) == expected_length, (
+        f"Expected {expected_length} losses, found {len(losses)} in {log_path}."
+    )
+    smoothed = gaussian_smooth(losses, arg.gauss_sigma, truncate=arg.gauss_truncate, mode=arg.pad_mode)
+    print(f"{prefix} smoothed losses:", smoothed)
+    assert smoothed[0] >= smoothed[-1], "Losses should drop over time when smoothed."
+    return smoothed
+
+
+@wrap()
+def main(arg: Arg):
+    losses = grep_file(arg.log_path, arg.re_pattern, processor=float)
+    smoothed = check_smooth_loss(losses, arg.expected_length, arg, "Training", arg.log_path)
+
+    if arg.resume_expected_length is None and arg.resume_log_path is None:
+        return
+
+    if arg.resume_expected_length is None or arg.resume_log_path is None:
+        raise ValueError(
+            "Both resume_log_path and resume_expected_length must be provided if one is given. "
+            f"Got resume_log_path: {arg.resume_log_path}, "
+            f"Got resume_expected_length: {arg.resume_expected_length}, "
+        )
+
+    resume_losses = grep_file(arg.resume_log_path, arg.re_pattern, processor=float)
+    resume_smoothed = check_smooth_loss(resume_losses, arg.resume_expected_length, arg, "Resume", arg.resume_log_path)
+
+    # resuming start should be closer to the end of the training than the start
+    resume_start = resume_smoothed[0]
+    training_start = smoothed[0]
+    training_end = smoothed[-1]
+    print(
+        f"{resume_start=}, {training_start=}, {training_end=}, "
+        f"{abs(resume_start - training_end)=}, {abs(resume_start - training_start)=}."
+    )
+    assert abs(resume_start - training_end) <= abs(resume_start - training_start), (
+        "Resuming start loss should be closer to the end of the training than the start."
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/scripts/check_nonzero_grad_norm.py b/.github/scripts/check_nonzero_grad_norm.py
new file mode 100644
index 0000000..36a0f2d
--- /dev/null
+++ b/.github/scripts/check_nonzero_grad_norm.py
@@ -0,0 +1,39 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from utils import grep_file
+
+from opentau.configs.parser import wrap
+
+
+@dataclass
+class Arg:
+    log_path: str
+    expected_length: int
+    re_pattern: str = r"grad_norm:([0-9.eE+-]+)"
+
+
+@wrap()
+def main(arg: Arg) -> None:
+    grad_norm = grep_file(arg.log_path, arg.re_pattern, processor=float)
+    assert len(grad_norm) == arg.expected_length, (
+        f"Expected {arg.expected_length} grad_norms, found {len(grad_norm)} in {arg.log_path}."
+    )
+    assert all(g > 0 for g in grad_norm), f"All grad_norms should be greater than zero, got {grad_norm}."
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/scripts/check_state_keys.py b/.github/scripts/check_state_keys.py
new file mode 100644
index 0000000..9b16df4
--- /dev/null
+++ b/.github/scripts/check_state_keys.py
@@ -0,0 +1,128 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from dataclasses import dataclass + +from opentau.configs.parser import wrap + +MISSING_KEYS = { + "hf": { + "normalize_inputs.buffer_state.max", + "normalize_inputs.buffer_state.min", + "normalize_targets.buffer_actions.mean", + "normalize_targets.buffer_actions.std", + "normalize_actions.buffer_actions.max", + "normalize_actions.buffer_actions.min", + "unnormalize_outputs.buffer_actions.mean", + "unnormalize_outputs.buffer_actions.std", + "model.paligemma_with_expert.discrete_action_embedding.weight", + "model.paligemma_with_expert.da_head.weight", + "model.paligemma_with_expert.da_head.bias", + }, + "local": None, +} + + +@dataclass +class Arg: + log_path: str + source: str + + def __post_init__(self): + if self.source not in MISSING_KEYS: + raise ValueError(f"--source must be one of {MISSING_KEYS.keys()}. Got {self.source}") + + +def parse_missing_keys(log_path: str) -> list[set[str]]: + """Parse missing keys from log file. + + The log format is: + Missing keys when loading state dict: N keys + - key1 + - key2 + ... 
+ """ + all_key_sets = [] + current_keys = None + + with open(log_path) as f: + for line in f: + if "Missing keys when loading state dict:" in line: + # Start collecting keys for a new occurrence + if current_keys is not None: + all_key_sets.append(current_keys) + current_keys = set() + elif current_keys is not None: + # Check if line is a key entry (starts with " - ") + stripped = line.strip() + if stripped.startswith("- "): + key = stripped[2:].strip() + current_keys.add(key) + elif stripped and not stripped.startswith("-"): + # Non-empty line that's not a key entry means section ended + all_key_sets.append(current_keys) + current_keys = None + + # Don't forget the last set if file ended while collecting + if current_keys is not None: + all_key_sets.append(current_keys) + + return all_key_sets + + +def check_no_unexpected_keys(log_path: str): + """Check that 'Unexpected keys when loading state dict:' does not appear in the log.""" + print("Checking for unexpected keys") + with open(log_path) as f: + for line in f: + if "Unexpected keys when loading state dict:" in line: + raise ValueError(f"Found unexpected keys in log: {line.strip()}") + print("Passed - no unexpected keys found") + + +def check_missing_keys(key_sets: list[set[str]], source: str): + """Check that all missing key sets match the expected keys.""" + print("Checking missing keys") + expected_keys = MISSING_KEYS[source] + + if expected_keys is None: + if key_sets: + raise ValueError(f"Found missing keys but expecting none: {key_sets}") + elif not key_sets: + raise ValueError(f"No missing keys found, should be {expected_keys}") + else: + for i, keys in enumerate(key_sets): + if keys != expected_keys: + missing_from_expected = expected_keys - keys + extra_in_found = keys - expected_keys + raise ValueError( + f"Missing keys mismatch at occurrence {i + 1}:\n" + f" Expected but not found: {missing_from_expected}\n" + f" Found but not expected: {extra_in_found}" + ) + print("Passed") + + +@wrap() +def 
main(arg: Arg) -> None: + # Check that no unexpected keys appear + check_no_unexpected_keys(arg.log_path) + + # Parse and check missing keys + key_sets = parse_missing_keys(arg.log_path) + check_missing_keys(key_sets, arg.source) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/utils.py b/.github/scripts/utils.py new file mode 100644 index 0000000..b2b88fb --- /dev/null +++ b/.github/scripts/utils.py @@ -0,0 +1,27 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def grep_file(file: str, pattern: str, processor=None) -> list: + processor = processor or (lambda x: x) + values = [] + with open(file) as f: + for line in f: + match = re.search(pattern, line) + if not match: + continue + values.append(processor(match.group(1))) + return values diff --git a/.github/workflows/gpu_test.yml b/.github/workflows/gpu_test.yml index bcfa7bf..768a875 100644 --- a/.github/workflows/gpu_test.yml +++ b/.github/workflows/gpu_test.yml @@ -42,14 +42,14 @@ jobs: - name: Start Instance run: | - aws autoscaling set-desired-capacity --auto-scaling-group-name github-runner-asg --desired-capacity 1 + aws autoscaling set-desired-capacity --auto-scaling-group-name github-runner-asg-g6-2xlarge --desired-capacity 1 echo "Waiting for instance to be ready..." 
gpu-test: name: Run Pytest on GPU needs: start-runner runs-on: [g6.2xlarge] - timeout-minutes: 60 + timeout-minutes: 30 container: image: nvidia/cuda:12.2.0-devel-ubuntu22.04 @@ -110,4 +110,4 @@ jobs: - name: Stop Instance run: | - aws autoscaling set-desired-capacity --auto-scaling-group-name github-runner-asg --desired-capacity 0 + aws autoscaling set-desired-capacity --auto-scaling-group-name github-runner-asg-g6-2xlarge --desired-capacity 0 diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml new file mode 100644 index 0000000..3f1bb9c --- /dev/null +++ b/.github/workflows/regression_test.yml @@ -0,0 +1,233 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+name: Nightly Regression Tests
+
+on:
+  schedule:
+    # Run at 2:00 AM PST every day (10:00 AM UTC)
+    - cron: '0 10 * * *'
+  workflow_dispatch:
+
+permissions:
+  contents: read
+
+env:
+  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
+  MUJOCO_GL: "egl"
+  PYOPENGL_PLATFORM: "egl"
+
+jobs:
+  start-runner:
+    name: Start GPU Runner
+    runs-on: ubuntu-latest
+    permissions:
+      id-token: write
+      contents: read
+    steps:
+      - name: Configure AWS Credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          role-to-assume: ${{ secrets.AWS_ROLE_ARN }}
+          aws-region: us-west-2
+
+      - name: Start Instance
+        run: |
+          aws autoscaling set-desired-capacity --auto-scaling-group-name github-runner-asg-g6-12xlarge --desired-capacity 1
+          echo "Waiting for instance to be ready..."
+
+  train-regression:
+    name: Train with Model Parallelism
+    needs: start-runner
+    runs-on: [g6.12xlarge]
+    timeout-minutes: 30
+
+    container:
+      image: nvidia/cuda:12.2.0-devel-ubuntu22.04
+      options: --gpus all --ipc=host
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+
+      - name: Install system dependencies
+        run: |
+          apt-get update && apt-get install -y python3 python3-pip git ffmpeg libegl1 libegl-mesa0 libegl-dev libgl1 libglx-mesa0 libgles2 mesa-utils curl cmake build-essential
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          version: "latest"
+
+      - name: Install dependencies
+        run: |
+          uv sync --extra dev --extra libero --extra openai
+
+      - name: Check GPU
+        run: nvidia-smi
+
+      - name: Set up HuggingFace authentication
+        shell: bash
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+          source .venv/bin/activate
+          huggingface-cli login --token $HF_TOKEN
+
+      - name: Set up Wandb authentication
+        shell: bash
+        env:
+          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
+        run: |
+          source .venv/bin/activate
+          wandb login $WANDB_API_KEY
+          wandb offline
+
+      - name: Set up Libero Configs
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          mkdir -p /tmp/libero-assets/libero/libero
+          echo "LIBERO_CONFIG_PATH=$(pwd)/.github/assets/libero" >> "$GITHUB_ENV"
+
+      - name: Run Training
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          opentau-train --accelerate-config=configs/examples/accelerate_deepspeed_config.yaml --config_path=configs/dev/ci_config.json --output_dir=outputs/train/ci/ 2>&1 | tee /tmp/train.log
+
+      - name: Check Loss Drop
+        id: check-loss-drop
+        continue-on-error: true
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          python3 .github/scripts/check_loss_drop.py --log_path=/tmp/train.log --expected_length=25
+          echo "Loss drop confirmed"
+
+      - name: Check Non-Zero Grad Norm
+        id: check-grad-norm
+        continue-on-error: true
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          python3 .github/scripts/check_nonzero_grad_norm.py --log_path=/tmp/train.log --expected_length=25
+          echo "Non-zero grad norm confirmed"
+
+      - name: Check Accumulate Grad Sync
+        id: check-grad-sync
+        continue-on-error: true
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          python3 .github/scripts/check_accumulate_grad_sync.py --log_path=/tmp/train.log --expected_length=50
+          echo "Accumulate grad sync confirmed"
+
+      - name: Check State Keys
+        id: check-state-keys
+        continue-on-error: true
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          python3 .github/scripts/check_state_keys.py --log_path=/tmp/train.log --source=hf
+          echo "Checks for state keys passed"
+
+      - name: Convert Checkpoint
+        shell: bash
+        run: |
+          source .venv/bin/activate
+          ./src/opentau/scripts/convert_checkpoint.sh outputs/train/ci/checkpoints/000025
+
+      # - name: Resume Training
+      #   shell: bash
+      #   run: |
+      #     source .venv/bin/activate
+      #     opentau-train --accelerate-config=configs/examples/accelerate_deepspeed_config.yaml --config_path=outputs/train/ci/checkpoints/000025/train_config.json --resume=true --steps=50 2>&1 | tee /tmp/resume.log
+
+      # - name: Check Loss Drop (after Resume)
+      #   continue-on-error: true
+      #   shell: bash
+      #   run: |
+      #     source .venv/bin/activate
# python3 .github/scripts/check_loss_drop.py --log_path=/tmp/train.log --expected_length=25 --resume_log_path=/tmp/resume.log --resume_expected_length=25 + # echo "Loss drop confirmed" + + # - name: Check Non-Zero Grad Norm (after Resume) + # continue-on-error: true + # shell: bash + # run: | + # source .venv/bin/activate + # python3 .github/scripts/check_nonzero_grad_norm.py --log_path=/tmp/resume.log --expected_length=25 + # echo "Non-zero grad norm confirmed" + + # - name: Check Accumulate Grad Sync (after Resume) + # continue-on-error: true + # shell: bash + # run: | + # source .venv/bin/activate + # python3 .github/scripts/check_accumulate_grad_sync.py --log_path=/tmp/resume.log --expected_length=50 + # echo "Accumulate grad sync confirmed" + + # - name: Check State Keys (after Resume) + # continue-on-error: true + # shell: bash + # run: | + # source .venv/bin/activate + # python3 .github/scripts/check_state_keys.py --log_path=/tmp/resume.log --source=local + # echo "Checks for state keys passed" + + - name: Run Inference + shell: bash + run: | + source .venv/bin/activate + python src/opentau/scripts/inference.py --config_path=outputs/train/ci/checkpoints/000025/train_config.json + + - name: Fail if checks failed + if: always() + env: + LOSS_DROP: ${{ steps.check-loss-drop.outcome }} + GRAD_NORM: ${{ steps.check-grad-norm.outcome }} + GRAD_SYNC: ${{ steps.check-grad-sync.outcome }} + STATE_KEYS: ${{ steps.check-state-keys.outcome }} + run: | + failed="" + [ "$LOSS_DROP" == "failure" ] && failed="$failed check-loss-drop" + [ "$GRAD_NORM" == "failure" ] && failed="$failed check-grad-norm" + [ "$GRAD_SYNC" == "failure" ] && failed="$failed check-grad-sync" + [ "$STATE_KEYS" == "failure" ] && failed="$failed check-state-keys" + if [ -n "$failed" ]; then + echo "The following checks failed:$failed" + exit 1 + fi + + stop-runner: + name: Stop GPU Runner + needs: [start-runner, train-regression] + if: always() + runs-on: ubuntu-latest + permissions: + id-token: write 
+ contents: read + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_ARN }} + aws-region: us-west-2 + + - name: Stop Instance + run: | + aws autoscaling set-desired-capacity --auto-scaling-group-name github-runner-asg-g6-12xlarge --desired-capacity 0 diff --git a/configs/dev/ci_config.json b/configs/dev/ci_config.json index f285b79..145a186 100644 --- a/configs/dev/ci_config.json +++ b/configs/dev/ci_config.json @@ -17,70 +17,42 @@ "image_resample_strategy": "nearest", "vector_resample_strategy": "nearest" }, - "env": { - "type": "libero", - "task": "libero_spatial", - "task_ids": [0, 2] - }, - "eval": { - "n_episodes": 8, - "batch_size": 8 - }, - "eval_freq": 25, "policy": { - "type": "tau0", - "pretrained_path": "lerobot/pi0", + "type": "pi05", + "pretrained_path": "TensorAuto/pi05_base", "n_obs_steps": 1, "normalization_mapping": { "VISUAL": "IDENTITY", - "STATE": "MEAN_STD", + "STATE": "MIN_MAX", "ACTION": "MEAN_STD" }, - "chunk_size": 50, - "n_action_steps": 50, + "chunk_size": 10, + "predict_response": true, + "n_action_steps": 10, "max_state_dim": 32, "max_action_dim": 32, - "cloud_vlm_latency_mean": 0.16, - "cloud_vlm_latency_std": 0.05, - "cloud_vlm_latency_lower": 0.10, - "cloud_vlm_latency_upper": 0.25, - "action_decoder_latency_mean": 0.032, - "action_decoder_latency_std": 0.010, - "action_decoder_latency_lower": 0.020, - "action_decoder_latency_upper": 0.050, - "tokenizer_max_length": 52, - "response_max_tokens": 52, - "n_cross_att_tokens": 10, "proj_width": 1024, "num_steps": 10, + "init_strategy": "expert_only_he_init", "attention_implementation": "eager", "freeze_vision_encoder": true, "train_expert_only": true, - "train_state_proj": true, - "optimizer_lr": 1e-4, - "optimizer_betas": [ - 0.9, - 0.95 - ], - "optimizer_eps": 1e-08, - "optimizer_weight_decay": 0, - "scheduler_warmup_steps": 0, - "scheduler_decay_steps": 30000, - "scheduler_decay_lr": 0 + 
"prompt_max_length": 52, + "response_max_length": 5, + "discrete_action_max_length": 10 }, "resume": false, "seed": 1000, "resolution": [224, 224], "num_cams": 2, - "action_expert_num_cams": 1, "max_state_dim": 32, "max_action_dim": 32, - "action_chunk": 50, + "action_chunk": 10, "loss_weighting": {"MSE": 1, "CE": 1}, - "num_workers": 4, - "batch_size": 16, + "num_workers": 12, + "batch_size": 4, "gradient_accumulation_steps": 2, - "dataloader_batch_size": 8, + "dataloader_batch_size": 2, "prefetch_factor": 8, "steps": 25, "log_freq": 1, @@ -109,10 +81,10 @@ "wandb": { "enable": true, "entity": "wyautox-autox", - "project": "tau0-ci", + "project": "github-ci", "run_id": null, "name": null, - "notes": "CI Config", + "notes": "GitHub CI Config", "tags": [], "group": null, "job_type": null, diff --git a/configs/examples/accelerate_deepspeed_config.yaml b/configs/examples/accelerate_deepspeed_config.yaml index 5c444fd..7196648 100644 --- a/configs/examples/accelerate_deepspeed_config.yaml +++ b/configs/examples/accelerate_deepspeed_config.yaml @@ -14,7 +14,7 @@ machine_rank: 0 main_training_function: main mixed_precision: 'no' num_machines: 1 -num_processes: 2 +num_processes: 4 rdzv_backend: static same_network: true tpu_env: [] diff --git a/docs/source/tutorials/inference.rst b/docs/source/tutorials/inference.rst index 88a6365..5a2c6d3 100644 --- a/docs/source/tutorials/inference.rst +++ b/docs/source/tutorials/inference.rst @@ -14,7 +14,7 @@ To run inference, run the following command: .. 
code-block:: bash - python lerobot/scripts/inference.py --config_path=outputs/train/pi05/checkpoints/000040/train_config.json + python src/opentau/scripts/inference.py --config_path=outputs/train/pi05/checkpoints/000040/train_config.json Running inference with autoregressive response prediction diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py index 6a1af1c..8a02b3a 100644 --- a/src/opentau/policies/pi05/modeling_pi05.py +++ b/src/opentau/policies/pi05/modeling_pi05.py @@ -42,6 +42,7 @@ PaliGemmaWithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy, T +from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -351,9 +352,12 @@ def from_pretrained( model = cls(config, **kwargs) # Now manually load and remap the state dict + acc = get_proc_accelerator() + is_main_process = acc.is_main_process if acc else True try: # Try to load the pytorch_model.bin or model.safetensors file - print(f"Loading model from: {pretrained_name_or_path}") + if is_main_process: + print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file @@ -372,10 +376,12 @@ def from_pretrained( from safetensors.torch import load_file original_state_dict = load_file(resolved_file) - print("✓ Loaded state dict from model.safetensors") + if is_main_process: + print("✓ Loaded state dict from model.safetensors") except Exception as e: - print(f"Could not load state dict from remote files: {e}") - print("Returning model without loading pretrained weights") + if is_main_process: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") return model # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` @@ -390,18 +396,18 @@ def from_pretrained( new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 - if remap_count <= 10: # 
Only print first 10 to avoid spam + if remap_count <= 10 and is_main_process: # Only print first 10 to avoid spam print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value - if remap_count > 0: + if remap_count > 0 and is_main_process: print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) - if missing_keys: + if missing_keys and is_main_process: print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 20: for key in missing_keys: @@ -411,7 +417,7 @@ def from_pretrained( print(f" - {key}") print(f" ... and {len(missing_keys) - 20} more") - if unexpected_keys: + if unexpected_keys and is_main_process: print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 20: for key in unexpected_keys: @@ -421,11 +427,12 @@ def from_pretrained( print(f" - {key}") print(f" ... and {len(unexpected_keys) - 20} more") - if not missing_keys and not unexpected_keys: + if not missing_keys and not unexpected_keys and is_main_process: print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + if is_main_process: + print(f"Warning: Could not remap state dict keys: {e}") return model