10 changes: 10 additions & 0 deletions Dockerfile.mlflow
@@ -0,0 +1,10 @@
FROM python:3.11-slim

# Install MLflow and downgrade protobuf to a compatible version
RUN pip install mlflow==2.19.0 protobuf==3.20.1

# Expose the port for MLflow UI
EXPOSE 5000

# Command to run the MLflow server
CMD ["mlflow", "server", "--host", "0.0.0.0", "--port", "5000"]
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -7,12 +7,24 @@ Create a `.env` file in the root directory with the following content:
```env
GEMINI_API_KEY=your_gemini_api_key
```
The simplest way to run the application is to use the `uv` and `docker` commands:
```shell
uv run download_public_data data
sudo docker compose up
```
Then open your browser and go to:
- http://localhost:7860/ (ReAgentAI)
- http://localhost:5000/ (MLflow)

#### setup with uv
```sh
uv run download_public_data data
uv run run.py
```
Optionally, you can set the `MLFLOW_TRACKING_URI` environment variable to point to your MLflow server:
```sh
MLFLOW_TRACKING_URI=http://localhost:5000 uv run run.py
```
Note: You need a trained model and a stock collection. You can download a publicly available model based on USPTO and a stock
collection from the ZINC database using the `download_public_data data` command shown above.
#### setup with pip
@@ -35,12 +47,31 @@ Build the Docker image:
```sh
sudo docker build -t reagentai .
```
Optionally, you can set the `MLFLOW_TRACKING_URI` environment variable to point to your MLflow server.
Run the Docker container:
```sh
sudo docker run -p 7860:7860 --env-file .env reagentai
```
Access the application in your browser at: http://127.0.0.1:7860/
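For example, to point the containerized app at an MLflow server running on the host (a sketch: `host.docker.internal` resolves out of the box on Docker Desktop; on Linux you may need to add `--add-host=host.docker.internal:host-gateway`):
```sh
sudo docker run -p 7860:7860 --env-file .env \
  -e MLFLOW_TRACKING_URI=http://host.docker.internal:5000 reagentai
```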

### MLflow
```shell
uv run mlflow server
```
### MLflow in Docker container
```shell
sudo docker network create mlflow-network
```
```shell
sudo docker build -f Dockerfile.mlflow -t mlflow-server .
```
To run just the MLflow server:
```shell
sudo docker run --rm -p 5000:5000 mlflow-server
```
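If the ReAgentAI container should reach this standalone server by hostname, attach it to the network created above; naming the container `mlflow` mirrors the service hostname used in `docker-compose.yml` (a sketch, not part of this PR):
```shell
sudo docker run --rm -p 5000:5000 --network mlflow-network --name mlflow mlflow-server
```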
To run both the MLflow server and ReAgentAI:
```shell
sudo docker compose up
```


### Troubleshooting
24 changes: 24 additions & 0 deletions docker-compose.yml
@@ -0,0 +1,24 @@
version: '3'
services:
  mlflow:
    build:
      context: .
      dockerfile: Dockerfile.mlflow
    ports:
      - "5000:5000"
    networks:
      - mlflow-network
  reagentai:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "7860:7860"
    env_file:
      - .env
    environment:
      - MLFLOW_TRACKING_URI=http://mlflow:5000
    networks:
      - mlflow-network
networks:
  mlflow-network:
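Within the Compose network, the `reagentai` service reaches MLflow at the service hostname `mlflow`, hence `MLFLOW_TRACKING_URI=http://mlflow:5000`. A quick connectivity check once the stack is up (a sketch; it assumes the `reagentai` image ships with Python, which it does as a Python app):
```shell
sudo docker compose exec reagentai python -c \
  "import urllib.request; print(urllib.request.urlopen('http://mlflow:5000').status)"
```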
1 change: 1 addition & 0 deletions pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
"pydantic-ai>=0.2.4",
"python-dotenv>=1.1.0",
"gradio>=5.29.1",
"mlflow>=2.22.1",
]

[tool.black]
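With `mlflow` added to the dependency list, the lockfile and environment can be refreshed with `uv` (which the README already uses):
```sh
uv sync
```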
5 changes: 5 additions & 0 deletions src/reagentai/agents/main/main_agent.py
@@ -4,6 +4,7 @@
from pydantic_ai import Tool

from src.reagentai.common.client import LLMClient
from src.reagentai.common.mlflow_tracking import MLflowTracker
from src.reagentai.constants import AIZYNTHFINDER_CONFIG_PATH
from src.reagentai.tools.retrosynthesis import (
    initialize_aizynthfinder,
@@ -14,6 +15,9 @@
MAIN_AGENT_INSTRUCTIONS_PATH: str = "src/reagentai/agents/main/instructions.txt"
MAIN_AGENT_MODEL: str = "google-gla:gemini-2.0-flash"

# Initialize MLflow tracker for the main agent
agent_mlflow_tracker = MLflowTracker(experiment_name="main_agent_interactions")


@dataclass
class MainAgentDependencyTypes:
@@ -53,6 +57,7 @@ def create_main_agent() -> LLMClient:
        instructions=instructions,
        dependency_types=MainAgentDependencyTypes,
        dependencies=MainAgentDependencyTypes(aizynth_finder=aizynth_finder),
        mlflow_tracker=agent_mlflow_tracker,
    )

    return llm_client
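The `MLflowTracker` imported from `src/reagentai/common/mlflow_tracking.py` is not shown in this diff. A minimal sketch of what it might look like, assuming it is a thin wrapper over the `mlflow` run API (the method names match the calls made here and in `client.py`; everything else is an assumption):
```python
# Hypothetical sketch of src/reagentai/common/mlflow_tracking.py -- not part of this diff.
import logging

import mlflow

logger = logging.getLogger(__name__)


class MLflowTracker:
    """Thin wrapper over mlflow so tracking failures never break the agent."""

    def __init__(self, experiment_name: str = "reagentai"):
        self.experiment_name = experiment_name

    def start_run(self, run_name: str) -> None:
        try:
            # mlflow picks up MLFLOW_TRACKING_URI from the environment automatically.
            mlflow.set_experiment(self.experiment_name)
            mlflow.start_run(run_name=run_name)
        except Exception:
            logger.exception("Failed to start MLflow run")

    def log_params(self, params: dict) -> None:
        try:
            mlflow.log_params(params)
        except Exception:
            logger.exception("Failed to log params to MLflow")

    def log_metrics(self, metrics: dict) -> None:
        try:
            mlflow.log_metrics(metrics)
        except Exception:
            logger.exception("Failed to log metrics to MLflow")

    def end_run(self) -> None:
        mlflow.end_run()
```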
96 changes: 86 additions & 10 deletions src/reagentai/common/client.py
@@ -5,6 +5,7 @@
from pydantic_ai import Agent, Tool
from pydantic_ai.tools import AgentDepsT

from src.reagentai.common.mlflow_tracking import MLflowTracker
from src.reagentai.models.llm_output import MultipleOutputs

logger = logging.getLogger(__name__)
@@ -19,6 +20,7 @@ def __init__(
        instructions: str | None = None,
        dependency_types: AgentDepsT | None = None,
        dependencies: Any | None = None,
        mlflow_tracker: MLflowTracker | None = None,
    ):
"""
Initializes the LLMClient with the specified model, tools, instructions, and dependencies.
@@ -29,12 +31,14 @@
            instructions (str | None): Instructions for the agent, if any.
            dependency_types (AgentDepsT | None): Types of dependencies required by the agent.
            dependencies (Any | None): Actual dependencies to be used by the agent.
            mlflow_tracker (MLflowTracker | None): MLflow tracker for experiment tracking.
        """

        self.model_name = model_name
        self.instructions = instructions
        self.tools = tools
        self.dependencies = dependencies
        self.mlflow_tracker = mlflow_tracker or MLflowTracker()

        self.agent = Agent(
            model_name,
@@ -47,6 +51,17 @@ def __init__(
        self.result_history = None
        logger.info(f"LLMClient initialized with model: {model_name}")

        # Log agent initialization
        self.mlflow_tracker.start_run("agent_initialization")
        self.mlflow_tracker.log_params(
            {
                "model_name": model_name,
                "num_tools": len(tools),
                "instructions_length": len(instructions) if instructions else 0,
            }
        )
        self.mlflow_tracker.end_run()

    def set_model(self, model_name: str):
        """
        Sets the model for the LLMClient.
@@ -62,6 +77,11 @@ def set_model(self, model_name: str):
        )
        logger.info(f"LLMClient model set to: {model_name}")

        # Log model change
        self.mlflow_tracker.start_run("model_change")
        self.mlflow_tracker.log_params({"new_model": model_name})
        self.mlflow_tracker.end_run()

    def get_token_usage(self) -> int:
        """
        Returns the token usage of the current agent.
@@ -95,18 +115,74 @@ def respond(self, user_query: str, **kwargs) -> list[dict]:
            list[dict]: A list of messages generated by the agent in response to the user query.
        """

        # Start an MLflow run for this interaction
        self.mlflow_tracker.start_run(
            f"interaction_{str(hash(user_query))[:8] if user_query else 'new'}"
        )
        self.mlflow_tracker.log_params(
            {
                "query_length": len(user_query),
            }
        )

        if self.result_history is not None:
            message_history = self.result_history.all_messages()
        else:
            message_history = None

        result = self.agent.run_sync(
            user_query,
            message_history=message_history,
            deps=self.dependencies,
            **kwargs,
        )
        self.result_history = result
        logger.info(f"LLMClient response: {result.output}")
        bot_message = result.output.to_message()
        return bot_message
        import time  # used to measure wall-clock response time

        try:
            # Track token usage before the response
            tokens_before = self.get_token_usage()
            start_time = time.time()

            result = self.agent.run_sync(
                user_query,
                message_history=message_history,
                deps=self.dependencies,
                **kwargs,
            )

            response_time = time.time() - start_time

            self.result_history = result
            logger.info(f"LLMClient response: {result.output}")
            bot_message = result.output.to_message()

            # Calculate tokens used in this interaction
            tokens_after = self.get_token_usage()
            tokens_used = tokens_after - tokens_before

            # Log metrics to MLflow
            self.mlflow_tracker.log_metrics(
                {
                    "response_time_seconds": response_time,
                    "tokens_used": tokens_used,
                    "total_tokens": tokens_after,
                }
            )

            # Log which tools were used, if any
            tool_calls = []
            if hasattr(result, "tool_calls") and result.tool_calls:
                for tool_call in result.tool_calls:
                    tool_calls.append(tool_call.name)

            self.mlflow_tracker.log_params(
                {
                    "tools_used": ", ".join(tool_calls),
                    "num_tool_calls": len(tool_calls),
                }
            )

            self.mlflow_tracker.end_run()
            return bot_message

        except Exception as e:
            self.mlflow_tracker.log_params({"error": str(e), "error_type": type(e).__name__})
            self.mlflow_tracker.end_run()
            raise
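With this wiring, every call to `respond` is recorded as its own MLflow run. A hypothetical usage (the query is illustrative; it assumes `GEMINI_API_KEY` is set and an MLflow server is reachable via `MLFLOW_TRACKING_URI`):
```python
from src.reagentai.agents.main.main_agent import create_main_agent

client = create_main_agent()
messages = client.respond("Suggest a retrosynthetic route for aspirin.")
# Params (query_length, tools_used) and metrics (response_time_seconds, tokens_used,
# total_tokens) for this interaction now appear in the MLflow UI, e.g. http://localhost:5000.
```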