4 changes: 4 additions & 0 deletions src/opentau/policies/pi05/configuration_pi05.py
@@ -116,6 +116,10 @@ class PI05Config(PreTrainedConfig):
# Decoding
num_steps: int = 10

# Real-time inference
# Maximum number of frozen actions carried over from the action queue
max_delay: int = 0

# Initialization strategy
init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"
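
A hedged sketch of the new knob; constructing PI05Config directly and comparing against chunk_size are assumptions, not usage taken from this PR:

from opentau.policies.pi05.configuration_pi05 import PI05Config

# Hypothetical: allow up to 5 queued actions to stay frozen per replan.
config = PI05Config(max_delay=5)
assert 0 <= config.max_delay <= config.chunk_size  # the frozen prefix must fit in one chunk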

168 changes: 120 additions & 48 deletions src/opentau/policies/pi05/modeling_pi05.py
@@ -29,7 +29,7 @@
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from einops import rearrange
from einops import rearrange, repeat
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer

@@ -52,23 +52,23 @@ def create_sinusoidal_pos_embedding(
"""Computes sine-cosine positional embedding vectors for scalar positions.

Args:
time: A 1-D tensor of shape (batch_size,).
time: A 2-D tensor of shape (batch_size, action_chunk_length).
dimension: The dimension of the embedding vectors. Must be divisible by 2.
min_period: The minimum period of the sinusoidal functions.
max_period: The maximum period of the sinusoidal functions.
device: The device to create the tensors on. Defaults to "cpu".

Returns:
A tensor of shape (batch_size, dimension) containing the positional embeddings.
A tensor of shape (batch_size, action_chunk_length, dimension) containing the positional embeddings.

Raises:
ValueError: If dimension is not divisible by 2 or if time tensor is not 1-D.
ValueError: If dimension is not divisible by 2 or if time tensor is not 2-D with shape (batch_size, action_chunk_length).
"""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")

if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim != 2:
raise ValueError("The time tensor is expected to be of shape `(batch_size, action_chunk_length)`.")

dtype = (
get_safe_dtype(torch.float64, device.type)
@@ -80,8 +80,8 @@

# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
sin_input = rearrange(scaling_factor, "d -> 1 1 d") * rearrange(time, "b c -> b c 1")
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=2)
return pos_emb
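
A minimal standalone sketch of the new (batch, chunk) call shape; the geometric period schedule and the default periods below are assumptions, not values copied from this file:

import math
import torch
from einops import rearrange

def sincos_pos_emb(time, dimension, min_period=4e-3, max_period=4.0):
    # time: (batch_size, action_chunk_length) -> (batch_size, action_chunk_length, dimension)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=torch.float64)
    period = min_period * (max_period / min_period) ** fraction  # assumed schedule
    sin_input = rearrange(1.0 / period * 2 * math.pi, "d -> 1 1 d") * rearrange(
        time.double(), "b c -> b c 1"
    )
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=2)

time = torch.zeros(2, 50)  # per-action times: frozen prefix at t=0, remainder at t=0.7
time[:, 5:] = 0.7
emb = sincos_pos_emb(time, 128)  # -> (2, 50, 128)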


@@ -271,7 +271,7 @@ def __init__(
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.normalize_actions = Normalize(
self.normalize_discrete_actions = Normalize(
config.output_features, {"ACTION": NormalizationMode.MIN_MAX}, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
@@ -525,9 +525,12 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.

This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
This method uses an action queue that is replenished whenever it holds config.max_delay or fewer
actions (including when it is empty). When replenishing, the remaining queued actions are passed
as the action_prefix to sample_actions, and the queue is refilled with the newly sampled chunk.

Note: This method should only be called when running a policy in simulation. For real-world
inference, the equivalent queueing logic should live in the ROS client node.

Args:
batch: Batch of data containing environment observations.
@@ -538,24 +541,58 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -
"""
self.eval()

# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.sample_actions(batch, noise=noise)
self._action_queue.extend(actions)
return self._action_queue.popleft()
if len(self._action_queue) == 0 or len(self._action_queue) <= self.config.max_delay:
# Use current queue as action prefix to replenish
action_prefix = None
delay = 0
if len(self._action_queue) > 0:
prefix_actions = list(self._action_queue)
delay = min(len(prefix_actions), self.config.max_delay)
assert delay == self.config.max_delay, f"Delay must be equal to {self.config.max_delay}"
prefix_actions = prefix_actions[-delay:]
action_prefix = torch.stack(prefix_actions, dim=1)
action_prefix = self.normalize_targets({"actions": action_prefix})["actions"]
original_action_dim = self.config.action_feature.shape[0]
if original_action_dim < self.config.max_action_dim:
action_prefix = F.pad(
action_prefix,
(0, self.config.max_action_dim - original_action_dim),
)
if delay < self.config.chunk_size:
action_prefix = F.pad(
action_prefix,
(0, 0, 0, self.config.chunk_size - delay),
)
actions = self.sample_actions(batch, noise=noise, action_prefix=action_prefix, delay=delay)
actions = rearrange(actions, "b c d -> c b d")
self._action_queue.extend(actions[delay:])
assert len(self._action_queue) == self.config.n_action_steps, (
f"Action queue must have {self.config.n_action_steps} actions"
)

action = self._action_queue.popleft()
return action
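
A standalone sketch of the prefix construction above, with made-up sizes and normalization omitted:

import torch
import torch.nn.functional as F

chunk_size, max_action_dim, max_delay = 50, 32, 5
queue = [torch.randn(1, 7) for _ in range(max_delay)]  # leftover queued actions, action dim 7

prefix = torch.stack(queue, dim=1)                         # (1, 5, 7)
prefix = F.pad(prefix, (0, max_action_dim - 7))            # pad action dim -> (1, 5, 32)
prefix = F.pad(prefix, (0, 0, 0, chunk_size - max_delay))  # pad chunk axis -> (1, 50, 32)
assert prefix.shape == (1, chunk_size, max_action_dim)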

@torch.no_grad()
def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def sample_actions(
self,
batch: dict[str, Tensor],
noise: Tensor | None = None,
action_prefix: Tensor | None = None,
delay: int = 0,
) -> Tensor:
"""Sample actions from the policy given environment observations.

Args:
batch: Batch of data containing environment observations.
noise: Optional noise tensor.

action_prefix: Optional action prefix tensor of shape (batch_size, action_chunk_length, action_dim).
delay: Number of frozen actions at the start of the chunk, taken from action_prefix.

Returns:
The sampled actions tensor of shape (batch_size, action_dim).
The sampled actions tensor of shape (batch_size, action_chunk_length, action_dim).
"""
assert 0 <= delay <= self.config.max_delay, f"Delay must be between 0 and {self.config.max_delay}"

batch = self.normalize_inputs(batch)

images, img_masks = self.prepare_images(batch)
@@ -566,6 +603,8 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None)
img_masks,
lang_tokens,
lang_masks,
action_prefix=action_prefix,
delay=delay,
noise=noise,
)

@@ -575,9 +614,6 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None)

actions = self.unnormalize_outputs({"actions": actions})["actions"]

# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
actions = actions.transpose(0, 1)
return actions

def forward(
@@ -594,7 +630,7 @@ def forward(
A dictionary containing the loss components ("MSE" and "CE").
"""
batch = self.normalize_inputs(batch)
batch["discrete_actions"] = self.normalize_actions(dict(batch))["actions"]
batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch))["actions"]
batch = self.normalize_targets(batch)

images, img_masks = self.prepare_images(
@@ -622,6 +658,7 @@
lang_tokens,
lang_masks,
actions,
actions_is_pad,
response_tokens,
response_masks,
noise,
@@ -632,17 +669,8 @@

mse_loss = losses["MSE"]
ce_loss = losses["CE"]
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
mse_loss = mse_loss * in_episode_bound.unsqueeze(-1)

# Remove padding
mse_loss = mse_loss[:, :, : self.config.max_action_dim]

# For backward pass
loss = mse_loss.mean()

return {"MSE": loss, "CE": ce_loss}
return {"MSE": mse_loss, "CE": ce_loss}

def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
"""Discretizes the state into bins and converts it to a string representation.
@@ -1048,7 +1076,7 @@ def embed_suffix(self, noisy_actions: Tensor, timestep: Tensor) -> tuple[Tensor,

Args:
noisy_actions: Tensor containing noisy actions.
timestep: Tensor containing timesteps.
timestep: Tensor containing timesteps of shape (batch_size, action_chunk_length).

Returns:
A tuple containing:
@@ -1107,6 +1135,7 @@ def forward(
lang_tokens: Tensor,
lang_masks: Tensor,
actions: Tensor,
actions_is_pad: Tensor | None = None,
response_tokens: Tensor | None = None,
response_masks: Tensor | None = None,
noise: Tensor | None = None,
@@ -1124,6 +1153,7 @@
response_tokens: Response language token tensor.
response_masks: Response language mask tensor.
actions: Action tensor.
actions_is_pad: Optional mask tensor marking padded actions.
noise: Optional noise tensor.
time: Optional time tensor.
discrete_actions: Optional discrete action tensor.
@@ -1161,13 +1191,24 @@ def forward(
)

# Now run action expert
batch_size = actions.shape[0]
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)

if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time = self.sample_time(batch_size, actions.device)

time_expanded = time[:, None, None]
# Handle real-time inference: randomly freeze a per-sample action prefix during training
delay = torch.randint(0, self.config.max_delay + 1, (batch_size,))
prefix_mask = rearrange(torch.arange(self.config.chunk_size), "c -> 1 c") < rearrange(
delay, "b -> b 1"
)
prefix_mask = prefix_mask.to(device=actions.device)
time = torch.where(
prefix_mask, 0, rearrange(time, "b -> b 1")
)  # frozen actions use time 0 (clean sample) instead of flow-matching time 1 (noise)

time_expanded = rearrange(time, "b c -> b c 1")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions

@@ -1206,7 +1247,28 @@ def forward(
v_t = self.action_out_proj(suffix_out)
v_t = v_t.to(dtype=torch.float32)

losses = F.mse_loss(u_t, v_t, reduction="none")
mse_loss = F.mse_loss(u_t, v_t, reduction="none")

# mask out frozen actions and padded actions
postfix_mask = rearrange(
torch.logical_not(prefix_mask), "b c -> b c 1"
) # 0 for frozen actions, 1 for non-frozen actions

if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
in_episode_bound = rearrange(
in_episode_bound, "b c -> b c 1"
) # 0 for padded actions, 1 for non-padded actions
postfix_mask = torch.logical_and(postfix_mask, in_episode_bound)

mse_loss = mse_loss * postfix_mask

# Remove padding
mse_loss = mse_loss[:, :, : self.config.max_action_dim]

# Do not include frozen actions and padded actions in the mean loss calculation
postfix_mask_expanded = repeat(postfix_mask, "b c 1 -> b c d", d=mse_loss.shape[-1])
mse_loss = mse_loss.sum() / (postfix_mask_expanded.sum() + 1e-8)

# compute cross entropy loss for discrete actions
batch_size, seq_len = discrete_actions.shape
@@ -1263,16 +1325,18 @@ def forward(
# compute mean
response_ce_loss = response_ce_loss.mean()
else:
response_ce_loss = torch.tensor(0.0, device=losses.device)
response_ce_loss = torch.tensor(0.0, device=mse_loss.device)

return {"MSE": losses, "CE": discrete_action_ce_loss + response_ce_loss}
return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss}

def sample_actions(
self,
images: list[Tensor],
img_masks: list[Tensor],
lang_tokens: Tensor,
lang_masks: Tensor,
action_prefix: Tensor | None = None,
delay: int = 0,
noise: Tensor | None = None,
) -> Tensor:
"""Do a full inference forward and compute the action.
@@ -1283,7 +1347,8 @@ def sample_actions(
lang_tokens: Language token tensor.
lang_masks: Language mask tensor.
noise: Optional noise tensor.

action_prefix: Optional action prefix tensor of shape (batch_size, chunk_size, action_dim).
delay: Number of frozen actions at the start of the chunk.

Returns:
The sampled action tensor.
"""
@@ -1347,39 +1412,46 @@ def sample_actions(

x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
prefix_mask = rearrange(torch.arange(self.config.chunk_size, device=device), "c -> 1 c") < delay
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# If delay > 0, freeze the action prefix at the beginning of the action chunk
if delay > 0:
x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t)
masked_time = torch.where(prefix_mask, 0, time)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
masked_time,
)

# Euler step
x_t += dt * v_t
time += dt

# Ensure the frozen actions are returned unmodified after the final denoising step
if delay > 0:
x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t)
return x_t
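
A toy version of the frozen-prefix integration loop above; the decay field below is a stand-in for denoise_step:

import torch
from einops import rearrange

chunk_size, action_dim, delay, num_steps = 50, 32, 5, 10
action_prefix = torch.randn(1, chunk_size, action_dim)
prefix_mask = rearrange(torch.arange(chunk_size), "c -> 1 c") < delay  # (1, c)

x_t = torch.randn(1, chunk_size, action_dim)  # start from noise at t = 1
dt = torch.tensor(-1.0 / num_steps)
time = torch.tensor(1.0)
while time >= -dt / 2:
    if delay > 0:
        x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t)
    masked_time = torch.where(prefix_mask, 0, time)  # frozen slots pinned to t = 0
    v_t = -x_t * masked_time.unsqueeze(-1)  # stand-in velocity; a real step would call denoise_step
    x_t = x_t + dt * v_t
    time = time + dt
if delay > 0:
    x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t)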

def denoise_step(
self,
prefix_pad_masks: Tensor,
past_key_values: list[dict[str, Tensor]],
x_t: Tensor,
timestep: Tensor,
time: Tensor,
) -> Tensor:
"""Apply one denoising step of the noise `x_t` at a given timestep.

Args:
prefix_pad_masks: Prefix padding masks.
past_key_values: Past key values from the VLM.
x_t: Current noise tensor.
timestep: Current timestep.

time: Time tensor of shape (batch_size, action_chunk_length).

Returns:
The predicted velocity tensor (v_t).
"""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)

num_cross_att_tokens = prefix_pad_masks.shape[1]
action_expert_2d_attention_mask = make_att_2d_masks(
4 changes: 2 additions & 2 deletions src/opentau/scripts/grpc/server.py
@@ -189,9 +189,9 @@ def GetActionChunk(
# Run inference
with torch.inference_mode():
action_chunk = self.policy.sample_actions(batch)
# action_chunk shape: (n_action_steps, batch_size=1, action_dim)
# action_chunk shape: (batch_size=1, n_action_steps, action_dim)
# Remove batch dimension and convert to numpy
action_chunk = action_chunk.squeeze(1).to("cpu", torch.float32).numpy()
action_chunk = action_chunk.squeeze(0).to("cpu", torch.float32).numpy()

# Populate 2D action chunk structure
for action_vector in action_chunk:
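
A standalone check of the new shape convention in the server, using dummy values:

import torch

action_chunk = torch.randn(1, 50, 32)  # (batch_size=1, n_action_steps, action_dim)
action_chunk = action_chunk.squeeze(0).to("cpu", torch.float32).numpy()
assert action_chunk.shape == (50, 32)
for action_vector in action_chunk:  # one (action_dim,) vector per step
    assert action_vector.shape == (32,)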
3 changes: 0 additions & 3 deletions src/opentau/utils/transformers_patch.py
@@ -155,9 +155,6 @@ def forward(
raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")

modulation = self.dense(cond)
# Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
if len(x.shape) == 3: # [batch, seq, features]
modulation = modulation.unsqueeze(1)

scale, shift, gate = torch.chunk(modulation, 3, dim=-1)

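
A hedged sketch of why the unsqueeze became unnecessary: with per-action timesteps, cond already carries a sequence axis, so dense(cond) lines up with x. The layer sizes and the scale/shift application below are assumptions:

import torch

batch, seq, feat, cond_dim = 2, 50, 16, 8
x = torch.randn(batch, seq, feat)
dense = torch.nn.Linear(cond_dim, 3 * feat)
cond = torch.randn(batch, seq, cond_dim)  # previously (batch, cond_dim), hence the old unsqueeze

modulation = dense(cond)  # (batch, seq, 3 * feat): already aligned with x
scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
y = x * (1 + scale) + shift  # typical adaptive-norm application (assumed)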