diff --git a/deadend_cli/deadend_agent/src/deadend_agent/core.py b/deadend_cli/deadend_agent/src/deadend_agent/core.py
index 031ebee..63e46d3 100644
--- a/deadend_cli/deadend_agent/src/deadend_agent/core.py
+++ b/deadend_cli/deadend_agent/src/deadend_agent/core.py
@@ -66,9 +66,10 @@ def sandbox_setup() -> SandboxManager:
     sandbox_manager = SandboxManager()
     return sandbox_manager
 
-def setup_model_registry(config: Config) -> ModelRegistry:
+async def setup_model_registry(config: Config) -> ModelRegistry:
     """Setup Model registry"""
     model_registry = ModelRegistry(config=config)
+    await model_registry.initialize()
     return model_registry
 
 def _file_matches_sha256(path: Path, expected_hash: str) -> bool:
diff --git a/deadend_cli/deadend_agent/src/deadend_agent/models/registry.py b/deadend_cli/deadend_agent/src/deadend_agent/models/registry.py
index 8dac622..0e5c032 100644
--- a/deadend_cli/deadend_agent/src/deadend_agent/models/registry.py
+++ b/deadend_cli/deadend_agent/src/deadend_agent/models/registry.py
@@ -10,7 +10,7 @@
 objects that are consumed by the CoreAgent and other components.
 """
 
-from typing import Dict
+from typing import Dict, Optional
 import aiohttp
 
 from deadend_agent.config.settings import Config, ModelSpec, EmbeddingSpec, ProvidersList
@@ -19,23 +19,25 @@ class EmbedderClient:
     """Client for generating embeddings using various embedding API providers.
-    
+
     This class provides a unified interface for embedding generation across
     different providers (OpenAI, OpenRouter, etc.) by abstracting the
     API communication and response parsing.
-    
+
     Attributes:
         model: Name of the embedding model to use.
         api_key: API key for authenticating with the embedding service.
         base_url: Base URL for the embedding API endpoint.
+        _session: Shared aiohttp ClientSession for connection reuse.
     """
     model: str
     api_key: str
     base_url: str
+    _session: Optional[aiohttp.ClientSession]
 
     def __init__(self, model_name: str, api_key: str, base_url: str) -> None:
         """Initialize the EmbedderClient with provider configuration.
-    
+
         Args:
             model_name: Name of the embedding model to use (e.g., "text-embedding-3-small").
             api_key: API key for authenticating with the embedding service.
@@ -44,65 +46,88 @@ def __init__(self, model_name: str, api_key: str, base_url: str) -> None:
         self.model = model_name
         self.api_key = api_key
         self.base_url = base_url
+        self._session = None
+
+    async def initialize(self) -> None:
+        """Initialize the shared ClientSession for HTTP requests.
+
+        Creates a persistent aiohttp ClientSession that will be reused
+        across all embedding requests to avoid resource exhaustion from
+        creating too many concurrent connections.
+        """
+        if self._session is None:
+            self._session = aiohttp.ClientSession()
+
+    async def close(self) -> None:
+        """Close the shared ClientSession and clean up resources.
+
+        Should be called when the EmbedderClient is no longer needed
+        to properly release HTTP connection resources.
+        """
+        if self._session is not None:
+            await self._session.close()
+            self._session = None
 
     async def batch_embed(self, input: list) -> list:
         """Generate embeddings for a batch of input texts.
-    
+
         Sends a batch embedding request to the configured API endpoint
         and handles various response formats. Supports OpenAI-compatible
         APIs and other providers with different response structures.
-    
+
         Args:
             input: List of text strings to embed. Each string will be
                 embedded into a vector representation.
-    
+
         Returns:
             List of embedding dictionaries. Each dictionary contains
             an 'embedding' key with the vector representation.
             Returns empty list if no embeddings were generated.
-    
+
         Raises:
             ValueError: If the API returns a non-200 status code,
                 an error response, or an unexpected response structure.
+            RuntimeError: If the session has not been initialized.
         """
-        async with aiohttp.ClientSession() as session:
-            response = await session.post(
-                url=self.base_url,
-                headers={
-                    "Authorization": f"Bearer {self.api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": self.model,
-                    "input": input
-                }
-            )
-
-            # Check HTTP status code
-            if response.status != 200:
-                error_text = await response.text()
-                raise ValueError(f"Embedding API returned status {response.status}: {error_text}")
-
-            data = await response.json()
-
-            # Handle different response structures
-            # OpenAI format: {"data": [{"embedding": [...]}, ...]}
-            # Some APIs might return the data directly or in a different structure
-            if isinstance(data, dict) and 'data' in data:
-                embeddings = data['data']
-            elif isinstance(data, list):
-                # Response is already a list of embeddings
-                embeddings = data
-            elif isinstance(data, dict) and 'error' in data:
-                # API returned an error
-                error_info = data.get('error', {})
-                error_msg = error_info.get('message', str(error_info)) if isinstance(error_info, dict) else str(error_info)
-                raise ValueError(f"Embedding API error: {error_msg}")
-            else:
-                # Try to find embeddings in the response
-                error_msg = f"Unexpected response structure: \
-                    {list(data.keys()) if isinstance(data, dict) else type(data)}"
-                raise ValueError(error_msg)
+        if self._session is None:
+            raise RuntimeError("EmbedderClient session not initialized. Call initialize() first.")
+
+        response = await self._session.post(
+            url=self.base_url,
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            },
+            json={
+                "model": self.model,
+                "input": input
+            }
+        )
+
+        # Check HTTP status code
+        if response.status != 200:
+            error_text = await response.text()
+            raise ValueError(f"Embedding API returned status {response.status}: {error_text}")
+
+        data = await response.json()
+
+        # Handle different response structures
+        # OpenAI format: {"data": [{"embedding": [...]}, ...]}
+        # Some APIs might return the data directly or in a different structure
+        if isinstance(data, dict) and 'data' in data:
+            embeddings = data['data']
+        elif isinstance(data, list):
+            # Response is already a list of embeddings
+            embeddings = data
+        elif isinstance(data, dict) and 'error' in data:
+            # API returned an error
+            error_info = data.get('error', {})
+            error_msg = error_info.get('message', str(error_info)) if isinstance(error_info, dict) else str(error_info)
+            raise ValueError(f"Embedding API error: {error_msg}")
+        else:
+            # Try to find embeddings in the response
+            error_msg = f"Unexpected response structure: {list(data.keys()) if isinstance(data, dict) else type(data)}"
+            raise ValueError(error_msg)
 
         return embeddings if embeddings else []
@@ -113,16 +138,18 @@ class ModelInfo(BaseModel):
 
 class ModelRegistry:
     """Registry for managing model specifications from multiple providers.
-    
+
     This class initializes and manages access to language model specifications
     from various providers (OpenAI, Anthropic, Google/Gemini, OpenRouter, Local)
     based on configuration settings. It also manages the embedding client
     for generating vector embeddings via HTTP.
-    
+
     Attributes:
         embedder_model: Embedding client instance, or None if not initialized.
+        _initialized: Flag indicating whether async initialization is complete.
     """
     embedder_model: EmbedderClient | None
+    _initialized: bool
 
     def __init__(self, config: Config):
         """Initialize the ModelRegistry with configuration.
@@ -131,6 +158,9 @@ def __init__(self, config: Config):
         model instances for all configured providers. Also sets up
         the embedding client based on the first available provider
         configuration.
+        Note: After creating ModelRegistry, you must call initialize() before
+        using the embedder client.
+
         Args:
             config: Configuration object containing API keys and model
                 settings for various providers.
@@ -139,8 +169,19 @@ def __init__(self, config: Config):
         self._models: Dict[str, list[ModelSpec]] = {}
         # Keep a reference to config for runtime spec creation
         self._config = config
+        self._initialized = False
         self._initialize_models(config=config)
 
+    async def initialize(self) -> None:
+        """Initialize async resources like the embedder ClientSession.
+
+        Must be called after __init__ and before using the embedder client.
+        This is a separate method because __init__ cannot be async.
+        """
+        if not self._initialized and self.embedder_model is not None:
+            await self.embedder_model.initialize()
+            self._initialized = True
+
     def _initialize_models(self, config: Config):
         """Initialize model specifications and embedding client.
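
The registry now splits construction into two phases: the synchronous constructor builds the per-provider model specs, and the async initialize() opens the single shared aiohttp.ClientSession that batch_embed reuses instead of creating a session per request. Below is a minimal caller-side sketch of that lifecycle. The embed_once wrapper and the Config construction are illustrative, and the explicit close() call is an assumption: the patch adds EmbedderClient.close(), but none of the call sites in this diff invoke it.

    from deadend_agent.config.settings import Config
    from deadend_agent.models.registry import ModelRegistry

    async def embed_once(config: Config) -> None:
        registry = ModelRegistry(config=config)  # sync phase: builds model specs
        await registry.initialize()              # async phase: opens the shared session
        try:
            if registry.embedder_model is not None:
                # batch_embed returns a list of dicts, each with an 'embedding' key
                vectors = await registry.embedder_model.batch_embed(["hello world"])
                print(len(vectors))
        finally:
            if registry.embedder_model is not None:
                await registry.embedder_model.close()  # releases pooled connections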
""" embedder_model: EmbedderClient | None + _initialized: bool def __init__(self, config: Config): """Initialize the ModelRegistry with configuration. @@ -131,6 +158,9 @@ def __init__(self, config: Config): model instances for all configured providers. Also sets up the embedding client based on the first available provider configuration. + Note: After creating ModelRegistry, you must call initialize() before + using the embedder client. + Args: config: Configuration object containing API keys and model settings for various providers. @@ -139,8 +169,19 @@ def __init__(self, config: Config): self._models: Dict[str, list[ModelSpec]] = {} # Keep a reference to config for runtime spec creation self._config = config + self._initialized = False self._initialize_models(config=config) + async def initialize(self) -> None: + """Initialize async resources like the embedder ClientSession. + + Must be called after __init__ and before using the embedder client. + This is a separate method because __init__ cannot be async. + """ + if not self._initialized and self.embedder_model is not None: + await self.embedder_model.initialize() + self._initialized = True + def _initialize_models(self, config: Config): """Initialize model specifications and embedding client. diff --git a/deadend_cli/deadend_agent/src/deadend_agent/rag/db_cruds.py b/deadend_cli/deadend_agent/src/deadend_agent/rag/db_cruds.py index ff13fa0..370e429 100644 --- a/deadend_cli/deadend_agent/src/deadend_agent/rag/db_cruds.py +++ b/deadend_cli/deadend_agent/src/deadend_agent/rag/db_cruds.py @@ -14,6 +14,7 @@ from datetime import datetime from typing import List, Optional, Dict, Any, AsyncGenerator from contextlib import asynccontextmanager +from urllib.parse import urlparse # import numpy as np from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy import text, select @@ -28,12 +29,20 @@ def __init__(self, database_url: str, pool_size: int = 20, max_overflow: int = 3 if database_url.startswith("postgresql://"): database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1) + # Disable SSL for localhost connections to fix macOS asyncpg issues + # asyncpg requires ssl=False instead of sslmode URL parameter + parsed = urlparse(database_url) + connect_args = {} + if parsed.hostname in ('localhost', '127.0.0.1', '::1'): + connect_args['ssl'] = False + self.engine = create_async_engine( database_url, pool_size=pool_size, max_overflow=max_overflow, pool_pre_ping=True, - echo=False # Set to True for SQL debugging + echo=False, # Set to True for SQL debugging + connect_args=connect_args ) self.async_session = async_sessionmaker( diff --git a/deadend_cli/src/deadend_cli/chat.py b/deadend_cli/src/deadend_cli/chat.py index 17c6f77..b317662 100644 --- a/deadend_cli/src/deadend_cli/chat.py +++ b/deadend_cli/src/deadend_cli/chat.py @@ -361,6 +361,7 @@ async def chat_interface( ): """Chat Interface for the CLI""" model_registry = ModelRegistry(config=config) + await model_registry.initialize() if not model_registry.has_any_model(): raise RuntimeError(f"No LM model configured. 
diff --git a/deadend_cli/src/deadend_cli/chat.py b/deadend_cli/src/deadend_cli/chat.py
index 17c6f77..b317662 100644
--- a/deadend_cli/src/deadend_cli/chat.py
+++ b/deadend_cli/src/deadend_cli/chat.py
@@ -361,6 +361,7 @@ async def chat_interface(
 ):
     """Chat Interface for the CLI"""
     model_registry = ModelRegistry(config=config)
+    await model_registry.initialize()
     if not model_registry.has_any_model():
         raise RuntimeError(f"No LM model configured. You can run `deadend init` to \
 initialize the required Model configuration for {llm_provider}")
diff --git a/deadend_cli/src/deadend_cli/eval.py b/deadend_cli/src/deadend_cli/eval.py
index ff321a1..61a4a83 100644
--- a/deadend_cli/src/deadend_cli/eval.py
+++ b/deadend_cli/src/deadend_cli/eval.py
@@ -67,6 +67,7 @@ async def eval_interface(
     eval_metadata = EvalMetadata(**data)
 
     model_registry = ModelRegistry(config=config)
+    await model_registry.initialize()
     if not model_registry.has_any_model():
         raise RuntimeError(f"No LM model configured. You can run `deadend init` to \
 initialize the required Model configuration for {providers[0]}")
diff --git a/deadend_cli/src/deadend_cli/init.py b/deadend_cli/src/deadend_cli/init.py
index bfdb9c4..7994783 100644
--- a/deadend_cli/src/deadend_cli/init.py
+++ b/deadend_cli/src/deadend_cli/init.py
@@ -12,6 +12,7 @@
 import time
 from pathlib import Path
 import sys
+
 import docker
 import toml
 import typer
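
A last note on the call sites: chat_interface and eval_interface are already coroutines, so they can await model_registry.initialize() directly, but core.py's setup_model_registry is itself now async. Any synchronous code path that still builds a registry through it would have to bridge into an event loop, roughly as in this sketch (the build_registry_sync wrapper is hypothetical and assumes no event loop is already running):

    import asyncio

    from deadend_agent.config.settings import Config
    from deadend_agent.core import setup_model_registry

    def build_registry_sync(config: Config):
        # asyncio.run() both constructs the registry and awaits its
        # initialize() via the now-async setup_model_registry.
        return asyncio.run(setup_model_registry(config))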