File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -252,7 +252,7 @@ class CausalLMOutputWithValue(ModelOutput):
252252 value : Optional [torch .FloatTensor ] = None
253253
254254
255- def make_value_branch (base_model , num_value_layers_unfrozen , dtype ):
255+ def make_value_branch (base_model , num_value_layers_unfrozen , dtype = torch . float32 ):
256256 value_head = make_head (hf_get_hidden_size (base_model .config ), 1 , dtype )
257257 if num_value_layers_unfrozen == 0 :
258258 return value_head
@@ -1211,9 +1211,12 @@ def __init__(
12111211 ):
12121212 super ().__init__ (base_model , peft_config = peft_config )
12131213 # TODO: Support Seq2Seq value branching
1214+ parameter = next (hf_get_lm_head (self .base_model ).parameters ())
1215+ dtype = parameter .dtype
1216+ device = parameter .device
12141217 if num_value_layers_unfrozen > 0 :
12151218 raise NotImplementedError ("Value branches unsupported for Seq2Seq architecture" )
1216- self .v_head = make_head (hf_get_hidden_size (self .base_model .config ), 1 )
1219+ self .v_head = make_head (hf_get_hidden_size (self .base_model .config ), 1 , dtype ). to ( device )
12171220
12181221 def forward (
12191222 self ,
Original file line number Diff line number Diff line change 1010import transformers
1111
1212
13- def make_head (n_embd : int , out : int , dtype : type = torch .float32 ) -> nn .Sequential :
13+ def make_head (n_embd : int , out : int , dtype : torch . dtype = torch .float32 ) -> nn .Sequential :
1414 """Returns a generic sequential MLP head."""
1515 return nn .Sequential (
1616 nn .Linear (n_embd , n_embd * 2 , dtype = dtype ),
You can’t perform that action at this time.
0 commit comments