Skip to content

Commit fd58c49

Browse files
authored
fix(accelerate_ppo_trainer): no resizing when using peft reference
1 parent d2c4b3b commit fd58c49

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

trlx/trainer/accelerate_ppo_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,15 @@ def __init__(self, config: TRLConfig, **kwargs):
6969

7070
# Set up a reference model when hydra heads are not used
7171
if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
72+
# Full Reference Copy
7273
self.ref_model = self.get_arch(self.config)
7374
self.ref_model.base_model.resize_token_embeddings(len(self.tokenizer))
7475
self.ref_model.to(self.accelerator.device)
7576
self.ref_model.eval()
76-
else:
77-
# resize hydra heads
77+
elif hasattr(self.model, "frozen_head"):
78+
# Hydra Reference: Use the frozen base layers and head as the reference model, resize hydra heads
7879
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
80+
# TODO: else PEFT Reference, do something?
7981

8082
# Set up the KL controller
8183
# This helps prevent large divergences in the controller (policy)

0 commit comments

Comments
 (0)