Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions configs/examples/value_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
"dataset_mixture": {
"datasets": [
{
"repo_id": "physical-intelligence/libero",
"episodes": [0,1,2,3,4,5,6,7,8,9]
"repo_id": "physical-intelligence/libero"
},
{
"grounding": "clevr"
}
],
"weights": [
1.0,
1.0
],
"action_freq": 30.0,
Expand All @@ -22,7 +25,8 @@
"VALUE": "MEAN_STD"
},
"max_state_dim": 32,
"tokenizer_max_length": 52,
"prompt_max_length": 256,
"response_max_length": 52,
"reward_config": {
"number_of_bins": 201,
"C_neg": -1000.0,
Expand All @@ -41,10 +45,11 @@
"gradient_accumulation_steps": 1,
"dataloader_batch_size": 2,
"prefetch_factor": 2,
"steps": 100,
"log_freq": 1,
"steps": 10000,
"log_freq": 100,
"val_freq": 500,
"save_checkpoint": true,
"save_freq": 100,
"save_freq": 1000,
"use_policy_training_preset": false,
"trace_nans": true,
"optimizer": {
Expand Down
3 changes: 2 additions & 1 deletion src/opentau/policies/value/configuration_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class ValueConfig(PreTrainedConfig):
empty_cameras: int = 0

# Tokenizer
tokenizer_max_length: int = 48
prompt_max_length: int = 48
response_max_length: int = 52

# Reward config
reward_config: RewardConfig = field(default_factory=RewardConfig)
Expand Down
202 changes: 164 additions & 38 deletions src/opentau/policies/value/modeling_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
Uses SIGLIP for vision encoding and Gemma 3 270M for language processing.
"""

import logging
import math

import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from einops import rearrange
Expand Down Expand Up @@ -228,9 +232,8 @@ def predict_value(self, batch: dict[str, Tensor]) -> Tensor:

images, img_masks = self.prepare_images(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
state = batch.get("state")

logits = self.model.forward(images, img_masks, lang_tokens, lang_masks, state)
logits = self.model.get_value(images, img_masks, lang_tokens, lang_masks)
return self.calculate_value(logits)

def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | None]:
Expand All @@ -246,22 +249,49 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] |

images, img_masks = self.prepare_images(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
state = batch.get("state")
response_tokens, response_masks = self.prepare_response(batch)

logits = self.model.forward(images, img_masks, lang_tokens, lang_masks, state)
values = self.calculate_value(logits)
ce_logits, response_logits = self.model.forward(
images, img_masks, lang_tokens, lang_masks, response_tokens, response_masks
)
values = self.calculate_value(ce_logits)
# Compute Cross-Entropy loss
logits = logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
ce_logits = ce_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
batch["return_bin_idx"] = batch["return_bin_idx"].to(dtype=torch.long)
loss = F.cross_entropy(logits, batch["return_bin_idx"])
ce_loss = F.cross_entropy(ce_logits, batch["return_bin_idx"], reduction="none")

action_is_pad = batch.get("action_is_pad")
# Zero the CE loss for samples whose action_is_pad flags are all true. This covers VQA datasets, which have no action tokens.
ce_loss = ce_loss * (~action_is_pad.all(dim=1)).float()

ce_loss = ce_loss.mean()

l1_loss = F.l1_loss(values, batch["return_continuous"])

accuracy = (logits.argmax(dim=-1) == batch["return_bin_idx"]).float().mean()
accuracy = (ce_logits.argmax(dim=-1) == batch["return_bin_idx"]).float().mean()

batch_size, seq_len = response_logits.shape[0], response_logits.shape[1]
response_slice = slice(1, None)
response_logits = response_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
response_logits = rearrange(response_logits, "b s d -> (b s) d")
response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)")
response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none")

response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)

# remove pad tokens
response_is_pad = ~response_masks # convert into format where value for pad is True
# Mask response loss if response is padded
response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]
# Keep the response loss only for samples whose action_is_pad flags are all true; it is zeroed for robotic samples, which have at least one action token.
response_ce_loss = response_ce_loss * rearrange((action_is_pad.all(dim=1)).float(), "b -> b 1")

# compute mean
response_ce_loss = response_ce_loss.mean()

return {
"MSE": torch.zeros_like(loss, requires_grad=False),
"CE": loss,
"MSE": torch.zeros_like(ce_loss, requires_grad=False),
"CE": ce_loss + response_ce_loss,
"L1": l1_loss,
"Accuracy": accuracy,
}
Expand Down Expand Up @@ -321,6 +351,35 @@ def prepare_images(self, batch):

return images, img_masks

def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
    """Discretize the state into bin indices and render each row as text.

    Every state dimension is mapped to one of 256 equal-width bins that are
    linearly spaced over [-1, 1], so each index lies in [0, 255]. The
    indices of one row are joined into a single space-separated string.

    Values outside [-1, 1] do not raise: a warning with the observed
    min/max is logged and the values are clipped into range before binning.

    Args:
        batch: Batch of data containing the "state" tensor. Iterating the
            first axis must yield one row of values per sample.

    Returns:
        A list with one string per row of the state tensor, each a
        space-separated list of discretized state values.
    """
    num_bins = 256

    # detach() first: Tensor.numpy() refuses tensors that track gradients.
    state_np = batch["state"].detach().to(device="cpu", dtype=torch.float32).numpy()

    if np.any(state_np < -1.0) or np.any(state_np > 1.0):
        # Lazy %-style arguments: the message is only formatted if emitted.
        logging.warning(
            "State values are not normalized between -1 and 1. Min: %s, Max: %s",
            state_np.min(),
            state_np.max(),
        )
        state_np = np.clip(state_np, -1.0, 1.0)

    # Left edges of the bins. Dropping the final edge (+1.0) makes the value
    # 1.0 fall into the last bin instead of one past it.
    bin_edges = np.linspace(-1, 1, num_bins + 1)[:-1]
    # np.digitize is 1-based for values >= the first edge; shift to 0-based.
    discretized_states = np.digitize(state_np, bins=bin_edges) - 1

    return [
        " ".join(map(str, row)) for row in discretized_states
    ]  # TODO: return a tensor instead of a list of strings?

def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenizes the text input for the model.

Expand All @@ -333,21 +392,54 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
device = batch.get("state", list(batch.values())[0]).device
tasks = batch["prompt"]

# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
state = self.prepare_discrete_state(batch)
# using <eos> to separate each modality
prompt = [f"Task: {task}<eos>State: {state}<eos>" for task, state in zip(tasks, state, strict=False)]

tokenized_prompt = self.language_tokenizer.__call__(
tasks,
prompt,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
max_length=self.config.prompt_max_length,
return_tensors="pt",
truncation=True,
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

return lang_tokens, lang_masks

def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
    """Tokenize the response input.

    Responses are padded on the right (and truncated) to
    ``self.config.response_max_length`` and moved to the device of
    ``batch["state"]``.

    Args:
        batch: Batch of data containing the keys "state" (used only to pick
            the target device) and "response" (an iterable of responses,
            one per sample).

    Returns:
        A tuple containing:
            - response_tokens: Tensor of response language tokens.
            - response_masks: Boolean tensor of response attention masks.
    """
    device = batch["state"].device

    # An empty response ('') marks a sample whose response is not used for the
    # loss (robotic dataset with no subtask). No special handling happens here;
    # the entry is tokenized like any other and padded to max_length.
    # NOTE(review): confirm what the tokenizer emits for '' (e.g. BOS/EOS).
    response_prompt = [str(response) for response in batch["response"]]

    tokenized_response = self.language_tokenizer(
        response_prompt,
        padding="max_length",
        padding_side="right",
        max_length=self.config.response_max_length,
        return_tensors="pt",
        truncation=True,
    )
    response_tokens = tokenized_response["input_ids"].to(device=device)
    response_masks = tokenized_response["attention_mask"].to(device=device, dtype=torch.bool)

    return response_tokens, response_masks


class ValueModel(nn.Module):
"""
Expand Down Expand Up @@ -376,8 +468,6 @@ class ValueModel(nn.Module):
└──────────────────────────────┘
"""

CLASSIFICATION_TOKEN_ID = 6 # unused token id in Gemma 3 270M that we repurpose for classification

def __init__(self, config):
"""Initializes the ValueModel.

Expand All @@ -388,7 +478,8 @@ def __init__(self, config):
self.config = config

siglip_gemma_value_config = SiglipGemmaValueConfig(
num_value_bins=self.config.reward_config.number_of_bins
num_value_bins=self.config.reward_config.number_of_bins,
response_max_length=self.config.response_max_length,
)
self.siglip_gemma_value = SiglipGemmaValueModel(siglip_gemma_value_config)

Expand All @@ -399,7 +490,13 @@ def __init__(self, config):
self.c_neg = config.reward_config.C_neg

def embed_sequence(
self, images, img_masks, lang_tokens, lang_masks, state
self,
images,
img_masks,
lang_tokens,
lang_masks,
response_tokens: torch.Tensor | None = None,
response_masks: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embeds sequence of images and language tokens.

Expand Down Expand Up @@ -451,25 +548,19 @@ def embed_sequence(
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs

# embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
embs.append(state_emb[:, None, :])
if response_tokens is not None:
response_emb = self.siglip_gemma_value.embed_language_tokens(response_tokens)

state_mask = torch.ones(state_emb.shape[0], 1, dtype=torch.bool, device=state_emb.device)
pad_masks.append(state_mask)
# Normalize response language embeddings
response_emb_dim = response_emb.shape[-1]
response_emb = response_emb * math.sqrt(response_emb_dim)

# full attention between state and image and language inputs
att_masks += [0]
embs.append(response_emb)
pad_masks.append(response_masks)

# add classification token
cls_token = torch.full(
(bsize, 1), self.CLASSIFICATION_TOKEN_ID, device=state_emb.device, dtype=torch.long
)
cls_token_emb = self.siglip_gemma_value.gemma.embed_tokens(cls_token)
embs.append(cls_token_emb)
pad_masks.append(torch.ones(bsize, 1, dtype=torch.bool, device=state_emb.device))
att_masks += [0]
# full attention between image, language and response inputs
num_response_embs = response_emb.shape[1]
att_masks += [1] * num_response_embs

embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
Expand All @@ -484,7 +575,42 @@ def forward(
img_masks: list[torch.Tensor],
lang_tokens: torch.Tensor,
lang_masks: torch.Tensor,
state: torch.Tensor | None = None,
response_tokens: torch.Tensor | None = None,
response_masks: torch.Tensor | None = None,
) -> torch.Tensor:
"""Predict value estimates given observations.

Args:
images: List of image tensors
img_masks: List of image masks
lang_tokens: Language token IDs
lang_masks: Language attention masks
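response_tokens: Optional response token IDs, embedded after the language tokens
response_masks: Optional response attention masks (True for non-pad tokens)

Returns:
Tuple of (ce_logits, response_logits) as returned by the underlying SiglipGemmaValueModel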
"""
embs, pad_masks, att_masks = self.embed_sequence(
images, img_masks, lang_tokens, lang_masks, response_tokens, response_masks
)

att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1

ce_logits, response_logits = self.siglip_gemma_value.forward(
inputs_embeds=embs,
attention_mask=att_2d_masks,
position_ids=position_ids,
)

return ce_logits, response_logits

def get_value(
self,
images: list[torch.Tensor],
img_masks: list[torch.Tensor],
lang_tokens: torch.Tensor,
lang_masks: torch.Tensor,
) -> torch.Tensor:
"""Predict value estimates given observations.

Expand All @@ -498,15 +624,15 @@ def forward(
Returns:
Tensor of shape [batch_size, 1] containing value estimates
"""
embs, pad_masks, att_masks = self.embed_sequence(images, img_masks, lang_tokens, lang_masks, state)
embs, pad_masks, att_masks = self.embed_sequence(images, img_masks, lang_tokens, lang_masks)

att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1

logits = self.siglip_gemma_value.forward(
value_logits = self.siglip_gemma_value.get_value(
inputs_embeds=embs,
attention_mask=att_2d_masks,
position_ids=position_ids,
)

return logits
return value_logits
Loading
Loading