 # SPDX-License-Identifier: Apache-2.0
 # Standard
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from importlib import resources
-from typing import Optional
+from typing import Iterable, Optional
+import math
 import os.path

 # Third Party
-from datasets import Dataset
+from datasets import Dataset, concatenate_datasets
+from openai import OpenAI
 import yaml

 # Local
@@ -22,16 +26,47 @@ class EmptyDatasetError(Exception):


 # This is part of the public API.
-class PipelineContext:
-    def __init__(
-        self, client, model_family, model_id, num_instructions_to_generate
-    ) -> None:
-        self.client = client
-        self.model_family = model_family
-        self.model_id = model_id
-        self.num_instructions_to_generate = num_instructions_to_generate
-        # FIXME: base this on the available number of CPUs
-        self.num_procs = 8
+@dataclass
+class PipelineContext:  # pylint: disable=too-many-instance-attributes
+    """
+    A PipelineContext holds the common attributes shared between blocks in a
+    pipeline.
+
+    client: The OpenAI client handle.
+    model_id: The ID of the teacher model to be used for client calls.
+    model_family: The family identifier for the model being updated.
+    num_instructions_to_generate: The total number of instructions the user
+        wants to generate during this run.
+    batch_size: The size of the dataset batches for parallel generation. Set to
+        0 to disable batching.
+    batch_num_workers: The number of worker threads/processes to maintain in the
+        central executor pool.
+    dataset_num_procs: The number of processes to use when performing parallel
+        map operations on individual datasets.
+    """
+
+    # A batch size of 8 has proven to be a good default for standard
+    # instructlab workloads when running with vllm batching.
+    DEFAULT_BATCH_SIZE = 8
+
+    # The default number of processes to use when performing parallel
+    # operations on individual datasets
+    DEFAULT_DATASET_NUM_PROCS = 8
+
+    client: OpenAI
+    model_family: str
+    model_id: str
+    num_instructions_to_generate: int
+    dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
+    batch_size: int = DEFAULT_BATCH_SIZE
+    batch_num_workers: Optional[int] = None
+
+    @property
+    def batching_enabled(self) -> bool:
+        """Batching is enabled if and only if the batch size is specified and
+        the number of workers is not explicitly set to 1.
+        """
+        return self.batch_size > 0 and self.batch_num_workers != 1


 # This is part of the public API.
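As a sanity check on the new dataclass semantics, here is a minimal usage sketch of `PipelineContext` and its `batching_enabled` property. The endpoint, API key, and model names are placeholders, and the `instructlab.sdg.pipeline` import path is an assumption, not something this diff establishes:

```python
# Hypothetical usage sketch; base_url, api_key, and model names are placeholders.
from openai import OpenAI

from instructlab.sdg.pipeline import PipelineContext  # assumed module path

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
ctx = PipelineContext(
    client=client,
    model_family="mixtral",
    model_id="teacher-model",  # placeholder model ID
    num_instructions_to_generate=30,
)

# Defaults: batch_size=8, batch_num_workers=None -> batching is on
assert ctx.batching_enabled

# Either a zero batch size or an explicit single worker turns batching off
assert not PipelineContext(
    client, "mixtral", "teacher-model", 30, batch_size=0
).batching_enabled
assert not PipelineContext(
    client, "mixtral", "teacher-model", 30, batch_num_workers=1
).batching_enabled
```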
@@ -63,7 +98,12 @@ def exception_message(self) -> str:

 # This is part of the public API.
 class Pipeline:
-    def __init__(self, ctx, config_path, chained_blocks: list) -> None:
+    def __init__(
+        self,
+        ctx: PipelineContext,
+        config_path: str,
+        chained_blocks: list[dict],
+    ) -> None:
         """
         Initialize the Pipeline class with a configuration dictionary.
         config_dict: the run config py or yaml loaded into a dictionary
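Alongside the stricter types, a hedged construction sketch. The `{"name", "type", "config"}` block shape is inferred from the `chained_blocks: list[dict]` annotation and the `_block_types` registry at the bottom of this file; the config path and the empty block config are placeholders, and a real block would need its block-specific config:

```python
# Sketch only; reuses `ctx` from the earlier PipelineContext example. The
# block entry is a placeholder: a real block needs its block-specific config.
blocks = [{"name": "combine", "type": "CombineColumnsBlock", "config": {}}]
pipeline = Pipeline(ctx, "custom_pipeline.yaml", blocks)  # placeholder path

# Packaged configs can also be loaded by relative path via the classmethod:
# pipeline = Pipeline.from_file(ctx, "some_pipeline.yaml")
```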
@@ -81,20 +121,40 @@ def from_file(cls, ctx, pipeline_yaml):
             pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
         return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml))

-    def _drop_duplicates(self, dataset, cols):
-        """
-        Drop duplicates from the dataset based on the columns provided.
-        """
-        df = dataset.to_pandas()
-        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
-        ds = Dataset.from_pandas(df)
-        return ds
-
     def generate(self, dataset) -> Dataset:
         """
         Generate the dataset by running the pipeline steps.
         dataset: the input dataset
         """
+        # If not batching, simply delegate to _generate_single
+        if not self.ctx.batching_enabled:
+            logger.info("Running pipeline single-threaded")
+            return self._generate_single(dataset)
+
+        # Otherwise, split the dataset into batches and run each batch as a
+        # future in the thread pool
+        logger.info(
+            "Running pipeline with multi-threaded batching. Using %s workers for batches of size %s",
+            self.ctx.batch_num_workers,
+            self.ctx.batch_size,
+        )
+        input_splits = self._split_dataset(dataset)
+        with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
+            futures = [
+                executor.submit(self._generate_single, input_split)
+                for input_split in input_splits
+            ]
+
+            # Collect the results of each batch as they finish. This needs to
+            # wait for them all, so the order of waiting doesn't matter
+            output_splits = [future.result() for future in futures]
+
+        return concatenate_datasets(output_splits)
+
+    ## Implementation Details ##
+
+    def _generate_single(self, dataset) -> Dataset:
+        """Generate a single dataset by running the pipeline steps."""
         for block_prop in self.chained_blocks:
             # Initialize arguments for error handling to None
             block, block_name, block_type = None, None, None
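To illustrate the batching flow in `generate()` above in isolation, here is a self-contained sketch of the same split / submit / concatenate pattern, with a trivial `map` standing in for the real pipeline blocks:

```python
# Standalone sketch of generate()'s fan-out pattern; fake_generate is a
# stand-in for _generate_single, not part of this PR.
from concurrent.futures import ThreadPoolExecutor

from datasets import Dataset, concatenate_datasets

def fake_generate(batch: Dataset) -> Dataset:
    # Stand-in for running the pipeline blocks over one batch
    return batch.map(lambda row: {"output": row["input"].upper()})

data = Dataset.from_list([{"input": f"sample {i}"} for i in range(20)])
batch_size = 8
splits = [
    data.select(range(start, min(start + batch_size, len(data))))
    for start in range(0, len(data), batch_size)
]

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(fake_generate, split) for split in splits]
    # Results are collected in submission order, so row order is preserved
    outputs = [future.result() for future in futures]

merged = concatenate_datasets(outputs)
assert len(merged) == len(data)
```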
@@ -134,6 +194,39 @@ def generate(self, dataset) -> Dataset:

         return dataset

+    def _drop_duplicates(self, dataset, cols):
+        """
+        Drop duplicates from the dataset based on the columns provided.
+        """
+        df = dataset.to_pandas()
+        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
+        ds = Dataset.from_pandas(df)
+        return ds
+
+    def _split_dataset(self, dataset: Dataset) -> list[Dataset]:
+        """Split the dataset into smaller batches."""
+        assert (
+            self.ctx.batch_size is not None
+        ), "Programming Error: Should not call _split_dataset if batching disabled"
+        total_size = len(dataset)
+        num_batches = math.ceil(total_size / self.ctx.batch_size)
+        batches = [
+            dataset.select(self._get_batch_indices(i, total_size))
+            for i in range(num_batches)
+        ]
+        return batches
+
+    def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int]:
+        assert (
+            self.ctx.batch_size is not None
+        ), "Programming Error: Should not call _get_batch_indices if batching disabled"
+        return range(
+            # Start index is offset by the batch size
+            batch_index * self.ctx.batch_size,
+            # End index is the next batch offset or the end of the dataset
+            min((batch_index + 1) * self.ctx.batch_size, total_size),
+        )
+

 _block_types = {
     "CombineColumnsBlock": utilblocks.CombineColumnsBlock,
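A quick worked example of the index arithmetic in `_split_dataset` / `_get_batch_indices`: with `total_size=10` and `batch_size=4`, `ceil(10/4)` gives three batches, and the `min()` clamp keeps the final range inside the dataset:

```python
# Worked example of the batch index math, using only the stdlib.
import math

batch_size, total_size = 4, 10
num_batches = math.ceil(total_size / batch_size)  # 3
batches = [
    range(i * batch_size, min((i + 1) * batch_size, total_size))
    for i in range(num_batches)
]
assert [list(b) for b in batches] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```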