diff --git a/llm/dashboard.json b/llm/dashboard.json new file mode 100644 index 0000000..4e3adb5 --- /dev/null +++ b/llm/dashboard.json @@ -0,0 +1,249 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 1, + "links": [], + "panels": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "aesy917veeltsd" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "fillOpacity": 80, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineWidth": 1, + "scaleDistribution": { + "type": "linear" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "id": 2, + "options": { + "barRadius": 0, + "barWidth": 0.97, + "fullHighlight": false, + "groupWidth": 0.7, + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "orientation": "auto", + "showValue": "always", + "stacking": "none", + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + }, + "xField": "total_tokens", + "xTickLabelRotation": 0, + "xTickLabelSpacing": 0 + }, + "pluginVersion": "12.1.0", + "targets": [ + { + "editorType": "sql", + "format": 1, + "meta": { + "builderOptions": { + "columns": [], + "database": "", + "limit": 1000, + "mode": "list", + "queryType": "table", + "table": "" + } + }, + "pluginVersion": "4.10.1", + "queryType": "table", + "rawSql": "SELECT parent_run_id, total_tokens, total_cost, status\r\nFROM \"guardian\".\"langchain_metrics\"\r\nORDER BY start_time", + "refId": "A" + } + ], + "title": "Token Metrics", + "type": "barchart" + }, + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "aesy917veeltsd" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "continuous-BlYlRd" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "fillOpacity": 80, + "gradientMode": "hue", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineWidth": 1, + "scaleDistribution": { + "type": "linear" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 1, + "options": { + "barRadius": 0, + "barWidth": 0.97, + "colorByField": "duration", + "fullHighlight": false, + "groupWidth": 0.7, + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "orientation": "auto", + "showValue": "auto", + "stacking": "none", + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + }, + "xTickLabelRotation": 0, + "xTickLabelSpacing": 0 + }, + "pluginVersion": "12.1.0", + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "aesy917veeltsd" + }, + "editorType": "sql", + "format": 1, + "meta": { + "builderOptions": { + "columns": [], + "database": "", + "limit": 1000, + "mode": "list", + "queryType": "table", + "table": "" + } + }, + "pluginVersion": "4.10.1", + "queryType": "table", + "rawSql": "SELECT parent_run_id, duration, status, outputs\r\nFROM \"guardian\".\"langchain_metrics\"\r\nWHERE parent_run_id != 'None'\r\nORDER BY start_time ASC\r\nLIMIT 1000", + "refId": "A" + } + ], + "title": "Duration Metrics", + "transparent": true, + "type": "barchart" + } + ], + "preload": false, + "schemaVersion": 41, + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "2025-07-27T10:59:29.712Z", + "to": "2025-07-28T10:59:29.712Z" + }, + "timepicker": {}, + "timezone": "browser", + "title": "DB Metrics Visualization", + "uid": "90ced2bd-5ea8-42c5-b87b-be9e1a8cdb4c", + "version": 11 +} \ No newline at end of file diff --git a/llm/docker-compose.yml b/llm/docker-compose.yml new file mode 100644 index 0000000..a70a8b3 --- /dev/null +++ b/llm/docker-compose.yml @@ -0,0 +1,13 @@ +version: "3.8" + +services: + grafana: + image: grafana/grafana-oss + container_name: grafana + ports: + - "3000:3000" + environment: + - GF_SECURITY_ALLOW_EMBEDDING=true + - GF_AUTH_ANONYMOUS_ENABLED=true + - GF_AUTH_ANONYMOUS_ORG_ROLE=Viewer + restart: unless-stopped \ No newline at end of file diff --git a/llm/langchain_gui.py b/llm/langchain_gui.py index 96c4d7f..527d21c 100644 --- a/llm/langchain_gui.py +++ b/llm/langchain_gui.py @@ -14,7 +14,8 @@ LOADING_URL = "https://cdn.pixabay.com/animation/2025/04/08/09/08/09-08-31-655_512.gif" DURATION_METRICS_URL = "http://localhost:3000/d-solo/90ced2bd-5ea8-42c5-b87b-be9e1a8cdb4c/db-metrics-visualization?orgId=1&from=1753395832378&to=1753417432378&timezone=browser&panelId=1&__feature.dashboardSceneSolo=true" TOKEN_METRICS_URL = "http://localhost:3000/d-solo/90ced2bd-5ea8-42c5-b87b-be9e1a8cdb4c/db-metrics-visualization?orgId=1&from=1753645594350&to=1753667194350&timezone=browser&panelId=2&__feature.dashboardSceneSolo=true" - +MOST_RECENT_RUNS_URL = "http://localhost:3000/d-solo/90ced2bd-5ea8-42c5-b87b-be9e1a8cdb4c/db-metrics-visualization?orgId=1&from=1753613969712&to=1753700369712&timezone=browser&panelId=4&__feature.dashboardSceneSolo=true" +CONTEXT_METRICS_URL = "http://localhost:3000/d-solo/90ced2bd-5ea8-42c5-b87b-be9e1a8cdb4c/db-metrics-visualization?orgId=1&from=1753645594350&to=1753667194350&timezone=browser&panelId=3&__feature.dashboardSceneSolo=true" # Database options - customize these based on your available databases DATABASE_OPTIONS = [ db.value[0] for db in Database @@ -668,14 +669,24 @@ def log_user_interaction(interaction_type, details): with tab4: - tab4a, tab4b = st.tabs(["Duration Metrics", "Token Metrics"]) + tab4a, tab4b, tab4c, tab4d = st.tabs(["Most Recent Run", "Duration Metrics", "Token Metrics", "Context Metrics"]) with tab4a: + st.components.v1.html( + f'', + height=800) + + with tab4b: st.components.v1.html( f'', height=600) - with tab4b: + with tab4c: st.components.v1.html( f'', + height=600) + + with tab4d: + st.components.v1.html( + f'', height=600) \ No newline at end of file diff --git a/llm/langchain_metrics.py b/llm/langchain_metrics.py new file mode 100644 index 0000000..c071761 --- /dev/null +++ b/llm/langchain_metrics.py @@ -0,0 +1,211 @@ +from langsmith import Client +import os +import json +from datetime import datetime, timedelta +import csv +import clickhouse_connect +from dotenv import load_dotenv +from typing import List, Any +from pytz import timezone, utc +load_dotenv() + + +# Check if API key is set from environment variable +print("Current system time:", datetime.now()) + +class LangchainMetrics: + def __init__(self): + self.API_KEY = os.getenv("LANGSMITH_API_KEY") + # Initialize client (uses LANGSMITH_API_KEY env var) + self.client = Client(api_key=self.API_KEY) + + + def connect_clickhouse(self): + """Connect to ClickHouse database""" + try: + self.clickhouse_client = clickhouse_connect.get_client( + host='10.0.100.92', + port=8123, + username='user', + password='default', + database='guardian' + ) + print("Connected to ClickHouse successfully") + return True + except Exception as e: + print(f"Failed to connect to ClickHouse: {e}") + self.clickhouse_client = None + return False + + def get_runs_by_id(self, run_ids: List[str]) -> List[Any]: + # Convert UUID objects to strings if needed + run_ids = [str(run_id) for run_id in run_ids] + runs = list(self.client.list_runs(run_ids=run_ids)) + return runs[0] if runs else None + + def get_runs_by_id_safe(self, run_ids: List[str]) -> Any: + """Safe version that returns None instead of raising errors""" + try: + return self.get_runs_by_id(run_ids) + except Exception as e: + print(f"Error getting run by ID: {e}") + return None + + def find_root_run_id(self, run_id): + current_run = self.get_runs_by_id([run_id]) + while current_run and current_run.parent_run_id: + current_run = self.get_runs_by_id([current_run.parent_run_id]) + return current_run.id if current_run else None + + def save_to_clickhouse(self, run): + # Upload data to ClickHouse table (assumes connection already established) + table_name = "langchain_metrics" + + query = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + id String, + name String, + status String, + total_tokens Int32, + total_cost Float64, + start_time DateTime, + end_time DateTime, + duration Float64, + inputs String, + outputs String, + error String, + tags String, + metadata String, + parent_run_id String, + child_runs String, + ) ENGINE = MergeTree() + ORDER BY id + """ + # Create table if not exists (simple schema) + self.clickhouse_client.command(query) + + # Prepare row for insertion with safe datetime handling + start_time = run.start_time if run.start_time else datetime.now() + # end_time might be None if run is still in progress + end_time = run.end_time if run.end_time else datetime.now() + duration = (end_time - start_time).total_seconds() if end_time and start_time else 0.0 + + # make sure start and end time are in the same timezone + start_time = start_time.astimezone(timezone('Australia/Sydney')) + end_time = end_time.astimezone(timezone('Australia/Sydney')) + + row = [ + str(run.id), + str(run.name) if run.name else "", + str(run.status) if run.status else "", + int(run.total_tokens) if run.total_tokens is not None else 0, + float(run.total_cost) if run.total_cost is not None else 0.0, + start_time, + end_time, + float(duration), + json.dumps(run.inputs) if run.inputs else "{}", + json.dumps(run.outputs) if run.outputs else "{}", + str(run.error) if run.error else "", + json.dumps(run.tags) if run.tags else "[]", + json.dumps(run.metadata) if run.metadata else "{}", + str(run.parent_run_id) if run.parent_run_id else "", + json.dumps(run.child_runs) if run.child_runs else "[]", + ] + + # Only insert if we have valid data + if start_time: + self.clickhouse_client.insert( + table_name, + [row], # a list of rows + column_names=[ + "id", "name", "status", "total_tokens", "total_cost", "start_time", "end_time", + "duration", "inputs", "outputs", "error", "tags", "metadata", "parent_run_id", "child_runs" + ] + ) + print(f"Run {run.id} saved to ClickHouse") + else: + print(f"Run {run.id} not saved - missing start_time") + + def get_runs(self, num_runs: int = 1, run_ids: List[str] = None, run_name: str = None): + # Get runs from the last hour + print("Getting recent runs...") + try: + recent_runs = list(self.client.list_runs( + project_name="default", # Your LangSmith project + limit = num_runs, + run_ids = run_ids, + name = run_name, # Filter by run name + )) + print(f"Found {len(recent_runs)} runs") + + # Get metrics from recent runs and convert to a json + for run in recent_runs: + print(f"Run ID: {run.id}") + print(f"Trace ID: {run.trace_id}") + print(f"Name: {run.name}") + print(f"Status: {run.status}") + print(f"Tokens: {run.total_tokens}") + print(f"Cost: ${run.total_cost}") + print(f"Start time: {run.start_time}") + print(f"End time: {run.end_time}") + print(f"Duration: {run.end_time - run.start_time}") + print(f"Input: {run.inputs}") + print(f"Output: {run.outputs}") + print(f"Error: {run.error}") + print(f"Tags: {run.tags}") + print(f"Metadata: {run.metadata}") + print(f"Parent Run ID: {run.parent_run_id}") + print(f"Child Runs: {run.child_runs}") + print("---") + + + + # csv_file = "recent_runs.csv" + # fieldnames = [ + # "id", "name", "status", "total_tokens", "total_cost", "start_time", "end_time", + # "duration", "inputs", "outputs", "error", "tags", "metadata", "parent_run_id", "child_runs" + # ] + # # Check if file exists to write header only once + # write_header = not os.path.exists(csv_file) + # with open(csv_file, mode="a", newline="", encoding="utf-8") as f: + # writer = csv.DictWriter(f, fieldnames=fieldnames) + # if write_header: + # writer.writeheader() + # writer.writerow({ + # "id": run.id, + # "name": run.name, + # "status": run.status, + # "total_tokens": run.total_tokens, + # "total_cost": run.total_cost, + # "start_time": run.start_time, + # "end_time": run.end_time, + # "duration": (run.end_time - run.start_time) if run.end_time and run.start_time else None, + # "inputs": json.dumps(run.inputs), + # "outputs": json.dumps(run.outputs), + # "error": run.error, + # "tags": json.dumps(run.tags), + # "metadata": json.dumps(run.metadata), + # "parent_run_id": run.parent_run_id, + # "child_runs": json.dumps(run.child_runs), + # }) + return recent_runs + + + + except Exception as e: + print(f"Error: {e}") + return [] + +if __name__ == "__main__": + langchain_metrics = LangchainMetrics() + langchain_metrics.connect_clickhouse() + runs = langchain_metrics.get_runs() + for run in runs: + langchain_metrics.save_to_clickhouse(run=run) + + + + + + + diff --git a/llm/langchain_metrics2.py b/llm/langchain_metrics2.py new file mode 100644 index 0000000..1417bd2 --- /dev/null +++ b/llm/langchain_metrics2.py @@ -0,0 +1,147 @@ +from langsmith import Client +import os +import csv +from io import StringIO +from datetime import datetime, timedelta +from typing import List, Optional, Dict, Any, Union +from dotenv import load_dotenv +load_dotenv() + + +class LangChainMetrics: + def __init__(self, project_name: str = "default"): + self.api_key = os.getenv("LANGSMITH_API_KEY") + self.project_name = project_name + self.client = Client(api_key=self.api_key) + + def get_latest_runs(self, limit: int = 50, hours_back: int = 24) -> List[Any]: + try: + start_time = datetime.now() - timedelta(hours=hours_back) + return list(self.client.list_runs( + project_name=self.project_name, + start_time=start_time, + limit=limit + )) + except Exception as e: + print(f"Error getting latest runs: {e}") + return [] + + def get_runs_by_id(self, run_ids: List[str]) -> List[Any]: + runs = [] + for run_id in run_ids: + try: + runs.append(self.client.read_run(run_id)) + except Exception as e: + print(f"Error retrieving run {run_id}: {e}") + return runs + + def get_runs_by_tags(self, tags: List[str], match_all: bool = True, limit: int = 100) -> List[Any]: + try: + operator = " and " if match_all else " or " + filter_str = operator.join([f'has(tags, "{tag}")' for tag in tags]) + + return list(self.client.list_runs( + project_name=self.project_name, + filter=filter_str, + limit=limit + )) + except Exception as e: + print(f"Error searching runs by tags {tags}: {e}") + return [] + + def get_runs_by_start_time( + self, + start_time_gte: datetime, + start_time_lte: Optional[datetime] = None, + limit: int = 100 + ) -> List[Any]: + try: + filters = [f'start_time >= "{start_time_gte.isoformat()}"'] + if start_time_lte: + filters.append(f'start_time <= "{start_time_lte.isoformat()}"') + + return list(self.client.list_runs( + project_name=self.project_name, + filter=" and ".join(filters), + limit=limit + )) + except Exception as e: + print(f"Error searching runs by start time: {e}") + return [] + + def pretty_print_metrics(self, runs: List[Any], group_by: str = "parent_id") -> None: + # Group runs + grouped_runs = {} + if group_by == "parent_id": + for run in runs: + parent_id = str(run.parent_run_id) if run.parent_run_id else str(run.id) + if parent_id not in grouped_runs: + grouped_runs[parent_id] = [] + grouped_runs[parent_id].append(run) + elif group_by == "tag": + for run in runs: + tags = run.tags if run.tags else ["no_tag"] + for tag in tags: + if tag not in grouped_runs: + grouped_runs[tag] = [] + grouped_runs[tag].append(run) + + # CSV format output + csv_buffer = StringIO() + + # Summary CSV + print("=== SUMMARY CSV ===") + summary_writer = csv.writer(csv_buffer) + summary_writer.writerow([ + "group_key", "group_type", "total_runs", "total_cost", "total_tokens", + "success_count", "error_count", "success_rate" + ]) + + for group_key, group_runs in grouped_runs.items(): + success_count = sum(1 for run in group_runs if run.status == "success") + summary_writer.writerow([ + group_key, + group_by, + len(group_runs), + sum(float(run.total_cost or 0) for run in group_runs), + sum(int(run.total_tokens or 0) for run in group_runs), + success_count, + len(group_runs) - success_count, + f"{(success_count / len(group_runs) * 100):.1f}%" if group_runs else "0%" + ]) + + #print(csv_buffer.getvalue()) + csv_buffer.seek(0) + csv_buffer.truncate(0) + + # Detailed runs CSV + print("\n=== DETAILED RUNS CSV ===") + detail_writer = csv.writer(csv_buffer) + detail_writer.writerow([ + "group_key", "run_id", "name", "status", "total_tokens", "total_cost", + "start_time", "end_time", "duration_seconds", "tags", "parent_run_id", "error" + ]) + + for group_key, group_runs in grouped_runs.items(): + for run in group_runs: + detail_writer.writerow([ + group_key, + str(run.id) if run.id else "", + str(run.name) if run.name else "", + str(run.status) if run.status else "", + int(run.total_tokens) if run.total_tokens else 0, + float(run.total_cost) if run.total_cost else 0.0, + run.start_time.isoformat() if run.start_time else "", + run.end_time.isoformat() if run.end_time else "", + (run.end_time - run.start_time).total_seconds() if run.start_time and run.end_time else "", + "|".join(run.tags) if run.tags else "", + str(run.parent_run_id) if run.parent_run_id else "", + str(run.error) if run.error else "" + ]) + + #print(csv_buffer.getvalue()) + +if __name__ == "__main__": + metrics = LangChainMetrics() + runs = metrics.get_latest_runs(limit=10) + metrics.pretty_print_metrics(runs, group_by="parent_id") \ No newline at end of file diff --git a/llm/llm_utils/langchain_pipeline.py b/llm/llm_utils/langchain_pipeline.py index 300fc88..9c68ed6 100644 --- a/llm/llm_utils/langchain_pipeline.py +++ b/llm/llm_utils/langchain_pipeline.py @@ -13,6 +13,11 @@ # === LangGraph imports === from langgraph.graph import StateGraph, START from typing_extensions import TypedDict +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from langchain_metrics import LangchainMetrics +from langchain_core.callbacks.base import BaseCallbackHandler load_dotenv() @@ -29,8 +34,64 @@ class State(TypedDict): question: str context: List[Document] answer: str - port: int +class RunIdCollector(BaseCallbackHandler): + def __init__(self): + self.run_ids = [] + self.last_run_id = None # Add this line + self.chain_run_ids = [] + + def on_llm_start(self, serialized, prompts, *, run_id, parent_run_id=None, **kwargs): + self.run_ids.append(run_id) + print(f"LLM run started with run_id: {run_id}") + # Store run_id for later processing in on_llm_end + self.current_run_id = run_id + + def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs): + print(f"LLM run ended with run_id: {run_id}") + # Process the run after it's completed + import time + time.sleep(1) # Give LangSmith time to save the run + langchain_metrics = LangchainMetrics() + langchain_metrics.connect_clickhouse() # Connect to ClickHouse + # run = langchain_metrics.get_runs_by_id_safe([run_id]) + # if run: + # # Process the run here + # langchain_metrics.save_to_clickhouse(run) + # else: + # print(f"No run found for run_id: {run_id}") + + run_name = "retrieve" + runs = langchain_metrics.get_runs(num_runs=1, run_ids=None, run_name=run_name) + print(f"Runs: {runs}") + + if runs and len(runs) > 0: + # Get the first run from the list + run = runs[0] + print(f"Found run with name: {run.name}") + #save run to clickhouse + langchain_metrics.save_to_clickhouse(run) + else: + print(f"No run found for run_name: {run_name}") + + + # def on_chain_start(self, inputs, *, run_id, parent_run_id=None, **kwargs): + # print(f"Chain run started with run_id: {run_id}") + # self.chain_run_ids.append(run_id) + + # def on_chain_end(self, outputs, *, run_id, parent_run_id=None, **kwargs): + # print(f"Chain run ended with run_id: {run_id}") + # # Process the run here + # langchain_metrics = LangchainMetrics() + # langchain_metrics.connect_clickhouse() # Connect to ClickHouse + # run = langchain_metrics.get_runs_by_id_safe([run_id]) + # if run: + # print(f"Found run: {run.trace_id}") + + + + + # 2. Step 1: retrieve relevant articles def retrieve(state: State) -> Dict[str, Any]: @@ -100,13 +161,16 @@ def __init__(self, name: str, max_articles: int = 5): api_key = os.getenv("ANTHROPIC_API_KEY") if not api_key: raise ValueError("ANTHROPIC_API_KEY is required") + + collector = RunIdCollector() self.llm = ChatAnthropic( model_name="claude-3-5-sonnet-latest", api_key=SecretStr(api_key), temperature=0.1, timeout=60, - stop=[] + stop=[], + callbacks=[collector] ) self.rag_prompt = PromptTemplate( input_variables=["question", "context"], @@ -121,9 +185,7 @@ def __init__(self, name: str, max_articles: int = 5): Please provide a comprehensive answer based on the context above. If the context doesn't contain enough information to answer the question, say so. Use the Guardian articles as your primary source of information. Answer:""" - ) - - # 4. Build the LangGraph orchestration + ) # 4. Build the LangGraph orchestration builder = StateGraph(State).add_sequence([ post, retrieve, @@ -161,6 +223,8 @@ def answer_question(self, question: str, database: str) -> Dict[str, Any]: for d in docs ] } + + except Exception as e: logging.error(f"RAG pipeline failed: {e}") return { @@ -175,10 +239,6 @@ def answer_question(self, question: str, database: str) -> Dict[str, Any]: if __name__ == "__main__": state_app = RAGApplication(max_articles=5) for q in [ - "Give me the latest on news corp columnist Lucy Zelić.", + "Give me the latest on Trump.", ]: res = state_app.answer_question(q) - print(f"\nQuestion: {res['question']}") - print(f"Answer: {res['answer']}") - print(f"Articles used: {res['articles_used']}") - print("-" * 60) diff --git a/llm/llm_utils/langchain_pipeline_2.py b/llm/llm_utils/langchain_pipeline_2.py new file mode 100644 index 0000000..10c1c25 --- /dev/null +++ b/llm/llm_utils/langchain_pipeline_2.py @@ -0,0 +1,211 @@ +import os +import logging +from dotenv import load_dotenv +from typing import List, Dict, Any +from anthropic import Anthropic +from langchain_anthropic import ChatAnthropic +from langchain.schema import Document +from langchain.prompts import PromptTemplate +from pydantic import SecretStr +import requests +from langchain_metrics import LangchainMetrics +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.runnables import RunnableSequence +from typing_extensions import TypedDict + +load_dotenv() + + +# 1. Define the RAG chain +class RAGChain: + """Custom RAG chain that combines retrieval and generation.""" + + def __init__(self, llm: ChatAnthropic, rag_prompt: PromptTemplate): + self.llm = llm + self.rag_prompt = rag_prompt + + def _retrieve_documents(self, question: str) -> List[Document]: + """Retrieve relevant documents.""" + if os.getenv("DATABASE_TYPE", "").lower() == "clickhouse": + port = 8000 + elif os.getenv("DATABASE_TYPE", "").lower() == "postgres": + port = 8001 + else: + raise ValueError("DATABASE_TYPE must be either clickhouse or postgres") + + docs = requests.get(f"http://localhost:{port}/related-articles?query={question}").json() + return [ + Document( + page_content=body, + metadata={ + "url": url, + "title": title, + "publication_date": pub_date, + "similarity_score": score + } + ) + for url, title, body, pub_date, score in docs + ] + + def invoke(self, inputs: Dict[str, Any], callbacks=None) -> Dict[str, Any]: + """Invoke the RAG chain with proper callback handling.""" + question = inputs["question"] + + # Step 1: Retrieve documents + docs = self._retrieve_documents(question) + + # Step 2: Generate answer using LLMChain for proper callback support + ctx = "\n\n".join( + f"Title: {doc.metadata['title']}\n" + f"Date: {doc.metadata['publication_date']}\n" + f"Content: {doc.page_content}" + for doc in docs + ) + + # Create a temporary runnable for the generation step + runnable = self.rag_prompt | self.llm + response = runnable.invoke( + {"question": question, "context": ctx}, + config={"callbacks": callbacks} + ) + + return { + "answer": response.content, + "context": [ + { + "title": d.metadata["title"], + "url": d.metadata["url"], + "publication_date": d.metadata["publication_date"], + "similarity_score": d.metadata["similarity_score"], + "snippet": (d.page_content[:200] + "...") if len(d.page_content) > 200 else d.page_content + } + for d in docs + ], + "articles_used": len(docs) + } + +class RunIdCollector(BaseCallbackHandler): + def __init__(self): + self.run_ids = [] + self.last_run_id = None # Add this line + self.chain_run_ids = [] + + def on_llm_start(self, serialized, prompts, *, run_id, parent_run_id=None, **kwargs): + self.run_ids.append(run_id) + print(f"LLM run started with run_id: {run_id}") + # Store run_id for later processing in on_llm_end + self.current_run_id = run_id + + def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs): + print(f"LLM run ended with run_id: {run_id}") + # Process the run after it's completed + import time + time.sleep(1) # Give LangSmith time to save the run + langchain_metrics = LangchainMetrics() + langchain_metrics.connect_clickhouse() # Connect to ClickHouse + run = langchain_metrics.get_runs_by_id_safe([run_id]) + if run: + print(f"Found run: {run.trace_id}") + # Process the run here + langchain_metrics.save_to_clickhouse(run) + else: + print(f"No run found for run_id: {run_id}") + + def on_chain_start(self, serialized, inputs, *, run_id, parent_run_id=None, **kwargs): + print(f"Chain run started with run_id: {run_id}") + print(f"Inputs: {inputs}") + self.chain_run_ids.append(run_id) + + def on_chain_end(self, outputs, *, run_id, parent_run_id=None, **kwargs): + print(f"Chain run ended with run_id: {run_id}") + print(f"Outputs: {outputs}") + # Process the run here + import time + time.sleep(1) # Give LangSmith time to save the run + langchain_metrics = LangchainMetrics() + langchain_metrics.connect_clickhouse() # Connect to ClickHouse + run = langchain_metrics.get_runs_by_id_safe([run_id]) + if run: + print(f"Found run: {run.trace_id}") + langchain_metrics.save_to_clickhouse(run) + else: + print(f"No run found for run_id: {run_id}") + + def on_llm_start(self, serialized, prompts, *, run_id, parent_run_id=None, **kwargs): + print(f"LLM run started with run_id: {run_id}") + self.run_ids.append(run_id) + + def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs): + print(f"LLM run ended with run_id: {run_id}") + self.last_run_id = run_id + + + + +# These functions are now integrated into the RAGChain class + + +class RAGApplication: + def __init__(self, max_articles: int = 5): + self.max_articles = max_articles + self.anthropic = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key: + raise ValueError("ANTHROPIC_API_KEY is required") + + self.llm = ChatAnthropic( + model_name="claude-3-5-sonnet-latest", + api_key=SecretStr(api_key), + temperature=0.1, + timeout=60, + stop=[] + ) + self.rag_prompt = PromptTemplate( + input_variables=["question", "context"], + template=""" + You are a helpful AI assistant that answers questions based on the provided context from Guardian articles. + + Context from Guardian articles: + {context} + + Question: {question} + + Please provide a comprehensive answer based on the context above. If the context doesn't contain enough information to answer the question, say so. Use the Guardian articles as your primary source of information. + + Answer:""" + ) + + # Create the RAG chain + self.rag_chain = RAGChain(llm=self.llm, rag_prompt=self.rag_prompt) + + def answer_question(self, question: str) -> Dict[str, Any]: + """Invoke the RAG chain with callbacks.""" + try: + collector = RunIdCollector() + # Invoke the chain with callbacks + result = self.rag_chain.invoke({"question": question}, callbacks=[collector]) + + return { + "question": question, + "answer": result["answer"], + "articles_used": result["articles_used"], + "context": result["context"] + } + + except Exception as e: + logging.error(f"RAG pipeline failed: {e}") + return { + "question": question, + "answer": f"Error: {e}", + "context": [], + "articles_used": 0 + } + + +# === Example usage === +if __name__ == "__main__": + state_app = RAGApplication(max_articles=5) + for q in [ + "Give me the latest on Epstein.", + ]: + res = state_app.answer_question(q) diff --git a/requirements.txt b/requirements.txt index 0bbe95e..4d1fc2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ #uvicorn==0.35.0 # LangChain dependencies -#clickhouse_connect==0.8.18 +clickhouse_connect==0.8.18 langchain~=0.3.26 langchain-community langchain-core diff --git a/services/cassandra/Dockerfile b/services/cassandra/Dockerfile index d821782..4b9f0f4 100644 --- a/services/cassandra/Dockerfile +++ b/services/cassandra/Dockerfile @@ -21,10 +21,42 @@ RUN apt-get update && \ RUN ls -la # Install dependencies +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD +COPY requirements.txt . +======= COPY services/cassandra/requirements.txt . +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 +======= +COPY services/cassandra/requirements.txt . +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 +======= +COPY services/cassandra/requirements.txt . +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 +======= +COPY services/cassandra/requirements.txt . +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 RUN pip install --upgrade pip RUN pip install --no-cache-dir -r requirements.txt RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu # Default command if not overridden by docker-compose +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD +CMD ["python", "-m", "uvicorn", "services.cassandra.cassandra_controller:app", "--host", "0.0.0.0", "--port", "8002"] +======= +CMD ["python", "-m", "uvicorn", "services.cassandra.cassandra_controller:app", "--host", "0.0.0.0", "--port", "8003"] +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 +======= +CMD ["python", "-m", "uvicorn", "services.cassandra.cassandra_controller:app", "--host", "0.0.0.0", "--port", "8003"] +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 +======= +CMD ["python", "-m", "uvicorn", "services.cassandra.cassandra_controller:app", "--host", "0.0.0.0", "--port", "8003"] +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 +======= CMD ["python", "-m", "uvicorn", "services.cassandra.cassandra_controller:app", "--host", "0.0.0.0", "--port", "8003"] +>>>>>>> 9634c8a0ef870fa079b7b6de8463d67f3425a819 diff --git a/services/clickhouse/docker-compose.yaml b/services/clickhouse/docker-compose.yaml index 61ed390..b5ae685 100644 --- a/services/clickhouse/docker-compose.yaml +++ b/services/clickhouse/docker-compose.yaml @@ -4,7 +4,7 @@ services: container_name: clickhouse-server hostname: clickhouse ports: - - "8124:8123" # HTTP interface + - "8123:8123" # HTTP interface - "9001:9000" # Native TCP interface volumes: - ./user_directories/config.xml:/etc/clickhouse-server/user_directories/config.xml diff --git a/services/streamlit/docker-compose.yaml b/services/streamlit/docker-compose.yaml index 9f3dd50..595d761 100644 --- a/services/streamlit/docker-compose.yaml +++ b/services/streamlit/docker-compose.yaml @@ -57,4 +57,6 @@ services: - GRAFANA_USER=admin - GRAFANA_PASS=admin restart: "no" + networks: + - default env_file: "../../.env" diff --git a/services/streamlit/provisioner/dashboard.json b/services/streamlit/provisioner/dashboard.json index 96fb5ae..88d3d44 100644 --- a/services/streamlit/provisioner/dashboard.json +++ b/services/streamlit/provisioner/dashboard.json @@ -24,7 +24,7 @@ { "datasource": { "type": "grafana-clickhouse-datasource", - "uid": "__datasource__" + "uid": "ClickHouse" }, "fieldConfig": { "defaults": { @@ -126,7 +126,7 @@ { "datasource": { "type": "grafana-clickhouse-datasource", - "uid": "__datasource__" + "uid": "ClickHouse" }, "fieldConfig": { "defaults": { @@ -206,7 +206,7 @@ { "datasource": { "type": "grafana-clickhouse-datasource", - "uid": "__datasource__" + "uid": "ClickHouse" }, "editorType": "sql", "format": 1, diff --git a/services/streamlit/provisioner/provision_grafana.py b/services/streamlit/provisioner/provision_grafana.py index 2ac8ec9..a87e34c 100644 --- a/services/streamlit/provisioner/provision_grafana.py +++ b/services/streamlit/provisioner/provision_grafana.py @@ -8,6 +8,13 @@ PASSWORD = os.getenv("GRAFANA_PASS", "admin") DASHBOARD_PATH = "dashboard.json" +# ClickHouse connection details +CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "10.0.100.92") +CLICKHOUSE_PORT = os.getenv("CLICKHOUSE_PORT", "8123") +CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "guardian") +CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "user") +CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "default") + MAX_RETRIES = 10 RETRY_DELAY = 3 # seconds @@ -37,18 +44,19 @@ def create_datasource(session): "type": "grafana-clickhouse-datasource", "access": "proxy", "url": f"http://{host}:{port}", # adjust in .env + "url": f"http://{CLICKHOUSE_HOST}:{CLICKHOUSE_PORT}", "basicAuth": False, "jsonData": { - "defaultDatabase": "guardian", - "port": 8123, - "username": "user", - "server": host, # this needs to be here + "defaultDatabase": CLICKHOUSE_DATABASE, + "port": int(CLICKHOUSE_PORT), + "username": CLICKHOUSE_USER, + "server": CLICKHOUSE_HOST, "secure": False, "protocol": "http", "skip-tls-verify": True }, "secureJsonData": { - "password": "default" + "password": CLICKHOUSE_PASSWORD }, "isDefault": True } diff --git a/services/streamlit/requirements.txt b/services/streamlit/requirements.txt index 84682d5..3ae4a9e 100644 --- a/services/streamlit/requirements.txt +++ b/services/streamlit/requirements.txt @@ -1,5 +1,6 @@ # Streamlit service requirements anthropic==0.59.0 +clickhouse_connect==0.8.18 fastapi==0.116.1 langchain==0.3.27 langchain_anthropic==0.3.17