Commit 4614877

Merge pull request #222 from derekhiggins/checkpointing
Add data checkpointing capability
2 parents 7581308 + f5d22d7 commit 4614877

5 files changed: +235 -3 lines changed

src/instructlab/sdg/checkpointing.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# Standard
import logging
import uuid

# Third Party
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets.data_files import EmptyDatasetError

# First Party
from instructlab.sdg.utils import pandas

logger = logging.getLogger(__name__)


class Checkpointer:
    def __init__(self, checkpoint_dir=None, save_freq=1):
        self._checkpoint_dir = checkpoint_dir

        self._save_freq = save_freq
        self._cache = []

    def checkpoint(self, dataset):
        self._cache.append(dataset)
        if len(self._cache) < self._save_freq:
            return
        self.save()
        self._cache.clear()

    def done(self):
        if self._cache:
            self.save()
            self._cache.clear()

    def save(self):
        if self._checkpoint_dir is None:
            return
        checkpoint_id = uuid.uuid4().hex
        checkpoint_file = (
            f"{self._checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
        )
        logger.info(f"Saving checkpoint to {checkpoint_file}")
        # Saves all the current records to new file in the checkpoint dir
        concatenate_datasets(self._cache).to_json(
            checkpoint_file, orient="records", lines=True
        )

    def load(self, dataset: Dataset) -> Dataset:
        if self._checkpoint_dir is None:
            return dataset, None

        try:
            pre_generated_data = load_dataset(
                "json", data_dir=self._checkpoint_dir, split="train"
            )
        except EmptyDatasetError:
            logger.info(
                f"No existing checkpoints found in {self._checkpoint_dir}, generating from scratch"
            )
            return dataset, None

        logger.info(
            f"Loading existing checkpoints from {self._checkpoint_dir}, with {pre_generated_data.num_rows} rows"
        )
        seed_data = self._get_missing_data(dataset, pre_generated_data)
        logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
        return seed_data, pre_generated_data

    def _get_missing_data(self, seed_data, generated_data):
        # Get the common columns between the two datasets
        common_columns = list(
            set(seed_data.column_names) & set(generated_data.column_names)
        )

        # Extract the relevant data based on common columns
        seed_data_common = seed_data.select_columns(common_columns)
        generated_data_common = generated_data.select_columns(common_columns)

        # Convert to Pandas DataFrames for easier comparison
        seed_df = seed_data_common.to_pandas()
        generated_df = generated_data_common.to_pandas()

        # Identify missing rows
        missing_rows = ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))

        missing_df = seed_data.to_pandas()[missing_rows]
        return pandas.dataset_from_pandas_dataframe(missing_df)
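
The resume logic above hinges on _get_missing_data: seed rows and checkpointed rows are compared as whole-row tuples over their shared columns, and only the seed rows with no match are handed back for generation. Below is a self-contained sketch of that comparison on toy data (not part of the commit; the column names are illustrative only):

# Third Party
from datasets import Dataset

seed = Dataset.from_list([{"foo": i} for i in range(5)])
generated = Dataset.from_list([{"foo": i, "bar": i + 100} for i in (0, 1, 2)])

# Compare only the columns the two datasets share (here just "foo"),
# mirroring what _get_missing_data does.
common = list(set(seed.column_names) & set(generated.column_names))
seed_df = seed.select_columns(common).to_pandas()
generated_df = generated.select_columns(common).to_pandas()

# Row-wise tuples let .isin() match entire rows rather than single values.
missing = ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
print(seed.to_pandas()[missing])  # only the rows with foo == 3 and foo == 4 remain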

src/instructlab/sdg/generate_data.py

Lines changed: 11 additions & 1 deletion
@@ -5,6 +5,7 @@
 from importlib import resources
 from pathlib import Path
 from typing import Optional
+import dataclasses
 import json
 import os
 import time
@@ -181,6 +182,8 @@ def _context_init(
     model_family: str,
     model_id: str,
     num_instructions_to_generate: int,
+    checkpoint_dir: str,
+    save_freq: int,
     batch_num_workers: Optional[int],
     batch_size: Optional[int],
 ):
@@ -194,6 +197,8 @@ def _context_init(
         model_family=model_family,
         model_id=model_id,
         num_instructions_to_generate=num_instructions_to_generate,
+        checkpoint_dir=checkpoint_dir,
+        save_freq=save_freq,
         **extra_kwargs,
     )
 
@@ -284,6 +289,7 @@ def generate_data(
     client: Optional[openai.OpenAI] = None,
     pipeline: Optional[str] = "simple",
     batch_size: Optional[int] = None,
+    checkpoint_dir: Optional[str] = None,
 ) -> None:
     """Generate data for training and testing a model.
 
@@ -348,13 +354,17 @@ def generate_data(
         model_family,
         model_name,
         num_instructions_to_generate,
+        checkpoint_dir,
+        1,  # save_freq
         batch_size=batch_size,
         batch_num_workers=num_cpus,
     )
 
     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(ctx, pipeline)
 
-    mmlu_bench_pipe = mmlubench_pipe_init(ctx)
+    # Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline)
+    mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
+    mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)
 
     mixer = _mixer_init(ctx, output_dir, date_suffix)
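
The mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None) line above copies the pipeline context with checkpointing switched off, so the mmlubench pipeline cannot pick up checkpoints written by the main pipeline. Here is the idiom in isolation (not part of the commit), using a stripped-down stand-in for PipelineContext rather than the real class:

# Standard
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Ctx:  # hypothetical stand-in; the real PipelineContext has more fields
    model_id: str
    checkpoint_dir: Optional[str] = None
    save_freq: Optional[int] = 1


ctx = Ctx(model_id="foo.bar", checkpoint_dir="/path/to/checkpoints")

# dataclasses.replace returns a new instance with the named field overridden;
# the original context is left untouched.
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)

assert ctx.checkpoint_dir == "/path/to/checkpoints"
assert mmlu_ctx.checkpoint_dir is None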

src/instructlab/sdg/pipeline.py

Lines changed: 17 additions & 2 deletions
@@ -13,6 +13,7 @@
 import yaml
 
 # First Party
+from instructlab.sdg.checkpointing import Checkpointer
 from instructlab.sdg.utils import pandas
 
 # Local
@@ -61,6 +62,8 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
     model_id: str
     num_instructions_to_generate: int
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
+    checkpoint_dir: Optional[str] = None
+    save_freq: Optional[int] = 1
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None
 
@@ -129,6 +132,12 @@ def generate(self, dataset) -> Dataset:
         Generate the dataset by running the pipeline steps.
         dataset: the input dataset
         """
+
+        # The checkpointer allows us to resume from where we left off
+        # Saving the output of pipe instances along the way
+        checkpointer = Checkpointer(self.ctx.checkpoint_dir, self.ctx.save_freq)
+        dataset, pre_generated_data = checkpointer.load(dataset)
+
         # If not batching, simply delegate to _generate_single
         if not self.ctx.batching_enabled:
             logger.info("Running pipeline single-threaded")
@@ -142,6 +151,7 @@ def generate(self, dataset) -> Dataset:
             self.ctx.batch_size,
         )
         input_splits = self._split_dataset(dataset)
+        output_splits = []
         with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
             futures = [
                 executor.submit(self._generate_single, input_split)
@@ -150,8 +160,13 @@ def generate(self, dataset) -> Dataset:
 
             # Collect the results of each batch as they finish. This needs to
             # wait for them all, so the order of waiting doesn't matter
-            output_splits = [future.result() for future in futures]
-
+            for future in futures:
+                ds = future.result()
+                output_splits.append(ds)
+                checkpointer.checkpoint(ds)
+        checkpointer.done()
+        if pre_generated_data:
+            output_splits.append(pre_generated_data)
         return concatenate_datasets(output_splits)
 
 ## Implementation Details ##
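
Putting the pieces together, generate() now follows a small save/resume lifecycle: load() returns the rows still missing from the checkpoints plus any previously generated data, each completed split is passed to checkpoint() and flushed to disk every save_freq calls, done() flushes whatever is still buffered, and the pre-generated rows are concatenated back in at the end. A minimal sketch of that lifecycle driven by hand (not part of the commit), with a trivial map standing in for the real pipeline and assuming the checkpoint directory already exists:

# Third Party
from datasets import Dataset, concatenate_datasets

# First Party
from instructlab.sdg.checkpointing import Checkpointer

seed = Dataset.from_list([{"foo": i} for i in range(30)])

checkpointer = Checkpointer(checkpoint_dir="/tmp/sdg_checkpoints", save_freq=1)
remaining, pre_generated = checkpointer.load(seed)  # rows already on disk are skipped

output_splits = []
for start in range(0, len(remaining), 10):
    split = remaining.select(range(start, min(start + 10, len(remaining))))
    generated = split.map(lambda x: {"bar": x["foo"] + 100})  # stand-in for the real pipeline
    output_splits.append(generated)
    checkpointer.checkpoint(generated)  # written out every `save_freq` calls
checkpointer.done()  # flush anything still buffered

if pre_generated is not None:
    output_splits.append(pre_generated)
result = concatenate_datasets(output_splits)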

tests/test_checkpointing.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# Standard
import json
import os

# Third Party
from datasets import Dataset
import pytest

# First Party
from instructlab.sdg.checkpointing import Checkpointer


def _add_bar(sample, add_value=100):
    sample["bar"] = sample["foo"] + add_value
    return sample


def _populate_checkpoints(tmpdir, dataset, checkpoints_count, remove_column):
    for i in range(0, checkpoints_count):
        checkpoint_dataset = dataset.select(range(i * 10, (i + 1) * 10))
        checkpoint_dataset = checkpoint_dataset.map(
            lambda x: _add_bar(x, add_value=100)
        )
        if remove_column:
            checkpoint_dataset = checkpoint_dataset.remove_columns("foo")
        checkpoint_dataset.to_json(
            os.path.join(tmpdir, f"data_checkpoint_abcde{i}.jsonl"),
            orient="records",
            lines=True,
        )


def _validate_checkpoints(tmpdir, expected_files_count, expected_length, remove_column):
    saved_files = os.listdir(tmpdir)
    assert len(saved_files) == expected_files_count
    assert all(f.startswith("data_checkpoint_") for f in saved_files)
    assert all(f.endswith(".jsonl") for f in saved_files)

    for f in saved_files:
        with open(os.path.join(tmpdir, f), "r") as f:
            l = list(f)
            if isinstance(expected_length, list):
                expected_length.remove(len(l))
            else:
                assert len(l) == expected_length
            for s in l:
                data = json.loads(s)
                if remove_column:
                    assert "foo" not in data and "bar" in data
                else:
                    assert "foo" in data and "bar" in data


@pytest.mark.parametrize(
    "save_freq, remove_column, dataset_size, init_checkpoints, splits, final_checkpoints, checkpoint_length",
    [
        (1, False, 10, 0, 0, 1, 10),
        (1, True, 10, 0, 0, 1, 10),
        (1, False, 100, 1, 9, 10, 10),
        (1, True, 100, 1, 9, 10, 10),
        (1, False, 100, 2, 8, 10, 10),
        (3, False, 100, 2, 8, 5, [10, 10, 30, 30, 20]),
    ],
)
def test_checkpointing(
    tmpdir,
    save_freq,
    remove_column,
    dataset_size,
    init_checkpoints,
    splits,
    final_checkpoints,
    checkpoint_length,
):
    # Our initial dataset
    dataset = Dataset.from_list([{"idx": i, "foo": i} for i in range(dataset_size)])

    # Generate and save some checkpoints to disk
    _populate_checkpoints(tmpdir, dataset, init_checkpoints, remove_column)

    # Load checkpoints, giving us the remaining dataset to process and
    # the generated data loaded from the checkpoints
    checkpointer = Checkpointer(checkpoint_dir=tmpdir, save_freq=save_freq)
    dataset, pre_generated_data = checkpointer.load(dataset)

    # Should be present, even if removed from the checkpoint (remove_column=True)
    assert "foo" in dataset.features

    # When testing save_freq, we will have checkpoints of different lengths
    if isinstance(checkpoint_length, list):
        checkpoints_total = sum(checkpoint_length[:init_checkpoints])
    else:
        checkpoints_total = checkpoint_length * init_checkpoints

    # Validate pre-generated data loaded from the checkpoints
    assert len(dataset) == (dataset_size - checkpoints_total)
    if init_checkpoints > 0:
        assert len(pre_generated_data) == checkpoints_total

    # Apply pipeline to the remaining dataset and save checkpoints
    if splits:
        for i in range(0, splits):
            split = dataset.select(range(i * 10, (i + 1) * 10))
            split = split.map(lambda x: _add_bar(x, add_value=100))
            if remove_column:
                split = split.remove_columns("foo")
            checkpointer.checkpoint(split)
    else:
        dataset = dataset.map(lambda x: _add_bar(x, add_value=10))
        if remove_column:
            dataset = dataset.remove_columns("foo")
        checkpointer.checkpoint(dataset)

    checkpointer.done()

    # Validate that all checkpoints are now saved to disk
    _validate_checkpoints(tmpdir, final_checkpoints, checkpoint_length, remove_column)
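
As the test helpers above show, each checkpoint is just a JSON-lines file: one JSON object per generated row, written via to_json(..., orient="records", lines=True). A tiny sketch of writing and inspecting such a file (not part of the commit; the file name and values are illustrative only):

# Third Party
from datasets import Dataset

ds = Dataset.from_list([{"foo": 0, "bar": 100}, {"foo": 1, "bar": 101}])
ds.to_json("data_checkpoint_example.jsonl", orient="records", lines=True)

with open("data_checkpoint_example.jsonl") as f:
    for line in f:
        print(line.strip())  # one record per line, e.g. {"foo":0,"bar":100}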

tests/test_generate_data.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,8 @@ def test_context_init_batch_size_optional():
         "mixtral",
         "foo.bar",
         1,
+        "/checkpoint/dir",
+        1,
         batch_size=None,
         batch_num_workers=None,
     )
@@ -32,6 +34,8 @@ def test_context_init_batch_size_optional():
         "mixtral",
         "foo.bar",
         1,
+        "/checkpoint/dir",
+        1,
         batch_size=20,
         batch_num_workers=32,
     )
