diff --git a/docs/getting-started/benchmark.md b/docs/getting-started/benchmark.md index 11a04d19a..e99821736 100644 --- a/docs/getting-started/benchmark.md +++ b/docs/getting-started/benchmark.md @@ -62,6 +62,88 @@ GuideLLM supports several benchmark profiles and strategies: - `poisson`: Sends requests following a Poisson distribution - `sweep`: Automatically determines optimal performance points (default) +#### Sweep Profile Configuration + +The sweep profile includes advanced configuration options for optimizing benchmarks on CPU-based deployments. These parameters help manage saturation detection and prevent graph artifacts: + +**Available Parameters:** + +| Parameter | Description | Default | Environment Variable | +| ----------------------------- | ------------------------------------------------- | ------- | ------------------------------------- | +| `--exclude-throughput-target` | Stop constant-rate tests before throughput level | `false` | `GUIDELLM__EXCLUDE_THROUGHPUT_TARGET` | +| `--exclude-throughput-result` | Exclude throughput benchmark from saved results | `false` | `GUIDELLM__EXCLUDE_THROUGHPUT_RESULT` | +| `--saturation-threshold` | Efficiency threshold for stopping sweep (0.0-1.0) | `0.98` | `GUIDELLM__SATURATION_THRESHOLD` | + +**When to Use:** + +- **CPU based system under test**: Enable `exclude-throughput-target` and `exclude-throughput-result` to prevent anomalous data points in performance graphs (TTFT spikes, inter-token latency anomalies) +- **GPU based system under test**: Use default settings (all disabled) + +**Example for CPU-optimized benchmarking:** + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --profile sweep \ + --exclude-throughput-target true \ + --exclude-throughput-result true \ + --saturation-threshold 0.98 \ + --data "prompt_tokens=256,output_tokens=128" \ + --max-seconds 300 +``` + +**Using Environment Variables:** + +```bash +export GUIDELLM__EXCLUDE_THROUGHPUT_TARGET=true +export 
GUIDELLM__EXCLUDE_THROUGHPUT_RESULT=true +export GUIDELLM__SATURATION_THRESHOLD=0.98 + +guidellm benchmark \ + --target "http://localhost:8000" \ + --profile sweep \ + --data "prompt_tokens=256,output_tokens=128" +``` + +**How It Works:** + +The sweep profile runs tests in this order: + +1. **Synchronous test**: Measures baseline single-request performance +2. **Throughput test**: Discovers maximum server capacity with parallel requests +3. **Constant-rate tests**: Tests at interpolated rates between synchronous and throughput + +Each parameter optimizes a different aspect: + +- **`exclude-throughput-target`**: Prevents generating a constant-rate test at the throughput level itself + + - **Why**: The highest constant-rate test would target the same rate as the throughput test, creating redundant "elbow" artifacts in graphs + - **Effect**: Stops constant-rate tests just before reaching throughput rate + +- **`exclude-throughput-result`**: Removes the throughput benchmark from saved results + + - **Why**: Throughput tests measure burst capacity with severe queuing (e.g., 23+ second TTFT), creating extreme outliers in graphs + - **Effect**: Graphs only show sustainable performance metrics from constant-rate tests + +- **`saturation-threshold`**: Stops the sweep when efficiency drops below threshold + + - **Why**: Once saturation is detected (achieved rate < target rate × threshold), further tests provide diminishing returns + - **Effect**: Saves time by stopping early when the server can no longer meet target rates + +**Why use all three together?** + +For a CPU-based system under test, all three parameters work synergistically: + +- `saturation-threshold` stops the sweep efficiently when saturation is detected + +- `exclude-throughput-target` prevents testing at the unsustainable throughput rate +- `exclude-throughput-result` removes the anomalous throughput spike from graphs + +This combination produces clean, efficient benchmarks that focus on sustainable performance 
ranges. + +**Important Note:** + +Do not set `--max-concurrency` or `GUIDELLM__MAX_CONCURRENCY` when running sweep tests. The sweep profile uses the throughput test to discover the server's true capacity, and artificially limiting concurrency will result in an underestimated throughput measurement. This causes the constant-rate tests to run at rates far below the actual server capacity, preventing proper saturation detection and producing misleading results where TTFT may decrease instead of increase. + ### Data Options For synthetic data, some key options include, among others: diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index 75c8c787b..0b53373ad 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -534,7 +534,14 @@ async def benchmark_generative_text( prefer_response_metrics=args.prefer_response_metrics, ): if benchmark: - report.benchmarks.append(benchmark) + # Check if we should exclude the throughput benchmark + should_exclude = ( + hasattr(profile, "exclude_throughput_result") + and profile.exclude_throughput_result + and benchmark.config.strategy.type_ == "throughput" + ) + if not should_exclude: + report.benchmarks.append(benchmark) output_format_results = {} for key, output in output_formats.items(): diff --git a/src/guidellm/benchmark/profiles.py b/src/guidellm/benchmark/profiles.py index 054356c10..68cfa0362 100644 --- a/src/guidellm/benchmark/profiles.py +++ b/src/guidellm/benchmark/profiles.py @@ -595,6 +595,35 @@ class SweepProfile(Profile): default=42, description="Random seed for Poisson distribution strategy", ) + exclude_throughput_target: bool = Field( + default=False, + description=( + "Exclude constant-rate test at throughput level. " + "When True, constant-rate tests stop before reaching throughput rate, " + "preventing 'elbow' artifacts in performance graphs. " + "Recommended for CPU-based deployments." 
+ ), + ) + exclude_throughput_result: bool = Field( + default=False, + description=( + "Exclude throughput benchmark from saved results. " + "When True, the throughput benchmark is not saved to the report, " + "preventing anomalous data points in graphs. " + "Recommended for CPU based system under test when saturation is detected." + ), + ) + saturation_threshold: float = Field( + default=0.98, + ge=0.0, + le=1.0, + description=( + "Efficiency threshold for saturation detection (achieved/target rate). " + "Sweep stops when efficiency drops below this value. " + "Default 0.98 (98%) is recommended for CPU based system under test. " + "Use 0.95 (95%) for noisier systems, 0.99 (99%) for very stable systems." + ), + ) synchronous_rate: float = Field( default=-1.0, description="Measured rate from synchronous strategy execution", @@ -634,6 +663,24 @@ def resolve_args( kwargs["random_seed"] = random_seed if rate_type in ["constant", "poisson"]: kwargs["strategy_type"] = rate_type + + # Resolve sweep profile parameters from settings if not provided + if ( + "exclude_throughput_target" not in kwargs + or kwargs["exclude_throughput_target"] is None + ): + kwargs["exclude_throughput_target"] = settings.exclude_throughput_target + if ( + "exclude_throughput_result" not in kwargs + or kwargs["exclude_throughput_result"] is None + ): + kwargs["exclude_throughput_result"] = settings.exclude_throughput_result + if ( + "saturation_threshold" not in kwargs + or kwargs["saturation_threshold"] is None + ): + kwargs["saturation_threshold"] = settings.saturation_threshold + return kwargs @property @@ -645,7 +692,7 @@ def strategy_types(self) -> list[str]: types += [self.strategy_type] * (self.sweep_size - len(types)) return types - def next_strategy( + def next_strategy( # noqa: C901 self, prev_strategy: SchedulingStrategy | None, prev_benchmark: Benchmark | None, @@ -685,13 +732,57 @@ def next_strategy( "Invalid rates in sweep; aborting. " "Were there any successful requests?" 
) - self.measured_rates = list( - np.linspace( - self.synchronous_rate, - self.throughput_rate, - self.sweep_size - 1, - ) - )[1:] # don't rerun synchronous + + # Generate interpolated rates between synchronous and throughput. + # The behavior depends on exclude_throughput_target setting: + # + # When exclude_throughput_target=False (default, GPU mode): + # - Generate (sweep_size - 1) points from sync to throughput + # - Remove sync (already tested), keep throughput-level test + # - Example: sweep_size=10 -> 9 points, remove 1 = 8 async tests + # - Last async test targets throughput_rate + # + # When exclude_throughput_target=True (CPU mode): + # - Generate (sweep_size) points from sync to throughput + # - Remove sync AND throughput-level test + # - Example: sweep_size=10 -> 10 points, remove 2 = 8 async tests + # - Last async test stops before throughput_rate + # - Prevents "elbow" artifact in graphs + if self.exclude_throughput_target: + # CPU mode: stop before throughput level + self.measured_rates = list( + np.linspace( + self.synchronous_rate, + self.throughput_rate, + self.sweep_size, + ) + )[1:-1] + else: + # GPU mode: include throughput level + self.measured_rates = list( + np.linspace( + self.synchronous_rate, + self.throughput_rate, + self.sweep_size - 1, + ) + )[1:] + + # Check for saturation: if the previous constant-rate test couldn't + # achieve its target rate, the system has saturated + if ( + prev_strategy + and prev_strategy.type_ in ["constant", "poisson"] + and prev_benchmark + and hasattr(prev_strategy, "rate") + and hasattr(prev_benchmark, "metrics") + ): + target_rate = prev_strategy.rate # type: ignore[attr-defined] + achieved_rate = prev_benchmark.metrics.requests_per_second.successful.mean # type: ignore[attr-defined] + + # If achieved rate is below threshold, system is saturated + if achieved_rate < (target_rate * self.saturation_threshold): + # System saturated - don't test higher rates + return None next_index = ( 
len(self.completed_strategies) - 1 - 1 diff --git a/src/guidellm/benchmark/schemas/generative/entrypoints.py b/src/guidellm/benchmark/schemas/generative/entrypoints.py index fff2bec37..adef61fe6 100644 --- a/src/guidellm/benchmark/schemas/generative/entrypoints.py +++ b/src/guidellm/benchmark/schemas/generative/entrypoints.py @@ -225,6 +225,33 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: default=None, description="Additional dataloader configuration arguments" ) random_seed: int = Field(default=42, description="Random seed for reproducibility") + # Sweep profile configuration + exclude_throughput_target: bool | None = Field( + default=None, + description=( + "Exclude constant-rate test at throughput level. " + "When True, constant-rate tests stop before reaching throughput rate. " + "Recommended for CPU-based deployments." + ), + ) + exclude_throughput_result: bool | None = Field( + default=None, + description=( + "Exclude throughput benchmark from saved results. " + "When True, throughput benchmark is not saved to the report. " + "Recommended for CPU-based deployments when saturation is detected." + ), + ) + saturation_threshold: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description=( + "Efficiency threshold for saturation detection (achieved/target rate). " + "Sweep stops when efficiency drops below this value. " + "Default 0.98 (98%) is recommended for CPU based system under test." + ), + ) # Output configuration outputs: list[str] | tuple[str] = Field( default_factory=lambda: ["json", "csv", "html"], diff --git a/src/guidellm/data/builders.py b/src/guidellm/data/builders.py index 7ff584b68..ff96a46a6 100644 --- a/src/guidellm/data/builders.py +++ b/src/guidellm/data/builders.py @@ -219,9 +219,7 @@ def process_dataset( Main method to process and save a dataset with sampled prompt/output token counts. 
""" _validate_output_suffix(output_path) - logger.info( - f"Starting dataset conversion | Input: {data} | Output: {output_path}" - ) + logger.info(f"Starting dataset conversion | Input: {data} | Output: {output_path}") # Parse config config_obj = parse_synthetic_config(config) @@ -320,9 +318,7 @@ def _extract_column_names( output_mappings = column_mapper.datasets_column_mappings.get( "output_tokens_count_column", [] ) - output_column = ( - output_mappings[0][1] if output_mappings else "output_tokens_count" - ) + output_column = output_mappings[0][1] if output_mappings else "output_tokens_count" return prompt_column, prefix_column, output_column @@ -436,9 +432,7 @@ def _process_single_row( if prefix_tokens_max is not None: prefix_tokens_list = tokenizer.encode(prefix_text) if len(prefix_tokens_list) > prefix_tokens_max: - prefix_text = tokenizer.decode( - prefix_tokens_list[:prefix_tokens_max] - ) + prefix_text = tokenizer.decode(prefix_tokens_list[:prefix_tokens_max]) # Count prefix tokens toward prompt if enabled if include_prefix_in_token_count: @@ -450,9 +444,11 @@ def _process_single_row( elif count_adjustment > 0: adjusted_prompt_len = target_prompt_len - count_adjustment if adjusted_prompt_len <= 0: - logger.warning("The prefix exceeds target output length with " - "--include-prefix-in-token-count enabled; Using prompt size" - "of 1; skipping row") + logger.warning( + "The prefix exceeds target output length with " + "--include-prefix-in-token-count enabled; Using prompt size" + "of 1; skipping row" + ) return None target_prompt_len = adjusted_prompt_len diff --git a/src/guidellm/data/config.py b/src/guidellm/data/config.py index 2b0b2133a..401b5db2f 100644 --- a/src/guidellm/data/config.py +++ b/src/guidellm/data/config.py @@ -48,9 +48,7 @@ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None: if Path(data).is_file() and data_path.suffix.lower() == ".json": try: - return config_class.model_validate_json( - data_path.read_text() - ) 
+ return config_class.model_validate_json(data_path.read_text()) except Exception as err: # noqa: BLE001 error = err @@ -60,9 +58,7 @@ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None: ".config", }: try: - return config_class.model_validate( - yaml.safe_load(data_path.read_text()) - ) + return config_class.model_validate(yaml.safe_load(data_path.read_text())) except Exception as err: # noqa: BLE001 error = err @@ -101,9 +97,7 @@ def _load_config_str(data: str, config_class: type[ConfigT]) -> ConfigT | None: for item in items: key, value = item.split("=") config_dict[key.strip()] = ( - int(value.strip()) - if value.strip().isnumeric() - else value.strip() + int(value.strip()) if value.strip().isnumeric() else value.strip() ) return config_class.model_validate(config_dict) diff --git a/src/guidellm/data/entrypoints.py b/src/guidellm/data/entrypoints.py index 1d88f34f2..f39631187 100644 --- a/src/guidellm/data/entrypoints.py +++ b/src/guidellm/data/entrypoints.py @@ -46,7 +46,18 @@ def process_dataset( :raises ValueError: If the output path is invalid or pushing conditions unmet. """ builders.process_dataset( - data, output_path, processor, config, processor_args, data_args, - data_column_mapper, short_prompt_strategy, pad_char, concat_delimiter, - include_prefix_in_token_count, push_to_hub, hub_dataset_id, random_seed, + data, + output_path, + processor, + config, + processor_args, + data_args, + data_column_mapper, + short_prompt_strategy, + pad_char, + concat_delimiter, + include_prefix_in_token_count, + push_to_hub, + hub_dataset_id, + random_seed, ) diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py index 16af56dff..763f18073 100644 --- a/src/guidellm/data/schemas.py +++ b/src/guidellm/data/schemas.py @@ -25,26 +25,28 @@ "audio_column", ] + class DataNotSupportedError(Exception): """ Exception raised when the data format is not supported by deserializer or config. 
""" + class DataConfig(StandardBaseModel): """ A generic parent class for various configs for the data package that can be passed in as key-value pairs or JSON. """ -class PreprocessDatasetConfig(DataConfig): +class PreprocessDatasetConfig(DataConfig): prompt_tokens: int = Field( description="The average number of text tokens retained or added to prompts.", gt=0, ) prompt_tokens_stdev: int | None = Field( description="The standard deviation of the number of tokens retained in or " - "added to prompts.", + "added to prompts.", gt=0, default=None, ) @@ -64,7 +66,7 @@ class PreprocessDatasetConfig(DataConfig): ) output_tokens_stdev: int | None = Field( description="The standard deviation of the number of tokens retained or " - "added to outputs.", + "added to outputs.", gt=0, default=None, ) @@ -84,6 +86,7 @@ class PreprocessDatasetConfig(DataConfig): default=None, ) + class SyntheticTextPrefixBucketConfig(StandardBaseModel): bucket_weight: int = Field( description="Weight of this bucket in the overall distribution.", @@ -151,7 +154,6 @@ class SyntheticTextDatasetConfig(DataConfig): default=None, ) - @model_validator(mode="after") def check_prefix_options(self) -> SyntheticTextDatasetConfig: if self.__pydantic_extra__ is not None: diff --git a/src/guidellm/settings.py b/src/guidellm/settings.py index 0e6e6c455..4a67a4842 100644 --- a/src/guidellm/settings.py +++ b/src/guidellm/settings.py @@ -132,6 +132,11 @@ class Settings(BaseSettings): constraint_error_window_size: float = 30 constraint_error_min_processed: float = 30 + # Sweep profile settings + exclude_throughput_target: bool = False + exclude_throughput_result: bool = False + saturation_threshold: float = 0.98 + # Data settings dataset: DatasetSettings = DatasetSettings() diff --git a/tests/unit/data/deserializers/test_synthetic.py b/tests/unit/data/deserializers/test_synthetic.py index eda02ef58..3664470be 100644 --- a/tests/unit/data/deserializers/test_synthetic.py +++ 
b/tests/unit/data/deserializers/test_synthetic.py @@ -413,7 +413,8 @@ def test_load_config_file_yaml(self): try: loaded_config = config_module._load_config_file( - yaml_path, SyntheticTextDatasetConfig, + yaml_path, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 60 @@ -443,7 +444,8 @@ def test_load_config_file_config_extension(self): try: loaded_config = config_module._load_config_file( - config_path, SyntheticTextDatasetConfig, + config_path, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 90 @@ -460,7 +462,8 @@ def test_load_config_str_json(self): """ json_str = '{"prompt_tokens": 50, "output_tokens": 25}' loaded_config = config_module._load_config_str( - json_str, SyntheticTextDatasetConfig, + json_str, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 50 @@ -474,7 +477,8 @@ def test_load_config_str_key_value(self): """ kv_str = "prompt_tokens=50,output_tokens=25" loaded_config = config_module._load_config_str( - kv_str, SyntheticTextDatasetConfig, + kv_str, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 50 @@ -488,7 +492,8 @@ def test_load_config_str_invalid_format(self): """ with pytest.raises(DataNotSupportedError, match="Unsupported string data"): config_module._load_config_str( - "invalid_format_string", SyntheticTextDatasetConfig, + "invalid_format_string", + SyntheticTextDatasetConfig, ) @pytest.mark.regression @@ -498,7 +503,8 @@ def test_load_config_file_non_existent(self): ### WRITTEN BY AI ### """ loaded_config = config_module._load_config_file( - "/non/existent/path.config", SyntheticTextDatasetConfig, + "/non/existent/path.config", + SyntheticTextDatasetConfig, ) assert loaded_config is None diff --git a/tests/unit/data/test_builders.py b/tests/unit/data/test_builders.py index 946b9cd1b..d0626a739 100644 --- a/tests/unit/data/test_builders.py +++ b/tests/unit/data/test_builders.py @@ -52,60 +52,72 @@ def decode_side_effect(tokens, skip_special_tokens=False): 
@pytest.fixture def sample_dataset_default_columns(): """Sample dataset with default column names.""" - return Dataset.from_dict({ - "prompt": [ - ( - "This is a very long prompt that should be sufficient for " - "testing purposes. " - ) * 10, - "Short.", - ( - "Another very long prompt for testing the dataset processing " - "functionality. " - ) * 10, - ], - }) + return Dataset.from_dict( + { + "prompt": [ + ( + "This is a very long prompt that should be sufficient for " + "testing purposes. " + ) + * 10, + "Short.", + ( + "Another very long prompt for testing the dataset processing " + "functionality. " + ) + * 10, + ], + } + ) @pytest.fixture def sample_dataset_custom_columns(): """Sample dataset with custom column names requiring mapping.""" - return Dataset.from_dict({ - "question": [ - ( - "What is the meaning of life? This is a longer question that " - "should work for testing. " - ) * 10, - ( - "How does this work? Let me explain in detail how this system " - "functions. " - ) * 10, - ( - "Tell me about machine learning. Machine learning is a " - "fascinating field. " - ) * 10, - ], - }) + return Dataset.from_dict( + { + "question": [ + ( + "What is the meaning of life? This is a longer question that " + "should work for testing. " + ) + * 10, + ( + "How does this work? Let me explain in detail how this system " + "functions. " + ) + * 10, + ( + "Tell me about machine learning. Machine learning is a " + "fascinating field. " + ) + * 10, + ], + } + ) @pytest.fixture def sample_dataset_with_prefix(): """Sample dataset with prefix column.""" - return Dataset.from_dict({ - "prompt": [ - ( - "This is a long prompt that should be sufficient for testing " - "purposes. " - ) * 10, - "Another long prompt here that will work for testing. " * 10, - "Yet another long prompt for testing purposes. 
" * 10, - ], - "system_prompt": [ - "You are a helpful assistant.", - "You are a helpful assistant.", - "You are a helpful assistant.", - ], - }) + return Dataset.from_dict( + { + "prompt": [ + ( + "This is a long prompt that should be sufficient for testing " + "purposes. " + ) + * 10, + "Another long prompt here that will work for testing. " * 10, + "Yet another long prompt for testing purposes. " * 10, + ], + "system_prompt": [ + "You are a helpful assistant.", + "You are a helpful assistant.", + "You are a helpful assistant.", + ], + } + ) @pytest.fixture @@ -192,30 +204,32 @@ def test_process_dataset_concatenate_strategy( # Create a dataset with short prompts that can be concatenated to reach target # Use a lower target (15 tokens) so concatenation is achievable short_config = '{"prompt_tokens": 15, "output_tokens": 10}' - short_prompts_dataset = Dataset.from_dict({ - "prompt": [ - "A", # 1 char = 1 token - "B", # 1 char = 1 token - "C", # 1 char = 1 token - "D", # 1 char = 1 token - "E", # 1 char = 1 token - "F", # 1 char = 1 token - "G", # 1 char = 1 token - "H", # 1 char = 1 token - "I", # 1 char = 1 token - "J", # 1 char = 1 token - "K", # 1 char = 1 token - "L", # 1 char = 1 token - "M", # 1 char = 1 token - "N", # 1 char = 1 token - "O", # 1 char = 1 token - "P", # 1 char = 1 token - "Q", # 1 char = 1 token - "R", # 1 char = 1 token - "S", # 1 char = 1 token - "T", # 1 char = 1 token - ], - }) + short_prompts_dataset = Dataset.from_dict( + { + "prompt": [ + "A", # 1 char = 1 token + "B", # 1 char = 1 token + "C", # 1 char = 1 token + "D", # 1 char = 1 token + "E", # 1 char = 1 token + "F", # 1 char = 1 token + "G", # 1 char = 1 token + "H", # 1 char = 1 token + "I", # 1 char = 1 token + "J", # 1 char = 1 token + "K", # 1 char = 1 token + "L", # 1 char = 1 token + "M", # 1 char = 1 token + "N", # 1 char = 1 token + "O", # 1 char = 1 token + "P", # 1 char = 1 token + "Q", # 1 char = 1 token + "R", # 1 char = 1 token + "S", # 1 char = 1 token + "T", # 1 
char = 1 token + ], + } + ) # Setup mocks mock_check_processor.return_value = tokenizer_mock @@ -323,8 +337,9 @@ def test_process_dataset_pad_strategy( # Verify that prompts meet minimum token count requirements actual_tokens = len(tokenizer_mock.encode(row["prompt"])) - assert actual_tokens >= 50, \ + assert actual_tokens >= 50, ( f"Padded prompt should have at least 50 tokens, got {actual_tokens}" + ) assert row["prompt_tokens_count"] == actual_tokens # For the "Short." prompt (index 1), verify it was padded @@ -527,12 +542,14 @@ def test_process_dataset_with_instruction_column( """ # Create dataset with 'instruction' column (one of the default # text_column names) - dataset = Dataset.from_dict({ - "instruction": [ - "Follow these instructions carefully. " * 20, - "Complete the task as described. " * 20, - ], - }) + dataset = Dataset.from_dict( + { + "instruction": [ + "Follow these instructions carefully. " * 20, + "Complete the task as described. " * 20, + ], + } + ) # Setup mocks mock_check_processor.return_value = tokenizer_mock @@ -823,10 +840,12 @@ def test_process_dataset_empty_after_filtering( ## WRITTEN BY AI ## """ # Create dataset with only very short prompts that will be filtered out - dataset = Dataset.from_dict({ - # Very short prompts (1 char each, less than 50 tokens) - "prompt": ["A", "B", "C"], - }) + dataset = Dataset.from_dict( + { + # Very short prompts (1 char each, less than 50 tokens) + "prompt": ["A", "B", "C"], + } + ) # Setup mocks mock_check_processor.return_value = tokenizer_mock @@ -1462,8 +1481,9 @@ def test_prompt_trimming_accuracy( # Verify all prompts are trimmed to exactly 50 tokens for row in saved_dataset: actual_tokens = len(tokenizer_mock.encode(row["prompt"])) - assert actual_tokens == 50, \ + assert actual_tokens == 50, ( f"Prompt not trimmed correctly: expected 50 tokens, got {actual_tokens}" + ) @pytest.mark.sanity @patch("guidellm.data.builders.save_dataset_to_file") @@ -1515,8 +1535,9 @@ def 
test_prompt_padding_accuracy( for row in saved_dataset: prompt_text = row["prompt"] actual_tokens = len(tokenizer_mock.encode(prompt_text)) - assert actual_tokens == 100, \ + assert actual_tokens == 100, ( f"Prompt not padded correctly: expected 100 tokens, got {actual_tokens}" + ) assert row["prompt_tokens_count"] == 100 # Verify that pad_char "X" appears in the padded prompts @@ -1813,9 +1834,11 @@ def test_process_dataset_push_to_hub_called( ): """Test that push_to_hub is called when push_to_hub=True.""" # Create a dataset with prompts long enough to be processed - sample_dataset = Dataset.from_dict({ - "prompt": ["abc " * 50], # Long enough - }) + sample_dataset = Dataset.from_dict( + { + "prompt": ["abc " * 50], # Long enough + } + ) mock_check_processor.return_value = tokenizer_mock mock_deserializer_factory_class.deserialize.return_value = sample_dataset @@ -1854,9 +1877,11 @@ def test_process_dataset_push_to_hub_not_called( ): """Test that push_to_hub is not called when push_to_hub=False.""" # Create a dataset with prompts long enough to be processed - sample_dataset = Dataset.from_dict({ - "prompt": ["abc " * 50], # Long enough - }) + sample_dataset = Dataset.from_dict( + { + "prompt": ["abc " * 50], # Long enough + } + ) mock_check_processor.return_value = tokenizer_mock mock_deserializer_factory_class.deserialize.return_value = sample_dataset @@ -1918,15 +1943,18 @@ def test_strategy_handler_called( ): """Test that strategy handlers are called during dataset processing.""" from guidellm.data.builders import STRATEGY_HANDLERS + mock_handler = MagicMock(return_value="processed_prompt") with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}): # Create a dataset with prompts that need processing - sample_dataset = Dataset.from_dict({ - "prompt": [ - "abc" * 20, # Long enough to pass - "def" * 20, # Long enough to pass - ], - }) + sample_dataset = Dataset.from_dict( + { + "prompt": [ + "abc" * 20, # Long enough to pass + "def" * 20, # 
Long enough to pass + ], + } + ) mock_check_processor.return_value = tokenizer_mock mock_deserializer_factory_class.deserialize.return_value = sample_dataset