From c62db16236b8255ed07131fde10231b1bff1805f Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Fri, 19 Dec 2025 13:33:04 -0800 Subject: [PATCH 1/4] Replace Optional with | None --- Fuser/auto_agent.py | 18 +++++++++--------- Fuser/compose_end_to_end.py | 4 ++-- Fuser/config.py | 5 ++--- Fuser/dedup.py | 4 ++-- Fuser/event_adapter.py | 14 +++++++------- Fuser/pipeline.py | 7 +++---- Fuser/prompting.py | 10 +++++----- Fuser/runner.py | 3 +-- Fuser/subgraph_extractor.py | 6 +++--- Fuser/worker.py | 8 ++++---- scripts/fuser_ui.py | 18 +++++++++--------- scripts/pipeline_ui.py | 16 ++++++++-------- scripts/triton_ui.py | 8 ++++---- triton_kernel_agent/agent.py | 20 ++++++++++---------- triton_kernel_agent/manager.py | 12 ++++++------ triton_kernel_agent/prompt_manager.py | 12 ++++++------ triton_kernel_agent/worker.py | 8 ++++---- utils/providers/base.py | 8 ++++---- utils/providers/env_config.py | 6 +++--- utils/providers/models.py | 10 +++++----- utils/providers/openai_base.py | 4 ++-- 21 files changed, 99 insertions(+), 102 deletions(-) diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index f44eb6f..fc9822f 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -50,7 +50,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple from dotenv import load_dotenv from Fuser.pipeline import run_pipeline @@ -299,18 +299,18 @@ class RouteResult: route: str # "kernelagent" or "fuser" success: bool details: Dict[str, Any] - kernel_code: Optional[str] = None + kernel_code: str | None = None class AutoKernelRouter: def __init__( self, - ka_model: Optional[str] = None, + ka_model: str | None = None, ka_num_workers: int = 4, ka_max_rounds: int = 10, ka_high_reasoning: bool = True, # Router LLM - router_model: Optional[str] = "gpt-5", + router_model: str | None = "gpt-5", router_high_reasoning: bool = True, router_temperature: float = 0.2, router_max_tokens: int = 700, @@ -431,7 +431,7 @@ def _solve_with_fuser(self, problem_path: Path) -> RouteResult: comp = res.get("composition", {}) or {} ok = bool(comp.get("verify_passed", not self.verify)) - kernel_code: Optional[str] = None + kernel_code: str | None = None try: composed_path = comp.get("composed_path") if composed_path and Path(composed_path).is_file(): @@ -457,8 +457,8 @@ def solve(self, problem_path: Path) -> RouteResult: cache = _load_router_cache() cached = cache.get(code_hash) - strategy: Optional[str] = None - route_conf: Optional[float] = None + strategy: str | None = None + route_conf: float | None = None route_cfg: Dict[str, Any] = {} if isinstance(cached, dict): @@ -545,7 +545,7 @@ def solve(self, problem_path: Path) -> RouteResult: # -------- LLM decision helper -------- def _llm_decide_route( self, problem_path: Path, code: str, cx: Complexity - ) -> Tuple[Optional[str], Optional[float], Dict[str, Any]]: + ) -> Tuple[str | None, float | None, Dict[str, Any]]: """Ask an LLM to choose a routing STRATEGY and optional budgets. 
The LLM must return JSON with keys: @@ -668,7 +668,7 @@ def _llm_decide_route( # ------------------------ -def main(argv: Optional[list[str]] = None) -> int: +def main(argv: list[str | None] = None) -> int: p = argparse.ArgumentParser( description="Auto-router for KernelBench problems (KernelAgent vs Fuser)" ) diff --git a/Fuser/compose_end_to_end.py b/Fuser/compose_end_to_end.py index 17ae07c..02e293a 100644 --- a/Fuser/compose_end_to_end.py +++ b/Fuser/compose_end_to_end.py @@ -45,7 +45,7 @@ import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple from dotenv import load_dotenv @@ -404,7 +404,7 @@ def compose( return result -def main(argv: Optional[List[str]] = None) -> int: +def main(argv: List[str | None] = None) -> int: load_dotenv() p = argparse.ArgumentParser( description="Compose end-to-end Triton kernel from subgraphs + generated kernels" diff --git a/Fuser/config.py b/Fuser/config.py index f44d9ba..ec8a64f 100644 --- a/Fuser/config.py +++ b/Fuser/config.py @@ -14,7 +14,6 @@ from __future__ import annotations from dataclasses import dataclass, asdict from pathlib import Path -from typing import Optional import json import time import uuid @@ -68,8 +67,8 @@ class WorkerConfig: @dataclass class ResultSummary: run_id: str - winner_worker_id: Optional[str] - artifact_path: Optional[str] + winner_worker_id: str | None + artifact_path: str | None reason: str diff --git a/Fuser/dedup.py b/Fuser/dedup.py index 70a1f0c..abf1f7a 100644 --- a/Fuser/dedup.py +++ b/Fuser/dedup.py @@ -15,12 +15,12 @@ import json import time from pathlib import Path -from typing import Tuple, Optional +from typing import Tuple def register_digest( shared_digests_dir: Path, sha256: str, worker_id: str, iter_index: int -) -> Tuple[str, Optional[str]]: +) -> Tuple[str, str | None]: """ Atomically register a digest in shared_digests_dir. Returns (status, owner_worker_id or None), where status is one of: diff --git a/Fuser/event_adapter.py b/Fuser/event_adapter.py index 0a70092..289b8f4 100644 --- a/Fuser/event_adapter.py +++ b/Fuser/event_adapter.py @@ -17,7 +17,7 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Callable, Optional, Any +from typing import Callable, Any try: # OpenAI official SDK (Responses API) @@ -48,9 +48,9 @@ def __init__( store_responses: bool, timeout_s: int, jsonl_path: Path, - stop_event: Optional[threading.Event] = None, - on_delta: Optional[Callable[[str], None]] = None, - client: Optional[Any] = None, + stop_event: threading.Event | None = None, + on_delta: Callable[[str | None, None]] = None, + client: Any | None = None, ) -> None: self.model = model self.store_responses = store_responses @@ -117,7 +117,7 @@ def stream( self, system_prompt: str, user_prompt: str, - extras: Optional[dict[str, Any]] = None, + extras: dict[str, Any | None] = None, ) -> dict[str, Any]: """ Start streaming and persist events. 
Returns a dict with summary fields: @@ -125,8 +125,8 @@ def stream( """ client = self._ensure_client() output_text_parts: list[str] = [] - response_id: Optional[str] = None - error_msg: Optional[str] = None + response_id: str | None = None + error_msg: str | None = None # Start background flusher done_flag = threading.Event() diff --git a/Fuser/pipeline.py b/Fuser/pipeline.py index 16633a3..386d3d9 100644 --- a/Fuser/pipeline.py +++ b/Fuser/pipeline.py @@ -36,7 +36,6 @@ import argparse import json from pathlib import Path -from typing import Optional from .subgraph_extractor import extract_subgraphs_to_json from .dispatch_kernel_agent import run as dispatch_run @@ -46,14 +45,14 @@ def run_pipeline( problem_path: Path, extract_model: str, - dispatch_model: Optional[str], + dispatch_model: str | None, compose_model: str, dispatch_jobs: int | str, workers: int, max_iters: int, llm_timeout_s: int, run_timeout_s: int, - out_root: Optional[Path] = None, + out_root: Path | None = None, verify: bool = True, compose_max_iters: int = 5, ) -> dict: @@ -130,7 +129,7 @@ def run_pipeline( } -def main(argv: Optional[list[str]] = None) -> int: +def main(argv: list[str | None] = None) -> int: # Load .env if present for OPENAI_API_KEY, proxies, etc. try: from dotenv import load_dotenv # type: ignore diff --git a/Fuser/prompting.py b/Fuser/prompting.py index f5de060..fd53225 100644 --- a/Fuser/prompting.py +++ b/Fuser/prompting.py @@ -14,7 +14,7 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from typing import Optional, Any +from typing import Any # Prompt rendering for the Fuse orchestrator # - Deterministic, stateless per iteration @@ -66,7 +66,7 @@ def _variant_line(idx: int) -> str: def build_user_prompt( attempt_index: int, problem_file_content: str, - error_context: Optional[str], + error_context: str | None, variant_index: int, ) -> str: parts: list[str] = [] @@ -89,10 +89,10 @@ def render_prompt( problem_path: Path, variant_index: int, attempt_index: int, - error_context: Optional[str], + error_context: str | None, enable_reasoning_extras: bool, - seed: Optional[int] = None, - model_name: Optional[str] = None, + seed: int | None = None, + model_name: str | None = None, ) -> RenderedPrompt: """Render system+user prompts and extras for the Responses API (deterministic).""" content = problem_path.read_text(encoding="utf-8") diff --git a/Fuser/runner.py b/Fuser/runner.py index 683eeaf..41454fb 100644 --- a/Fuser/runner.py +++ b/Fuser/runner.py @@ -24,7 +24,6 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import Optional from Fuser.runner_util import _run_candidate_multiprocess @@ -203,7 +202,7 @@ def run_candidate( timeout_s: int, isolated: bool, deny_network: bool, - cancel_event: Optional["threading.Event"] = None, + cancel_event: "threading.Event" | None = None, ) -> RunResult: """ Execute a candidate program in a fresh run directory under run_root. 
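For reference while reading this first pass: PEP 604's `X | None` is the same type as `typing.Optional[X]`. Evaluating `|` between types at runtime requires Python 3.10+, whereas modules that use `from __future__ import annotations` (as several of the touched files above, e.g. `Fuser/config.py` and `Fuser/prompting.py`, already do per their hunk headers) never evaluate their annotations at all. A minimal sketch of the equivalence; the function names here are illustrative, not from the repo:

```python
from __future__ import annotations  # annotations stay unevaluated strings (PEP 563)

from typing import Optional


# Old spelling: Optional[str] -- new spelling: str | None. Both mean "a str or None".
def greet_old(name: Optional[str] = None) -> str:
    return f"hello {name}" if name is not None else "hello"


def greet_new(name: str | None = None) -> str:
    return f"hello {name}" if name is not None else "hello"


# On Python 3.10+ the two forms also compare equal at runtime (PEP 604).
assert (str | None) == Optional[str]
```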
diff --git a/Fuser/subgraph_extractor.py b/Fuser/subgraph_extractor.py index afbb5d7..2a7a98b 100644 --- a/Fuser/subgraph_extractor.py +++ b/Fuser/subgraph_extractor.py @@ -38,7 +38,7 @@ import sys import tarfile from pathlib import Path -from typing import Any, Optional, Tuple, Dict +from typing import Any, Tuple, Dict from .cli import _load_dotenv_if_present # reuse env loader from .config import OrchestratorConfig, new_run_id @@ -72,7 +72,7 @@ def _load_code_from_tar(artifact_path: Path) -> str: def _extract_json_block(text: str) -> str: """Extract the last fenced JSON block or fallback to best-effort slice.""" matches = list(_JSON_BLOCK_RE.finditer(text)) - chosen: Optional[re.Match[str]] = None + chosen: re.Match[str | None] = None for m in reversed(matches): lang = (m.group(1) or "").strip().lower() if lang == "json": @@ -367,7 +367,7 @@ def sort_w(obj: Any) -> Dict[str, Any]: return dirs["run_dir"], out_path -def main(argv: Optional[list[str]] = None) -> int: +def main(argv: list[str | None] = None) -> int: _load_dotenv_if_present() p = argparse.ArgumentParser( description="Extract unique subgraphs with shapes (JSON)" diff --git a/Fuser/worker.py b/Fuser/worker.py index 468662b..4d58519 100644 --- a/Fuser/worker.py +++ b/Fuser/worker.py @@ -16,7 +16,7 @@ import queue from dataclasses import dataclass, asdict from pathlib import Path -from typing import Optional, Any, Callable +from typing import Any, Callable from .config import WorkerConfig from .event_adapter import EventAdapter @@ -33,8 +33,8 @@ class WorkerState: worker_id: str iter_index: int - last_response_id: Optional[str] - last_error: Optional[str] + last_response_id: str | None + last_error: str | None passed: bool @@ -77,7 +77,7 @@ def __init__( problem_path: Path, winner_queue: Any, cancel_event: Any, - on_delta: Optional[Callable[[str], None]] = None, + on_delta: Callable[[str | None, None]] = None, ) -> None: self.cfg = cfg self.problem_path = problem_path diff --git a/scripts/fuser_ui.py b/scripts/fuser_ui.py index 6977cba..cd55e83 100644 --- a/scripts/fuser_ui.py +++ b/scripts/fuser_ui.py @@ -25,7 +25,7 @@ import zipfile from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Tuple import gradio as gr from dotenv import load_dotenv @@ -48,7 +48,7 @@ class RunArtifacts: summary_md: str code_text: str run_info_md: str - zip_path: Optional[Path] + zip_path: Path | None def _list_kernelbench_problems(base: Path) -> List[Tuple[str, str]]: @@ -81,7 +81,7 @@ def base_name(node: ast.expr) -> str: return node.id if isinstance(node, ast.Attribute): parts: List[str] = [] - cur: Optional[ast.AST] = node + cur: ast.AST | None = node while isinstance(cur, ast.Attribute): parts.append(cur.attr) cur = cur.value @@ -134,7 +134,7 @@ def _load_code_from_tar(artifact_path: Path) -> str: return extracted.read().decode("utf-8") -def _create_zip_from_tar(artifact_path: Path, zip_path: Path) -> Optional[Path]: +def _create_zip_from_tar(artifact_path: Path, zip_path: Path) -> Path | None: if not artifact_path.is_file(): return None with ( @@ -153,7 +153,7 @@ def _create_zip_from_tar(artifact_path: Path, zip_path: Path) -> Optional[Path]: def _compose_run_info( - run_dir: Path, summary_reason: str, elapsed: float, winner: Optional[str] + run_dir: Path, summary_reason: str, elapsed: float, winner: str | None ) -> str: lines = ["## 📁 Run Information"] lines.append(f"- Run directory: `{run_dir}`") @@ -174,7 +174,7 @@ def run_fuser_problem( llm_timeout: int, run_timeout: int, 
enable_reasoning: bool, - user_api_key: Optional[str] = None, + user_api_key: str | None = None, ) -> RunArtifacts: """Execute the Fuser orchestrator and collect artifacts.""" if not problem_path: @@ -664,8 +664,8 @@ def run( llm_timeout: int, run_timeout: int, enable_reasoning: bool, - user_api_key: Optional[str], - ) -> Tuple[str, str, str, str, Optional[str]]: + user_api_key: str | None, + ) -> Tuple[str, str, str, str, str | None]: problem_path = custom_problem.strip() or selected_problem artifacts = run_fuser_problem( problem_path=problem_path, @@ -802,7 +802,7 @@ def generate( run_timeout: int, reasoning: bool, strict_compile: bool, - api_key: Optional[str], + api_key: str | None, ): selected_path = problem_mapping.get(selected_label, default_problem) status, summary, code_text, run_info, zip_path = ui.run( diff --git a/scripts/pipeline_ui.py b/scripts/pipeline_ui.py index 5dba6b3..d9d446b 100644 --- a/scripts/pipeline_ui.py +++ b/scripts/pipeline_ui.py @@ -24,7 +24,7 @@ import zipfile from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Tuple import gradio as gr from dotenv import load_dotenv @@ -51,7 +51,7 @@ def _list_kernelbench_problems(base: Path) -> List[Tuple[str, str]]: return problems -def _zip_dir(src_dir: Path, zip_path: Path) -> Optional[Path]: +def _zip_dir(src_dir: Path, zip_path: Path) -> Path | None: try: with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: for root, _dirs, files in os.walk(src_dir): @@ -81,7 +81,7 @@ class PipelineArtifacts: details_md: str code_text: str run_info_md: str - zip_path: Optional[Path] + zip_path: Path | None def _write_temp_problem(code: str) -> Path: @@ -154,9 +154,9 @@ def run_pipeline_ui( compose_max_iters: int, verify: bool, auto_route: bool = False, - router_model: Optional[str] = None, + router_model: str | None = None, router_high_reasoning: bool = True, - user_api_key: Optional[str] = None, + user_api_key: str | None = None, ) -> PipelineArtifacts: from Fuser.auto_agent import AutoKernelRouter from Fuser.pipeline import run_pipeline @@ -448,8 +448,8 @@ def run( run_timeout: int, compose_max_iters: int, verify: bool, - user_api_key: Optional[str], - ) -> Tuple[str, str, str, str, Optional[str]]: + user_api_key: str | None, + ) -> Tuple[str, str, str, str, str | None]: problem_mapping = {label: path for label, path in self.problem_choices} selected_path = problem_mapping.get(selected_problem_label, "") # Use description override if present; otherwise selected path @@ -642,7 +642,7 @@ def on_run( run_timeout: int, compose_max_iters: int, verify: bool, - api_key: Optional[str], + api_key: str | None, ): return ui.run( selected_problem_label=selected_label, diff --git a/scripts/triton_ui.py b/scripts/triton_ui.py index 74884e5..7eb09af 100644 --- a/scripts/triton_ui.py +++ b/scripts/triton_ui.py @@ -20,7 +20,7 @@ import time import traceback from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Tuple import gradio as gr from dotenv import load_dotenv @@ -99,11 +99,11 @@ def _provider_env_var(self, class_name: str) -> str: def generate_kernel( self, problem_description: str, - test_code: Optional[str] = None, + test_code: str | None = None, model_name: str = "o3-2025-04-16", provider_class_name: str = "", high_reasoning_effort: bool = True, - user_api_key: Optional[str] = None, + user_api_key: str | None = None, ) -> Tuple[str, str, str, str, str, str]: """ Generate a Triton kernel based on the problem 
description @@ -546,7 +546,7 @@ def _create_app() -> gr.Blocks: # Event handlers - def update_problem_descriptor(selection: Optional[str]): + def update_problem_descriptor(selection: str | None): if not selection: return gr.update() diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py index 753d095..9cc4e55 100644 --- a/triton_kernel_agent/agent.py +++ b/triton_kernel_agent/agent.py @@ -18,7 +18,7 @@ import json import re from pathlib import Path -from typing import Optional, List, Dict, Any +from typing import List, Dict, Any from datetime import datetime import logging from dotenv import load_dotenv @@ -33,12 +33,12 @@ class TritonKernelAgent: def __init__( self, - num_workers: Optional[int] = None, - max_rounds: Optional[int] = None, - log_dir: Optional[str] = None, - model_name: Optional[str] = None, + num_workers: int | None = None, + max_rounds: int | None = None, + log_dir: str | None = None, + model_name: str | None = None, high_reasoning_effort: bool = True, - preferred_provider: Optional[BaseProvider] = None, + preferred_provider: BaseProvider | None = None, ): """ Initialize the Triton Kernel Agent. @@ -111,7 +111,7 @@ def _setup_logging(self): def _extract_code_from_response( self, response_text: str, language: str = "python" - ) -> Optional[str]: + ) -> str | None: """ Extract code from LLM response text. @@ -182,7 +182,7 @@ def _call_llm(self, messages: List[Dict[str, str]], **kwargs) -> str: return response.content def _generate_test( - self, problem_description: str, provided_test_code: Optional[str] = None + self, problem_description: str, provided_test_code: str | None = None ) -> str: """ Generate test code for the problem using OpenAI API. @@ -305,7 +305,7 @@ def test_kernel(): return test_code def _generate_kernel_seeds( - self, problem_description: str, test_code: str, num_seeds: Optional[int] = None + self, problem_description: str, test_code: str, num_seeds: int | None = None ) -> List[str]: """ Generate initial kernel implementations using OpenAI API. @@ -422,7 +422,7 @@ def kernel_function(*args, **kwargs): return kernels def generate_kernel( - self, problem_description: str, test_code: Optional[str] = None + self, problem_description: str, test_code: str | None = None ) -> Dict[str, Any]: """ Generate an optimized Triton kernel for the given problem. diff --git a/triton_kernel_agent/manager.py b/triton_kernel_agent/manager.py index 5161b57..5cd6898 100644 --- a/triton_kernel_agent/manager.py +++ b/triton_kernel_agent/manager.py @@ -19,7 +19,7 @@ import multiprocessing as mp import queue from pathlib import Path -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any from datetime import datetime import logging from contextlib import contextmanager @@ -33,8 +33,8 @@ def __init__( num_workers: int = 4, max_rounds: int = 10, history_size: int = 8, - log_dir: Optional[str] = None, - openai_api_key: Optional[str] = None, + log_dir: str | None = None, + openai_api_key: str | None = None, openai_model: str = "gpt-5", high_reasoning_effort: bool = True, ): @@ -110,8 +110,8 @@ def run_verification( kernel_seeds: List[str], test_code: str, problem_description: str, - session_log_dir: Optional[Path] = None, - ) -> Optional[Dict[str, Any]]: + session_log_dir: Path | None = None, + ) -> Dict[str, Any | None]: """ Run parallel verification on multiple kernel seeds. 
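The only subtle part of this mechanical rewrite is where `| None` lands relative to a generic's parameters: wrapping the whole generic is what `Optional[...]` meant, while moving it inside the brackets changes the meaning. That distinction is what the follow-up "Fix optional typos" commit in this series is chasing. A minimal sketch of the difference, with illustrative variable names (Python 3.10+, or under `from __future__ import annotations`):

```python
from typing import Any, Dict, Optional

# Optional[Dict[str, Any]]: the mapping itself may be absent ...
maybe_results: Optional[Dict[str, Any]] = None
# ... which is spelled like this in the new style:
maybe_results_new: dict[str, Any] | None = None

# dict[str, Any | None]: the mapping is always there, but individual values may be None.
sparse_results: dict[str, Any | None] = {"worker_0": None, "worker_1": {"passed": True}}
```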
@@ -220,7 +220,7 @@ def worker_process( history_size: int, success_event: mp.Event, result_queue: mp.Queue, - openai_api_key: Optional[str], + openai_api_key: str | None, openai_model: str, high_reasoning_effort: bool, ): diff --git a/triton_kernel_agent/prompt_manager.py b/triton_kernel_agent/prompt_manager.py index 174d5dc..8c6a368 100644 --- a/triton_kernel_agent/prompt_manager.py +++ b/triton_kernel_agent/prompt_manager.py @@ -15,7 +15,7 @@ """Prompt Manager for handling Jinja2 templates.""" from pathlib import Path -from typing import Dict, Optional +from typing import Dict try: from jinja2 import Environment, FileSystemLoader, Template @@ -36,7 +36,7 @@ class PromptManager: for test generation, kernel generation, and kernel refinement. """ - def __init__(self, templates_dir: Optional[str] = None): + def __init__(self, templates_dir: str | None = None): """ Initialize the prompt manager. @@ -91,7 +91,7 @@ def _load_templates(self): raise FileNotFoundError(f"Template file not found: {template_path}") def render_test_generation_prompt( - self, problem_description: str, provided_test_code: Optional[str] = None + self, problem_description: str, provided_test_code: str | None = None ) -> str: """ Render the test generation prompt. @@ -113,7 +113,7 @@ def render_kernel_generation_prompt( self, problem_description: str, test_code: str, - triton_guidelines: Optional[str] = None, + triton_guidelines: str | None = None, ) -> str: """ Render the kernel generation prompt. @@ -144,8 +144,8 @@ def render_kernel_refinement_prompt( test_code: str, kernel_code: str, error_info: Dict[str, str], - history_context: Optional[str] = None, - triton_guidelines: Optional[str] = None, + history_context: str | None = None, + triton_guidelines: str | None = None, ) -> str: """ Render the kernel refinement prompt. diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py index eee6f13..6128197 100644 --- a/triton_kernel_agent/worker.py +++ b/triton_kernel_agent/worker.py @@ -24,7 +24,7 @@ from collections import deque from datetime import datetime from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple from .prompt_manager import PromptManager from utils.providers import get_model_provider @@ -127,7 +127,7 @@ def __init__( log_dir: Path, max_rounds: int = 10, history_size: int = 8, - openai_api_key: Optional[str] = None, + openai_api_key: str | None = None, openai_model: str = "gpt-5", high_reasoning_effort: bool = True, ): @@ -187,7 +187,7 @@ def _setup_logging(self): def _extract_code_from_response( self, response_text: str, language: str = "python" - ) -> Optional[str]: + ) -> str | None: """ Extract code from LLM response text. 
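Callers of a `-> str | None` function have to narrow the `None` away before using the result, which is the practical payoff of annotating it honestly. A small self-contained sketch of that pattern; `extract_code` and `require_code` are illustrative helpers, not the repo's implementation:

```python
import re


def extract_code(response_text: str, language: str = "python") -> str | None:
    """Return the body of the first fenced block for `language`, or None if absent."""
    match = re.search(rf"```{language}\n(.*?)```", response_text, re.DOTALL)
    return match.group(1) if match else None


def require_code(response_text: str) -> str:
    code = extract_code(response_text)
    if code is None:  # type checkers require narrowing str | None before use
        raise ValueError("no fenced code block found in the response")
    return code


print(require_code("Here you go:\n```python\nprint('hi')\n```"))
```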
@@ -258,7 +258,7 @@ def _strip_comments_and_strings(self, code: str) -> str: pattern = re.compile(r'("""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'|#.*)') return re.sub(pattern, "", code) - def _detect_pytorch_compute(self, kernel_code: str) -> Optional[str]: + def _detect_pytorch_compute(self, kernel_code: str) -> str | None: """Detect disallowed PyTorch usage inside the kernel wrapper.""" sanitized = self._strip_comments_and_strings(kernel_code) for pattern, message in DISALLOWED_TORCH_PATTERNS: diff --git a/utils/providers/base.py b/utils/providers/base.py index ddbc655..38f1ae1 100644 --- a/utils/providers/base.py +++ b/utils/providers/base.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any import os @@ -27,8 +27,8 @@ class LLMResponse: content: str model: str provider: str - usage: Optional[Dict[str, Any]] = None - response_id: Optional[str] = None + usage: Dict[str, Any | None] = None + response_id: str | None = None class BaseProvider(ABC): @@ -97,7 +97,7 @@ def get_max_tokens_limit(self, model_name: str) -> int: """Get the maximum tokens limit for a model.""" return 8192 # Default limit - def _get_api_key(self, env_var: str) -> Optional[str]: + def _get_api_key(self, env_var: str) -> str | None: """Helper to get API key from environment.""" api_key = os.getenv(env_var) if api_key and api_key != "your-api-key-here": diff --git a/utils/providers/env_config.py b/utils/providers/env_config.py index 4dc8ca3..f342f6e 100644 --- a/utils/providers/env_config.py +++ b/utils/providers/env_config.py @@ -17,10 +17,10 @@ import os import subprocess import logging -from typing import Dict, Optional +from typing import Dict -def _get_meta_proxy_config() -> Optional[Dict[str, str]]: +def _get_meta_proxy_config() -> Dict[str, str | None]: """ Get Meta's proxy configuration if available. @@ -59,7 +59,7 @@ def _get_meta_proxy_config() -> Optional[Dict[str, str]]: return None -def configure_proxy_environment() -> Optional[Dict[str, Optional[str]]]: +def configure_proxy_environment() -> Dict[str[str | None]]: """ Configure proxy environment variables for Meta environment. This is the centralized proxy configuration logic used by all providers. diff --git a/utils/providers/models.py b/utils/providers/models.py index a71a33c..f6ccdcf 100644 --- a/utils/providers/models.py +++ b/utils/providers/models.py @@ -14,13 +14,13 @@ """Model registry and configuration for KernelAgent.""" -from typing import Dict, List, Optional, Type +from typing import Dict, List, Type from .base import BaseProvider from .model_config import ModelConfig # Cached model lookup dictionary (lazily initialized) -_model_name_to_config: Optional[Dict[str, ModelConfig]] = None +_model_name_to_config: Dict[str, ModelConfig | None] = None # Provider instances cache _provider_instances: Dict[Type[BaseProvider], BaseProvider] = {} @@ -50,7 +50,7 @@ def _get_model_name_to_config() -> Dict[str, ModelConfig]: def get_model_provider( - model_name: str, preferred_provider: Optional[Type[BaseProvider]] = None + model_name: str, preferred_provider: Type[BaseProvider | None] = None ) -> BaseProvider: """ Get the first available provider instance for a given model. 
If a preferred @@ -58,7 +58,7 @@ def get_model_provider( Args: model_name: Name of the model - preferred_provider: Optional preffered provider class + preferred_provider: preffered provider class Returns: Provider instance @@ -105,7 +105,7 @@ def get_model_provider( def is_model_available( - model_name: str, preferred_provider: Optional[Type[BaseProvider]] = None + model_name: str, preferred_provider: Type[BaseProvider | None] = None ) -> bool: """Check if a model is available and its provider is ready. If a preferred provider is specified, only it will be checked diff --git a/utils/providers/openai_base.py b/utils/providers/openai_base.py index 43e07b4..f4ee527 100644 --- a/utils/providers/openai_base.py +++ b/utils/providers/openai_base.py @@ -14,7 +14,7 @@ """Base provider for OpenAI-compatible APIs.""" -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any import logging from .base import BaseProvider, LLMResponse from .env_config import configure_proxy_environment @@ -31,7 +31,7 @@ class OpenAICompatibleProvider(BaseProvider): """Base provider for OpenAI-compatible APIs.""" - def __init__(self, api_key_env: str, base_url: Optional[str] = None): + def __init__(self, api_key_env: str, base_url: str | None = None): self.api_key_env = api_key_env self.base_url = base_url self._original_proxy_env = None From e04acb9ff1d1b47da738f7dc0cc9607c654af2be Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Fri, 19 Dec 2025 13:58:44 -0800 Subject: [PATCH 2/4] Fix optional typos --- Fuser/auto_agent.py | 2 +- Fuser/compose_end_to_end.py | 2 +- Fuser/event_adapter.py | 4 ++-- Fuser/pipeline.py | 2 +- Fuser/subgraph_extractor.py | 4 ++-- Fuser/worker.py | 2 +- utils/providers/base.py | 2 +- utils/providers/env_config.py | 4 ++-- utils/providers/models.py | 6 +++--- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index fc9822f..7a8b8d1 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -668,7 +668,7 @@ def _llm_decide_route( # ------------------------ -def main(argv: list[str | None] = None) -> int: +def main(argv: list[str] | None = None) -> int: p = argparse.ArgumentParser( description="Auto-router for KernelBench problems (KernelAgent vs Fuser)" ) diff --git a/Fuser/compose_end_to_end.py b/Fuser/compose_end_to_end.py index 02e293a..94976f8 100644 --- a/Fuser/compose_end_to_end.py +++ b/Fuser/compose_end_to_end.py @@ -404,7 +404,7 @@ def compose( return result -def main(argv: List[str | None] = None) -> int: +def main(argv: List[str] | None = None) -> int: load_dotenv() p = argparse.ArgumentParser( description="Compose end-to-end Triton kernel from subgraphs + generated kernels" diff --git a/Fuser/event_adapter.py b/Fuser/event_adapter.py index 289b8f4..d9c609d 100644 --- a/Fuser/event_adapter.py +++ b/Fuser/event_adapter.py @@ -49,7 +49,7 @@ def __init__( timeout_s: int, jsonl_path: Path, stop_event: threading.Event | None = None, - on_delta: Callable[[str | None, None]] = None, + on_delta: Callable[[str, None]] | None = None, client: Any | None = None, ) -> None: self.model = model @@ -117,7 +117,7 @@ def stream( self, system_prompt: str, user_prompt: str, - extras: dict[str, Any | None] = None, + extras: dict[str, Any] | None = None, ) -> dict[str, Any]: """ Start streaming and persist events. 
Returns a dict with summary fields: diff --git a/Fuser/pipeline.py b/Fuser/pipeline.py index 386d3d9..f96444e 100644 --- a/Fuser/pipeline.py +++ b/Fuser/pipeline.py @@ -129,7 +129,7 @@ def run_pipeline( } -def main(argv: list[str | None] = None) -> int: +def main(argv: list[str] | None = None) -> int: # Load .env if present for OPENAI_API_KEY, proxies, etc. try: from dotenv import load_dotenv # type: ignore diff --git a/Fuser/subgraph_extractor.py b/Fuser/subgraph_extractor.py index 2a7a98b..b28f5cf 100644 --- a/Fuser/subgraph_extractor.py +++ b/Fuser/subgraph_extractor.py @@ -72,7 +72,7 @@ def _load_code_from_tar(artifact_path: Path) -> str: def _extract_json_block(text: str) -> str: """Extract the last fenced JSON block or fallback to best-effort slice.""" matches = list(_JSON_BLOCK_RE.finditer(text)) - chosen: re.Match[str | None] = None + chosen: re.Match[str] | None = None for m in reversed(matches): lang = (m.group(1) or "").strip().lower() if lang == "json": @@ -367,7 +367,7 @@ def sort_w(obj: Any) -> Dict[str, Any]: return dirs["run_dir"], out_path -def main(argv: list[str | None] = None) -> int: +def main(argv: list[str] | None = None) -> int: _load_dotenv_if_present() p = argparse.ArgumentParser( description="Extract unique subgraphs with shapes (JSON)" diff --git a/Fuser/worker.py b/Fuser/worker.py index 4d58519..8481193 100644 --- a/Fuser/worker.py +++ b/Fuser/worker.py @@ -77,7 +77,7 @@ def __init__( problem_path: Path, winner_queue: Any, cancel_event: Any, - on_delta: Callable[[str | None, None]] = None, + on_delta: Callable[[str, None]] | None = None, ) -> None: self.cfg = cfg self.problem_path = problem_path diff --git a/utils/providers/base.py b/utils/providers/base.py index 38f1ae1..6ba3fe1 100644 --- a/utils/providers/base.py +++ b/utils/providers/base.py @@ -27,7 +27,7 @@ class LLMResponse: content: str model: str provider: str - usage: Dict[str, Any | None] = None + usage: Dict[str, Any] | None = None response_id: str | None = None diff --git a/utils/providers/env_config.py b/utils/providers/env_config.py index f342f6e..72b0e2e 100644 --- a/utils/providers/env_config.py +++ b/utils/providers/env_config.py @@ -20,7 +20,7 @@ from typing import Dict -def _get_meta_proxy_config() -> Dict[str, str | None]: +def _get_meta_proxy_config() -> Dict[str, str] | None: """ Get Meta's proxy configuration if available. @@ -59,7 +59,7 @@ def _get_meta_proxy_config() -> Dict[str, str | None]: return None -def configure_proxy_environment() -> Dict[str[str | None]]: +def configure_proxy_environment() -> Dict[str, str] | None: """ Configure proxy environment variables for Meta environment. This is the centralized proxy configuration logic used by all providers. diff --git a/utils/providers/models.py b/utils/providers/models.py index f6ccdcf..ee73238 100644 --- a/utils/providers/models.py +++ b/utils/providers/models.py @@ -20,7 +20,7 @@ from .model_config import ModelConfig # Cached model lookup dictionary (lazily initialized) -_model_name_to_config: Dict[str, ModelConfig | None] = None +_model_name_to_config: Dict[str, ModelConfig] | None = None # Provider instances cache _provider_instances: Dict[Type[BaseProvider], BaseProvider] = {} @@ -50,7 +50,7 @@ def _get_model_name_to_config() -> Dict[str, ModelConfig]: def get_model_provider( - model_name: str, preferred_provider: Type[BaseProvider | None] = None + model_name: str, preferred_provider: Type[BaseProvider] | None = None ) -> BaseProvider: """ Get the first available provider instance for a given model. 
If a preferred @@ -105,7 +105,7 @@ def get_model_provider( def is_model_available( - model_name: str, preferred_provider: Type[BaseProvider | None] = None + model_name: str, preferred_provider: Type[BaseProvider] | None = None ) -> bool: """Check if a model is available and its provider is ready. If a preferred provider is specified, only it will be checked From e803fffb2ac5b3d6556472dce5c693da41ea7002 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Fri, 19 Dec 2025 14:13:15 -0800 Subject: [PATCH 3/4] Update typing for dict/list/tuple --- Fuser/auto_agent.py | 22 ++++++++-------- Fuser/compose_end_to_end.py | 36 ++++++++++++------------- Fuser/dedup.py | 4 +-- Fuser/dispatch_kernel_agent.py | 38 +++++++++++++-------------- Fuser/orchestrator.py | 4 +-- Fuser/subgraph_extractor.py | 12 ++++----- scripts/fuser_ui.py | 16 +++++------ scripts/pipeline_ui.py | 8 +++--- scripts/triton_ui.py | 25 +++++++++--------- triton_kernel_agent/agent.py | 8 +++--- triton_kernel_agent/manager.py | 10 +++---- triton_kernel_agent/prompt_manager.py | 4 +-- triton_kernel_agent/worker.py | 8 +++--- triton_kernel_agent/worker_util.py | 4 +-- utils/providers/anthropic_provider.py | 8 +++--- utils/providers/base.py | 10 +++---- utils/providers/env_config.py | 6 ++--- utils/providers/model_config.py | 4 +-- utils/providers/models.py | 10 +++---- utils/providers/openai_base.py | 12 ++++----- utils/providers/relay_provider.py | 8 +++--- 21 files changed, 129 insertions(+), 128 deletions(-) diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index 7a8b8d1..583fe60 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -50,7 +50,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Any from dotenv import load_dotenv from Fuser.pipeline import run_pipeline @@ -105,7 +105,7 @@ def _file_sha256_text(txt: str) -> str: return hashlib.sha256(txt.encode("utf-8")).hexdigest() -def _load_router_cache() -> Dict[str, Any]: +def _load_router_cache() -> dict[str, Any]: try: if _ROUTER_CACHE_PATH.is_file(): return json.loads(_ROUTER_CACHE_PATH.read_text(encoding="utf-8")) @@ -114,7 +114,7 @@ def _load_router_cache() -> Dict[str, Any]: return {} -def _save_router_cache(cache: Dict[str, Any]) -> None: +def _save_router_cache(cache: dict[str, Any]) -> None: try: _ensure_dir(_ROUTER_CACHE_PATH) _ROUTER_CACHE_PATH.write_text(json.dumps(cache, indent=2), encoding="utf-8") @@ -144,7 +144,7 @@ class Complexity: pool_ops: int act_ops: int chain_len_estimate: int - raw_op_names: Dict[str, int] + raw_op_names: dict[str, int] def route_to_fuser(self) -> bool: # Primary triggers @@ -213,7 +213,7 @@ def analyze_problem_code(code: str) -> Complexity: # AST path: inspect Model.forward for ops and control flow has_control_flow = False - raw_op_counts: Dict[str, int] = {} + raw_op_counts: dict[str, int] = {} has_attention_like = False has_conv_transpose = False has_group_norm = False @@ -298,7 +298,7 @@ def visit_Assign(self, node: ast.Assign) -> Any: class RouteResult: route: str # "kernelagent" or "fuser" success: bool - details: Dict[str, Any] + details: dict[str, Any] kernel_code: str | None = None @@ -459,7 +459,7 @@ def solve(self, problem_path: Path) -> RouteResult: strategy: str | None = None route_conf: float | None = None - route_cfg: Dict[str, Any] = {} + route_cfg: dict[str, Any] = {} if isinstance(cached, dict): strategy = ( @@ -545,7 +545,7 @@ def solve(self, problem_path: Path) -> RouteResult: # -------- LLM decision helper -------- def 
_llm_decide_route( self, problem_path: Path, code: str, cx: Complexity - ) -> Tuple[str | None, float | None, Dict[str, Any]]: + ) -> tuple[str | None, float | None, dict[str, Any]]: """Ask an LLM to choose a routing STRATEGY and optional budgets. The LLM must return JSON with keys: @@ -621,7 +621,7 @@ def _llm_decide_route( f"Features:\n```json\n{json.dumps(feats, indent=2)}\n```\n\n" "Problem code:\n```python\n" + code + "\n```\n" ) - kwargs: Dict[str, Any] = { + kwargs: dict[str, Any] = { "max_tokens": self.router_max_tokens, "temperature": self.router_temperature, } @@ -636,7 +636,7 @@ def _llm_decide_route( # Best-effort JSON parse route = None conf = None - raw_info: Dict[str, Any] = {"raw": txt} + raw_info: dict[str, Any] = {"raw": txt} try: # If model returned extra text, try to locate JSON object first = txt.find("{") @@ -748,7 +748,7 @@ def main(argv: list[str] | None = None) -> int: ) return 1 - out: Dict[str, Any] = { + out: dict[str, Any] = { "route": res.route, "success": res.success, "details": res.details, diff --git a/Fuser/compose_end_to_end.py b/Fuser/compose_end_to_end.py index 94976f8..5c7c142 100644 --- a/Fuser/compose_end_to_end.py +++ b/Fuser/compose_end_to_end.py @@ -45,7 +45,7 @@ import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any from dotenv import load_dotenv @@ -73,11 +73,11 @@ def _read_text(path: Path) -> str: return path.read_text(encoding="utf-8") -def _load_kernels_from_summary(summary_path: Path) -> List[KernelItem]: +def _load_kernels_from_summary(summary_path: Path) -> list[KernelItem]: data = json.loads(_read_text(summary_path)) if not isinstance(data, list): raise SystemExit("kernels summary must be a JSON array (from dispatch step)") - items: List[KernelItem] = [] + items: list[KernelItem] = [] for it in data: if not isinstance(it, dict): continue @@ -98,8 +98,8 @@ def _load_kernels_from_summary(summary_path: Path) -> List[KernelItem]: return items -def _summarize_subgraphs_for_prompt(subgraphs: List[Dict[str, Any]]) -> str: - lines: List[str] = [] +def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str: + lines: list[str] = [] for it in subgraphs: sid = str(it.get("id", "unknown")) typ = str(it.get("type", "")) @@ -126,8 +126,8 @@ def _summarize_subgraphs_for_prompt(subgraphs: List[Dict[str, Any]]) -> str: def _build_composition_prompt( problem_code: str, - subgraphs: List[Dict[str, Any]], - kernel_items: List[KernelItem], + subgraphs: list[dict[str, Any]], + kernel_items: list[KernelItem], ) -> str: """Create a single user message to instruct composition by the LLM.""" # Provide a succinct summary of subgraphs up front @@ -135,7 +135,7 @@ def _build_composition_prompt( # Include only essential snippets from each kernel to keep token usage sane # We include full files for now; callers can trim by model limits. 
- kernels_section_parts: List[str] = [] + kernels_section_parts: list[str] = [] for ki in kernel_items: kernels_section_parts.append( f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n" @@ -190,7 +190,7 @@ def _build_composition_prompt( """ ).strip() - user_lines: List[str] = [] + user_lines: list[str] = [] user_lines.append(guidance) user_lines.append("") user_lines.append("SUBGRAPHS (summary):") @@ -212,10 +212,10 @@ def _build_composition_prompt( def _build_refinement_prompt( problem_code: str, - subgraphs: List[Dict[str, Any]], - kernel_items: List[KernelItem], + subgraphs: list[dict[str, Any]], + kernel_items: list[KernelItem], previous_code: str, - error_info: Dict[str, str], + error_info: dict[str, str], ) -> str: """Prompt the LLM to refine the previously produced code based on errors.""" err_tail = error_info.get("stderr_tail", "") @@ -239,7 +239,7 @@ def _build_refinement_prompt( """ ).strip() - lines: List[str] = [] + lines: list[str] = [] lines.append(guidance) lines.append("") lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```") @@ -259,7 +259,7 @@ def _build_refinement_prompt( return "\n".join(lines) -def _auto_patch_common_triton_issues(code: str) -> Tuple[str, bool]: +def _auto_patch_common_triton_issues(code: str) -> tuple[str, bool]: """Apply tiny safe textual patches for known Triton pitfalls. - Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants. @@ -289,7 +289,7 @@ def compose( model_name: str, verify: bool = False, max_iters: int = 5, -) -> Dict[str, Any]: +) -> dict[str, Any]: if get_model_provider is None: raise SystemExit( "KernelAgent providers unavailable; ensure package import and dependencies" @@ -310,7 +310,7 @@ def compose( last_usage = None last_code = None - verify_info: Dict[str, Any] = {} + verify_info: dict[str, Any] = {} for i in range(1, max_iters + 1): if i == 1 or last_code is None: @@ -388,7 +388,7 @@ def compose( composed_path = out_dir / "composed_kernel.py" composed_path.write_text(last_code or "", encoding="utf-8") - result: Dict[str, Any] = { + result: dict[str, Any] = { "success": bool(verify_info.get("verify_passed", not verify)), "composed_path": str(composed_path.resolve()), "model": model_name, @@ -404,7 +404,7 @@ def compose( return result -def main(argv: List[str] | None = None) -> int: +def main(argv: list[str] | None = None) -> int: load_dotenv() p = argparse.ArgumentParser( description="Compose end-to-end Triton kernel from subgraphs + generated kernels" diff --git a/Fuser/dedup.py b/Fuser/dedup.py index abf1f7a..8e34c8d 100644 --- a/Fuser/dedup.py +++ b/Fuser/dedup.py @@ -15,12 +15,12 @@ import json import time from pathlib import Path -from typing import Tuple + def register_digest( shared_digests_dir: Path, sha256: str, worker_id: str, iter_index: int -) -> Tuple[str, str | None]: +) -> tuple[str, str | None]: """ Atomically register a digest in shared_digests_dir. 
Returns (status, owner_worker_id or None), where status is one of: diff --git a/Fuser/dispatch_kernel_agent.py b/Fuser/dispatch_kernel_agent.py index 89307f5..b32807e 100644 --- a/Fuser/dispatch_kernel_agent.py +++ b/Fuser/dispatch_kernel_agent.py @@ -39,7 +39,7 @@ import os import textwrap from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any import concurrent.futures as _futures from dotenv import load_dotenv @@ -50,7 +50,7 @@ TritonKernelAgent = None # type: ignore -def _shape_list(shape: Any) -> List[str]: +def _shape_list(shape: Any) -> list[str]: if isinstance(shape, list): return [str(x) for x in shape] return [str(shape)] if shape is not None else [] @@ -69,7 +69,7 @@ def _py_tuple(arr: Any) -> str: return f"({', '.join(vals)})" -def _pick_weights(item: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: +def _pick_weights(item: dict[str, Any], keys: list[str]) -> dict[str, Any]: # Prefer explicit fused/original dicts; fallback to generic 'weights' ws = ( item.get("weights_fused") @@ -79,7 +79,7 @@ def _pick_weights(item: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: ) if not isinstance(ws, dict): return {} - out: Dict[str, Any] = {} + out: dict[str, Any] = {} for k in keys: if k in ws: out[k] = ws[k] @@ -89,20 +89,20 @@ def _pick_weights(item: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: return out -def _build_reference_code(item: Dict[str, Any]) -> Tuple[str, List[str]]: +def _build_reference_code(item: dict[str, Any]) -> tuple[str, list[str]]: """Return (reference_code_str, param_names) implementing the subgraph. param_names are additional parameters to reference() beyond the first input(s). """ - ops: List[Dict[str, Any]] = [ + ops: list[dict[str, Any]] = [ op for op in (item.get("ops") or []) if isinstance(op, dict) ] - lines: List[str] = ["import torch", "import torch.nn.functional as F", ""] - params: List[str] = [] + lines: list[str] = ["import torch", "import torch.nn.functional as F", ""] + params: list[str] = [] # Determine if multi-input inputs_multi = item.get("inputs") - input_names: List[str] + input_names: list[str] if isinstance(inputs_multi, list) and inputs_multi: input_names = [f"x{i}" for i in range(len(inputs_multi))] header = f"def reference({', '.join(input_names)}" # weights appended later @@ -110,7 +110,7 @@ def _build_reference_code(item: Dict[str, Any]) -> Tuple[str, List[str]]: input_names = ["x"] header = "def reference(x" - body: List[str] = [] + body: list[str] = [] cur = input_names[0] if input_names else "x" for op in ops: @@ -123,7 +123,7 @@ def _build_reference_code(item: Dict[str, Any]) -> Tuple[str, List[str]]: else ("weight" if "weight" in wmap else "conv_weight") ) b = "bias" if "bias" in wmap else None - args: List[str] = [cur, w] + args: list[str] = [cur, w] if b: args.append(b) stride = _py_tuple(op.get("stride", (1, 1))) @@ -142,7 +142,7 @@ def _build_reference_code(item: Dict[str, Any]) -> Tuple[str, List[str]]: else ("weight" if "weight" in wmap else "conv_transpose_weight") ) b = "bias" if "bias" in wmap else None - args: List[str] = [cur, w] + args: list[str] = [cur, w] if b: args.append(b) stride = _py_tuple(op.get("stride", (1, 1))) @@ -252,7 +252,7 @@ def _build_reference_code(item: Dict[str, Any]) -> Tuple[str, List[str]]: return "\n".join(lines) + "\n", params -def _synthesize_problem_description(item: Dict[str, Any]) -> str: +def _synthesize_problem_description(item: dict[str, Any]) -> str: id_ = str(item.get("id", "unknown")) type_ = str(item.get("type", "")) layout = 
item.get("data_layout") or "NCHW" @@ -327,7 +327,7 @@ def run( ) with subgraphs_path.open("r", encoding="utf-8") as f: - items: List[Dict[str, Any]] = json.load(f) + items: list[dict[str, Any]] = json.load(f) if not isinstance(items, list): raise SystemExit("subgraphs.json must be a JSON array") @@ -335,7 +335,7 @@ def run( # Worker function: create a dedicated agent instance per subgraph to avoid # cross-thread state interactions inside the agent/manager. - def _handle_one(idx_item: Tuple[int, Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]: + def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, Any]]: idx, item = idx_item sid = str(item.get("id", f"subgraph_{idx}")) pdesc = _synthesize_problem_description(item) @@ -384,8 +384,8 @@ def _handle_one(idx_item: Tuple[int, Dict[str, Any]]) -> Tuple[int, Dict[str, An # Submit tasks with bounded concurrency jobs = max(1, int(jobs or 1)) - ordered_inputs: List[Tuple[int, Dict[str, Any]]] = list(enumerate(items, start=1)) - results: Dict[int, Dict[str, Any]] = {} + ordered_inputs: list[tuple[int, dict[str, Any]]] = list(enumerate(items, start=1)) + results: dict[int, dict[str, Any]] = {} if jobs == 1: for pair in ordered_inputs: i, res = _handle_one(pair) @@ -400,13 +400,13 @@ def _handle_one(idx_item: Tuple[int, Dict[str, Any]]) -> Tuple[int, Dict[str, An results[i] = res # Preserve input order in summary output - summary: List[Dict[str, Any]] = [results[i] for i in sorted(results.keys())] + summary: list[dict[str, Any]] = [results[i] for i in sorted(results.keys())] out_summary = out_dir / "summary.json" out_summary.write_text(json.dumps(summary, indent=2), encoding="utf-8") return out_summary -def main(argv: List[str] | None = None) -> int: +def main(argv: list[str] | None = None) -> int: load_dotenv() p = argparse.ArgumentParser( description="Generate Triton kernels for subgraphs via KernelAgent" diff --git a/Fuser/orchestrator.py b/Fuser/orchestrator.py index 402706c..f47a489 100644 --- a/Fuser/orchestrator.py +++ b/Fuser/orchestrator.py @@ -23,7 +23,7 @@ from dataclasses import asdict from pathlib import Path from queue import Empty -from typing import Any, Dict +from typing import Any from dotenv import load_dotenv @@ -35,7 +35,7 @@ # Spawn-safe worker entrypoint (top-level function; pass plain payload) def _worker_process_main( - cfg_payload: Dict[str, Any], + cfg_payload: dict[str, Any], problem_path: str, winner_queue: Any, cancel_event: Any, diff --git a/Fuser/subgraph_extractor.py b/Fuser/subgraph_extractor.py index b28f5cf..80851fa 100644 --- a/Fuser/subgraph_extractor.py +++ b/Fuser/subgraph_extractor.py @@ -38,7 +38,7 @@ import sys import tarfile from pathlib import Path -from typing import Any, Tuple, Dict +from typing import Any from .cli import _load_dotenv_if_present # reuse env loader from .config import OrchestratorConfig, new_run_id @@ -139,7 +139,7 @@ def norm_shapes(arr: Any) -> Any: return out -def _build_llm_prompt_for_shapes(fused_code: str, problem_code: str) -> Tuple[str, str]: +def _build_llm_prompt_for_shapes(fused_code: str, problem_code: str) -> tuple[str, str]: system = "Return a single JSON array only." user_lines: list[str] = [] user_lines.append( @@ -212,7 +212,7 @@ def extract_subgraphs_to_json( max_iters: int, llm_timeout_s: int, run_timeout_s: int, -) -> Tuple[Path, Path]: +) -> tuple[Path, Path]: """Run Fuser to produce fused code, then use LLM to emit subgraphs JSON. Returns (run_dir, json_path). 
@@ -300,9 +300,9 @@ def extract_subgraphs_to_json( raise SystemExit("LLM output JSON is not a list") # Merge duplicates by signature and sum counts - grouped: Dict[str, Dict[str, Any]] = {} + grouped: dict[str, dict[str, Any]] = {} - def sig_of(it: Dict[str, Any]) -> str: + def sig_of(it: dict[str, Any]) -> str: # Build a robust signature from ops + shapes + weights ops = it.get("ops") or [] # normalize ops by sorting keys of each op dict @@ -320,7 +320,7 @@ def sig_of(it: Dict[str, Any]) -> str: weights_original = it.get("weights_original") or {} # sort weight dicts by key for stability - def sort_w(obj: Any) -> Dict[str, Any]: + def sort_w(obj: Any) -> dict[str, Any]: if isinstance(obj, dict): return {k: obj[k] for k in sorted(obj.keys())} return {} diff --git a/scripts/fuser_ui.py b/scripts/fuser_ui.py index cd55e83..97e1a8b 100644 --- a/scripts/fuser_ui.py +++ b/scripts/fuser_ui.py @@ -25,7 +25,7 @@ import zipfile from dataclasses import dataclass from pathlib import Path -from typing import List, Tuple + import gradio as gr from dotenv import load_dotenv @@ -51,9 +51,9 @@ class RunArtifacts: zip_path: Path | None -def _list_kernelbench_problems(base: Path) -> List[Tuple[str, str]]: +def _list_kernelbench_problems(base: Path) -> list[tuple[str, str]]: """Return list of (label, absolute_path) pairs for KernelBench problems.""" - problems: List[Tuple[str, str]] = [] + problems: list[tuple[str, str]] = [] if not base.exists(): return problems for level_dir in sorted(base.glob("level*")): @@ -80,7 +80,7 @@ def base_name(node: ast.expr) -> str: if isinstance(node, ast.Name): return node.id if isinstance(node, ast.Attribute): - parts: List[str] = [] + parts: list[str] = [] cur: ast.AST | None = node while isinstance(cur, ast.Attribute): parts.append(cur.attr) @@ -91,9 +91,9 @@ def base_name(node: ast.expr) -> str: return ".".join(parts) return ast.dump(node, include_attributes=False) - class_lines: List[str] = ["## 🧩 Fusion Module Summary"] - classes: List[ast.ClassDef] = [n for n in tree.body if isinstance(n, ast.ClassDef)] - functions: List[ast.FunctionDef] = [ + class_lines: list[str] = ["## 🧩 Fusion Module Summary"] + classes: list[ast.ClassDef] = [n for n in tree.body if isinstance(n, ast.ClassDef)] + functions: list[ast.FunctionDef] = [ n for n in tree.body if isinstance(n, ast.FunctionDef) ] @@ -665,7 +665,7 @@ def run( run_timeout: int, enable_reasoning: bool, user_api_key: str | None, - ) -> Tuple[str, str, str, str, str | None]: + ) -> tuple[str, str, str, str, str | None]: problem_path = custom_problem.strip() or selected_problem artifacts = run_fuser_problem( problem_path=problem_path, diff --git a/scripts/pipeline_ui.py b/scripts/pipeline_ui.py index d9d446b..fa1a583 100644 --- a/scripts/pipeline_ui.py +++ b/scripts/pipeline_ui.py @@ -24,7 +24,7 @@ import zipfile from dataclasses import dataclass from pathlib import Path -from typing import List, Tuple + import gradio as gr from dotenv import load_dotenv @@ -35,9 +35,9 @@ sys.path.insert(0, str(_PROJECT_ROOT)) -def _list_kernelbench_problems(base: Path) -> List[Tuple[str, str]]: +def _list_kernelbench_problems(base: Path) -> list[tuple[str, str]]: """Return list of (label, absolute_path) pairs for KernelBench problems.""" - problems: List[Tuple[str, str]] = [] + problems: list[tuple[str, str]] = [] if not base.exists(): return problems for level_dir in sorted(base.glob("level*")): @@ -449,7 +449,7 @@ def run( compose_max_iters: int, verify: bool, user_api_key: str | None, - ) -> Tuple[str, str, str, str, str | None]: + ) -> 
tuple[str, str, str, str, str | None]: problem_mapping = {label: path for label, path in self.problem_choices} selected_path = problem_mapping.get(selected_problem_label, "") # Use description override if present; otherwise selected path diff --git a/scripts/triton_ui.py b/scripts/triton_ui.py index 7eb09af..d3c0908 100644 --- a/scripts/triton_ui.py +++ b/scripts/triton_ui.py @@ -20,7 +20,8 @@ import time import traceback from pathlib import Path -from typing import Tuple +from typing import Any + import gradio as gr from dotenv import load_dotenv @@ -43,9 +44,9 @@ def load_kernelbench_problem_map( - levels: Tuple[str, ...] = ("level1", "level2"), -) -> Dict[str, Path]: - problem_map: Dict[str, Path] = {} + levels: tuple[str, ...] = ("level1", "level2"), +) -> dict[str, Path]: + problem_map: dict[str, Path] = {} for level in levels: level_dir = KERNELBENCH_BASE_PATH / level if not level_dir.is_dir(): @@ -104,7 +105,7 @@ def generate_kernel( provider_class_name: str = "", high_reasoning_effort: bool = True, user_api_key: str | None = None, - ) -> Tuple[str, str, str, str, str, str]: + ) -> tuple[str, str, str, str, str, str]: """ Generate a Triton kernel based on the problem description @@ -235,7 +236,7 @@ def generate_kernel( if key_env_var in os.environ: del os.environ[key_env_var] - def _format_logs(self, result: Dict[str, Any], generation_time: float) -> str: + def _format_logs(self, result: dict[str, Any], generation_time: float) -> str: """Format generation logs for display""" logs = f"""## Generation Summary @@ -260,7 +261,7 @@ def _format_logs(self, result: Dict[str, Any], generation_time: float) -> str: """ return logs - def _format_error_logs(self, result: Dict[str, Any], generation_time: float) -> str: + def _format_error_logs(self, result: dict[str, Any], generation_time: float) -> str: """Format error logs for display""" logs = f"""## Generation Failed @@ -282,7 +283,7 @@ def _format_error_logs(self, result: Dict[str, Any], generation_time: float) -> """ return logs - def _format_session_info(self, result: Dict[str, Any]) -> str: + def _format_session_info(self, result: dict[str, Any]) -> str: """Format session information""" session_path = result["session_dir"] session_name = os.path.basename(session_path) @@ -313,7 +314,7 @@ def _format_session_info(self, result: Dict[str, Any]) -> str: """ return info - def _create_download_info(self, result: Dict[str, Any]) -> str: + def _create_download_info(self, result: dict[str, Any]) -> str: """Create download information""" if not result["success"]: return "" @@ -352,7 +353,7 @@ def _create_app() -> gr.Blocks: kernelbench_problem_map = load_kernelbench_problem_map() # Add external problems (manual entries) - extra_problem_map: Dict[str, Path] = {} + extra_problem_map: dict[str, Path] = {} try: external_cf = Path( "/home/leyuan/workplace/kernel_fuser/external/control_flow.py" @@ -371,13 +372,13 @@ def _create_app() -> gr.Blocks: pass # Combine: external first, then KernelBench - combined_problem_map: Dict[str, Path] = { + combined_problem_map: dict[str, Path] = { **extra_problem_map, **kernelbench_problem_map, } problem_choices = list(combined_problem_map.keys()) default_problem_choice = problem_choices[0] if problem_choices else None - problem_cache: Dict[str, str] = {} + problem_cache: dict[str, str] = {} if default_problem_choice: try: diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py index 9cc4e55..8c79cbb 100644 --- a/triton_kernel_agent/agent.py +++ b/triton_kernel_agent/agent.py @@ -18,7 +18,7 @@ import 
json import re from pathlib import Path -from typing import List, Dict, Any +from typing import Any from datetime import datetime import logging from dotenv import load_dotenv @@ -160,7 +160,7 @@ def _extract_code_from_response( self.logger.warning("No code block found in LLM response") return None - def _call_llm(self, messages: List[Dict[str, str]], **kwargs) -> str: + def _call_llm(self, messages: list[dict[str, str]], **kwargs) -> str: """ Call the LLM provider for the configured model. @@ -306,7 +306,7 @@ def test_kernel(): def _generate_kernel_seeds( self, problem_description: str, test_code: str, num_seeds: int | None = None - ) -> List[str]: + ) -> list[str]: """ Generate initial kernel implementations using OpenAI API. @@ -423,7 +423,7 @@ def kernel_function(*args, **kwargs): def generate_kernel( self, problem_description: str, test_code: str | None = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Generate an optimized Triton kernel for the given problem. diff --git a/triton_kernel_agent/manager.py b/triton_kernel_agent/manager.py index 5cd6898..15d2401 100644 --- a/triton_kernel_agent/manager.py +++ b/triton_kernel_agent/manager.py @@ -19,7 +19,7 @@ import multiprocessing as mp import queue from pathlib import Path -from typing import List, Dict, Any +from typing import Any from datetime import datetime import logging from contextlib import contextmanager @@ -71,7 +71,7 @@ def __init__( # Setup multiprocessing self.success_event = mp.Event() # Shared event to signal success self.result_queue = mp.Queue() # Queue for collecting results - self.workers: List[mp.Process] = [] + self.workers: list[mp.Process] = [] # Setup logger self._setup_logging() @@ -89,7 +89,7 @@ def _setup_logging(self): self.logger = logging.getLogger(__name__) @contextmanager - def temp_workdirs(self) -> List[Path]: + def temp_workdirs(self) -> list[Path]: """Create temporary working directories for workers.""" workdirs = [] try: @@ -107,11 +107,11 @@ def temp_workdirs(self) -> List[Path]: def run_verification( self, - kernel_seeds: List[str], + kernel_seeds: list[str], test_code: str, problem_description: str, session_log_dir: Path | None = None, - ) -> Dict[str, Any | None]: + ) -> dict[str, Any | None]: """ Run parallel verification on multiple kernel seeds. 
diff --git a/triton_kernel_agent/prompt_manager.py b/triton_kernel_agent/prompt_manager.py index 8c6a368..5fa76cf 100644 --- a/triton_kernel_agent/prompt_manager.py +++ b/triton_kernel_agent/prompt_manager.py @@ -15,7 +15,7 @@ """Prompt Manager for handling Jinja2 templates.""" from pathlib import Path -from typing import Dict + try: from jinja2 import Environment, FileSystemLoader, Template @@ -143,7 +143,7 @@ def render_kernel_refinement_prompt( problem_description: str, test_code: str, kernel_code: str, - error_info: Dict[str, str], + error_info: dict[str, str], history_context: str | None = None, triton_guidelines: str | None = None, ) -> str: diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py index 6128197..f33b5fb 100644 --- a/triton_kernel_agent/worker.py +++ b/triton_kernel_agent/worker.py @@ -24,7 +24,7 @@ from collections import deque from datetime import datetime from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Any from .prompt_manager import PromptManager from utils.providers import get_model_provider @@ -266,7 +266,7 @@ def _detect_pytorch_compute(self, kernel_code: str) -> str | None: return message return None - def _run_test(self) -> Tuple[bool, str, str]: + def _run_test(self) -> tuple[bool, str, str]: """ Run the test script and capture results. @@ -327,7 +327,7 @@ def _call_llm(self, messages: list, **kwargs) -> str: def _refine_kernel( self, kernel_code: str, - error_info: Dict[str, str], + error_info: dict[str, str], problem_description: str, test_code: str, ) -> str: @@ -419,7 +419,7 @@ def run( test_code: str, problem_description: str, success_event: mp.Event, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Run verification and refinement loop. diff --git a/triton_kernel_agent/worker_util.py b/triton_kernel_agent/worker_util.py index 6b72387..1a8bc22 100644 --- a/triton_kernel_agent/worker_util.py +++ b/triton_kernel_agent/worker_util.py @@ -17,7 +17,7 @@ import multiprocessing as mp import os from pathlib import Path -from typing import Tuple + from logging import Logger @@ -102,7 +102,7 @@ def _run_test_process(test_file: Path, workdir: Path, result_queue: mp.Queue) -> def _run_test_multiprocess( logger: Logger, workdir: Path, test_file: Path -) -> Tuple[bool, str, str]: +) -> tuple[bool, str, str]: """ Run the test script and capture results using multiprocessing. 
diff --git a/utils/providers/anthropic_provider.py b/utils/providers/anthropic_provider.py index b80c5e6..ef4bc9d 100644 --- a/utils/providers/anthropic_provider.py +++ b/utils/providers/anthropic_provider.py @@ -14,7 +14,7 @@ """Anthropic provider implementation.""" -from typing import List, Dict + from .base import BaseProvider, LLMResponse from .env_config import configure_proxy_environment @@ -44,7 +44,7 @@ def _initialize_client(self) -> None: self.client = Anthropic(api_key=api_key) def get_response( - self, model_name: str, messages: List[Dict[str, str]], **kwargs + self, model_name: str, messages: list[dict[str, str]], **kwargs ) -> LLMResponse: if not self.is_available(): raise RuntimeError("Anthropic client not available") @@ -64,8 +64,8 @@ def get_response( ) def get_multiple_responses( - self, model_name: str, messages: List[Dict[str, str]], n: int = 1, **kwargs - ) -> List[LLMResponse]: + self, model_name: str, messages: list[dict[str, str]], n: int = 1, **kwargs + ) -> list[LLMResponse]: return [ self.get_response( model_name, diff --git a/utils/providers/base.py b/utils/providers/base.py index 6ba3fe1..a77bfc7 100644 --- a/utils/providers/base.py +++ b/utils/providers/base.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Dict, Any +from typing import Any import os @@ -27,7 +27,7 @@ class LLMResponse: content: str model: str provider: str - usage: Dict[str, Any] | None = None + usage: dict[str, Any] | None = None response_id: str | None = None @@ -45,7 +45,7 @@ def _initialize_client(self) -> None: @abstractmethod def get_response( - self, model_name: str, messages: List[Dict[str, str]], **kwargs + self, model_name: str, messages: list[dict[str, str]], **kwargs ) -> LLMResponse: """ Get response from the LLM provider. @@ -62,8 +62,8 @@ def get_response( @abstractmethod def get_multiple_responses( - self, model_name: str, messages: List[Dict[str, str]], n: int = 1, **kwargs - ) -> List[LLMResponse]: + self, model_name: str, messages: list[dict[str, str]], n: int = 1, **kwargs + ) -> list[LLMResponse]: """ Get multiple responses from the LLM provider. diff --git a/utils/providers/env_config.py b/utils/providers/env_config.py index 72b0e2e..1adb36f 100644 --- a/utils/providers/env_config.py +++ b/utils/providers/env_config.py @@ -17,10 +17,10 @@ import os import subprocess import logging -from typing import Dict -def _get_meta_proxy_config() -> Dict[str, str] | None: + +def _get_meta_proxy_config() -> dict[str, str] | None: """ Get Meta's proxy configuration if available. @@ -59,7 +59,7 @@ def _get_meta_proxy_config() -> Dict[str, str] | None: return None -def configure_proxy_environment() -> Dict[str, str] | None: +def configure_proxy_environment() -> dict[str, str] | None: """ Configure proxy environment variables for Meta environment. This is the centralized proxy configuration logic used by all providers. 
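For reference, the env_config.py hunk above keeps the dict[str, str] | None return shape for optional proxy configuration. A minimal, illustrative sketch (not part of the patch; read_proxy_from_env and build_request_env are hypothetical stand-ins) of how a helper with that signature is typically written and consumed under the new spelling:

from __future__ import annotations

import os


def read_proxy_from_env() -> dict[str, str] | None:
    # Hypothetical helper: returns a proxy mapping, or None when no proxy is configured.
    proxy = os.environ.get("HTTPS_PROXY")
    if proxy is None:
        return None
    return {"http_proxy": proxy, "https_proxy": proxy}


def build_request_env() -> dict[str, str]:
    # The X | None union narrows under an `is not None` check exactly like Optional[X].
    cfg = read_proxy_from_env()
    return dict(cfg) if cfg is not None else {}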
diff --git a/utils/providers/model_config.py b/utils/providers/model_config.py index c31d6ba..83ece71 100644 --- a/utils/providers/model_config.py +++ b/utils/providers/model_config.py @@ -15,7 +15,7 @@ """Model configuration dataclass for KernelAgent.""" from dataclasses import dataclass -from typing import List, Type +from typing import Type from .base import BaseProvider @@ -25,5 +25,5 @@ class ModelConfig: """Configuration for a specific model.""" name: str - provider_classes: List[Type[BaseProvider]] + provider_classes: list[Type[BaseProvider]] description: str = "" diff --git a/utils/providers/models.py b/utils/providers/models.py index ee73238..447e198 100644 --- a/utils/providers/models.py +++ b/utils/providers/models.py @@ -14,16 +14,16 @@ """Model registry and configuration for KernelAgent.""" -from typing import Dict, List, Type +from typing import Type from .base import BaseProvider from .model_config import ModelConfig # Cached model lookup dictionary (lazily initialized) -_model_name_to_config: Dict[str, ModelConfig] | None = None +_model_name_to_config: dict[str, ModelConfig] | None = None # Provider instances cache -_provider_instances: Dict[Type[BaseProvider], BaseProvider] = {} +_provider_instances: dict[Type[BaseProvider], BaseProvider] = {} def _get_or_create_provider( @@ -35,13 +35,13 @@ def _get_or_create_provider( return _provider_instances[provider_class] -def get_available_models() -> List[ModelConfig]: +def get_available_models() -> list[ModelConfig]: from .available_models import AVAILABLE_MODELS return AVAILABLE_MODELS -def _get_model_name_to_config() -> Dict[str, ModelConfig]: +def _get_model_name_to_config() -> dict[str, ModelConfig]: """Get the model name to config lookup dictionary (lazily initialized).""" global _model_name_to_config if _model_name_to_config is None: diff --git a/utils/providers/openai_base.py b/utils/providers/openai_base.py index f4ee527..3351ace 100644 --- a/utils/providers/openai_base.py +++ b/utils/providers/openai_base.py @@ -14,7 +14,7 @@ """Base provider for OpenAI-compatible APIs.""" -from typing import List, Dict, Any +from typing import Any import logging from .base import BaseProvider, LLMResponse from .env_config import configure_proxy_environment @@ -54,7 +54,7 @@ def _initialize_client(self) -> None: self.client = OpenAI(api_key=api_key) def get_response( - self, model_name: str, messages: List[Dict[str, str]], **kwargs + self, model_name: str, messages: list[dict[str, str]], **kwargs ) -> LLMResponse: """Get single response.""" if not self.is_available(): @@ -77,8 +77,8 @@ def get_response( ) def get_multiple_responses( - self, model_name: str, messages: List[Dict[str, str]], n: int = 1, **kwargs - ) -> List[LLMResponse]: + self, model_name: str, messages: list[dict[str, str]], n: int = 1, **kwargs + ) -> list[LLMResponse]: """Get multiple responses using n parameter.""" if not self.is_available(): raise RuntimeError(f"{self.name} client not available") @@ -103,8 +103,8 @@ def get_multiple_responses( ] def _build_api_params( - self, model_name: str, messages: List[Dict[str, str]], **kwargs - ) -> Dict[str, Any]: + self, model_name: str, messages: list[dict[str, str]], **kwargs + ) -> dict[str, Any]: """Build API parameters for OpenAI-compatible call.""" params = { "model": model_name, diff --git a/utils/providers/relay_provider.py b/utils/providers/relay_provider.py index 43e1606..6d8bb21 100644 --- a/utils/providers/relay_provider.py +++ b/utils/providers/relay_provider.py @@ -14,7 +14,7 @@ """Relay provider 
implementation.""" -from typing import Dict, List + import requests import logging @@ -48,7 +48,7 @@ def _initialize_client(self) -> None: self.is_available_flag = False def get_response( - self, model_name: str, messages: List[Dict[str, str]], **kwargs + self, model_name: str, messages: list[dict[str, str]], **kwargs ) -> LLMResponse: """ Supported kwargs: @@ -116,10 +116,10 @@ def get_response( def get_multiple_responses( self, model_name: str, - messages: List[Dict[str, str]], + messages: list[dict[str, str]], n: int = 1, **kwargs, - ) -> List[LLMResponse]: + ) -> list[LLMResponse]: return [ self.get_response( model_name, From 409a05799fe220b783fb172019ad5672c2190216 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Fri, 19 Dec 2025 14:36:05 -0800 Subject: [PATCH 4/4] Whitespace --- Fuser/dedup.py | 1 - utils/providers/anthropic_provider.py | 1 - utils/providers/env_config.py | 1 - utils/providers/relay_provider.py | 2 -- 4 files changed, 5 deletions(-) diff --git a/Fuser/dedup.py b/Fuser/dedup.py index 8e34c8d..b30f06e 100644 --- a/Fuser/dedup.py +++ b/Fuser/dedup.py @@ -17,7 +17,6 @@ from pathlib import Path - def register_digest( shared_digests_dir: Path, sha256: str, worker_id: str, iter_index: int ) -> tuple[str, str | None]: diff --git a/utils/providers/anthropic_provider.py b/utils/providers/anthropic_provider.py index ef4bc9d..7fca461 100644 --- a/utils/providers/anthropic_provider.py +++ b/utils/providers/anthropic_provider.py @@ -14,7 +14,6 @@ """Anthropic provider implementation.""" - from .base import BaseProvider, LLMResponse from .env_config import configure_proxy_environment diff --git a/utils/providers/env_config.py b/utils/providers/env_config.py index 1adb36f..3238580 100644 --- a/utils/providers/env_config.py +++ b/utils/providers/env_config.py @@ -19,7 +19,6 @@ import logging - def _get_meta_proxy_config() -> dict[str, str] | None: """ Get Meta's proxy configuration if available. diff --git a/utils/providers/relay_provider.py b/utils/providers/relay_provider.py index 6d8bb21..025f3c7 100644 --- a/utils/providers/relay_provider.py +++ b/utils/providers/relay_provider.py @@ -14,8 +14,6 @@ """Relay provider implementation.""" - - import requests import logging import os