Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ triton_kernel_logs/
session_*/
worker_*/
.fuse/
outputs/

# Generated kernels
kernel.py
Expand Down
103 changes: 47 additions & 56 deletions Fuser/auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,41 @@

If the chosen path fails, the agent can optionally fall back to the other path.

CLI:
python -m Fuser.auto_agent --problem /abs/path/to/problem.py \
[--ka-model gpt-5] [--extract-model gpt-5] [--dispatch-model o4-mini] [--compose-model o4-mini] \
[--verify] [--no-fallback]

CLI (Hydra-based):
python -m Fuser.auto_agent problem=/abs/path/to/problem.py

# Override config values:
python -m Fuser.auto_agent problem=/abs/path/to/problem.py \
ka.model=gpt-5 \
router.model=gpt-5 \
fuser.extracter.model=gpt-5 \
fuser.dispatcher.model=o4-mini \
fuser.composer.model=o4-mini \
fuser.composer.verify=true \
routing.allow_fallback=false

# Or use a custom config:
python -m Fuser.auto_agent --config-name custom_auto_agent problem=/abs/path/to/problem.py

Config file: configs/pipeline/auto_agent.yaml
Returns a JSON summary to stdout and writes the generated kernel path (if available).
"""

from __future__ import annotations

import argparse
import ast
import hashlib
import json
import sys

from hydra import main as hydra_main

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from dotenv import load_dotenv
from omegaconf import DictConfig
from Fuser.pipeline import run_pipeline

# Local imports (available inside repo)
Expand Down Expand Up @@ -668,64 +683,40 @@ def _llm_decide_route(
# ------------------------


def main(argv: Optional[list[str]] = None) -> int:
p = argparse.ArgumentParser(
description="Auto-router for KernelBench problems (KernelAgent vs Fuser)"
)
p.add_argument("--problem", required=True, help="Absolute path to the problem file")
p.add_argument(
"--ka-model",
default=None,
help="Model for KernelAgent (optional; uses env default if omitted)",
)
p.add_argument("--ka-workers", type=int, default=4)
p.add_argument("--ka-rounds", type=int, default=10)
p.add_argument("--no-ka-high-reasoning", action="store_true")
p.add_argument("--router-model", default="gpt-5")
p.add_argument("--no-router-high-reasoning", action="store_true")
p.add_argument("--router-temp", type=float, default=0.2)
p.add_argument("--router-max-tokens", type=int, default=700)
p.add_argument("--extract-model", default="gpt-5")
p.add_argument("--dispatch-model", default="o4-mini")
p.add_argument("--compose-model", default="o4-mini")
p.add_argument("--workers", type=int, default=4)
p.add_argument("--max-iters", type=int, default=5)
p.add_argument("--llm-timeout-s", type=int, default=1200)
p.add_argument("--run-timeout-s", type=int, default=1200)
p.add_argument("--compose-max-iters", type=int, default=5)
p.add_argument("--verify", action="store_true")
p.add_argument("--dispatch-jobs", type=int, default=2)
p.add_argument("--no-fallback", action="store_true")
args = p.parse_args(argv)

@hydra_main(
version_base=None,
config_path=str(Path(__file__).resolve().parent.parent / "configs/pipeline"),
config_name="auto_agent",
)
def main(cfg: DictConfig) -> int:
# Load environment variables from .env file
load_dotenv()

problem_path = Path(args.problem).resolve()
problem_path = Path(cfg.problem).resolve()
if not problem_path.is_file():
print(f"problem not found: {problem_path}", file=sys.stderr)
return 2

router = AutoKernelRouter(
ka_model=args.ka_model,
ka_num_workers=args.ka_workers,
ka_max_rounds=args.ka_rounds,
ka_high_reasoning=(not args.no_ka_high_reasoning),
router_model=args.router_model,
router_high_reasoning=(not args.no_router_high_reasoning),
router_temperature=args.router_temp,
router_max_tokens=args.router_max_tokens,
extract_model=args.extract_model,
dispatch_model=args.dispatch_model,
compose_model=args.compose_model,
workers=args.workers,
max_iters=args.max_iters,
llm_timeout_s=args.llm_timeout_s,
run_timeout_s=args.run_timeout_s,
compose_max_iters=args.compose_max_iters,
verify=args.verify,
dispatch_jobs=args.dispatch_jobs,
allow_fallback=(not args.no_fallback),
ka_model=cfg.ka.model_name,
ka_num_workers=cfg.ka.num_workers,
ka_max_rounds=cfg.ka.max_rounds,
ka_high_reasoning=cfg.ka.high_reasoning,
router_model=cfg.router.model,
router_high_reasoning=cfg.router.high_reasoning,
router_temperature=cfg.router.temperature,
router_max_tokens=cfg.router.max_tokens,
extract_model=cfg.fuser.extractor.model,
dispatch_model=cfg.fuser.dispatcher.model,
compose_model=cfg.fuser.composer.model,
workers=cfg.fuser.extractor.workers,
max_iters=cfg.fuser.extractor.max_iters,
llm_timeout_s=cfg.fuser.extractor.llm_timeout_s,
run_timeout_s=cfg.fuser.extractor.run_timeout_s,
compose_max_iters=cfg.fuser.composer.max_iters,
verify=cfg.fuser.composer.verify,
dispatch_jobs=cfg.fuser.dispatcher.jobs,
allow_fallback=cfg.routing.allow_fallback,
)

try:
Expand Down
122 changes: 65 additions & 57 deletions Fuser/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -13,15 +12,50 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
"""
Orchestrates parallel LLM workers to generate and verify
fused code.

