Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions Fuser/auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any

from dotenv import load_dotenv
from Fuser.pipeline import run_pipeline
Expand Down Expand Up @@ -105,7 +105,7 @@ def _file_sha256_text(txt: str) -> str:
return hashlib.sha256(txt.encode("utf-8")).hexdigest()


def _load_router_cache() -> Dict[str, Any]:
def _load_router_cache() -> dict[str, Any]:
try:
if _ROUTER_CACHE_PATH.is_file():
return json.loads(_ROUTER_CACHE_PATH.read_text(encoding="utf-8"))
Expand All @@ -114,7 +114,7 @@ def _load_router_cache() -> Dict[str, Any]:
return {}


def _save_router_cache(cache: Dict[str, Any]) -> None:
def _save_router_cache(cache: dict[str, Any]) -> None:
try:
_ensure_dir(_ROUTER_CACHE_PATH)
_ROUTER_CACHE_PATH.write_text(json.dumps(cache, indent=2), encoding="utf-8")
Expand Down Expand Up @@ -144,7 +144,7 @@ class Complexity:
pool_ops: int
act_ops: int
chain_len_estimate: int
raw_op_names: Dict[str, int]
raw_op_names: dict[str, int]

def route_to_fuser(self) -> bool:
# Primary triggers
Expand Down Expand Up @@ -213,7 +213,7 @@ def analyze_problem_code(code: str) -> Complexity:

# AST path: inspect Model.forward for ops and control flow
has_control_flow = False
raw_op_counts: Dict[str, int] = {}
raw_op_counts: dict[str, int] = {}
has_attention_like = False
has_conv_transpose = False
has_group_norm = False
Expand Down Expand Up @@ -298,19 +298,19 @@ def visit_Assign(self, node: ast.Assign) -> Any:
class RouteResult:
route: str # "kernelagent" or "fuser"
success: bool
details: Dict[str, Any]
kernel_code: Optional[str] = None
details: dict[str, Any]
kernel_code: str | None = None


class AutoKernelRouter:
def __init__(
self,
ka_model: Optional[str] = None,
ka_model: str | None = None,
ka_num_workers: int = 4,
ka_max_rounds: int = 10,
ka_high_reasoning: bool = True,
# Router LLM
router_model: Optional[str] = "gpt-5",
router_model: str | None = "gpt-5",
router_high_reasoning: bool = True,
router_temperature: float = 0.2,
router_max_tokens: int = 700,
Expand Down Expand Up @@ -431,7 +431,7 @@ def _solve_with_fuser(self, problem_path: Path) -> RouteResult:

comp = res.get("composition", {}) or {}
ok = bool(comp.get("verify_passed", not self.verify))
kernel_code: Optional[str] = None
kernel_code: str | None = None
try:
composed_path = comp.get("composed_path")
if composed_path and Path(composed_path).is_file():
Expand All @@ -457,9 +457,9 @@ def solve(self, problem_path: Path) -> RouteResult:
cache = _load_router_cache()
cached = cache.get(code_hash)

strategy: Optional[str] = None
route_conf: Optional[float] = None
route_cfg: Dict[str, Any] = {}
strategy: str | None = None
route_conf: float | None = None
route_cfg: dict[str, Any] = {}

if isinstance(cached, dict):
strategy = (
Expand Down Expand Up @@ -545,7 +545,7 @@ def solve(self, problem_path: Path) -> RouteResult:
# -------- LLM decision helper --------
def _llm_decide_route(
self, problem_path: Path, code: str, cx: Complexity
) -> Tuple[Optional[str], Optional[float], Dict[str, Any]]:
) -> tuple[str | None, float | None, dict[str, Any]]:
"""Ask an LLM to choose a routing STRATEGY and optional budgets.

The LLM must return JSON with keys:
Expand Down Expand Up @@ -621,7 +621,7 @@ def _llm_decide_route(
f"Features:\n```json\n{json.dumps(feats, indent=2)}\n```\n\n"
"Problem code:\n```python\n" + code + "\n```\n"
)
kwargs: Dict[str, Any] = {
kwargs: dict[str, Any] = {
"max_tokens": self.router_max_tokens,
"temperature": self.router_temperature,
}
Expand All @@ -636,7 +636,7 @@ def _llm_decide_route(
# Best-effort JSON parse
route = None
conf = None
raw_info: Dict[str, Any] = {"raw": txt}
raw_info: dict[str, Any] = {"raw": txt}
try:
# If model returned extra text, try to locate JSON object
first = txt.find("{")
Expand Down Expand Up @@ -668,7 +668,7 @@ def _llm_decide_route(
# ------------------------


def main(argv: Optional[list[str]] = None) -> int:
def main(argv: list[str] | None = None) -> int:
p = argparse.ArgumentParser(
description="Auto-router for KernelBench problems (KernelAgent vs Fuser)"
)
Expand Down Expand Up @@ -748,7 +748,7 @@ def main(argv: Optional[list[str]] = None) -> int:
)
return 1

out: Dict[str, Any] = {
out: dict[str, Any] = {
"route": res.route,
"success": res.success,
"details": res.details,
Expand Down
36 changes: 18 additions & 18 deletions Fuser/compose_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any

from dotenv import load_dotenv

Expand Down Expand Up @@ -73,11 +73,11 @@ def _read_text(path: Path) -> str:
return path.read_text(encoding="utf-8")


def _load_kernels_from_summary(summary_path: Path) -> List[KernelItem]:
def _load_kernels_from_summary(summary_path: Path) -> list[KernelItem]:
data = json.loads(_read_text(summary_path))
if not isinstance(data, list):
raise SystemExit("kernels summary must be a JSON array (from dispatch step)")
items: List[KernelItem] = []
items: list[KernelItem] = []
for it in data:
if not isinstance(it, dict):
continue
Expand All @@ -98,8 +98,8 @@ def _load_kernels_from_summary(summary_path: Path) -> List[KernelItem]:
return items


def _summarize_subgraphs_for_prompt(subgraphs: List[Dict[str, Any]]) -> str:
lines: List[str] = []
def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
lines: list[str] = []
for it in subgraphs:
sid = str(it.get("id", "unknown"))
typ = str(it.get("type", ""))
Expand All @@ -126,16 +126,16 @@ def _summarize_subgraphs_for_prompt(subgraphs: List[Dict[str, Any]]) -> str:

def _build_composition_prompt(
problem_code: str,
subgraphs: List[Dict[str, Any]],
kernel_items: List[KernelItem],
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
) -> str:
"""Create a single user message to instruct composition by the LLM."""
# Provide a succinct summary of subgraphs up front
sg_summary = _summarize_subgraphs_for_prompt(subgraphs)

# Include only essential snippets from each kernel to keep token usage sane
# We include full files for now; callers can trim by model limits.
kernels_section_parts: List[str] = []
kernels_section_parts: list[str] = []
for ki in kernel_items:
kernels_section_parts.append(
f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
Expand Down Expand Up @@ -190,7 +190,7 @@ def _build_composition_prompt(
"""
).strip()

