Skip to content

Commit 9dae10c

Browse files
committed
Fixes:
Seq2Seq PPO — create the value head with the base model's LM-head dtype and device, and correct `make_head`'s `dtype` default annotation (`type` → `torch.dtype`)
1 parent 4a6896b commit 9dae10c

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

trlx/models/modeling_ppo.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class CausalLMOutputWithValue(ModelOutput):
252252
value: Optional[torch.FloatTensor] = None
253253

254254

255-
def make_value_branch(base_model, num_value_layers_unfrozen, dtype):
255+
def make_value_branch(base_model, num_value_layers_unfrozen, dtype=torch.float32):
256256
value_head = make_head(hf_get_hidden_size(base_model.config), 1, dtype)
257257
if num_value_layers_unfrozen == 0:
258258
return value_head
@@ -1211,9 +1211,12 @@ def __init__(
12111211
):
12121212
super().__init__(base_model, peft_config=peft_config)
12131213
# TODO: Support Seq2Seq value branching
1214+
parameter = next(hf_get_lm_head(self.base_model).parameters())
1215+
dtype = parameter.dtype
1216+
device = parameter.device
12141217
if num_value_layers_unfrozen > 0:
12151218
raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
1216-
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
1219+
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1, dtype).to(device)
12171220

12181221
def forward(
12191222
self,

trlx/utils/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import transformers
1111

1212

13-
def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
13+
def make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
1414
"""Returns a generic sequential MLP head."""
1515
return nn.Sequential(
1616
nn.Linear(n_embd, n_embd * 2, dtype=dtype),

0 commit comments

Comments
 (0)