[feat] MLflow Integration #14
Base: main

Changes from all commits: 57db7aa, fc1b8e5, e6be060, 6ef24b6, b4cad46, 9f98284, 74d8bbd, 566f718
Dockerfile.mlflow (new file):

```dockerfile
FROM python:3.11-slim

# Install MLflow and downgrade protobuf to a compatible version
RUN pip install mlflow==2.22.1 protobuf==3.20.1

# Expose the port for MLflow UI
EXPOSE 5000

# Command to run the MLflow server
CMD ["mlflow", "server", "--host", "0.0.0.0", "--port", "5000"]
```
docker-compose.yml (new file):

```yaml
version: '3'
services:
  mlflow:
    build:
      context: .
      dockerfile: Dockerfile.mlflow
    ports:
      - "5000:5000"
    networks:
      - mlflow-network
  reagentai:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "7860:7860"
    env_file:
      - .env
    environment:
      - MLFLOW_TRACKING_URI=http://mlflow:5000
    networks:
      - mlflow-network
networks:
  mlflow-network:
```
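As a quick connectivity check, the wiring above can be exercised from inside the `reagentai` container. A minimal sketch (not part of the PR), assuming the compose stack is running:

```python
import os

import mlflow

# docker-compose injects MLFLOW_TRACKING_URI into the reagentai container;
# the "mlflow" hostname resolves over the shared mlflow-network.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://mlflow:5000"))

# Listing experiments round-trips a request to the server, confirming connectivity.
print(mlflow.search_experiments())
```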
pyproject.toml:

```diff
@@ -9,8 +9,18 @@ dependencies = [
     "pydantic-ai>=0.2.4",
     "python-dotenv>=1.1.0",
     "gradio>=5.29.1",
+    "mlflow>=2.22.1",
 ]
+
+[dependency-groups]
+dev = [
+    "pytest>=8.3.5",
+]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = "test_*.py"
```
Review comment on lines +20 to +23 (Owner):

> Maybe add this? Because it doesn't work without that for me (the reagent module is not found when running tests).
>
> `pythonpath = [`
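The suggested snippet is cut off above. A plausible completion — an assumption based on the project's `src.reagentai.*` import style, not the reviewer's actual text — would put the repository root on pytest's path:

```toml
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
# Assumed completion of the truncated suggestion: with the repo root on
# sys.path, `src.reagentai` resolves during test collection (pytest >= 7).
pythonpath = [
    ".",
]
```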
pyproject.toml (continued):

```diff

 [tool.black]
 target-version = ["py310", "py311"]
 line-length = 99
```
src/reagentai/common/mlflow_tracking.py (new file):

```python
import logging
import os
from typing import Any

import mlflow
from mlflow.tracking import MlflowClient

logger = logging.getLogger(__name__)


class MLflowTracker:
    """
    Handles MLflow experiment tracking for ReagentAI.
    """

    def __init__(self, experiment_name: str = "reagentai_experiments"):
        """
        Initialize MLflow tracker.

        Args:
            experiment_name: The name of the MLflow experiment to use
        """
        self.experiment_name = experiment_name
        self.tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "")
        self.active_run = None
        self.mlflow_enabled = bool(self.tracking_uri)

        if not self.mlflow_enabled:
            logger.info("MLflow tracking is disabled - MLFLOW_TRACKING_URI is not set.")
            return

        try:
            mlflow.set_tracking_uri(self.tracking_uri)

            # Create or get the experiment
            self.experiment = mlflow.get_experiment_by_name(experiment_name)
            if not self.experiment:
                self.experiment_id = mlflow.create_experiment(experiment_name)
            else:
                self.experiment_id = self.experiment.experiment_id

            self.client = MlflowClient()
        except Exception as e:
            logger.warning(f"MLflow experiment setup failed: {e}")
            self.mlflow_enabled = False

    def start_run(self, run_name: str | None = None) -> str | None:
        """
        Start a new MLflow run.

        Args:
            run_name: Optional name for the run

        Returns:
            The run ID of the created run or None if MLflow is disabled
        """
        if not self.mlflow_enabled:
            logger.debug("MLflow tracking disabled. Not starting run.")
            return None

        try:
            self.active_run = mlflow.start_run(experiment_id=self.experiment_id, run_name=run_name)
            return self.active_run.info.run_id
        except Exception as e:
            logger.warning(f"Failed to start MLflow run: {e}")
            self.mlflow_enabled = False
            return None

    def end_run(self) -> None:
        """End the current MLflow run."""
        if not self.mlflow_enabled or not self.active_run:
            return

        try:
            mlflow.end_run()
            self.active_run = None
        except Exception as e:
            logger.warning(f"Error ending MLflow run: {e}")

    def log_params(self, params: dict[str, Any]) -> None:
        """Log parameters to the current run."""
        if not self.mlflow_enabled or not self.active_run:
            return

        try:
            mlflow.log_params(params)
        except Exception as e:
            logger.warning(f"Failed to log params to MLflow: {e}")

    def log_metrics(self, metrics: dict[str, float | int], step: int | None = None) -> None:
        """Log metrics to the current run."""
        if not self.mlflow_enabled or not self.active_run:
            return

        try:
            mlflow.log_metrics(metrics, step=step)
        except Exception as e:
            logger.warning(f"Failed to log metrics to MLflow: {e}")

    def log_artifact(self, local_path: str) -> None:
        """Log an artifact to the current run."""
        if not self.mlflow_enabled or not self.active_run:
            return

        try:
            mlflow.log_artifact(local_path)
        except Exception as e:
            logger.warning(f"Failed to log artifact to MLflow: {e}")

    def set_tags(self, tags: dict[str, str]) -> None:
        """Set tags on the current run."""
        if not self.mlflow_enabled or not self.active_run:
            return

        try:
            mlflow.set_tags(tags)
        except Exception as e:
            logger.warning(f"Failed to set tags in MLflow: {e}")
```
Application entry point (the module defining `start_agent`):

```diff
@@ -1,8 +1,10 @@
 import logging
+import os

 from dotenv import load_dotenv

 from src.reagentai.agents.main.main_agent import create_main_agent
+from src.reagentai.common.mlflow_tracking import MLflowTracker
 from src.reagentai.logging import setup_logging
 from src.reagentai.ui.app import create_gradio_app

@@ -13,7 +15,45 @@ def start_agent():
     setup_logging()
     load_dotenv()

+    # Initialize MLflow tracking
+    tracker = MLflowTracker(experiment_name="reagentai_experiments")
+
+    # Start a new run for this application session
+    run_id = tracker.start_run(run_name="reagentai_session")
+    logger.info(f"MLflow tracking {'enabled' if run_id else 'disabled'}")
+
+    # Log system information and configuration parameters
+    if tracker.mlflow_enabled:
+        import platform
+        import sys
+
+        # Log system info as tags
+        tracker.set_tags(
+            {
+                "python_version": sys.version,
+                "platform": platform.platform(),
+                "application": "ReagentAI",
+            }
+        )
+
+        # Log configuration parameters
+        tracker.log_params(
+            {
+                "log_to_file": os.environ.get("LOG_TO_FILE", "True"),
+                "app_version": "0.1.0",  # Could be pulled from a version file
+            }
+        )
+
```
Review comment on lines +18 to +46 (Owner):

> Maybe add a function like setup_mlflow for this?
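One possible shape for that helper — a sketch only, lifting the block above out of `start_agent` unchanged:

```python
def setup_mlflow() -> MLflowTracker:
    """Create the tracker, start a session run, and log environment metadata."""
    import platform
    import sys

    tracker = MLflowTracker(experiment_name="reagentai_experiments")
    run_id = tracker.start_run(run_name="reagentai_session")
    logger.info(f"MLflow tracking {'enabled' if run_id else 'disabled'}")

    if tracker.mlflow_enabled:
        # Same tags and params as the inline block in this PR.
        tracker.set_tags(
            {
                "python_version": sys.version,
                "platform": platform.platform(),
                "application": "ReagentAI",
            }
        )
        tracker.log_params(
            {
                "log_to_file": os.environ.get("LOG_TO_FILE", "True"),
                "app_version": "0.1.0",
            }
        )
    return tracker
```

`start_agent` would then shrink to `tracker = setup_mlflow()` followed by the existing app wiring.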
The entry-point diff continues:

```diff
     main_agent = create_main_agent()
-    app = create_gradio_app(main_agent)

-    app.launch(server_name="0.0.0.0")
+    # Pass the MLflow tracker to the Gradio app
+    app = create_gradio_app(main_agent, mlflow_tracker=tracker)
+
+    # Launch the application
+    try:
+        app.launch(server_name="0.0.0.0")
+    finally:
+        # End the MLflow run when the application exits
+        if tracker.mlflow_enabled:
+            tracker.end_run()
+            logger.info("MLflow tracking session ended")
```
src/reagentai/ui/app.py:

```diff
@@ -74,7 +74,7 @@ def add_user_message_to_history(


 def handle_bot_response(
-    chat_history: ChatHistory, llm_client: MainAgent
+    chat_history: ChatHistory, llm_client: MainAgent, mlflow_tracker=None
 ) -> tuple[ChatHistory, int]:
     """
     Gets LLM response, updates chat history and token usage.
@@ -83,6 +83,20 @@ def handle_bot_response(
     response: list[ChatMessage] = llm_client.respond(user_query)
     chat_history.extend(response)
     token_used: int = llm_client.get_token_usage()
+
+    # Log metrics to MLflow
+    if mlflow_tracker and mlflow_tracker.mlflow_enabled:
+        mlflow_tracker.log_metrics(
+            {"token_usage": token_used, "conversation_length": len(chat_history)}
+        )
+
+        # Log user query as param for tracking purposes
+        mlflow_tracker.log_params(
+            {
+                f"query_{len(chat_history)}": user_query[:100]  # Truncate long queries
+            }
+        )
+
```
Review comment on lines +86 to +99 (Owner):

> I wouldn't pass the mlflow instance to the Gradio frontend, I don't think we should keep logic like that there. Maybe we should pass the tracker to the Agent class itself.
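A sketch of the suggested direction — hypothetical, since `MainAgent`'s internals aren't part of this diff: a wrapper that owns the tracker, so the Gradio handlers never touch MLflow:

```python
from src.reagentai.common.mlflow_tracking import MLflowTracker


class TrackedAgent:
    """Hypothetical wrapper: tracking lives with the agent, not the UI layer."""

    def __init__(self, llm_client, mlflow_tracker: MLflowTracker | None = None):
        self.llm_client = llm_client
        self.tracker = mlflow_tracker

    def respond(self, user_query: str):
        response = self.llm_client.respond(user_query)
        if self.tracker and self.tracker.mlflow_enabled:
            # Same metric handle_bot_response logs today, recorded agent-side.
            self.tracker.log_metrics({"token_usage": self.llm_client.get_token_usage()})
        return response

    def __getattr__(self, name):
        # Delegate everything else (set_model, get_token_usage, ...) to the agent.
        return getattr(self.llm_client, name)
```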
src/reagentai/ui/app.py (continued):

```diff
     return chat_history, token_used


@@ -96,12 +110,17 @@ def handle_clear_chat(llm_client: MainAgent) -> tuple[list[None], Literal[0]]:
     return [], 0


-def handle_model_change(model_name: str, llm_client: MainAgent) -> None:
+def handle_model_change(model_name: str, llm_client: MainAgent, mlflow_tracker=None) -> None:
     """
     Sets the new LLM model in the client.
     """
     llm_client.set_model(model_name)

+    # Log model change to MLflow
+    if mlflow_tracker and mlflow_tracker.mlflow_enabled:
+        mlflow_tracker.log_params({"llm_model": model_name})
+        mlflow_tracker.set_tags({"model_changed": "true"})
+

 def re_enable_chat_input() -> gr.MultimodalTextbox:
     """
@@ -111,7 +130,7 @@ def re_enable_chat_input() -> gr.MultimodalTextbox:


 # Main App Creation Function
-def create_gradio_app(llm_client: MainAgent) -> gr.Blocks:
+def create_gradio_app(llm_client: MainAgent, mlflow_tracker=None) -> gr.Blocks:
     with gr.Blocks(
         theme=gr.themes.Origin(),
     ) as demo:
@@ -131,7 +150,9 @@ def create_gradio_app(llm_client: MainAgent) -> gr.Blocks:
             inputs=[chatbot_display, chat_input],
             outputs=[chatbot_display, chat_input],
         ).then(
-            fn=functools.partial(handle_bot_response, llm_client=llm_client),
+            fn=functools.partial(
+                handle_bot_response, llm_client=llm_client, mlflow_tracker=mlflow_tracker
+            ),
             inputs=chatbot_display,
             outputs=[chatbot_display, token_usage_display],
             api_name="bot_response",
@@ -145,7 +166,9 @@ def create_gradio_app(llm_client: MainAgent) -> gr.Blocks:
         )

         llm_model_dropdown.change(
-            fn=functools.partial(handle_model_change, llm_client=llm_client),
+            fn=functools.partial(
+                handle_model_change, llm_client=llm_client, mlflow_tracker=mlflow_tracker
+            ),
             inputs=llm_model_dropdown,
             outputs=[],
         )
```
Review comment:

> nitpick: missing new line at the end