src/opentau/policies/normalize.py (32 additions, 8 deletions)
@@ -32,6 +32,22 @@
 EPS = 1e-8  # Small epsilon value for numerical stability in normalization
 
 
+def _is_tracing() -> bool:
+    """Check if we're currently in a tracing context (e.g., ONNX export, torch.compile).
+
+    During tracing, data-dependent operations like assert with .any() cannot be evaluated,
+    so we skip these assertions to allow ONNX export to succeed.
+
+    Returns:
+        bool: True if in a tracing/compilation context, False otherwise.
+    """
+    # Check for torch.compile/dynamo tracing
+    if torch.compiler.is_compiling():
+        return True
+    # Check for ONNX export context
+    return torch.onnx.is_in_onnx_export()
+
+
 def warn_missing_keys(features: dict[str, PolicyFeature], batch: dict[str, Tensor], mode: str) -> None:
     """Warns if expected features are missing from the batch.
 
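Why the guard is needed, as a minimal standalone sketch (not part of this diff; it assumes only the public torch.compile and torch.compiler.is_compiling APIs): the data-dependent assert runs in eager mode but its branch is dead once dynamo is tracing.

import torch

def guarded_normalize(x: torch.Tensor) -> torch.Tensor:
    if not torch.compiler.is_compiling():
        # Data-dependent check: evaluated eagerly, skipped while tracing.
        assert not torch.isinf(x).any(), "stats contain inf"
    return (x - x.mean()) / (x.std() + 1e-8)

guarded_normalize(torch.randn(4))                 # eager: assert executes
torch.compile(guarded_normalize)(torch.randn(4))  # traced: assert is skipped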
@@ -221,14 +237,18 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         if norm_mode is NormalizationMode.MEAN_STD:
             mean = buffer["mean"]
             std = buffer["std"]
-            assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
-            assert not torch.isinf(std).any(), _no_stats_error_str("std")
+            # Skip data-dependent assertions during tracing (ONNX export)
+            if not _is_tracing():
+                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+                assert not torch.isinf(std).any(), _no_stats_error_str("std")
             batch[key] = (batch[key] - mean) / (std + EPS)
         elif norm_mode is NormalizationMode.MIN_MAX:
             min = buffer["min"]
             max = buffer["max"]
-            assert not torch.isinf(min).any(), _no_stats_error_str("min")
-            assert not torch.isinf(max).any(), _no_stats_error_str("max")
+            # Skip data-dependent assertions during tracing (ONNX export)
+            if not _is_tracing():
+                assert not torch.isinf(min).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max).any(), _no_stats_error_str("max")
             batch[key] = (batch[key] - min) / (max - min + EPS)
             # normalize to [-1, 1]
             batch[key] = batch[key] * 2 - 1
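A worked instance of the MIN_MAX arithmetic above, with made-up stats rather than real buffers: values in [min, max] are first mapped to [0, 1], then rescaled to [-1, 1].

lo, hi, EPS = 0.0, 10.0, 1e-8        # stand-ins for buffer["min"], buffer["max"], EPS
x = 7.5
scaled = (x - lo) / (hi - lo + EPS)  # ~0.75, maps [lo, hi] -> [0, 1]
normalized = scaled * 2 - 1          # ~0.5, rescales [0, 1] -> [-1, 1]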
@@ -300,14 +320,18 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         if norm_mode is NormalizationMode.MEAN_STD:
             mean = buffer["mean"]
             std = buffer["std"]
-            assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
-            assert not torch.isinf(std).any(), _no_stats_error_str("std")
+            # Skip data-dependent assertions during tracing (ONNX export)
+            if not _is_tracing():
+                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+                assert not torch.isinf(std).any(), _no_stats_error_str("std")
             batch[key] = batch[key] * (std + EPS) + mean
         elif norm_mode is NormalizationMode.MIN_MAX:
             min = buffer["min"]
             max = buffer["max"]
-            assert not torch.isinf(min).any(), _no_stats_error_str("min")
-            assert not torch.isinf(max).any(), _no_stats_error_str("max")
+            # Skip data-dependent assertions during tracing (ONNX export)
+            if not _is_tracing():
+                assert not torch.isinf(min).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max).any(), _no_stats_error_str("max")
             batch[key] = (batch[key] + 1) / 2
             batch[key] = batch[key] * (max - min + EPS) + min
         else:
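This unnormalize branch is the exact inverse of the normalize path, up to the EPS padding on the range; a hypothetical round-trip check with the same made-up stats:

lo, hi, EPS = 0.0, 10.0, 1e-8
x = 7.5
n = ((x - lo) / (hi - lo + EPS)) * 2 - 1   # normalize: ~0.5
r = ((n + 1) / 2) * (hi - lo + EPS) + lo   # unnormalize: ~7.5
assert abs(r - x) < 1e-6                   # round-trip recovers x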