Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
GEMINI_API_KEY=<TOKEN>
REDIS_PASSWORD=<REDIS_PASSWORD>
File renamed without changes.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ For containerized deployment:

3. **Run the container**:
```sh
sudo docker run -p 7860:7860 --env-file .env reagentai
sudo docker compose up
```

4. **Access the application**:
Expand Down
43 changes: 43 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
services:
app:
container_name: reagent_app
build:
context: .
dockerfile: Dockerfile.reagent
ports:
- "7860:7860"
env_file:
- .env
environment:
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=${REDIS_PASSWORD}
networks:
- app-network
restart: no
depends_on:
redis:
condition: service_started

redis:
container_name: reagent_redis
image: redis:7-alpine
volumes:
- redis_data:/data
networks:
- app-network
restart: no
command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru --requirepass ${REDIS_PASSWORD}
healthcheck:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
interval: 60s
timeout: 5s
retries: 3

volumes:
redis_data:
driver: local

networks:
app-network:
driver: bridge
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"gradio>=5.29.1",
"pydantic-ai-slim[duckduckgo]>=0.2.4",
"pubchempy>=1.0.4",
"redis[hiredis]>=6.2.0",
]

[tool.black]
Expand Down
27 changes: 15 additions & 12 deletions src/reagentai/agents/main/main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self,
model_name: str,
instructions: str,
tools: list[Tool],
tools: list[Tool[MainAgentDependencyTypes]],
dependency_types: type[MainAgentDependencyTypes],
dependencies: MainAgentDependencyTypes,
output_type: type[str],
Expand Down Expand Up @@ -80,12 +80,13 @@ def _create_agent(self) -> Agent[MainAgentDependencyTypes, str]:
Agent[MainAgentDependencyTypes, str]: An instance of the Agent configured with the main agent's model and instructions.
"""

return Agent(
return Agent[MainAgentDependencyTypes, str](
self.model_name,
tools=self.tools,
instructions=self.instructions,
deps_type=self.dependency_types,
output_type=self.output_type,
retries=3,
)

def remove_last_messages(self, remove_user_prompt: bool = True):
Expand Down Expand Up @@ -122,7 +123,7 @@ def get_total_token_usage(self) -> int:
Returns:
int: The total number of tokens used by the agent.
"""
if self.usage:
if self.usage and self.usage.total_tokens:
return self.usage.total_tokens
else:
return 0
Expand All @@ -137,7 +138,9 @@ def clear_history(self):
self.usage = None

