
Commit 30fb578

Merge pull request #157 from gabe-l-hart/LLMBlockConcurrency-135

LLMBlock concurrency

2 parents e4765b9 + a27a1b8

File tree

10 files changed: +351 −45 lines

src/instructlab/sdg/filterblock.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -178,9 +178,13 @@ def generate(self, samples) -> Dataset:
                 samples,
                 self.column_name,
                 self.dtype,
-                self.ctx.num_procs,
+                self.ctx.dataset_num_procs,
             )

         return _filter_by_values(
-            samples, self.column_name, self.operation, self.value, self.ctx.num_procs
+            samples,
+            self.column_name,
+            self.operation,
+            self.value,
+            self.ctx.dataset_num_procs,
         )
```
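The `num_procs` → `dataset_num_procs` rename separates dataset-level map/filter parallelism from the new batch-level worker pool. For context, a minimal sketch of the kind of parallel map this value controls, assuming the helpers ultimately pass it to the `num_proc` argument of datasets operations (the helper body here is illustrative, not the module's actual code):

```python
# Sketch only: what ctx.dataset_num_procs governs. The real _map_dtype /
# _filter_by_values implementations may differ in detail.
from datasets import Dataset

ds = Dataset.from_list([{"age": "42"}, {"age": "17"}, {"age": "oops"}])

# datasets fans the map out across `num_proc` worker processes
ds = ds.map(
    lambda sample: {"age_int": int(sample["age"]) if sample["age"].isdigit() else None},
    num_proc=2,  # would be self.ctx.dataset_num_procs inside the block
)
```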

src/instructlab/sdg/generate_data.py

Lines changed: 28 additions & 7 deletions

```diff
@@ -173,7 +173,15 @@ def _check_pipeline_dir(pipeline):
     )


-def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
+def _sdg_init(
+    pipeline: Pipeline,
+    client: openai.OpenAI,
+    model_family: str,
+    model_id: str,
+    num_instructions_to_generate: int,
+    batch_num_workers: Optional[int],
+    batch_size: Optional[int],
+):
     pipeline_pkg = None

     # Search for the pipeline in User and Site data directories
@@ -200,7 +208,18 @@ def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_gene
         )
         _check_pipeline_dir(pipeline)

-    ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)
+    extra_kwargs = {}
+    if batch_size is not None:
+        extra_kwargs["batch_size"] = batch_size
+        extra_kwargs["batch_num_workers"] = batch_num_workers
+
+    ctx = PipelineContext(
+        client=client,
+        model_family=model_family,
+        model_id=model_id,
+        num_instructions_to_generate=num_instructions_to_generate,
+        **extra_kwargs,
+    )

     def load_pipeline(yaml_basename):
         if pipeline_pkg:
@@ -227,8 +246,6 @@ def generate_data(
     api_key: Optional[str] = None,
     model_family: Optional[str] = None,
     model_name: Optional[str] = None,
-    # TODO - not used -- when batching is enabled, this is relevant.
-    # Right now the code hard codes 8 cpus for batching
     num_cpus: Optional[int] = None,
     num_instructions_to_generate: Optional[int] = 30,
     taxonomy: Optional[str] = None,
@@ -247,6 +264,7 @@ def generate_data(
     tls_client_key: Optional[str] = None,
     tls_client_passwd: Optional[str] = None,
     pipeline: Optional[str] = "simple",
+    batch_size: Optional[int] = None,
 ) -> None:
     """Generate data for training and testing a model.

@@ -264,6 +282,10 @@
     """
     generate_start = time.time()

+    # FIXME: remove this when ilab knows to pass batch_size=0 with llama.cpp
+    if batch_size is None:
+        batch_size = 0
+
     if not os.path.exists(output_dir):
         os.mkdir(output_dir)

@@ -302,15 +324,14 @@ def generate_data(
     else:
         model_family = MODEL_FAMILY_MERLINITE

-    # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
-    # about whether we can turn this on (whether vllm is used or not)
-
     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
         pipeline,
         client,
         model_family,
         model_name,
         num_instructions_to_generate,
+        batch_size=batch_size,
+        batch_num_workers=num_cpus,
    )

     if console_output:
```
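The `extra_kwargs` indirection keeps the `PipelineContext` dataclass defaults in play: only an explicitly set `batch_size` overrides them, and `None` is currently coerced to `0` (batching off) until `ilab` passes the hint itself. A standalone sketch of that plumbing (the function name and prints are illustrative, not part of the commit):

```python
# Illustrative reduction of the generate_data -> _sdg_init plumbing
from typing import Optional

def resolve_batch_kwargs(batch_size: Optional[int], num_cpus: Optional[int]) -> dict:
    # The FIXME path from the diff: until ilab passes batch_size=0 for
    # llama.cpp, a missing value means "batching off"
    if batch_size is None:
        batch_size = 0
    # _sdg_init forwards only when set, so PipelineContext dataclass
    # defaults would apply otherwise
    extra_kwargs = {}
    if batch_size is not None:
        extra_kwargs["batch_size"] = batch_size
        extra_kwargs["batch_num_workers"] = num_cpus
    return extra_kwargs

print(resolve_batch_kwargs(None, 8))  # {'batch_size': 0, 'batch_num_workers': 8}
print(resolve_batch_kwargs(32, 4))    # {'batch_size': 32, 'batch_num_workers': 4}
```

Once the FIXME is removed, a `None` from `generate_data` would fall through `_sdg_init`'s `is not None` check untouched and leave the dataclass default of 8 intact.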

src/instructlab/sdg/pipeline.py

Lines changed: 115 additions & 22 deletions

```diff
@@ -1,11 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # Standard
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from importlib import resources
-from typing import Optional
+from typing import Iterable, Optional
+import math
 import os.path

 # Third Party
-from datasets import Dataset
+from datasets import Dataset, concatenate_datasets
+from openai import OpenAI
 import yaml

 # Local
```
```diff
@@ -22,16 +26,47 @@ class EmptyDatasetError(Exception):


 # This is part of the public API.
-class PipelineContext:
-    def __init__(
-        self, client, model_family, model_id, num_instructions_to_generate
-    ) -> None:
-        self.client = client
-        self.model_family = model_family
-        self.model_id = model_id
-        self.num_instructions_to_generate = num_instructions_to_generate
-        # FIXME: base this on the available number of CPUs
-        self.num_procs = 8
+@dataclass
+class PipelineContext:  # pylint: disable=too-many-instance-attributes
+    """
+    A PipelineContext holds the common attributes needed between blocks in a
+    pipeline
+
+    client: The OpenAI client handle.
+    model_id: The ID of the teacher model to be used for client calls.
+    model_family: The family identifier for the model being updated.
+    num_instructions_to_generate: The total number of instructions the user
+        wants to generate during this run.
+    batch_size: The size of the dataset batches for parallel generation. Set to
+        0 to disable batching.
+    batch_num_workers: The number of worker threads/processes to maintain in the
+        central executor pool.
+    dataset_num_procs: The number of processes to use when performing parallel
+        map operations on individual datasets.
+    """
+
+    # The default batch size of 8 has been determined as a good default for
+    # standard instructlab workloads when running with vllm batching.
+    DEFAULT_BATCH_SIZE = 8
+
+    # The default number of processes to use when performing parallel operations
+    # on individual datasets
+    DEFAULT_DATASET_NUM_PROCS = 8
+
+    client: OpenAI
+    model_family: str
+    model_id: str
+    num_instructions_to_generate: int
+    dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
+    batch_size: int = DEFAULT_BATCH_SIZE
+    batch_num_workers: Optional[int] = None
+
+    @property
+    def batching_enabled(self) -> bool:
+        """Batching is enabled IFF the batch size is specified and the number of
+        workers is not set explicitly to 1
+        """
+        return self.batch_size > 0 and self.batch_num_workers != 1


 # This is part of the public API.
```
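A quick check of the `batching_enabled` semantics under the new dataclass, using a mocked client (assumption: any object satisfies the `client` field for this purpose, just as the new test fixtures use `MagicMock`):

```python
from unittest import mock

from instructlab.sdg.pipeline import PipelineContext

common = dict(
    client=mock.MagicMock(),
    model_family="merlinite",
    model_id="teacher-model",
    num_instructions_to_generate=30,
)

# Defaults: batch_size=8, batch_num_workers=None -> batching on
assert PipelineContext(**common).batching_enabled

# batch_size=0 disables batching (the llama.cpp case)
assert not PipelineContext(**common, batch_size=0).batching_enabled

# An explicit single worker also disables batching, regardless of batch_size
assert not PipelineContext(**common, batch_num_workers=1).batching_enabled
```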
```diff
@@ -63,7 +98,12 @@ def exception_message(self) -> str:

 # This is part of the public API.
 class Pipeline:
-    def __init__(self, ctx, config_path, chained_blocks: list) -> None:
+    def __init__(
+        self,
+        ctx: PipelineContext,
+        config_path: str,
+        chained_blocks: list[dict],
+    ) -> None:
         """
         Initialize the Pipeline class with a configuration dictionary.
         config_dict: the run config py or yaml loaded into a dictionary
```
```diff
@@ -81,20 +121,40 @@ def from_file(cls, ctx, pipeline_yaml):
         pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
         return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml))

-    def _drop_duplicates(self, dataset, cols):
-        """
-        Drop duplicates from the dataset based on the columns provided.
-        """
-        df = dataset.to_pandas()
-        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
-        ds = Dataset.from_pandas(df)
-        return ds
-
     def generate(self, dataset) -> Dataset:
         """
         Generate the dataset by running the pipeline steps.
         dataset: the input dataset
         """
+        # If not batching, simply delegate to _generate_single
+        if not self.ctx.batching_enabled:
+            logger.info("Running pipeline single-threaded")
+            return self._generate_single(dataset)
+
+        # Otherwise, split the dataset into batches and run each batch as a
+        # future in the thread pool
+        logger.info(
+            "Running pipeline with multi-threaded batching. Using %s workers for batches of size %s",
+            self.ctx.batch_num_workers,
+            self.ctx.batch_size,
+        )
+        input_splits = self._split_dataset(dataset)
+        with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
+            futures = [
+                executor.submit(self._generate_single, input_split)
+                for input_split in input_splits
+            ]
+
+            # Collect the results of each batch as they finish. This needs to
+            # wait for them all, so the order of waiting doesn't matter
+            output_splits = [future.result() for future in futures]
+
+        return concatenate_datasets(output_splits)
+
+    ## Implementation Details ##
+
+    def _generate_single(self, dataset) -> Dataset:
+        """Generate a single dataset by running the pipeline steps."""
         for block_prop in self.chained_blocks:
             # Initialize arguments for error handling to None
             block, block_name, block_type = None, None, None
```
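The fan-out/fan-in shape of the new `generate`, reduced to a self-contained sketch (`work` stands in for `_generate_single`; the sizes mirror the new test fixtures):

```python
from concurrent.futures import ThreadPoolExecutor

from datasets import Dataset, concatenate_datasets

def work(split: Dataset) -> Dataset:
    # Stand-in for _generate_single: any Dataset -> Dataset transform
    return split.map(lambda s: {"foo": s["foo"] * 2})

ds = Dataset.from_list([{"foo": i} for i in range(10)])
splits = [ds.select(range(0, 6)), ds.select(range(6, 10))]  # batch_size=6

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(work, split) for split in splits]
    # future.result() blocks until done and re-raises worker exceptions,
    # so a failed batch fails the whole run instead of vanishing silently
    outputs = [future.result() for future in futures]

merged = concatenate_datasets(outputs)
assert len(merged) == 10
```

Because every future is awaited, the order of the waits doesn't matter, and `concatenate_datasets` preserves the original row order since the futures list mirrors `input_splits`.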
```diff
@@ -134,6 +194,39 @@ def generate(self, dataset) -> Dataset:

         return dataset

+    def _drop_duplicates(self, dataset, cols):
+        """
+        Drop duplicates from the dataset based on the columns provided.
+        """
+        df = dataset.to_pandas()
+        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
+        ds = Dataset.from_pandas(df)
+        return ds
+
+    def _split_dataset(self, dataset: Dataset) -> list[Dataset]:
+        """Split the dataset into smaller batches."""
+        assert (
+            self.ctx.batch_size is not None
+        ), "Programming Error: Should not call _split_dataset if batching disabled"
+        total_size = len(dataset)
+        num_batches = math.ceil(total_size / self.ctx.batch_size)
+        batches = [
+            dataset.select(self._get_batch_indices(i, total_size))
+            for i in range(num_batches)
+        ]
+        return batches
+
+    def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int]:
+        assert (
+            self.ctx.batch_size is not None
+        ), "Programming Error: Should not call _get_batch_indices if batching disabled"
+        return range(
+            # Start index offset by the batch size
+            batch_index * self.ctx.batch_size,
+            # End index is the next batch offset or the end of the dataset
+            min((batch_index + 1) * self.ctx.batch_size, total_size),
+        )
+

 _block_types = {
     "CombineColumnsBlock": utilblocks.CombineColumnsBlock,
```

src/instructlab/sdg/utilblocks.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -35,7 +35,7 @@ def populate(sample):

     def generate(self, samples) -> Dataset:
         return self._map_populate(
-            samples, self.configs, self.column_name, self.ctx.num_procs
+            samples, self.configs, self.column_name, self.ctx.dataset_num_procs
         )


@@ -64,7 +64,7 @@ def generate(self, samples: Dataset) -> Dataset:
            self.choice_map,
            self.choice_col,
            self.output_col,
-            self.ctx.num_procs,
+            self.ctx.dataset_num_procs,
        )


@@ -89,5 +89,9 @@ def combine(sample):

     def generate(self, samples: Dataset) -> Dataset:
         return self._map_combine(
-            samples, self.columns, self.output_col, self.separator, self.ctx.num_procs
+            samples,
+            self.columns,
+            self.output_col,
+            self.separator,
+            self.ctx.dataset_num_procs,
         )
```

tests/conftest.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -0,0 +1,48 @@
+"""
+Common fixtures and testing utilities
+"""
+
+# Standard
+from unittest import mock
+
+# Third Party
+from datasets import Dataset
+import pytest
+
+# First Party
+from instructlab.sdg.pipeline import PipelineContext
+
+
+def get_ctx(**kwargs) -> PipelineContext:
+    kwargs.setdefault("client", mock.MagicMock())
+    kwargs.setdefault("model_family", "test")
+    kwargs.setdefault("model_id", "test-model")
+    kwargs.setdefault("num_instructions_to_generate", 10)
+    kwargs.setdefault("dataset_num_procs", 1)
+    return PipelineContext(**kwargs)
+
+
+def get_single_threaded_ctx(**kwargs) -> PipelineContext:
+    kwargs["batch_size"] = 0
+    return get_ctx(**kwargs)
+
+
+def get_threaded_ctx(**kwargs) -> PipelineContext:
+    kwargs["batch_size"] = 6
+    kwargs["batch_num_workers"] = 2
+    return get_ctx(**kwargs)
+
+
+@pytest.fixture
+def single_threaded_ctx() -> PipelineContext:
+    return get_single_threaded_ctx()
+
+
+@pytest.fixture
+def threaded_ctx() -> PipelineContext:
+    return get_threaded_ctx()
+
+
+@pytest.fixture
+def sample_dataset():
+    return Dataset.from_list([{"foo": i} for i in range(10)])
```

tests/test_filterblock.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@
 class TestFilterByValueBlock(unittest.TestCase):
     def setUp(self):
         self.ctx = MagicMock()
-        self.ctx.num_procs = 1
+        self.ctx.dataset_num_procs = 1
         self.pipe = MagicMock()
         self.block = FilterByValueBlock(
             self.ctx,
```
