diff --git a/src/opentau/policies/pi05/configuration_pi05.py b/src/opentau/policies/pi05/configuration_pi05.py
index 97a4327..e3bb7ed 100644
--- a/src/opentau/policies/pi05/configuration_pi05.py
+++ b/src/opentau/policies/pi05/configuration_pi05.py
@@ -116,6 +116,10 @@ class PI05Config(PreTrainedConfig):
     # Decoding
     num_steps: int = 10
 
+    # Real-time inference
+    # Maximum number of frozen actions carried over from the previous chunk
+    max_delay: int = 0
+
     # Initialization strategy
     init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"
 
diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py
index 169e7e4..5dc5b85 100644
--- a/src/opentau/policies/pi05/modeling_pi05.py
+++ b/src/opentau/policies/pi05/modeling_pi05.py
@@ -29,7 +29,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import Tensor, nn
 from transformers import AutoProcessor, AutoTokenizer
 
@@ -52,23 +52,23 @@ def create_sinusoidal_pos_embedding(
     """Computes sine-cosine positional embedding vectors for scalar positions.
 
     Args:
-        time: A 1-D tensor of shape (batch_size,).
+        time: A 2-D tensor of shape (batch_size, action_chunk_length).
         dimension: The dimension of the embedding vectors. Must be divisible by 2.
         min_period: The minimum period of the sinusoidal functions.
         max_period: The maximum period of the sinusoidal functions.
         device: The device to create the tensors on. Defaults to "cpu".
 
     Returns:
-        A tensor of shape (batch_size, dimension) containing the positional embeddings.
+        A tensor of shape (batch_size, action_chunk_length, dimension) containing the positional embeddings.
 
     Raises:
-        ValueError: If dimension is not divisible by 2 or if time tensor is not 1-D.
+        ValueError: If dimension is not divisible by 2 or if time tensor is not 2-D with shape (batch_size, action_chunk_length).
     """
     if dimension % 2 != 0:
         raise ValueError(f"dimension ({dimension}) must be divisible by 2")
 
-    if time.ndim != 1:
-        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
+    if time.ndim != 2:
+        raise ValueError("The time tensor is expected to be of shape `(batch_size, action_chunk_length)`.")
 
     dtype = (
         get_safe_dtype(torch.float64, device.type)
@@ -80,8 +80,8 @@ def create_sinusoidal_pos_embedding(
 
     # Compute the outer product
     scaling_factor = 1.0 / period * 2 * math.pi
-    sin_input = scaling_factor[None, :] * time[:, None]
-    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+    sin_input = rearrange(scaling_factor, "d -> 1 1 d") * rearrange(time, "b c -> b c 1")
+    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=2)
     return pos_emb
 
 
@@ -271,7 +271,7 @@ def __init__(
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
-        self.normalize_actions = Normalize(
+        self.normalize_discrete_actions = Normalize(
             config.output_features, {"ACTION": NormalizationMode.MIN_MAX}, dataset_stats
         )
         self.unnormalize_outputs = Unnormalize(
@@ -525,9 +525,12 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
 
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
+        This method manages an action queue that is replenished whenever it holds `config.max_delay` or
+        fewer actions. When replenishing, the remaining queue contents are passed to `sample_actions` as
+        `action_prefix`, and the queue is then refilled with the newly sampled chunk.
+
+        Note: This method should only be called when running a policy in simulation. For real-world
+        inference, the equivalent logic should be implemented in the ROS client node.
 
         Args:
             batch: Batch of data containing environment observations.
@@ -538,24 +541,58 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """
         self.eval()
 
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.sample_actions(batch, noise=noise)
-            self._action_queue.extend(actions)
-        return self._action_queue.popleft()
+        if len(self._action_queue) <= self.config.max_delay:
+            # Use the remaining queued actions as the frozen prefix of the new chunk
+            action_prefix = None
+            delay = 0
+            if len(self._action_queue) > 0:
+                prefix_actions = list(self._action_queue)
+                delay = min(len(prefix_actions), self.config.max_delay)
+                assert delay == self.config.max_delay, f"Delay must be equal to {self.config.max_delay}, got {delay}"
+                prefix_actions = prefix_actions[-delay:]
+                action_prefix = torch.stack(prefix_actions, dim=1)
+                action_prefix = self.normalize_targets({"actions": action_prefix})["actions"]
+                original_action_dim = self.config.action_feature.shape[0]
+                if original_action_dim < self.config.max_action_dim:
+                    action_prefix = F.pad(
+                        action_prefix,
+                        (0, self.config.max_action_dim - original_action_dim),
+                    )
+                if delay < self.config.chunk_size:
+                    action_prefix = F.pad(
+                        action_prefix,
+                        (0, 0, 0, self.config.chunk_size - delay),
+                    )
+            actions = self.sample_actions(batch, noise=noise, action_prefix=action_prefix, delay=delay)
+            actions = rearrange(actions, "b c d -> c b d")
+            self._action_queue.extend(actions[delay:])
+            assert len(self._action_queue) == self.config.n_action_steps, (
+                f"Action queue must have {self.config.n_action_steps} actions"
+            )
+
+        action = self._action_queue.popleft()
+        return action
 
     @torch.no_grad()
-    def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def sample_actions(
+        self,
+        batch: dict[str, Tensor],
+        noise: Tensor | None = None,
+        action_prefix: Tensor | None = None,
+        delay: int = 0,
+    ) -> Tensor:
         """Sample actions from the policy given environment observations.
 
         Args:
             batch: Batch of data containing environment observations.
             noise: Optional noise tensor.
-
+            action_prefix: Optional action prefix tensor of shape (batch_size, action_chunk_length, action_dim).
+            delay: Number of frozen actions at the start of action_prefix.
         Returns:
-            The sampled actions tensor of shape (batch_size, action_dim).
+            The sampled actions tensor of shape (batch_size, action_chunk_length, action_dim).
""" + assert 0 <= delay <= self.config.max_delay, f"Delay must be between 0 and {self.config.max_delay}" + batch = self.normalize_inputs(batch) images, img_masks = self.prepare_images(batch) @@ -566,6 +603,8 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) img_masks, lang_tokens, lang_masks, + action_prefix=action_prefix, + delay=delay, noise=noise, ) @@ -575,9 +614,6 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) actions = self.unnormalize_outputs({"actions": actions})["actions"] - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = actions.transpose(0, 1) return actions def forward( @@ -594,7 +630,7 @@ def forward( A dictionary containing the loss components ("MSE" and "CE"). """ batch = self.normalize_inputs(batch) - batch["discrete_actions"] = self.normalize_actions(dict(batch))["actions"] + batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch))["actions"] batch = self.normalize_targets(batch) images, img_masks = self.prepare_images( @@ -622,6 +658,7 @@ def forward( lang_tokens, lang_masks, actions, + actions_is_pad, response_tokens, response_masks, noise, @@ -632,17 +669,8 @@ def forward( mse_loss = losses["MSE"] ce_loss = losses["CE"] - if actions_is_pad is not None: - in_episode_bound = ~actions_is_pad - mse_loss = mse_loss * in_episode_bound.unsqueeze(-1) - - # Remove padding - mse_loss = mse_loss[:, :, : self.config.max_action_dim] - - # For backward pass - loss = mse_loss.mean() - return {"MSE": loss, "CE": ce_loss} + return {"MSE": mse_loss, "CE": ce_loss} def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]: """Discretizes the state into bins and converts it to a string representation. @@ -1048,7 +1076,7 @@ def embed_suffix(self, noisy_actions: Tensor, timestep: Tensor) -> tuple[Tensor, Args: noisy_actions: Tensor containing noisy actions. - timestep: Tensor containing timesteps. + timestep: Tensor containing timesteps of shape (batch_size, action_chunk_length). Returns: A tuple containing: @@ -1107,6 +1135,7 @@ def forward( lang_tokens: Tensor, lang_masks: Tensor, actions: Tensor, + actions_is_pad: Tensor | None = None, response_tokens: Tensor | None = None, response_masks: Tensor | None = None, noise: Tensor | None = None, @@ -1124,6 +1153,7 @@ def forward( response_tokens: Response language token tensor. response_masks: Response language mask tensor. actions: Action tensor. + actions_is_pad: Optional action is padded mask tensor. noise: Optional noise tensor. time: Optional time tensor. discrete_actions: Optional discrete action tensor. 
@@ -1161,13 +1191,24 @@ def forward(
         )
 
         # Now run action expert
+        batch_size = actions.shape[0]
         if noise is None:
             noise = self.sample_noise(actions.shape, actions.device)
 
         if time is None:
-            time = self.sample_time(actions.shape[0], actions.device)
+            time = self.sample_time(batch_size, actions.device)
 
-        time_expanded = time[:, None, None]
+        # Handle real-time inference delay: sample how many prefix actions are frozen per sample
+        delay = torch.randint(0, self.config.max_delay + 1, (batch_size,))
+        prefix_mask = rearrange(torch.arange(self.config.chunk_size), "c -> 1 c") < rearrange(
+            delay, "b -> b 1"
+        )
+        prefix_mask = prefix_mask.to(device=actions.device)
+        time = torch.where(
+            prefix_mask, 0, rearrange(time, "b -> b 1")
+        )  # frozen prefix steps get time 0 (clean actions) instead of the sampled flow-matching time
+
+        time_expanded = rearrange(time, "b c -> b c 1")
         x_t = time_expanded * noise + (1 - time_expanded) * actions
         u_t = noise - actions
 
@@ -1206,7 +1247,28 @@
         v_t = self.action_out_proj(suffix_out)
         v_t = v_t.to(dtype=torch.float32)
 
-        losses = F.mse_loss(u_t, v_t, reduction="none")
+        mse_loss = F.mse_loss(u_t, v_t, reduction="none")
+
+        # Mask out frozen actions and padded actions
+        postfix_mask = rearrange(
+            torch.logical_not(prefix_mask), "b c -> b c 1"
+        )  # 0 for frozen actions, 1 for non-frozen actions
+
+        if actions_is_pad is not None:
+            in_episode_bound = ~actions_is_pad
+            in_episode_bound = rearrange(
+                in_episode_bound, "b c -> b c 1"
+            )  # 0 for padded actions, 1 for non-padded actions
+            postfix_mask = torch.logical_and(postfix_mask, in_episode_bound)
+
+        mse_loss = mse_loss * postfix_mask
+
+        # Remove action-dimension padding
+        mse_loss = mse_loss[:, :, : self.config.max_action_dim]
+
+        # Exclude frozen and padded actions from the mean loss
+        postfix_mask_expanded = repeat(postfix_mask, "b c 1 -> b c d", d=mse_loss.shape[-1])
+        mse_loss = mse_loss.sum() / (postfix_mask_expanded.sum() + 1e-8)
 
         # compute cross entropy loss for discrete actions
         batch_size, seq_len = discrete_actions.shape
@@ -1263,9 +1325,9 @@
             # compute mean
             response_ce_loss = response_ce_loss.mean()
         else:
-            response_ce_loss = torch.tensor(0.0, device=losses.device)
+            response_ce_loss = torch.tensor(0.0, device=mse_loss.device)
 
-        return {"MSE": losses, "CE": discrete_action_ce_loss + response_ce_loss}
+        return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss}
 
     def sample_actions(
         self,
@@ -1273,6 +1335,8 @@
         img_masks: list[Tensor],
         lang_tokens: Tensor,
         lang_masks: Tensor,
+        action_prefix: Tensor | None = None,
+        delay: int = 0,
         noise: Tensor | None = None,
     ) -> Tensor:
         """Do a full inference forward and compute the action.
@@ -1283,7 +1347,8 @@
             lang_tokens: Language token tensor.
             lang_masks: Language mask tensor.
             noise: Optional noise tensor.
-
+            action_prefix: Optional tensor of frozen prefix actions.
+            delay: Number of frozen actions at the start of the chunk.
         Returns:
             The sampled action tensor.
""" @@ -1347,18 +1412,26 @@ def sample_actions( x_t = noise time = torch.tensor(1.0, dtype=torch.float32, device=device) + prefix_mask = rearrange(torch.arange(self.config.chunk_size, device=device), "c -> 1 c") < delay while time >= -dt / 2: - expanded_time = time.expand(bsize) + # if delay is greater than 0, then freeze the action prefix at the beginning of action chunk + if delay > 0: + x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t) + masked_time = torch.where(prefix_mask, 0, time) v_t = self.denoise_step( prefix_pad_masks, past_key_values, x_t, - expanded_time, + masked_time, ) # Euler step x_t += dt * v_t time += dt + + # we need to ensure the frozen actions are not modified before returning the denoised actions + if delay > 0: + x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t) return x_t def denoise_step( @@ -1366,7 +1439,7 @@ def denoise_step( prefix_pad_masks: Tensor, past_key_values: list[dict[str, Tensor]], x_t: Tensor, - timestep: Tensor, + time: Tensor, ) -> Tensor: """Apply one denoising step of the noise `x_t` at a given timestep. @@ -1374,12 +1447,11 @@ def denoise_step( prefix_pad_masks: Prefix padding masks. past_key_values: Past key values from the VLM. x_t: Current noise tensor. - timestep: Current timestep. - + time: Time tensor of shape (batch_size, action_chunk_length). Returns: The predicted velocity tensor (v_t). """ - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) num_cross_att_tokens = prefix_pad_masks.shape[1] action_expert_2d_attention_mask = make_att_2d_masks( diff --git a/src/opentau/scripts/grpc/server.py b/src/opentau/scripts/grpc/server.py index c1e4832..3637254 100644 --- a/src/opentau/scripts/grpc/server.py +++ b/src/opentau/scripts/grpc/server.py @@ -189,9 +189,9 @@ def GetActionChunk( # Run inference with torch.inference_mode(): action_chunk = self.policy.sample_actions(batch) - # action_chunk shape: (n_action_steps, batch_size=1, action_dim) + # action_chunk shape: (batch_size=1, n_action_steps, action_dim) # Remove batch dimension and convert to numpy - action_chunk = action_chunk.squeeze(1).to("cpu", torch.float32).numpy() + action_chunk = action_chunk.squeeze(0).to("cpu", torch.float32).numpy() # Populate 2D action chunk structure for action_vector in action_chunk: diff --git a/src/opentau/utils/transformers_patch.py b/src/opentau/utils/transformers_patch.py index 8913ecc..08ccbf4 100644 --- a/src/opentau/utils/transformers_patch.py +++ b/src/opentau/utils/transformers_patch.py @@ -155,9 +155,6 @@ def forward( raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}") modulation = self.dense(cond) - # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features] - if len(x.shape) == 3: # [batch, seq, features] - modulation = modulation.unsqueeze(1) scale, shift, gate = torch.chunk(modulation, 3, dim=-1)