From 04c9645c752b73501ede157162811fb609e24324 Mon Sep 17 00:00:00 2001
From: William Yue
Date: Thu, 5 Feb 2026 14:19:22 -0800
Subject: [PATCH] fixed torch compile

---
 src/opentau/policies/pi05/modeling_pi05.py | 12 ++++++----
 src/opentau/scripts/inference.py           | 28 ++++++++++++++++------
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py
index 8a02b3a..71cbdb7 100644
--- a/src/opentau/policies/pi05/modeling_pi05.py
+++ b/src/opentau/policies/pi05/modeling_pi05.py
@@ -662,13 +662,15 @@ def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
             ValueError: If the state values are not normalized between -1 and 1.
         """
         state = batch["state"]
-        state_np = state.to(device="cpu", dtype=torch.float32).numpy()
-        if np.any(state_np < -1.0) or np.any(state_np > 1.0):
+        state_cpu = state.to(device="cpu", dtype=torch.float32)
+        if torch.any(state_cpu < -1.0) or torch.any(state_cpu > 1.0):
             logging.warning(
-                f"State values are not normalized between -1 and 1. Min: {state_np.min()}, Max: {state_np.max()}"
+                f"State values are not normalized between -1 and 1. Min: {state_cpu.min().item()}, Max: {state_cpu.max().item()}"
             )
-        state_np = np.clip(state_np, -1.0, 1.0)
-        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+        state_clipped = torch.clamp(state_cpu, -1.0, 1.0)
+        # replicate np.digitize with torch for torch.compile compatibility
+        bin_indices = ((state_clipped + 1.0) * 128.0).long().clamp(0, 255)
+        discretized_states = bin_indices.cpu().tolist()
         return [
             " ".join(map(str, row)) for row in discretized_states
         ]  # TODO: return a tensor instead of a list of strings?
diff --git a/src/opentau/scripts/inference.py b/src/opentau/scripts/inference.py
index a8b0d1f..6d9b7a0 100644
--- a/src/opentau/scripts/inference.py
+++ b/src/opentau/scripts/inference.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#!/usr/bin/env python
-
 import logging
+import time
 from dataclasses import asdict
 from pprint import pformat
 
@@ -47,7 +46,7 @@ def inference_main(cfg: TrainPipelineConfig):
     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
     policy.to(device=device, dtype=torch.bfloat16)
     policy.eval()
-    policy = attempt_torch_compile(policy, device_hint=device)
+    policy.sample_actions = attempt_torch_compile(policy.sample_actions, device_hint=device)
 
     # Always reset policy before episode to clear out action cache.
     policy.reset()
@@ -57,10 +56,25 @@ def inference_main(cfg: TrainPipelineConfig):
     print(observation.keys())
 
     with torch.inference_mode():
-        for _ in range(1000):
-            action = policy.select_action(observation)
-            action = action.to("cpu", torch.float32).numpy()
-            print(f"Output shape: {action.shape}")
+        # One warmup call right after compiling
+        _ = policy.sample_actions(observation)
+
+        # Run 10 times and record inference times
+        n_runs = 10
+        times_ms = []
+        for _ in range(n_runs):
+            t0 = time.perf_counter()
+            actions = policy.sample_actions(observation)
+            t1 = time.perf_counter()
+            times_ms.append((t1 - t0) * 1000.0)
+
+        actions = actions.to("cpu", torch.float32).numpy()
+        print(f"Output shape: {actions.shape}")
+
+        times_ms = torch.tensor(times_ms)
+        print(
+            f"Inference time (ms) over {n_runs} runs: min={times_ms.min().item():.2f}, max={times_ms.max().item():.2f}, avg={times_ms.mean().item():.2f}, std={times_ms.std().item():.2f}"
+        )
 
     logging.info("End of inference")
 
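Note (not part of the patch): a minimal standalone sanity check that the added
torch discretization matches the removed np.digitize path on inputs already
clipped to [-1, 1]. The 256-bin layout and the affine map are taken from the
hunk above; the test inputs and script are illustrative only.

    # Sketch: compare the removed NumPy path with the added torch path.
    import numpy as np
    import torch

    # Dense sweep of the normalized state range, endpoints included.
    x = torch.linspace(-1.0, 1.0, steps=10_000, dtype=torch.float32)

    # Removed code: 256 left bin edges over [-1, 1), digitize, shift to 0-based.
    bins = np.linspace(-1, 1, 256 + 1)[:-1]
    ref = np.digitize(x.numpy(), bins=bins) - 1

    # Added code: affine map onto [0, 256], truncate to int, clamp to [0, 255].
    new = ((torch.clamp(x, -1.0, 1.0) + 1.0) * 128.0).long().clamp(0, 255)

    assert np.array_equal(ref, new.numpy())
    print(f"paths agree on all {x.numel()} samples")

Compiling only the bound sample_actions method (rather than the whole policy,
as before) presumably keeps the Python-level action-cache handling in
select_action out of the compiled graph, and the single warmup call lets
torch.compile finish tracing before the timed runs.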