user_lines: List[str] = []
user_lines: list[str] = []
user_lines.append(guidance)
user_lines.append("")
user_lines.append("SUBGRAPHS (summary):")
Expand All @@ -212,10 +212,10 @@ def _build_composition_prompt(

def _build_refinement_prompt(
problem_code: str,
subgraphs: List[Dict[str, Any]],
kernel_items: List[KernelItem],
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
previous_code: str,
error_info: Dict[str, str],
error_info: dict[str, str],
) -> str:
"""Prompt the LLM to refine the previously produced code based on errors."""
err_tail = error_info.get("stderr_tail", "")
Expand All @@ -239,7 +239,7 @@ def _build_refinement_prompt(
"""
).strip()

lines: List[str] = []
lines: list[str] = []
lines.append(guidance)
lines.append("")
lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
Expand All @@ -259,7 +259,7 @@ def _build_refinement_prompt(
return "\n".join(lines)


def _auto_patch_common_triton_issues(code: str) -> Tuple[str, bool]:
def _auto_patch_common_triton_issues(code: str) -> tuple[str, bool]:
"""Apply tiny safe textual patches for known Triton pitfalls.

- Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
Expand Down Expand Up @@ -289,7 +289,7 @@ def compose(
model_name: str,
verify: bool = False,
max_iters: int = 5,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if get_model_provider is None:
raise SystemExit(
"KernelAgent providers unavailable; ensure package import and dependencies"
Expand All @@ -310,7 +310,7 @@ def compose(

last_usage = None
last_code = None
verify_info: Dict[str, Any] = {}
verify_info: dict[str, Any] = {}

for i in range(1, max_iters + 1):
if i == 1 or last_code is None:
Expand Down Expand Up @@ -388,7 +388,7 @@ def compose(
composed_path = out_dir / "composed_kernel.py"
composed_path.write_text(last_code or "", encoding="utf-8")

result: Dict[str, Any] = {
result: dict[str, Any] = {
"success": bool(verify_info.get("verify_passed", not verify)),
"composed_path": str(composed_path.resolve()),
"model": model_name,
Expand All @@ -404,7 +404,7 @@ def compose(
return result


def main(argv: Optional[List[str]] = None) -> int:
def main(argv: list[str] | None = None) -> int:
load_dotenv()
p = argparse.ArgumentParser(
description="Compose end-to-end Triton kernel from subgraphs + generated kernels"
Expand Down
5 changes: 2 additions & 3 deletions Fuser/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import annotations
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional
import json
import time
import uuid
Expand Down Expand Up @@ -68,8 +67,8 @@ class WorkerConfig:
@dataclass
class ResultSummary:
run_id: str
winner_worker_id: Optional[str]
artifact_path: Optional[str]
winner_worker_id: str | None
artifact_path: str | None
reason: str


Expand Down
3 changes: 1 addition & 2 deletions Fuser/dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import json
import time
from pathlib import Path
from typing import Tuple, Optional


def register_digest(
shared_digests_dir: Path, sha256: str, worker_id: str, iter_index: int
) -> Tuple[str, Optional[str]]:
) -> tuple[str, str | None]:
"""
Atomically register a digest in shared_digests_dir.
Returns (status, owner_worker_id or None), where status is one of:
Expand Down
Loading