Commit 9c57623

Author: s00652993 (committed)
Add DPO training example and fix minor bugs
1 parent f9656f2 commit 9c57623

File tree: 6 files changed (+130, -19 lines)


examples/hh/dpo_hh.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import json
import sys
from collections import defaultdict

import tqdm
from datasets import Dataset, load_dataset

import trlx
from trlx.data.default_configs import (
    DPOConfig,
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=100,
        total_steps=1000,
        batch_size=1,
        checkpoint_interval=10000,
        eval_interval=100,
        pipeline="PromptPipeline",
        trainer="AccelerateDPOTrainer",
        checkpoint_dir="checkpoints/dpo_hh",
    ),
    model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
    tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # T_max should match train.total_steps
    method=DPOConfig(
        name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), beta=0.1
    ),
)


def get_hh(split: str, sanity_check=False, silent=False):  # TRL-style loader for Anthropic/hh-rlhf, grouped by prompt
    dataset = load_dataset("Anthropic/hh-rlhf", split=split)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def extract_anthropic_prompt(prompt_and_response):
        """Extract the Anthropic prompt from a prompt-and-response pair."""
        search_term = "\n\nAssistant:"
        search_term_idx = prompt_and_response.rfind(search_term)
        assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
        return prompt_and_response[: search_term_idx + len(search_term)]

    def split_prompt_and_responses(ex):
        prompt = extract_anthropic_prompt(ex["chosen"])
        chosen_response = ex["chosen"][len(prompt) :]
        rejected_response = ex["rejected"][len(prompt) :]
        return prompt, chosen_response, rejected_response

    data = defaultdict(lambda: defaultdict(list))
    for row in tqdm.tqdm(dataset, desc="Processing HH", disable=silent):
        prompt, chosen, rejected = split_prompt_and_responses(row)
        responses = [chosen, rejected]
        n_responses = len(data[prompt]["responses"])
        data[prompt]["pairs"].append((n_responses, n_responses + 1))
        data[prompt]["responses"].extend(responses)
        data[prompt]["sft_target"] = chosen

    def gen():
        for prompt, values in data.items():
            yield {
                "prompt": prompt,
                "responses": values["responses"],
                "pairs": values["pairs"],
            }

    return Dataset.from_generator(gen)


def preprocess(sample):
    pass  # no-op placeholder: Dahoas/full-hh-rlhf rows already carry "prompt", "chosen" and "rejected"


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess)

    trlx.train(
        config=config,
        samples=dataset["train"],
        eval_prompts=dataset["test"]["prompt"][:280],
        # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
        stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
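Note on usage: the __main__ block parses an optional JSON string from the command line into hparams, and main() merges those overrides into default_config via TRLConfig.update before calling trlx.train. A minimal sketch of that override path (assumptions: run from the repository root so the example module is importable; the override keys and values below are illustrative, not part of the commit):

# Roughly equivalent to:  python examples/hh/dpo_hh.py '{"train": {"total_steps": 200}, "method": {"beta": 0.05}}'
from examples.hh.dpo_hh import main  # assumed import path for the new example module

main({"train": {"total_steps": 200}, "method": {"beta": 0.05}})  # merged into default_config via TRLConfig.update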

trlx/pipeline/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -166,8 +166,8 @@ def __next__(self): # noqa: C901
             minibatch = BatchEncoding(sliced_data)
         elif is_dataclass(batch):
             minibatch = batch.__class__(**sliced_data)
-        # else:
-        #     minibatch = sliced_data
+        else:
+            minibatch = sliced_data

         minibatches.append(minibatch)
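Why this matters: __next__ slices every field of the incoming batch and then rebuilds a minibatch of the same type. With the fallback commented out, a batch that is neither a BatchEncoding nor a dataclass left minibatch unset. A minimal sketch of the three branches on a plain dict batch (assumption: the first branch, not shown in the hunk, guards on isinstance(batch, BatchEncoding); shapes are illustrative only):

from dataclasses import is_dataclass

import torch
from transformers import BatchEncoding

batch = {"input_ids": torch.zeros(8, 4, dtype=torch.long)}  # a plain dict batch, e.g. from a custom collator
sliced_data = {k: v[:2] for k, v in batch.items()}  # stand-in for the iterator's per-field slicing

if isinstance(batch, BatchEncoding):
    minibatch = BatchEncoding(sliced_data)
elif is_dataclass(batch):
    minibatch = batch.__class__(**sliced_data)
else:
    minibatch = sliced_data  # the restored fallback: keep the sliced dict as-is

assert minibatch["input_ids"].shape[0] == 2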

trlx/pipeline/offline_pipeline.py

Lines changed: 16 additions & 8 deletions
@@ -288,19 +288,27 @@ class DPOPreferences:
 
 class DPOStore(BaseRolloutStore):
     # Adapted from TRL
-    def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        preferences: List[DPOPreferences],
+        tokenizer: PreTrainedTokenizer,
+        label_pad_token_id: int,
+        padding_value: int,
+    ):
         super().__init__()
         self.tokenizer = tokenizer
+        self.label_pad_token_id = label_pad_token_id
+        self.padding_value = padding_value

         self.history = [
             self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences
         ]

     @staticmethod
     def tokenize_preferences(samples, tokenizer, max_length=2048):
-        chosen_tokens = tokenizer(samples[0], add_special_tokens=False)
-        rejected_tokens = tokenizer(samples[1], add_special_tokens=False)
-        prompt_tokens = tokenizer(samples[2], add_special_tokens=False)
+        chosen_tokens = tokenizer(samples["chosen"], add_special_tokens=False)
+        rejected_tokens = tokenizer(samples["rejected"], add_special_tokens=False)
+        prompt_tokens = tokenizer(samples["prompt"], add_special_tokens=False)

         chosen_tokens["input_ids"].append(tokenizer.eos_token_id)
         chosen_tokens["attention_mask"].append(1)
@@ -313,14 +321,14 @@ def tokenize_preferences(samples, tokenizer, max_length=2048):
         # if combined sequence is too long, truncate the prompt only
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
             if tokenizer.truncation_side == "right":
-                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()}
             elif tokenizer.truncation_side == "left":
-                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()}

         # if that's still too long, truncate the response
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
-            chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
-            rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}
+            chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()}
+            rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()}

         return DPOPreferences(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)
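One detail worth flagging in the second hunk: the new response-truncation slices use max_length - max_length, which is always 0, so this fallback would empty both responses whenever it triggers (the removed lines referenced an undefined self.max_prompt_length, which is presumably what prompted the change). Below is a hedged sketch of what the truncation appears to intend, following the TRL logic this store says it adapts; the max_prompt_length argument is hypothetical and not part of this commit:

def truncate_preference_tokens(prompt_tokens, chosen_tokens, rejected_tokens,
                               max_length=2048, max_prompt_length=512, truncation_side="right"):
    """Sketch only: cap the prompt first, then the responses, so prompt + response fits in max_length."""
    longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

    # If the combined sequence is too long, truncate the prompt down to max_prompt_length tokens
    if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
        if truncation_side == "right":
            prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
        else:
            prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}

    # If that is still too long, truncate each response to the room left after the prompt
    if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
        budget = max_length - len(prompt_tokens["input_ids"])
        chosen_tokens = {k: v[:budget] for k, v in chosen_tokens.items()}
        rejected_tokens = {k: v[:budget] for k, v in rejected_tokens.items()}

    return prompt_tokens, chosen_tokens, rejected_tokens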

trlx/trainer/accelerate_dpo_trainer.py

Lines changed: 10 additions & 7 deletions
@@ -25,19 +25,20 @@ class DPOConfig(MethodConfig):
     """

     gen_kwargs: dict
-    beta: float = 0.1
+    beta: float = 0.1  # Beta value for DPO loss calculation
+    label_pad_token_id: int = -100  # -100 is ignore token for CELoss
+    padding_value: int = 0


 @register_trainer
 class AccelerateDPOTrainer(AccelerateRLTrainer):
     def __init__(self, config: TRLConfig, **kwargs):
         super().__init__(config, **kwargs)

-        # Set up a reference model when hydra heads are not used
-        if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
-            self.ref_model = self.get_arch(self.config)
-            self.ref_model.to(self.accelerator.device)
-            self.ref_model.eval()
+        # TODO: Avoid setting up a reference model when hydra heads are used
+        self.ref_model = self.get_arch(self.config)
+        self.ref_model.to(self.accelerator.device)
+        self.ref_model.eval()

         self.generate_kwargs = dict(
             config.method.gen_kwargs,
@@ -47,6 +48,8 @@ def __init__(self, config: TRLConfig, **kwargs):

         # `beta` corresponding to the DPO hyperparameter
         self.beta = config.method.beta
+        self.label_pad_token_id = config.method.label_pad_token_id
+        self.padding_value = config.method.padding_value

     def get_arch(self, config):
         from_fn = AutoModelForCausalLM.from_pretrained
@@ -250,4 +253,4 @@ def prepare_learning(self):

     def make_experience(self, samples, seq_length):
         preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples]
-        self.store = DPOStore(preferences, self.tokenizer)
+        self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value)
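The new config fields feed the loss and collation: beta scales the preference margin in the DPO objective, label_pad_token_id marks prompt and padding positions that should not contribute to sequence log-probabilities, and padding_value pads input ids when batching. A minimal sketch of the standard DPO loss these hyperparameters plug into (the textbook formulation from the DPO paper, not necessarily this trainer's exact code; logits and labels are assumed already shifted/aligned):

import torch
import torch.nn.functional as F


def sequence_logps(logits, labels, label_pad_token_id=-100):
    """Sum per-token log-probs over each sequence, ignoring positions labelled with label_pad_token_id."""
    mask = labels != label_pad_token_id
    safe_labels = labels.masked_fill(~mask, 0)  # any valid index; these positions are zeroed out below
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=safe_labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * mask).sum(-1)


def dpo_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """-log sigmoid(beta * ((logpi - logref)_chosen - (logpi - logref)_rejected)), averaged over the batch."""
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()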

trlx/trlx.py

Lines changed: 2 additions & 2 deletions
@@ -64,7 +64,7 @@ def train( # noqa: C901
             config = default_ppo_config()
         elif rewards:
             config = default_ilql_config()
-        else:
+        else:  # Could also be DPO, but that case is ignored since passing `config` implicitly is deprecated
             config = default_sft_config()

     set_seed(config.train.seed)
@@ -102,7 +102,7 @@ def train( # noqa: C901
         if eval_prompts is None:
             eval_prompts = prompts[:batch_size]

-    # Offline training from the collected samples (e.g. SFT, ILQL)
+    # Offline training from the collected samples (e.g. SFT, ILQL, DPO)
     elif samples:
         if rewards is not None:
             if len(samples) != len(rewards):
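As the new comment notes, the implicit-config path above never selects DPO: it only distinguishes PPO, ILQL and SFT, and implicit configs are deprecated anyway. A DPO run therefore passes its config explicitly along with preference-formatted samples, as the new example does. A minimal sketch (the sample row and eval prompt are illustrative; default_config stands for the DPO config defined in examples/hh/dpo_hh.py):

import trlx

trlx.train(
    config=default_config,  # explicit DPO config -> trainer="AccelerateDPOTrainer"
    samples=[
        # offline preference rows with the keys DPOStore.tokenize_preferences expects
        {"prompt": "\n\nHuman: ...\n\nAssistant:", "chosen": " helpful reply", "rejected": " unhelpful reply"},
    ],
    eval_prompts=["\n\nHuman: ...\n\nAssistant:"],  # prompts used only for periodic generation during eval
)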

trlx/utils/loading.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 # Register load trainers via module import
 from trlx.trainer import _TRAINERS, register_trainer
+from trlx.trainer.accelerate_dpo_trainer import AccelerateDPOTrainer
 from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer
 from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer
 from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer
