Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ libero = [
urdf = [
"rerun-sdk>=0.28.2",
]
trt = [
    "tensorrt>=10.15.1.29",
    # NOTE(review): the entry below is redundant as written — the unconditional
    # ">=10.15.1.29" pin above already satisfies ">=10.9.0" on every platform.
    # Confirm whether the platform markers were instead meant to gate the
    # stricter pin above (i.e. only require tensorrt on linux/win x86_64).
    "tensorrt>=10.9.0 ; (sys_platform == 'linux' and platform_machine == 'x86_64') or (sys_platform == 'win32' and (platform_machine == 'AMD64' or platform_machine == 'x86_64'))",
]

[tool.uv.sources]
libero = { git = "https://github.com/shuheng-liu/LIBERO" , branch = "master" } # the official libero repo is misconfigured for pip install with git
Expand Down
10 changes: 6 additions & 4 deletions src/opentau/policies/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,16 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
if not (torch.compiler.is_compiling() or torch.onnx.is_in_onnx_export()):
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * (std + EPS) + mean
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
if not (torch.compiler.is_compiling() or torch.onnx.is_in_onnx_export()):
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min + EPS) + min
else:
Expand Down
10 changes: 7 additions & 3 deletions src/opentau/policies/pi05/modeling_pi05.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
from opentau.utils.utils import get_safe_dtype


def _preferred_dtype():
return torch.float32 if torch.onnx.is_in_onnx_export() else torch.bfloat16


def create_sinusoidal_pos_embedding(
time: Tensor, dimension: int, min_period: float, max_period: float, device: torch.device | str = "cpu"
) -> Tensor:
Expand Down Expand Up @@ -988,7 +992,7 @@ def embed_prefix(
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
img_emb = img_emb.to(dtype=_preferred_dtype())

# image embeddings don't need to be unnormalized because `fix/lerobot_openpi` branch of huggingface
# already removed the normalization inside PaliGemma
Expand Down Expand Up @@ -1032,7 +1036,7 @@ def embed_prefix(

if discrete_actions is not None:
discrete_action_emb = self.paligemma_with_expert.embed_discrete_actions(discrete_actions)
embs.append(discrete_action_emb.to(dtype=torch.bfloat16))
embs.append(discrete_action_emb.to(dtype=_preferred_dtype()))
pad_masks.append(discrete_action_masks)
att_masks += [1] * discrete_action_emb.shape[1]

Expand Down Expand Up @@ -1062,7 +1066,7 @@ def embed_suffix(self, noisy_actions: Tensor, timestep: Tensor) -> tuple[Tensor,
att_masks = []

bsize = noisy_actions.shape[0]
dtype = torch.bfloat16
dtype = _preferred_dtype()
device = noisy_actions.device

# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
Expand Down
14 changes: 9 additions & 5 deletions src/opentau/policies/pi05/paligemma_with_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
"""

import torch
import torch.version
from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
Cache,
GemmaForCausalLM,
PaliGemmaForConditionalGeneration,
PretrainedConfig,
Expand All @@ -37,6 +36,10 @@
from transformers.models.gemma import modeling_gemma


def _preferred_dtype():
return torch.float32 if torch.onnx.is_in_onnx_export() else torch.bfloat16


def apply_rope(x: torch.Tensor, positions: torch.Tensor, max_wavelength: int = 10_000) -> torch.Tensor:
"""Applies RoPE positions to the input tensor.

Expand Down Expand Up @@ -246,7 +249,8 @@ def __init__(self, config: PaliGemmaWithExpertConfig):

self.dropout = nn.Dropout(config.dropout)

self.to_bfloat16_like_physical_intelligence()
if not torch.compiler.is_compiling(): # Only cast to bfloat16 if not compiling
self.to_bfloat16_like_physical_intelligence()
self.set_requires_grad()

def set_requires_grad(self) -> None:
Expand Down Expand Up @@ -402,7 +406,7 @@ def forward(
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

hidden_states = hidden_states.to(dtype=torch.bfloat16)
hidden_states = hidden_states.to(dtype=_preferred_dtype())
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
Expand Down Expand Up @@ -442,7 +446,7 @@ def forward(
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
att_output = att_output.to(dtype=_preferred_dtype())

# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []
Expand Down
Loading
Loading