Runs multiple workers concurrently against a KernelBench problem file, each
worker attempting to generate a valid solution. The first worker to produce a
passing candidate wins.

CLI (Hydra-based):
python -m Fuser.cli problem=/abs/path/to/problem.py

# Override config values:
python -m Fuser.cli problem=/abs/path/to/problem.py \
model=gpt-5 \
workers=4 \
max_iters=10 \
stream=winner

# Or use a custom config:
python -m Fuser.cli --config-name custom_fuser \
problem=/abs/path/to/problem.py

Config file: configs/pipeline/orchestrator.yaml

Requirements:
- OPENAI_API_KEY (.env in CWD or environment)

Outputs:
- Run directory path printed to stdout
- Artifacts in .fuse/<run_id>/
"""

from __future__ import annotations
import json
import sys
import os
import multiprocessing as mp
from pathlib import Path

from hydra import main as hydra_main
from omegaconf import DictConfig

from .config import new_run_id, OrchestratorConfig
from .constants import ExitCode
from .config import OrchestratorConfig, new_run_id
from .paths import ensure_abs_regular_file, make_run_dirs, PathSafetyError
from .logging_utils import setup_file_logger
from .orchestrator import Orchestrator
Expand Down Expand Up @@ -53,56 +87,43 @@ def _load_dotenv_if_present() -> None:
pass


def cmd_run(argv: list[str]) -> int:
@hydra_main(
version_base=None,
config_path=str(Path(__file__).resolve().parent.parent / "configs/pipeline"),
config_name="orchestrator",
)
def main(cfg: DictConfig) -> int:
_load_dotenv_if_present()
p = argparse.ArgumentParser(
prog="fuse run", description="Fuse Orchestrator — first-wins runner"
)
p.add_argument(
"--problem", required=True, help="Absolute path to the Python problem file"
)
p.add_argument(
"--model",
default="gpt-5",
help="OpenAI model name (Responses API, default: gpt-5)",
)
p.add_argument("--workers", type=int, default=4)
p.add_argument("--max-iters", type=int, default=10)
p.add_argument("--llm-timeout-s", type=int, default=120)
p.add_argument("--run-timeout-s", type=int, default=180)
p.add_argument("--stream", choices=["all", "winner", "none"], default="all")
p.add_argument("--store-responses", action="store_true", default=False)
p.add_argument("--isolated", action="store_true", default=False)
p.add_argument("--deny-network", action="store_true", default=False)
p.add_argument("--enable-reasoning-extras", action="store_true", default=True)
args = p.parse_args(argv)

try:
problem_path = ensure_abs_regular_file(args.problem)
problem_path = ensure_abs_regular_file(cfg.problem)
except PathSafetyError as e:
print(str(e), file=sys.stderr)
return int(ExitCode.INVALID_ARGS)

cfg = OrchestratorConfig(
orch_cfg = OrchestratorConfig(
problem_path=problem_path,
model=args.model,
workers=args.workers,
max_iters=args.max_iters,
llm_timeout_s=args.llm_timeout_s,
run_timeout_s=args.run_timeout_s,
stream_mode=args.stream,
store_responses=args.store_responses,
isolated=args.isolated,
deny_network=args.deny_network,
enable_reasoning_extras=args.enable_reasoning_extras,
model=cfg.model,
workers=cfg.workers,
max_iters=cfg.max_iters,
llm_timeout_s=cfg.llm_timeout_s,
run_timeout_s=cfg.run_timeout_s,
stream_mode=cfg.stream,
store_responses=cfg.store_responses,
isolated=cfg.isolated,
deny_network=cfg.deny_network,
enable_reasoning_extras=cfg.enable_reasoning_extras,
)

run_id = new_run_id()
FUSE_BASE_DIR.mkdir(exist_ok=True)
try:
d = make_run_dirs(FUSE_BASE_DIR, run_id)
except FileExistsError:
print("Run directory already exists unexpectedly; retry.", file=sys.stderr)
print(
"Run directory already exists unexpectedly; retry.",
file=sys.stderr,
)
return int(ExitCode.GENERIC_FAILURE)

orch_dir = d["orchestrator"]
Expand All @@ -113,7 +134,7 @@ def cmd_run(argv: list[str]) -> int:
json.dumps(
{
"run_id": run_id,
"config": json.loads(cfg.to_json()),
"config": json.loads(orch_cfg.to_json()),
},
indent=2,
)
Expand All @@ -128,7 +149,10 @@ def cmd_run(argv: list[str]) -> int:
# Spawn orchestrator and execute first-wins
mp.set_start_method("spawn", force=True)
orch = Orchestrator(
cfg, run_dir=run_dir, workers_dir=d["workers"], orchestrator_dir=orch_dir
orch_cfg,
run_dir=run_dir,
workers_dir=d["workers"],
orchestrator_dir=orch_dir,
)
summary = orch.run()

Expand All @@ -142,21 +166,5 @@ def cmd_run(argv: list[str]) -> int:
return int(ExitCode.SUCCESS)


def main(argv: list[str] | None = None) -> int:
argv = list(sys.argv[1:] if argv is None else argv)
if not argv:
print(
"usage: fuse run --problem /abs/path.py [--model <name>] [flags]",
file=sys.stderr,
)
return int(ExitCode.INVALID_ARGS)
cmd = argv[0]
if cmd == "run":
return cmd_run(argv[1:])
else:
print(f"unknown subcommand: {cmd}", file=sys.stderr)
return int(ExitCode.INVALID_ARGS)


if __name__ == "__main__":
sys.exit(main())
Loading