Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include = ["*"]
[tool.setuptools.package-data]
"guidellm.data" = ["*.gz"]
"guidellm.benchmark.scenarios" = ["*.json", "**/*.json"]
"guidellm.benchmark.outputs.html_outputs" = ["*.html"]

[[tool.uv.index]]
name = "pytorch-cpu"
Expand Down Expand Up @@ -70,11 +71,12 @@ dependencies = [
"transformers",
"uvloop>=0.18",
"torch",
"more-itertools>=10.8.0",
]

[project.optional-dependencies]
# Meta Extras
all = ["guidellm[perf,tokenizers,audio,vision]"]
all = ["guidellm[perf,tokenizers,audio,vision,embeddings]"]
recommended = ["guidellm[perf,tokenizers]"]
# Feature Extras
perf = ["orjson", "msgpack", "msgspec", "uvloop"]
Expand All @@ -90,6 +92,12 @@ vision = [
"datasets[vision]",
"pillow",
]
embeddings = [
# Quality validation with baseline models
"sentence-transformers>=2.2.0",
# MTEB benchmark integration
"mteb>=1.0.0",
]
# Dev Tooling
dev = [
# Install all optional dependencies
Expand Down Expand Up @@ -179,7 +187,9 @@ module = [
"transformers.*",
"setuptools.*",
"setuptools_git_versioning.*",
"torchcodec.*"
"torchcodec.*",
"sentence_transformers.*",
"mteb.*"
]
ignore_missing_imports = true

Expand Down
221 changes: 221 additions & 0 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,5 +792,226 @@ def mock_server(
server.run()


@benchmark.command(
    "embeddings",
    help=(
        "Run embeddings benchmark with optional quality validation. "
        "Supports cosine similarity validation and MTEB benchmark evaluation."
    ),
    context_settings={"auto_envvar_prefix": "GUIDELLM"},
)
@click.option(
    "--target",
    type=str,
    required=True,
    help="Target backend URL (e.g., http://localhost:8000).",
)
@click.option(
    "--data",
    type=str,
    multiple=True,
    required=True,
    help=(
        "HuggingFace dataset ID, path to dataset, path to data file "
        "(csv/json/jsonl/txt), or synthetic data config."
    ),
)
@click.option(
    "--profile",
    default="sweep",
    type=click.Choice(STRATEGY_PROFILE_CHOICES),
    help=f"Benchmark profile type. Options: {', '.join(STRATEGY_PROFILE_CHOICES)}.",
)
@click.option(
    "--rate",
    callback=cli_tools.parse_list_floats,
    multiple=True,
    default=None,
    help="Benchmark rate(s) to test. Meaning depends on profile.",
)
@click.option(
    "--backend",
    type=click.Choice(list(get_literal_vals(BackendType))),
    default="openai_http",
    help=f"Backend type. Options: {', '.join(get_literal_vals(BackendType))}.",
)
@click.option(
    "--backend-kwargs",
    callback=cli_tools.parse_json,
    default=None,
    help='JSON string of backend arguments. E.g., \'{"api_key": "key"}\'',
)
@click.option(
    "--model",
    default=None,
    type=str,
    help="Model ID to benchmark. If not provided, uses first available model.",
)
@click.option(
    "--request-format",
    default="embeddings",
    help="Format to use for requests (default: embeddings).",
)
@click.option(
    "--processor",
    default=None,
    type=str,
    help="Processor or tokenizer for token counts. If not provided, loads from model.",
)
@click.option(
    "--data-samples",
    default=-1,
    type=int,
    help="Number of samples from dataset. -1 (default) uses all samples.",
)
@click.option(
    "--outputs",
    default=["json", "csv", "html"],
    callback=cli_tools.parse_list,
    help=(
        "Comma-separated list of output formats: json,csv,html,console. "
        "Default: json,csv,html"
    ),
)
@click.option(
    "--output-dir",
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    default=Path.cwd(),
    help="Directory to save output files. Default: current directory.",
)
@click.option(
    "--max-requests",
    default=None,
    type=int,
    help="Maximum number of requests to execute.",
)
@click.option(
    "--max-errors",
    default=None,
    type=int,
    help="Maximum number of errors before stopping benchmark.",
)
@click.option(
    "--max-duration",
    default=None,
    type=float,
    help="Maximum duration in seconds for benchmark execution.",
)
# Embeddings-specific quality validation options
@click.option(
    "--enable-quality-validation",
    is_flag=True,
    default=False,
    help="Enable quality validation using cosine similarity against baseline model.",
)
@click.option(
    "--baseline-model",
    default=None,
    type=str,
    help=(
        "HuggingFace model for baseline comparison. "
        "E.g., 'sentence-transformers/all-MiniLM-L6-v2'. "
        "Defaults to target model if not specified."
    ),
)
@click.option(
    "--quality-tolerance",
    default=1e-2,
    type=float,
    help=(
        "Cosine similarity tolerance threshold. "
        "Default: 1e-2 (standard), use 5e-4 for MTEB-level validation."
    ),
)
@click.option(
    "--enable-mteb",
    is_flag=True,
    default=False,
    help="Enable MTEB benchmark evaluation for standardized quality scoring.",
)
@click.option(
    "--mteb-tasks",
    callback=cli_tools.parse_list,
    default=None,
    help=(
        "Comma-separated list of MTEB tasks. "
        "Default: STS12,STS13,STSBenchmark. E.g., 'STS12,STS13,STS14'"
    ),
)
@click.option(
    "--encoding-format",
    type=click.Choice(["float", "base64"]),
    default="float",
    help="Embedding encoding format. Options: float, base64. Default: float.",
)
@click.option(
    "--disable-console",
    is_flag=True,
    default=False,
    help="Disable all console output (including progress display).",
)
@click.option(
    "--disable-console-interactive",
    is_flag=True,
    default=False,
    help="Disable interactive console elements (progress bar, tables).",
)
@click.option(
    "--random-seed",
    default=42,
    type=int,
    help="Random seed for reproducibility. Default: 42.",
)
def embeddings(**kwargs):
    """
    Run embeddings benchmark with optional quality validation.

    Builds a ``BenchmarkEmbeddingsArgs`` instance from the parsed CLI options and
    runs ``benchmark_embeddings`` to completion on a fresh asyncio event loop.
    The two console flags are consumed here and are not forwarded to the
    arguments model.

    :param kwargs: Parsed click option values keyed by option name.
    :raises click.BadParameter: If ``BenchmarkEmbeddingsArgs.create`` raises a
        pydantic ``ValidationError``; the first reported error is surfaced and,
        when it is tied to a field, attributed to the matching ``--option``.
    """
    # Imported lazily so the embeddings modules are only loaded when this
    # command is actually invoked.
    from guidellm.benchmark.embeddings_entrypoints import benchmark_embeddings
    from guidellm.benchmark.schemas.embeddings import BenchmarkEmbeddingsArgs

    # Only set CLI args that differ from click defaults
    kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)

    # Handle console options: disabling the console entirely also disables the
    # interactive progress display.
    disable_console = kwargs.pop("disable_console", False)
    disable_console_interactive = (
        kwargs.pop("disable_console_interactive", False) or disable_console
    )
    console = Console() if not disable_console else None

    envs = cli_tools.list_set_env()
    if console and envs:
        console.print_update(
            title=(
                "Note: the following environment variables "
                "are set and **may** affect configuration"
            ),
            details=", ".join(envs),
            status="warning",
        )

    try:
        args = BenchmarkEmbeddingsArgs.create(scenario=None, **kwargs)
    except ValidationError as err:
        # Translate the pydantic validation error into a click argument error.
        errs = err.errors(include_url=False, include_context=True, include_input=True)
        first_err = errs[0]
        loc = first_err["loc"]
        # Model-level validation failures carry an empty ``loc``; indexing it
        # would raise IndexError and hide the real validation message, so only
        # build an option hint when a field location is available.
        param_hint = ("--" + str(loc[0]).replace("_", "-")) if loc else None
        raise click.BadParameter(
            first_err["msg"], ctx=click.get_current_context(), param_hint=param_hint
        ) from err

    # uvloop is an optional dependency; use its event loop policy when present.
    if uvloop is not None:
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

    asyncio.run(
        benchmark_embeddings(
            args=args,
            progress=(
                GenerativeConsoleBenchmarkerProgress()
                if not disable_console_interactive
                else None
            ),
            console=console,
        )
    )


# Script entry point: run the CLI only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    cli()
16 changes: 14 additions & 2 deletions src/guidellm/backends/openai/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"/v1/chat/completions": "v1/chat/completions",
"/v1/audio/transcriptions": "v1/audio/transcriptions",
"/v1/audio/translations": "v1/audio/translations",
"/v1/embeddings": "v1/embeddings",
"embeddings": "v1/embeddings", # Alias for convenience
}

