12 changes: 7 additions & 5 deletions src/opentau/policies/pi05/modeling_pi05.py
@@ -662,13 +662,15 @@ def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
             ValueError: If the state values are not normalized between -1 and 1.
         """
         state = batch["state"]
-        state_np = state.to(device="cpu", dtype=torch.float32).numpy()
-        if np.any(state_np < -1.0) or np.any(state_np > 1.0):
+        state_cpu = state.to(device="cpu", dtype=torch.float32)
+        if torch.any(state_cpu < -1.0) or torch.any(state_cpu > 1.0):
             logging.warning(
-                f"State values are not normalized between -1 and 1. Min: {state_np.min()}, Max: {state_np.max()}"
+                f"State values are not normalized between -1 and 1. Min: {state_cpu.min().item()}, Max: {state_cpu.max().item()}"
             )
-        state_np = np.clip(state_np, -1.0, 1.0)
-        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+        state_clipped = torch.clamp(state_cpu, -1.0, 1.0)
+        # replicate np.digitize with torch for torch.compile compatibility
+        bin_indices = ((state_clipped + 1.0) * 128.0).long().clamp(0, 255)
+        discretized_states = bin_indices.cpu().tolist()
         return [
             " ".join(map(str, row)) for row in discretized_states
         ]  # TODO: return a tensor instead of a list of strings?
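Note on the binning change in the hunk above: the closed-form index floor((x + 1) * 128), clamped to [0, 255], reproduces what np.digitize did with 256 left bin edges over [-1, 1). The snippet below is a small standalone sanity check of that equivalence, not part of the PR; it assumes only numpy and torch and uses hand-picked values that sit away from bin boundaries (exactly on a boundary, float32 rounding could make the two paths disagree by one bin).

import numpy as np
import torch

# Standalone check (not from the PR): the torch formula matches np.digitize
# for values that do not fall exactly on a bin edge.
state = torch.tensor([[-1.0, -0.731, -0.2, 0.0, 0.4999, 0.731, 1.0]])

# Old path: 256 left edges over [-1, 1), shifted so indices start at 0.
old = np.digitize(state.numpy(), bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

# New path: closed-form bin index, floor((x + 1) * 128), clamped to [0, 255].
new = ((torch.clamp(state, -1.0, 1.0) + 1.0) * 128.0).long().clamp(0, 255)

assert (torch.from_numpy(old) == new).all()  # both give [0, 34, 102, 128, 191, 221, 255]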
28 changes: 21 additions & 7 deletions src/opentau/scripts/inference.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#!/usr/bin/env python
-
 import logging
+import time
 from dataclasses import asdict
 from pprint import pformat
 
@@ -47,7 +46,7 @@ def inference_main(cfg: TrainPipelineConfig):
     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
     policy.to(device=device, dtype=torch.bfloat16)
     policy.eval()
-    policy = attempt_torch_compile(policy, device_hint=device)
+    policy.sample_actions = attempt_torch_compile(policy.sample_actions, device_hint=device)
 
     # Always reset policy before episode to clear out action cache.
     policy.reset()
@@ -57,10 +56,25 @@
     print(observation.keys())
 
     with torch.inference_mode():
-        for _ in range(1000):
-            action = policy.select_action(observation)
-            action = action.to("cpu", torch.float32).numpy()
-            print(f"Output shape: {action.shape}")
+        # One warmup call right after compiling
+        _ = policy.sample_actions(observation)
+
+        # Run 10 times and record inference times
+        n_runs = 10
+        times_ms = []
+        for _ in range(n_runs):
+            t0 = time.perf_counter()
+            actions = policy.sample_actions(observation)
+            t1 = time.perf_counter()
+            times_ms.append((t1 - t0) * 1000.0)
+
+        actions = actions.to("cpu", torch.float32).numpy()
+        print(f"Output shape: {actions.shape}")
+
+        times_ms = torch.tensor(times_ms)
+        print(
+            f"Inference time (ms) over {n_runs} runs: min={times_ms.min().item():.2f}, max={times_ms.max().item():.2f}, avg={times_ms.mean().item():.2f}, std={times_ms.std().item():.2f}"
+        )
 
     logging.info("End of inference")
 
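As a general note on the inference-script changes: compiling only the sampling method and making one warmup call before timing is a pattern that can be reproduced outside this repo. The sketch below is a minimal, self-contained illustration under stated assumptions: TinyPolicy and its sample_actions method are hypothetical stand-ins for the real policy, and torch.compile is called directly instead of the repo's attempt_torch_compile helper.

import time

import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in for the real policy; only here to show the pattern."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(32, 7)

    def sample_actions(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


policy = TinyPolicy().eval()
# Compile just the method being benchmarked, mirroring the change above.
policy.sample_actions = torch.compile(policy.sample_actions)

obs = torch.randn(1, 32)
with torch.inference_mode():
    _ = policy.sample_actions(obs)  # warmup call triggers compilation

    times_ms = []
    for _ in range(10):
        t0 = time.perf_counter()
        _ = policy.sample_actions(obs)
        t1 = time.perf_counter()
        times_ms.append((t1 - t0) * 1000.0)

times_ms = torch.tensor(times_ms)
print(f"min={times_ms.min().item():.2f} ms, avg={times_ms.mean().item():.2f} ms")

This sketch times on CPU; on CUDA, calling torch.cuda.synchronize() before reading perf_counter gives wall-clock numbers that reflect kernel completion.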