@asynccontextmanager
async def run_stream(self, user_query: str) -> AsyncIterator[result.StreamedRunResult]:
async def run_stream(
self, user_query: str
) -> AsyncIterator[result.StreamedRunResult[MainAgentDependencyTypes, str]]:
"""
Streams the response from the agent asynchronously.

Expand Down Expand Up @@ -199,14 +202,14 @@ def create_main_agent() -> MainAgent:
instructions = instructions_file.read()

tools = [
Tool(perform_retrosynthesis, takes_ctx=True),
Tool(is_valid_smiles),
Tool(smiles_to_image),
Tool(route_to_image),
Tool(find_similar_molecules),
Tool(get_smiles_from_name),
Tool(get_compound_info),
Tool(get_name_from_smiles),
Tool[MainAgentDependencyTypes](perform_retrosynthesis, takes_ctx=True),
Tool[MainAgentDependencyTypes](is_valid_smiles),
Tool[MainAgentDependencyTypes](smiles_to_image),
Tool[MainAgentDependencyTypes](route_to_image),
Tool[MainAgentDependencyTypes](find_similar_molecules),
Tool[MainAgentDependencyTypes](get_smiles_from_name),
Tool[MainAgentDependencyTypes](get_compound_info),
Tool[MainAgentDependencyTypes](get_name_from_smiles),
duckduckgo_search_tool(),
]

Expand Down
138 changes: 130 additions & 8 deletions src/reagentai/common/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,148 @@
import logging
import pickle

from aizynthfinder.context.config import Configuration

from src.reagentai.common.utils.redis import RedisManager
from src.reagentai.models.retrosynthesis import RouteCollection

logger = logging.getLogger(__name__)


class RetrosynthesisCache:
"""
A cache for storing retrosynthesis routes based on target SMILES strings.
This class provides methods to add, retrieve, and clear cached routes.
It also maintains a configuration for the AiZynthFinder instance used in retrosynthesis.
Supports both in-memory and Redis backends with automatic fallback.
"""

routes_cache: dict[str, RouteCollection] = {}
# Class-level cache for fast access
_memory_cache: dict[str, RouteCollection] = {}
finder_config: Configuration | None = None
_cache_prefix = "retrosynthesis"
_default_ttl = 86400 # 24 hours

@classmethod
def _serialize_data(cls, data: RouteCollection) -> bytes:
"""Serialize RouteCollection for Redis storage."""
try:
return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
logger.error(f"Failed to serialize data: {e}")
raise

@classmethod
def add(cls, target_smile: str, data: RouteCollection):
cls.routes_cache[target_smile] = data
def _deserialize_data(cls, data: bytes) -> RouteCollection:
"""Deserialize RouteCollection from Redis storage."""
try:
return pickle.loads(data)
except Exception as e:
logger.error(f"Failed to deserialize data: {e}")
raise

@classmethod
def _get_cache_key(cls, target_smile: str) -> str:
"""Generate standardized cache key."""
normalized_smile = target_smile.strip().lower()
return f"{cls._cache_prefix}:{normalized_smile}"

@classmethod
def add(cls, target_smile: str, data: RouteCollection, ttl: int | None = None) -> bool:
"""Add route collection to cache."""
if not target_smile or not data:
logger.warning("Invalid input for cache add operation")
return False

# Always store in memory cache
cls._memory_cache[target_smile] = data

# Attempt Redis storage
ttl = ttl or cls._default_ttl
cache_key = cls._get_cache_key(target_smile)

with RedisManager.get_client() as redis_client:
if redis_client:
try:
serialized_data = cls._serialize_data(data)
result = redis_client.setex(cache_key, ttl, serialized_data)
if result:
logger.debug(f"Cached to Redis: {cache_key}")
return True
except Exception as e:
logger.warning(f"Failed to cache to Redis: {e}")

logger.debug(f"Cached to memory only: {target_smile}")
return True

@classmethod
def get(cls, target_smile: str) -> RouteCollection | None:
return cls.routes_cache.get(target_smile)
"""Retrieve route collection from cache."""
if not target_smile:
return None

# Check memory cache first
if target_smile in cls._memory_cache:
logger.debug(f"Cache hit (memory): {target_smile}")
return cls._memory_cache[target_smile]

# Check Redis cache
cache_key = cls._get_cache_key(target_smile)

with RedisManager.get_client() as redis_client:
if redis_client:
try:
cached_data = redis_client.get(cache_key)
if cached_data and isinstance(cached_data, bytes):
data = cls._deserialize_data(cached_data)
cls._memory_cache[target_smile] = data
logger.debug(f"Cache hit (Redis): {target_smile}")
return data
except Exception as e:
logger.warning(f"Failed to retrieve from Redis: {e}")

logger.debug(f"Cache miss: {target_smile}")
return None

@classmethod
def delete(cls, target_smile: str) -> bool:
"""Delete specific entry from cache."""
if not target_smile:
return False

cls._memory_cache.pop(target_smile, None)
cache_key = cls._get_cache_key(target_smile)

with RedisManager.get_client() as redis_client:
if redis_client:
try:
result = redis_client.delete(cache_key)
logger.debug(f"Deleted from cache: {target_smile}")
return bool(result)
except Exception as e:
logger.warning(f"Failed to delete from Redis: {e}")

return True

@classmethod
def clear(cls) -> bool:
"""Clear all cached routes."""
cls._memory_cache.clear()

with RedisManager.get_client() as redis_client:
if redis_client:
try:
pipeline = redis_client.pipeline()
for key in redis_client.scan_iter(match=f"{cls._cache_prefix}:*", count=100):
pipeline.delete(key)
pipeline.execute()
logger.info("Cleared Redis cache")
return True
except Exception as e:
logger.warning(f"Failed to clear Redis cache: {e}")

logger.info("Cleared memory cache")
return True

@classmethod
def clear(cls):
cls.routes_cache.clear()
def close(cls):
"""Close Redis connections and cleanup resources."""
RedisManager.close()
80 changes: 80 additions & 0 deletions src/reagentai/common/utils/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections.abc import Generator
from contextlib import contextmanager
import logging
import os

import redis
from redis.exceptions import ConnectionError, RedisError, TimeoutError

logger = logging.getLogger(__name__)


class RedisManager:
"""Centralized Redis connection management."""

_pool: redis.ConnectionPool | None = None

@classmethod
def get_pool(cls) -> redis.ConnectionPool | None:
"""Get or create Redis connection pool."""
if cls._pool is None:
try:
cls._pool = redis.ConnectionPool(
host=os.getenv("REDIS_HOST", "localhost"),
port=int(os.getenv("REDIS_PORT", "6379")),
password=os.getenv("REDIS_PASSWORD"),
decode_responses=False,
socket_timeout=5,
socket_connect_timeout=5,
retry_on_timeout=True,
max_connections=20,
health_check_interval=30,
)
# Test connection
with redis.Redis(connection_pool=cls._pool) as client:
client.ping()
logger.info("Redis connection pool initialized successfully")
except (ConnectionError, TimeoutError, RedisError) as e:
logger.warning(f"Redis connection failed: {e}")
cls._pool = None
except Exception as e:
logger.error(f"Unexpected error initializing Redis: {e}")
cls._pool = None
return cls._pool

@classmethod
@contextmanager
def get_client(cls) -> Generator[redis.Redis | None, None, None]:
"""Context manager for Redis client."""
pool = cls.get_pool()
if pool is None:
yield None
return

client = None
try:
client = redis.Redis(connection_pool=pool)
yield client
except (ConnectionError, TimeoutError, RedisError) as e:
logger.warning(f"Redis operation failed: {e}")
yield None
except Exception as e:
logger.error(f"Unexpected Redis error: {e}")
yield None
finally:
if client:
try:
client.close()
except Exception:
pass

@classmethod
def close(cls):
"""Close Redis connection pool."""
if cls._pool:
try:
cls._pool.disconnect()
cls._pool = None
logger.info("Redis connection pool closed")
except Exception as e:
logger.warning(f"Error closing Redis pool: {e}")
Loading