README.md: 1 change (0 additions, 1 deletion)
@@ -47,7 +47,6 @@ OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we
| Create Validation Splits During Training | ❌ | ❌ | ✅ |
| $\pi^{*}_{0.6}$ style Reinforcement Learning Pipeline | ❌ | ❌ | ✅ |
| Framework | Jax / PyTorch | PyTorch | PyTorch |
-| Raw Robotic data to Structured LeRobot format conversion | ❌ | ❌ | ✅ |

## Quick Start
If you are familiar with LeRobot, getting started with OpenTau is very easy.
src/opentau/policies/pi05/modeling_pi05.py: 66 changes (34 additions, 32 deletions)
@@ -1231,39 +1231,41 @@ def forward(
# compute mean
discrete_action_ce_loss = discrete_action_ce_loss.mean()

-        # compute cross entropy loss for response language
-        batch_size, seq_len = response_tokens.shape
-        response_token_start = -self.config.response_max_length - self.config.discrete_action_max_length
-        # The last token of language will predict <BOS> token of response, so no need to include for loss calculation. Hence slice starts from -self.config.discrete_action_max_length - self.config.response_max_length.
-        # The last token of response predicts first token of discrete actions, so no need to include for loss calculation. Hence slice ends at -self.config.discrete_action_max_length - 1.
-        response_token_end = -self.config.discrete_action_max_length - 1
-        response_slice_object = slice(response_token_start, response_token_end)
-        response_out = prefix_out[
-            :,
-            response_slice_object,
-        ]
-        response_logits = self.paligemma_with_expert.paligemma.lm_head(response_out)
-        # response slice to exclude the <BOS> token from response while calculating loss.
-        response_slice = slice(1, None)
-        response_logits = response_logits.to(dtype=torch.float32)  # upcast to float32 for loss calculation
-        response_logits = rearrange(response_logits, "b s d -> (b s) d")
-        response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)")
-        response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none")
-
-        response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len - 1)
-
-        # remove pad tokens
-        response_is_pad = ~response_masks  # convert into format where value for pad is True
-        # helps to control loss for response tokens in case of robotic data and VQA data
-        response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]
-
-        # compute mean
-        response_ce_loss = response_ce_loss.mean()
+        # compute cross entropy loss for response language only when predict_response is set to True
+        if self.config.predict_response:
+            batch_size, seq_len = response_tokens.shape
+            response_token_start = -self.config.response_max_length - self.config.discrete_action_max_length
+            # The last token of language will predict <BOS> token of response, so no need to include for loss calculation. Hence slice starts from -self.config.discrete_action_max_length - self.config.response_max_length.
+            # The last token of response predicts first token of discrete actions, so no need to include for loss calculation. Hence slice ends at -self.config.discrete_action_max_length - 1.
+            response_token_end = -self.config.discrete_action_max_length - 1
+            response_slice_object = slice(response_token_start, response_token_end)
+            response_out = prefix_out[
+                :,
+                response_slice_object,
+            ]
+            response_logits = self.paligemma_with_expert.paligemma.lm_head(response_out)
+            # response slice to exclude the <BOS> token from response while calculating loss.
+            response_slice = slice(1, None)
+            response_logits = response_logits.to(
+                dtype=torch.float32
+            )  # upcast to float32 for loss calculation
+            response_logits = rearrange(response_logits, "b s d -> (b s) d")
+            response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)")
+            response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none")
+
+            response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len - 1)
+
+            # remove pad tokens
+            response_is_pad = ~response_masks  # convert into format where value for pad is True
+            # helps to control loss for response tokens in case of robotic data and VQA data
+            response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]
+
+            # compute mean
+            response_ce_loss = response_ce_loss.mean()
+        else:
+            response_ce_loss = torch.tensor(0.0, device=losses.device)

-        return {
-            "MSE": losses,
-            "CE": (discrete_action_ce_loss + response_ce_loss),
-        }
return {"MSE": losses, "CE": discrete_action_ce_loss + response_ce_loss}

def sample_actions(
self,
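To make the new control flow easier to follow outside the diff, here is a minimal, runnable sketch of the gated response-loss computation. Everything here is a stand-in: the `Cfg` class, the tensor shapes, and the `lm_head` linear layer are invented for illustration, and only the slicing, the <BOS> shift, the pad masking, and the `predict_response` gate mirror the change above.

```python
import torch
import torch.nn.functional as F
from einops import rearrange


class Cfg:
    # Hypothetical config values, chosen only for illustration.
    predict_response = True
    response_max_length = 8
    discrete_action_max_length = 4


cfg = Cfg()
vocab_size, hidden_dim = 32, 16
batch_size, seq_len = 2, cfg.response_max_length

# Stand-ins for the model state used in the diff.
prefix_out = torch.randn(batch_size, 64, hidden_dim)                 # backbone outputs
lm_head = torch.nn.Linear(hidden_dim, vocab_size)                    # plays paligemma.lm_head
response_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
response_masks = torch.ones(batch_size, seq_len, dtype=torch.bool)   # True = real token

if cfg.predict_response:
    # Slice out the positions whose outputs predict response tokens:
    # skip the position that predicts <BOS>, stop before the discrete-action block.
    start = -cfg.response_max_length - cfg.discrete_action_max_length
    end = -cfg.discrete_action_max_length - 1
    response_out = prefix_out[:, start:end]                          # (b, seq_len - 1, d)

    logits = lm_head(response_out).float()                           # upcast for a stable loss
    labels = response_tokens[:, 1:]                                  # drop <BOS>; position t predicts t + 1

    ce = F.cross_entropy(
        rearrange(logits, "b s d -> (b s) d"),
        rearrange(labels, "b s -> (b s)"),
        reduction="none",
    )
    ce = rearrange(ce, "(b s) -> b s", b=batch_size, s=seq_len - 1)
    ce = ce * response_masks[:, 1:]                                  # zero the loss on pad positions
    response_ce_loss = ce.mean()
else:
    # Mirrors the diff's fallback (minus the device argument, since `losses` is not modeled here).
    response_ce_loss = torch.tensor(0.0)

print(response_ce_loss)
```

Flipping `predict_response` to False produces a zero tensor instead of skipping the term entirely, so the summed "CE" entry in the returned dict stays well-defined for runs, such as robotics-only training, that have no response language to predict.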