DEFAULT_API = "/v1/chat/completions"
Expand All @@ -50,6 +52,9 @@
"audio_translations": "/v1/audio/translations",
}

# NOTE: This value is taken from httpx's default
FALLBACK_TIMEOUT = 5.0


@Backend.register("openai_http")
class OpenAIHTTPBackend(Backend):
Expand Down Expand Up @@ -83,7 +88,8 @@ def __init__(
api_key: str | None = None,
api_routes: dict[str, str] | None = None,
request_handlers: dict[str, Any] | None = None,
timeout: float = 60.0,
timeout: float | None = None,
timeout_connect: float | None = FALLBACK_TIMEOUT,
http2: bool = True,
follow_redirects: bool = True,
verify: bool = False,
Expand Down Expand Up @@ -133,6 +139,7 @@ def __init__(
self.api_routes = api_routes or DEFAULT_API_PATHS
self.request_handlers = request_handlers
self.timeout = timeout
self.timeout_connect = timeout_connect
self.http2 = http2
self.follow_redirects = follow_redirects
self.verify = verify
Expand Down Expand Up @@ -162,6 +169,7 @@ def info(self) -> dict[str, Any]:
"target": self.target,
"model": self.model,
"timeout": self.timeout,
"timeout_connect": self.timeout_connect,
"http2": self.http2,
"follow_redirects": self.follow_redirects,
"verify": self.verify,
Expand All @@ -182,7 +190,11 @@ async def process_startup(self):

self._async_client = httpx.AsyncClient(
http2=self.http2,
timeout=self.timeout,
timeout=httpx.Timeout(
FALLBACK_TIMEOUT,
read=self.timeout,
connect=self.timeout_connect,
),
follow_redirects=self.follow_redirects,
verify=self.verify,
# Allow unlimited connections
Expand Down
Loading
Loading