3 changes: 2 additions & 1 deletion deadend_cli/deadend_agent/src/deadend_agent/core.py
@@ -66,9 +66,10 @@ def sandbox_setup() -> SandboxManager:
sandbox_manager = SandboxManager()
return sandbox_manager

def setup_model_registry(config: Config) -> ModelRegistry:
async def setup_model_registry(config: Config) -> ModelRegistry:
"""Setup Model registry"""
model_registry = ModelRegistry(config=config)
await model_registry.initialize()
return model_registry

def _file_matches_sha256(path: Path, expected_hash: str) -> bool:
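This change makes `setup_model_registry` a coroutine so it can await the registry's new async `initialize()` step: the synchronous constructor keeps the cheap setup, and the awaitable does the I/O-bound part. A minimal, self-contained sketch of that two-phase pattern — the class and names here are illustrative stand-ins, not the real `deadend_agent` API:

```python
import asyncio

class TwoPhaseResource:
    """Stand-in for ModelRegistry's sync-construct / async-initialize split."""

    def __init__(self) -> None:
        # Cheap, synchronous setup only; __init__ cannot await anything.
        self._ready = False

    async def initialize(self) -> None:
        # Expensive or I/O-bound setup (e.g. opening an HTTP session) goes here.
        self._ready = True

async def setup() -> TwoPhaseResource:
    resource = TwoPhaseResource()
    await resource.initialize()  # callers of the factory must now await it
    return resource

asyncio.run(setup())
```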
137 changes: 89 additions & 48 deletions deadend_cli/deadend_agent/src/deadend_agent/models/registry.py
@@ -10,7 +10,7 @@
objects that are consumed by the CoreAgent and other components.
"""

from typing import Dict
from typing import Dict, Optional
import aiohttp

from deadend_agent.config.settings import Config, ModelSpec, EmbeddingSpec, ProvidersList
@@ -19,23 +19,25 @@

class EmbedderClient:
"""Client for generating embeddings using various embedding API providers.

This class provides a unified interface for embedding generation across
different providers (OpenAI, OpenRouter, etc.) by abstracting the API
communication and response parsing.

Attributes:
model: Name of the embedding model to use.
api_key: API key for authenticating with the embedding service.
base_url: Base URL for the embedding API endpoint.
_session: Shared aiohttp ClientSession for connection reuse.
"""
model: str
api_key: str
base_url: str
_session: Optional[aiohttp.ClientSession]

def __init__(self, model_name: str, api_key: str, base_url: str) -> None:
"""Initialize the EmbedderClient with provider configuration.

Args:
model_name: Name of the embedding model to use (e.g., "text-embedding-3-small").
api_key: API key for authenticating with the embedding service.
@@ -44,65 +44,88 @@ def __init__(self, model_name: str, api_key: str, base_url: str) -> None:
self.model = model_name
self.api_key = api_key
self.base_url = base_url
self._session = None

async def initialize(self) -> None:
"""Initialize the shared ClientSession for HTTP requests.

Creates a persistent aiohttp ClientSession that will be reused
across all embedding requests to avoid resource exhaustion from
creating too many concurrent connections.
"""
if self._session is None:
self._session = aiohttp.ClientSession()

async def close(self) -> None:
"""Close the shared ClientSession and cleanup resources.

Should be called when the EmbedderClient is no longer needed
to properly release HTTP connection resources.
"""
if self._session is not None:
await self._session.close()
self._session = None

async def batch_embed(self, input: list) -> list:
"""Generate embeddings for a batch of input texts.

Sends a batch embedding request to the configured API endpoint and
handles various response formats. Supports OpenAI-compatible APIs
and other providers with different response structures.

Args:
input: List of text strings to embed. Each string will be
embedded into a vector representation.

Returns:
List of embedding dictionaries. Each dictionary contains an
'embedding' key with the vector representation. Returns empty
list if no embeddings were generated.

Raises:
ValueError: If the API returns a non-200 status code, an error
response, or an unexpected response structure.
RuntimeError: If the session has not been initialized.
"""
async with aiohttp.ClientSession() as session:
response = await session.post(
url=self.base_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
json={
"model": self.model,
"input": input
}
)

# Check HTTP status code
if response.status != 200:
error_text = await response.text()
raise ValueError(f"Embedding API returned status {response.status}: {error_text}")

data = await response.json()

# Handle different response structures
# OpenAI format: {"data": [{"embedding": [...]}, ...]}
# Some APIs might return the data directly or in a different structure
if isinstance(data, dict) and 'data' in data:
embeddings = data['data']
elif isinstance(data, list):
# Response is already a list of embeddings
embeddings = data
elif isinstance(data, dict) and 'error' in data:
# API returned an error
error_info = data.get('error', {})
error_msg = error_info.get('message', str(error_info)) if isinstance(error_info, dict) else str(error_info)
raise ValueError(f"Embedding API error: {error_msg}")
else:
# Try to find embeddings in the response
error_msg = f"Unexpected response structure: \
{list(data.keys()) if isinstance(data, dict) else type(data)}"
raise ValueError(error_msg)
if self._session is None:
raise RuntimeError("EmbedderClient session not initialized. Call initialize() first.")

response = await self._session.post(
url=self.base_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
json={
"model": self.model,
"input": input
}
)

# Check HTTP status code
if response.status != 200:
error_text = await response.text()
raise ValueError(f"Embedding API returned status {response.status}: {error_text}")

data = await response.json()

# Handle different response structures
# OpenAI format: {"data": [{"embedding": [...]}, ...]}
# Some APIs might return the data directly or in a different structure
if isinstance(data, dict) and 'data' in data:
embeddings = data['data']
elif isinstance(data, list):
# Response is already a list of embeddings
embeddings = data
elif isinstance(data, dict) and 'error' in data:
# API returned an error
error_info = data.get('error', {})
error_msg = error_info.get('message', str(error_info)) if isinstance(error_info, dict) else str(error_info)
raise ValueError(f"Embedding API error: {error_msg}")
else:
# Try to find embeddings in the response
error_msg = f"Unexpected response structure: {list(data.keys()) if isinstance(data, dict) else type(data)}"
raise ValueError(error_msg)

return embeddings if embeddings else []
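
With the session now owned by the client, every caller has to run the `initialize()` / `close()` pair around its requests. A usage sketch, assuming the `EmbedderClient` above is importable; the model name, key, and endpoint are placeholder values, not configuration from this repo:

```python
import asyncio

async def main() -> None:
    client = EmbedderClient(
        model_name="text-embedding-3-small",  # placeholder model
        api_key="sk-...",                     # placeholder key
        base_url="https://api.openai.com/v1/embeddings",
    )
    await client.initialize()  # opens the shared ClientSession
    try:
        vectors = await client.batch_embed(["hello", "world"])
        print(len(vectors), "embeddings returned")
    finally:
        await client.close()   # always release the session, even on errors

asyncio.run(main())
```

Reusing one `ClientSession` this way amortizes connection setup across calls, which is exactly the resource-exhaustion problem the old per-call `aiohttp.ClientSession()` created.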

@@ -113,16 +138,18 @@ class ModelInfo(BaseModel):

class ModelRegistry:
"""Registry for managing model specifications from multiple providers.

This class initializes and manages access to language model specifications
from various providers (OpenAI, Anthropic, Google/Gemini, OpenRouter, Local)
based on configuration settings. It also manages the embedding client for
generating vector embeddings via HTTP.

Attributes:
embedder_model: Embedding client instance, or None if not initialized.
_initialized: Flag indicating whether async initialization is complete.
"""
embedder_model: EmbedderClient | None
_initialized: bool

def __init__(self, config: Config):
"""Initialize the ModelRegistry with configuration.
@@ -131,6 +158,9 @@ def __init__(self, config: Config):
model instances for all configured providers. Also sets up the
embedding client based on the first available provider configuration.

Note: After creating ModelRegistry, you must call initialize() before
using the embedder client.

Args:
config: Configuration object containing API keys and model settings
for various providers.
@@ -139,8 +169,19 @@ def __init__(self, config: Config):
self._models: Dict[str, list[ModelSpec]] = {}
# Keep a reference to config for runtime spec creation
self._config = config
self._initialized = False
self._initialize_models(config=config)

async def initialize(self) -> None:
"""Initialize async resources like the embedder ClientSession.

Must be called after __init__ and before using the embedder client.
This is a separate method because __init__ cannot be async.
"""
if not self._initialized and self.embedder_model is not None:
await self.embedder_model.initialize()
self._initialized = True

def _initialize_models(self, config: Config):
"""Initialize model specifications and embedding client.

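`initialize()` opens the embedder's session, but nothing on the registry closes it; teardown is left to whoever owns `embedder_model`. One way callers could keep the pair together is a small context-manager helper — a hypothetical wrapper sketched for illustration, not a helper this PR adds, and it assumes `embedder_model` is the registry's only async resource:

```python
from contextlib import asynccontextmanager

@asynccontextmanager
async def open_model_registry(config):
    """Hypothetical helper pairing initialize() with embedder teardown."""
    registry = ModelRegistry(config=config)
    await registry.initialize()  # opens the embedder ClientSession
    try:
        yield registry
    finally:
        if registry.embedder_model is not None:
            await registry.embedder_model.close()  # release HTTP resources
```

Entry points like `chat_interface` below would then read `async with open_model_registry(config) as model_registry:` instead of calling `initialize()` by hand.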
11 changes: 10 additions & 1 deletion deadend_cli/deadend_agent/src/deadend_agent/rag/db_cruds.py
@@ -14,6 +14,7 @@
from datetime import datetime
from typing import List, Optional, Dict, Any, AsyncGenerator
from contextlib import asynccontextmanager
from urllib.parse import urlparse
# import numpy as np
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import text, select
@@ -28,12 +29,20 @@ def __init__(self, database_url: str, pool_size: int = 20, max_overflow: int = 3
if database_url.startswith("postgresql://"):
database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1)

# Disable SSL for localhost connections to fix macOS asyncpg issues
# asyncpg requires ssl=False instead of sslmode URL parameter
parsed = urlparse(database_url)
connect_args = {}
if parsed.hostname in ('localhost', '127.0.0.1', '::1'):
connect_args['ssl'] = False

self.engine = create_async_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_pre_ping=True,
echo=False # Set to True for SQL debugging
echo=False, # Set to True for SQL debugging
connect_args=connect_args
)

self.async_session = async_sessionmaker(
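This fix works because SQLAlchemy passes `connect_args` through to `asyncpg.connect()`, which takes an `ssl=` keyword rather than the libpq-style `?sslmode=` URL parameter. The host check can be exercised in isolation — a sketch of the same logic as a pure function, with made-up URLs:

```python
from urllib.parse import urlparse

def asyncpg_connect_args(database_url: str) -> dict:
    """Return connect_args for create_async_engine: disable TLS for local hosts."""
    parsed = urlparse(database_url)
    if parsed.hostname in ("localhost", "127.0.0.1", "::1"):
        return {"ssl": False}  # asyncpg wants ssl=False, not ?sslmode=disable
    return {}

# Illustrative URLs only:
assert asyncpg_connect_args("postgresql+asyncpg://u:p@localhost:5432/db") == {"ssl": False}
assert asyncpg_connect_args("postgresql+asyncpg://u:p@db.example.com:5432/db") == {}
```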
1 change: 1 addition & 0 deletions deadend_cli/src/deadend_cli/chat.py
@@ -361,6 +361,7 @@ async def chat_interface(
):
"""Chat Interface for the CLI"""
model_registry = ModelRegistry(config=config)
await model_registry.initialize()
if not model_registry.has_any_model():
raise RuntimeError(f"No LM model configured. You can run `deadend init` to \
initialize the required Model configuration for {llm_provider}")
1 change: 1 addition & 0 deletions deadend_cli/src/deadend_cli/eval.py
@@ -67,6 +67,7 @@ async def eval_interface(
eval_metadata = EvalMetadata(**data)

model_registry = ModelRegistry(config=config)
await model_registry.initialize()
if not model_registry.has_any_model():
raise RuntimeError(f"No LM model configured. You can run `deadend init` to \
initialize the required Model configuration for {providers[0]}")
1 change: 1 addition & 0 deletions deadend_cli/src/deadend_cli/init.py
@@ -12,6 +12,7 @@
import time
from pathlib import Path
import sys

import docker
import toml
import typer