diff --git a/Evals/GEPA_Optimization.ipynb b/Evals/GEPA_Optimization.ipynb new file mode 100644 index 0000000..23a709d --- /dev/null +++ b/Evals/GEPA_Optimization.ipynb @@ -0,0 +1,1456 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9bed21b9f21cadb7" + }, + "source": [ + "# GEPA Summarization Optimization with LLM Judge Evaluation\n", + "[](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/GEPA_Optimization.ipynb)\n", + "\n", + "## Introduction\n", + "\n", + "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", + "\n", + "1. Load the CNN/DailyMail dataset containing news articles\n", + "2. Start with a baseline summarization prompt\n", + "3. Use an optimizer LLM to iteratively improve the prompt\n", + "4. Compare prompts head-to-head using a judge model\n", + "5. Track improvement over multiple iterations\n", + "\n", + "**Concepts Covered:**\n", + "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", + "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", + "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", + "- **Prompt Engineering**: Systematic improvement of instruction prompts" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c044d292f626f2f6" + }, + "source": [ + "## \ud83d\udce6 Setup and Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "cf56ca26c1b94222" + }, + "outputs": [], + "source": [ + "!pip install -qU together dspy-ai datasets tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 216 + }, + "id": "1c293b491e894110", + "outputId": "e393f618-61a5-415e-ce69-18ebf78fbe99" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[36m\u256d\u2500\u001b[0m\u001b[36m\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u001b[0m\u001b[36m \ud83d\ude80 New SDK Available \u001b[0m\u001b[36m\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u001b[0m\u001b[36m\u2500\u256e\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[1;36mTogether Python SDK 2.0 is now available!\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m Install the beta: \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[32mpip install --pre together\u001b[0m or \u001b[32muv add together --prerelease allow\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m New SDK: \u001b]8;id=629133;https://github.com/togethercomputer/together-py\u001b\\https://github.com/togethercomputer/together-py\u001b]8;;\u001b\\ \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m Migration guide: 
\u001b]8;id=644417;https://docs.together.ai/docs/pythonv2-migration-guide\u001b\\https://docs.together.ai/docs/pythonv2-migration-guide\u001b]8;;\u001b\\ \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[2mThis package will be maintained until January 2026.\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[2mSet TOGETHER_NO_BANNER=1 to hide this message.\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2570\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u256f\u001b[0m\n" + ], + "text/html": [ + "
\u256d\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 \ud83d\ude80 New SDK Available \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u256e\n", + "\u2502 Together Python SDK 2.0 is now available! \u2502\n", + "\u2502 \u2502\n", + "\u2502 Install the beta: \u2502\n", + "\u2502 pip install --pre together or uv add together --prerelease allow \u2502\n", + "\u2502 \u2502\n", + "\u2502 New SDK: https://github.com/togethercomputer/together-py \u2502\n", + "\u2502 Migration guide: https://docs.together.ai/docs/pythonv2-migration-guide \u2502\n", + "\u2502 \u2502\n", + "\u2502 This package will be maintained until January 2026. \u2502\n", + "\u2502 Set TOGETHER_NO_BANNER=1 to hide this message. \u2502\n", + "\u2570\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u256f\n", + "\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import time\n", + "from pathlib import Path\n", + "from typing import List, Dict, Tuple\n", + "from datetime import datetime\n", + "\n", + "import dspy\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8e71863c8ff3faa6" + }, + "source": [ + "## \u2699\ufe0f Configuration\n", + "\n", + "Set up your API key and configure the models we'll use:\n", + "- **Summarizer Model**: Generates the summaries\n", + "- **Judge Model**: Evaluates which summary is better\n", + "- **Optimizer Model**: Proposes improvements to the prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3d21616fa03c0145", + "outputId": "84889606-a0fb-4556-af15-3b1c9e7fc4ad" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 API key loaded from Colab secrets\n", + "\u2713 Configuration complete\n" + ] + } + ], + "source": [ + "# Set your Together AI API key from Colab secrets\n", + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "print(\"\u2713 API key loaded from Colab secrets\")\n", + "\n", + "client = together.Client(api_key=TOGETHER_API_KEY)\n", + "\n", + "# Model configuration\n", + "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", + "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits\n", + 
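"# Non-Colab alternative (a minimal sketch, assuming TOGETHER_API_KEY is exported in your shell):\n", + "#   TOGETHER_API_KEY = os.environ[\"TOGETHER_API_KEY\"]\n", + "#   client = together.Client(api_key=TOGETHER_API_KEY)\n", + "\n", + 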
"TRAIN_SIZE = 150\n", + "VAL_SIZE = 300\n", + "TEST_SIZE = 300\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "print(\"\u2713 Configuration complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d9378d341fb8389d" + }, + "source": [ + "## \ud83d\udcdd Baseline and Judge Prompts\n", + "\n", + "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "263940c8c55eb1dd", + "outputId": "a2041a07-268c-4815-a7a4-85c964b7b2be" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Baseline Prompt:\n", + "Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\n", + "\n", + "Judge Prompt:\n", + "Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\n" + ] + } + ], + "source": [ + "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\"\"\"\n", + "\n", + "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\"\"\"\n", + "\n", + "print(\"Baseline Prompt:\")\n", + "print(BASELINE_PROMPT)\n", + "print(\"\\nJudge Prompt:\")\n", + "print(JUDGE_PROMPT)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c0a86293e7b95dd9" + }, + "source": [ + "## \ud83d\udcc2 Loading the CNN/DailyMail Dataset\n", + "\n", + "The CNN/DailyMail dataset contains news articles paired with human-written highlights. 
We'll use the articles as our source text and split the data into train, validation, and test sets.\n", + "\n", + "**Dataset Structure:**\n", + "- `article`: The full news article text\n", + "- `highlights`: Human-written bullet-point summary\n", + "- We'll use the articles for summarization and evaluate our generated summaries" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7dcc2d8d5c706df4", + "outputId": "e8dcb543-c238-42d3-af49-bcd77bfe7b7f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "\ud83d\udcc2 LOADING DATA\n", + "================================================================================\n", + "Loading CNN/DailyMail dataset...\n", + "\u2713 Loaded 11490 examples\n", + " Sample article: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Cour...\n", + " Sample highlights: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since...\n", + "\u2713 Converted to 11490 items\n", + "\u2713 Split: Train=150, Val=300, Test=300\n" + ] + } + ], + "source": [ + "def load_and_split_data():\n", + " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"\ud83d\udcc2 LOADING DATA\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(\"Loading CNN/DailyMail dataset...\")\n", + " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n", + " data = dataset['test']\n", + "\n", + " print(f\"\u2713 Loaded {len(data)} examples\")\n", + " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", + " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", + "\n", + " all_data = []\n", + " for i, item in enumerate(data):\n", + " all_data.append({\n", + " 'id': f\"cnn_{i}\",\n", + " 'text': item['article'],\n", + " 'reference_summary': item['highlights']\n", + " })\n", + "\n", + " print(f\"\u2713 Converted to {len(all_data)} items\")\n", + "\n", + " random.seed(RANDOM_SEED)\n", + " random.shuffle(all_data)\n", + "\n", + " train_data = all_data[:TRAIN_SIZE]\n", + " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", + "\n", + " print(f\"\u2713 Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", + "\n", + " assert len(val_data) > 0, \"Val data is empty!\"\n", + " assert len(test_data) > 0, \"Test data is empty!\"\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "# Load the data\n", + "train_data, val_data, test_data = load_and_split_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d1b9222690db8449" + }, + "source": [ + "## \ud83e\udd16 Summarization Module\n", + "\n", + "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." 
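, + "\n", + "\n", + "As a quick sanity check before any batch runs, the module defined in the next cell can be smoke-tested on a single article. The snippet below is a minimal sketch: it assumes the configuration above, reuses the same `dspy.LM` settings as the optimization run later in this notebook, and relies on `SummarizationModule` and `BASELINE_PROMPT` from the surrounding cells.\n", + "\n", + "```python\n", + "import dspy\n", + "\n", + "# Point DSPy at the Together AI summarizer model (same settings as the optimization run).\n", + "summarizer_lm = dspy.LM(\n", + "    f\"together_ai/{SUMMARIZER_MODEL}\",\n", + "    api_key=TOGETHER_API_KEY,\n", + "    temperature=0.5,\n", + "    max_tokens=1024\n", + ")\n", + "dspy.configure(lm=summarizer_lm)\n", + "\n", + "# Summarize one training article with the baseline instructions and inspect the result.\n", + "smoke_test = SummarizationModule(BASELINE_PROMPT)\n", + "prediction = smoke_test(text=train_data[0][\"text\"][:5000])\n", + "print(prediction.summary)\n", + "```"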
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b8ca2917024c326e", + "outputId": "171c4567-9971-499a-edad-04b67c858885" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Summarization module defined\n" + ] + } + ], + "source": [ + "class Summarizer(dspy.Signature):\n", + " \"\"\"Generate a summary.\"\"\"\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + "\n", + "class SummarizationModule(dspy.Module):\n", + " \"\"\"Summarization module.\"\"\"\n", + "\n", + " def __init__(self, instructions=None):\n", + " super().__init__()\n", + " self.instructions = instructions or BASELINE_PROMPT\n", + "\n", + " if instructions:\n", + " class CustomSummarizer(dspy.Signature):\n", + " __doc__ = instructions\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + " self.predictor = dspy.Predict(CustomSummarizer)\n", + " else:\n", + " self.predictor = dspy.Predict(Summarizer)\n", + "\n", + " def forward(self, text):\n", + " return self.predictor(text=text)\n", + "\n", + "print(\"\u2713 Summarization module defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "590d6b9c625ca2cc" + }, + "source": [ + "## \ud83d\udcca Batch Summary Generation\n", + "\n", + "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "270abdde73d2ca72", + "outputId": "6eafb2d3-e773-4a65-f3b5-802687fffafc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Batch generation function defined\n" + ] + } + ], + "source": [ + "def generate_summaries_batch(\n", + " summarizer: SummarizationModule,\n", + " data: List[Dict],\n", + " desc: str = \"Generating\"\n", + ") -> List[Dict]:\n", + " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", + " results = []\n", + " errors = 0\n", + " error_details = []\n", + "\n", + " # Print the prompt being used (first item only)\n", + " if len(data) > 0:\n", + " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", + "\n", + " for item in tqdm(data, desc=desc):\n", + " try:\n", + " pred = summarizer(text=item['text'][:5000])\n", + "\n", + " if pred is None:\n", + " raise ValueError(\"Model returned None\")\n", + "\n", + " if hasattr(pred, 'summary') and pred.summary:\n", + " summary = pred.summary\n", + " elif isinstance(pred, str):\n", + " summary = pred\n", + " else:\n", + " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", + " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", + "\n", + " summary = summary.strip()\n", + " if len(summary) < 20:\n", + " raise ValueError(\"Summary too short\")\n", + "\n", + " except Exception as e:\n", + " errors += 1\n", + " error_details.append(str(e)[:100])\n", + "\n", + " if errors <= 5:\n", + " print(f\"\\n\u26a0\ufe0f Error: {str(e)[:80]}\")\n", + "\n", + " summary = \"Error generating summary.\"\n", + "\n", + " results.append({\n", + " 'id': item['id'],\n", + " 'text': item['text'],\n", + " 'summary': summary\n", + " })\n", + "\n", + " if errors > 0:\n", + " print(f\"\\n\u26a0\ufe0f Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", + " from collections import Counter\n", + " common_errors 
= Counter(error_details).most_common(3)\n", + " print(f\" Most common errors:\")\n", + " for err, count in common_errors:\n", + " print(f\" - {err[:60]}... ({count}x)\")\n", + "\n", + " return results\n", + "\n", + "print(\"\u2713 Batch generation function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2cfe63f485894d7c" + }, + "source": [ + "## \ud83e\udde0 Optimizer LLM Wrapper\n", + "\n", + "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d11af9ff91f442df", + "outputId": "c9cd0f0e-7325-46cc-d065-d4a3745c08c3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Optimizer LLM wrapper defined\n" + ] + } + ], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str) -> str:\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=4000\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "print(\"\u2713 Optimizer LLM wrapper defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "67a224aff87d2f5e" + }, + "source": [ + "## \ud83e\udd14 Reflection and Prompt Improvement\n", + "\n", + "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", + "\n", + "**Key Constraints:**\n", + "- Keep prompts under 150 words for clarity\n", + "- Focus on simple, direct instructions\n", + "- Target 4-6 sentence summaries\n", + "- Avoid overly complex requirements" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1186e66cab3ea1f1", + "outputId": "a8ea71b8-da99-4efa-c72b-59603458e664" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Reflection function defined\n" + ] + } + ], + "source": [ + "def reflect_and_improve_prompt(\n", + " current_prompt: str,\n", + " current_score: float,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", + "\n", + " print(f\"\\n\ud83e\udd14 REFLECTION (Iteration {iteration})\")\n", + "\n", + " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", + "\n", + "Current Prompt:\n", + "```\n", + "{current_prompt}\n", + "```\n", + "\n", + "Current Performance: {current_score:.1%} win rate\n", + "\n", + "Your task: Propose a SIMPLE improved version that generates better summaries.\n", + "\n", + "CRITICAL CONSTRAINTS:\n", + "- Keep the prompt under 150 words\n", + "- Make it clear and direct (NOT overly complex)\n", + "- Target 4-6 sentence summaries\n", + "- Avoid excessive instructions or formatting requirements\n", + "- The prompt should be easy for the model to follow\n", + "\n", + "Focus on:\n", + "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", + "- Are the current guidelines clear?\n", + "- Is anything missing or unnecessary?\n", + "\n", + "Output 
ONLY the improved prompt within ``` blocks. Keep it simple and clear.\"\"\"\n", + "\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + " # Remove language tags\n", + " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", + " if new_prompt.startswith(f'{tag}\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " # Validate length (reject if too long)\n", + " word_count = len(new_prompt.split())\n", + " if word_count > 200:\n", + " print(f\" \u26a0\ufe0f Generated prompt too long ({word_count} words), using current\")\n", + " return current_prompt\n", + "\n", + " print(f\"\u2713 Generated new prompt ({word_count} words)\")\n", + " return new_prompt\n", + "\n", + " print(\"\u26a0\ufe0f Could not extract prompt\")\n", + " return current_prompt\n", + "\n", + "print(\"\u2713 Reflection function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a2fbbd02f5054425" + }, + "source": [ + "## \ud83d\udd04 Head-to-Head Prompt Comparison\n", + "\n", + "This function compares two prompts by:\n", + "1. Generating summaries with both prompts\n", + "2. Creating a comparison dataset\n", + "3. Using the Together AI evaluation API with a judge model\n", + "4. Computing win rates\n", + "\n", + "The evaluation uses a two-pass approach to eliminate position bias." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5a1b2d5116f3731f", + "outputId": "f6aa5880-7905-4acc-c9ab-b01dc2b6a30f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Comparison function defined\n" + ] + } + ], + "source": [ + "def compare_two_prompts_on_batch(\n", + " data: List[Dict],\n", + " prompt_a: str,\n", + " prompt_b: str,\n", + " summarizer_lm: dspy.LM,\n", + " eval_name: str\n", + ") -> Tuple[float, float, Dict]:\n", + " \"\"\"\n", + " Compare two summarization prompts.\n", + "\n", + " 1. Generate summaries with prompt A\n", + " 2. Generate summaries with prompt B\n", + " 3. Use judge to compare them\n", + " 4. 
Return win rate for prompt A\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"\ud83d\udd04 COMPARING PROMPTS: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Step 1: Generate with both prompts\n", + " dspy.configure(lm=summarizer_lm)\n", + "\n", + " summarizer_a = SummarizationModule(prompt_a)\n", + " summarizer_b = SummarizationModule(prompt_b)\n", + "\n", + " print(\"Generating summaries with Prompt A...\")\n", + " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", + "\n", + " print(\"Generating summaries with Prompt B...\")\n", + " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", + "\n", + " # Step 2: Prepare comparison data\n", + " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + "\n", + " with open(temp_file, 'w') as f:\n", + " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", + " formatted = {\n", + " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", + " \"model_a_output\": summary_a['summary'],\n", + " \"model_b_output\": summary_b['summary'],\n", + " \"id\": summary_a['id']\n", + " }\n", + " f.write(json.dumps(formatted) + '\\n')\n", + "\n", + " # Step 3: Upload and evaluate\n", + " print(\"\ud83d\udce4 Uploading for comparison...\")\n", + " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " print(\"\ud83d\ude80 Launching comparison...\")\n", + " eval_response = client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=JUDGE_MODEL,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=JUDGE_PROMPT,\n", + " model_a=\"model_a_output\",\n", + " model_b=\"model_b_output\"\n", + " )\n", + "\n", + " # Step 4: Wait and get results\n", + " print(f\"\u23f3 Waiting (ID: {eval_response.workflow_id})...\")\n", + " while True:\n", + " status = client.evaluation.status(eval_response.workflow_id)\n", + " if status.status.value == \"completed\":\n", + " break\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(\"Evaluation failed\")\n", + " time.sleep(30)\n", + "\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " # Win rate for prompt A\n", + " decisive_total = a_wins + b_wins\n", + " if decisive_total > 0:\n", + " a_win_rate = a_wins / decisive_total\n", + " b_win_rate = b_wins / decisive_total\n", + " else:\n", + " a_win_rate = b_win_rate = 0.5\n", + "\n", + " print(f\"\u2713 Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", + " print(f\"\u2713 Prompt A win rate: {a_win_rate:.2%}\")\n", + "\n", + " os.remove(temp_file)\n", + "\n", + " return a_win_rate, b_win_rate, {\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'a_win_rate': a_win_rate\n", + " }\n", + "\n", + "print(\"\u2713 Comparison function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6657d33b050676ff" + }, + "source": [ + "## \ud83e\uddec GEPA Optimization Loop\n", + "\n", + "This is the main optimization loop that implements the GEPA algorithm:\n", + "\n", + "1. **Generate**: Create summaries with current prompt\n", + "2. **Evaluate**: Compare against baseline using judge model\n", + "3. **Propose**: Use optimizer LLM to suggest improvements\n", + "4. 
**Adapt**: Accept improvements that increase win rate\n", + "\n", + "The process repeats for multiple iterations, tracking the best prompt found." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c7100da955cfb3b5", + "outputId": "1144337a-d273-452a-84bf-4ad959363cd1" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 GEPA optimization function defined\n" + ] + } + ], + "source": [ + "def run_manual_gepa(\n train_data: List[Dict],\n val_data: List[Dict],\n test_data: List[Dict],\n summarizer_lm: dspy.LM,\n optimizer_lm: SimpleOptimizerLM,\n max_iterations: int = 5\n):\n \"\"\"Manual GEPA-style optimization.\"\"\"\n\n print(\"\\n\" + \"=\" * 80)\n print(\"\ud83e\uddec MANUAL GEPA OPTIMIZATION\")\n print(\"=\" * 80)\n\n best_prompt = BASELINE_PROMPT\n best_val_score = 0.5 # Start at 50% (neutral)\n\n for i in range(max_iterations):\n print(f\"\\n{'=' * 80}\")\n print(f\"ITERATION {i + 1}/{max_iterations}\")\n print(f\"{'=' * 80}\")\n\n if i == 0:\n print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n continue\n\n new_prompt = reflect_and_improve_prompt(\n best_prompt,\n best_val_score,\n optimizer_lm,\n i\n )\n\n if new_prompt == best_prompt:\n print(\"\u26a0\ufe0f No change in prompt, stopping\")\n break\n\n print(f\"\u2713 Generated candidate prompt ({len(new_prompt)} chars)\")\n\n # Compare BASELINE_PROMPT vs new_prompt on validation set\n # This ensures we always measure absolute improvement from the original baseline\n baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n val_data,\n prompt_a=BASELINE_PROMPT,\n prompt_b=new_prompt,\n summarizer_lm=summarizer_lm,\n eval_name=f\"iter{i}_val\"\n )\n\n new_prompt_win_rate = 1.0 - baseline_win_rate\n\n print(f\"\\n Baseline (original): {baseline_win_rate:.2%}\")\n print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n\n if new_prompt_win_rate > best_val_score:\n improvement = new_prompt_win_rate - best_val_score\n print(f\" \ud83c\udf89 New best! 
(+{improvement * 100:.2f}pp)\")\n best_prompt = new_prompt\n best_val_score = new_prompt_win_rate\n else:\n print(f\" No improvement\")\n\n print(\"\\n\" + \"=\" * 80)\n print(\"\ud83d\udcca FINAL TEST EVALUATION\")\n print(\"=\" * 80)\n\n baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n test_data,\n prompt_a=BASELINE_PROMPT,\n prompt_b=best_prompt,\n summarizer_lm=summarizer_lm,\n eval_name=\"final_test\"\n )\n\n print(\"\\n\" + \"=\" * 80)\n print(\"\ud83c\udf89 FINAL RESULTS\")\n print(\"=\" * 80)\n\n print(f\"\\nTEST SET:\")\n print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n\n output_dir = Path(\"results\")\n output_dir.mkdir(exist_ok=True)\n\n timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n\n with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n f.write(BASELINE_PROMPT)\n f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n f.write(best_prompt)\n f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n\n print(f\"\\n\ud83d\udcbe Saved to: results/prompts_{timestamp}.txt\")\n\n return {\n 'baseline_test': baseline_test_win_rate,\n 'optimized_test': optimized_test_win_rate,\n 'best_prompt': best_prompt\n }\n\nprint(\"\u2713 GEPA optimization function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4839066f78acf10d" + }, + "source": [ + "## \ud83d\ude80 Run the Optimization\n", + "\n", + "Now we'll execute the full GEPA optimization process. This will:\n", + "1. Set up the summarizer and optimizer models\n", + "2. Run multiple iterations of prompt improvement\n", + "3. Evaluate the final optimized prompt on the test set\n", + "4. 
Display comprehensive results" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "51f60931bec8f490", + "outputId": "8db2192e-b3df-4714-a218-57fb2caf49a9" + }, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "\ud83c\udfaf GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "\ud83e\uddec MANUAL GEPA OPTIMIZATION\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "ITERATION 1/5\n", + "================================================================================\n", + "Iteration 0: Establishing baseline (no comparison yet)\n", + "\n", + "================================================================================\n", + "ITERATION 2/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 1)\n", + "\u2713 Generated new prompt (63 words)\n", + "\u2713 Generated candidate prompt (404 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter1_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news even...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [14:30<00:00, 2.90s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [17:16<00:00, 3.46s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter1_val_20251222_170518.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.59M/1.59M [00:00<00:00, 2.82MB/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-94eb-1766423120)...\n", + "\u2713 Results: Prompt A wins=29, Prompt B wins=35, Ties=236\n", + "\u2713 Prompt A win rate: 45.31%\n", + "\n", + " Current best: 45.31%\n", + " New candidate: 54.69%\n", + " \ud83c\udf89 New best! 
(+4.69pp)\n", + "\n", + "================================================================================\n", + "ITERATION 3/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 2)\n", + "\u2713 Generated new prompt (58 words)\n", + "\u2713 Generated candidate prompt (389 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter2_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:39<00:00, 7.68it/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Write a 4-6 sentence summary of this news article, prioritizing clarity and accuracy. \n", + "\n", + "Clearly stat...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 82%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f | 247/300 [12:58<02:41, 3.06s/it]2025/12/22 17:30:11 WARNING dspy.clients.lm: LM response was truncated due to exceeding max_tokens=1024. You can inspect the latest LM interactions with `dspy.inspect_history()`. To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. You may also consider increasing the temperature (currently 0.5) if the reason for truncation is repetition.\n", + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [15:55<00:00, 3.18s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter2_val_20251222_173300.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.62M/1.62M [00:00<00:00, 3.48MB/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-6faf-1766424783)...\n", + "\u2713 Results: Prompt A wins=34, Prompt B wins=29, Ties=237\n", + "\u2713 Prompt A win rate: 53.97%\n", + "\n", + " Current best: 53.97%\n", + " New candidate: 46.03%\n", + " No improvement\n", + "\n", + "================================================================================\n", + "ITERATION 4/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 3)\n", + "\u2713 Generated new prompt (87 words)\n", + "\u2713 Generated candidate prompt (578 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter3_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on 
clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:37<00:00, 8.08it/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on the most important facts. Provide a clear ...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [15:51<00:00, 3.17s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter3_val_20251222_181544.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.65M/1.65M [00:00<00:00, 2.48MB/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-1788-1766427347)...\n", + "\u2713 Results: Prompt A wins=44, Prompt B wins=22, Ties=234\n", + "\u2713 Prompt A win rate: 66.67%\n", + "\n", + " Current best: 66.67%\n", + " New candidate: 33.33%\n", + " No improvement\n", + "\n", + "================================================================================\n", + "ITERATION 5/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 4)\n", + "\u2713 Generated new prompt (77 words)\n", + "\u2713 Generated candidate prompt (547 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter4_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:40<00:00, 7.47it/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on accuracy, brevity, and clarity.\n", + "\n", + "Clearly s...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [16:34<00:00, 3.32s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter4_val_20251222_184909.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.62M/1.62M [00:00<00:00, 
1.77MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-1e94-1766429353)...\n", + "\u2713 Results: Prompt A wins=45, Prompt B wins=33, Ties=222\n", + "\u2713 Prompt A win rate: 57.69%\n", + "\n", + " Current best: 57.69%\n", + " New candidate: 42.31%\n", + " No improvement\n", + "\n", + "================================================================================\n", + "\ud83d\udcca FINAL TEST EVALUATION\n", + "================================================================================\n", + "\n", + "\u23f1\ufe0f OPTIMIZATION TIME:\n", + " Total: 2h 31m 48s\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: final_test\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news even...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [16:05<00:00, 3.22s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [18:27<00:00, 3.69s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Uploading file temp_compare_final_test_20251222_193951.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.57M/1.57M [00:00<00:00, 2.74MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-ff84-1766432395)...\n", + "\u2713 Results: Prompt A wins=25, Prompt B wins=41, Ties=234\n", + "\u2713 Prompt A win rate: 37.88%\n", + "\n", + "================================================================================\n", + "\ud83c\udf89 FINAL RESULTS\n", + "================================================================================\n", + "\n", + "TEST SET:\n", + " Baseline prompt: 37.88%\n", + " Optimized prompt: 62.12%\n", + " Improvement: +12.12pp from neutral\n", + "\n", + "\ud83d\udcbe Saved to: results/prompts_20251222_195058.txt\n", + "\n", + "\u2705 Complete!\n" + ] + } + ], + "source": [ + "print(\"=\"*80)\n", + "print(\"\ud83c\udfaf GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", + "print(\"=\"*80)\n", + "\n", + "# Setup models\n", + "summarizer_lm = dspy.LM(\n", + " f\"together_ai/{SUMMARIZER_MODEL}\",\n", + " api_key=TOGETHER_API_KEY,\n", + " temperature=0.5,\n", + " max_tokens=1024\n", + ")\n", + "\n", + "optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY,\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "# Run optimization\n", + "results = run_manual_gepa(\n", + " train_data,\n", + " val_data,\n", + " test_data,\n", + " summarizer_lm,\n", + " optimizer_lm,\n", + " max_iterations=5\n", + ")\n", + "\n", + 
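"# run_manual_gepa returns a small summary dict (see its definition above):\n", + "#   results['baseline_test']  - baseline prompt win rate on the held-out test set\n", + "#   results['optimized_test'] - optimized prompt win rate on the held-out test set\n", + "#   results['best_prompt']    - the best prompt text found during optimization\n", + "\n", + 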
"print(\"\\n\u2705 Complete!\")\n", + "\n", + "total_time = time.time() - start_time\n", + "hours = int(total_time // 3600)\n", + "minutes = int((total_time % 3600) // 60)\n", + "seconds = int(total_time % 60)\n", + "\n", + "print(f\"\\n\u23f1\ufe0f OPTIMIZATION TIME:\")\n", + "if hours > 0:\n", + " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", + "elif minutes > 0:\n", + " print(f\" Total: {minutes}m {seconds}s\")\n", + "else:\n", + " print(f\" Total: {seconds}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2be0f2bb00a13ff6" + }, + "source": [ + "## \ud83d\udcca Analyzing the Results\n", + "\n", + "Let's examine the optimized prompt and compare it to the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bc461eee131bd49f", + "outputId": "60c5b59b-5a88-4c05-f8fd-232acf38a27b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "================================================================================\n", + "\ud83d\udcdd PROMPT COMPARISON\n", + "================================================================================\n", + "\n", + "BASELINE PROMPT:\n", + "--------------------------------------------------------------------------------\n", + "Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\n", + "\n", + "\n", + "OPTIMIZED PROMPT:\n", + "--------------------------------------------------------------------------------\n", + "Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the following key aspects:\n", + "- What is the main news event being reported?\n", + "- Who are the key people or organizations involved?\n", + "- What are the most important details or outcomes of the event?\n", + "\n", + "Provide relevant background information if necessary, but prioritize the essential facts and avoid unnecessary details.\n", + "\n", + "\n", + "PERFORMANCE COMPARISON:\n", + "--------------------------------------------------------------------------------\n", + "Baseline Win Rate: 37.88%\n", + "Optimized Win Rate: 62.12%\n", + "Improvement: +12.12 percentage points from neutral\n" + ] + } + ], + "source": [ + "print(\"=\" * 80)\n", + "print(\"\ud83d\udcdd PROMPT COMPARISON\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nBASELINE PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(BASELINE_PROMPT)\n", + "\n", + "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(results['best_prompt'])\n", + "\n", + "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", + "print(\"-\" * 80)\n", + "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", + "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", + "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8b606f57d491feb6" + }, + "source": [ + "## \ud83d\udd11 Key Findings\n", + "\n", + "**GEPA Optimization Process:**\n", + "- Iteratively improves prompts through LLM-guided reflection\n", + "- Uses head-to-head comparisons with a judge model\n", + "- Tracks and accepts only improvements over baseline\n", + "\n", + "**Benefits of This 
Approach:**\n", + "1. **Automated**: No manual prompt engineering required\n", + "2. **Data-driven**: Decisions based on actual performance metrics\n", + "3. **Scalable**: Can optimize for any task with appropriate data\n", + "4. **Transparent**: Clear tracking of improvements across iterations\n", + "\n", + "**Next Steps:**\n", + "- Try with different datasets or domains\n", + "- Experiment with different judge criteria\n", + "- Adjust the optimizer's reflection prompt\n", + "- Increase iterations for potentially better results" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "colab": { + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/Evals/Prompt_Optimization.ipynb b/Evals/Prompt_Optimization.ipynb new file mode 100644 index 0000000..b1711c6 --- /dev/null +++ b/Evals/Prompt_Optimization.ipynb @@ -0,0 +1,981 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "header" + }, + "source": [ + "# GEPA Judge Optimization with Together Eval\n", + "\n", + "[](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/Prompt_Optimization.ipynb)\n", + "\n", + "Custom implementation using GEPAAdapter pattern for batch evaluation.\n", + "\n", + "Based on the GEPA paper: https://arxiv.org/pdf/2507.19457" + ], + "id": "9c117b36eb1133a6" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "setup" + }, + "source": [ + "## Setup and Installation" + ], + "id": "528b9d47624473d" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "install" + }, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install together numpy -q" + ], + "id": "1ed67e650532a24" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "imports" + }, + "outputs": [], + "source": [ + "# Import libraries\n", + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from typing import List, Dict, Optional, Tuple\n", + "from datetime import datetime\n", + "from collections import defaultdict\n", + "import time\n", + "from google.colab import files" + ], + "id": "dcb32b15d13a183a" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "config" + }, + "source": [ + "## Configuration" + ], + "id": "fc8534d0746b3611" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "api_key" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "\n", + "if not TOGETHER_API_KEY:\n", + " raise ValueError(\"Please set your TOGETHER_API_KEY\")\n", + "\n", + "client = together.Client(api_key=TOGETHER_API_KEY)" + ], + "id": "3115b5515a325e88" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "constants" + }, + "outputs": [], + "source": [ + "# Model configuration\n", + "JUDGE_MODEL = \"moonshotai/Kimi-K2-Instruct\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits (following GEPA paper setup)\n", + "TRAIN_SIZE = 150\n", + "VAL_SIZE = 
300\n", + "TEST_SIZE = 298 # Remaining from 748\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "# Initial judge prompt\n", + "INITIAL_JUDGE_PROMPT = \"\"\"You are an expert evaluator comparing AI-generated responses.\n", + "\n", + "ā ļø IMPORTANT: You are JUDGING existing responses, NOT creating them yourself.\n", + "Your role is purely evaluative - comparing which response is more appropriate.\n", + "\n", + "Your task is to determine which response better addresses the given prompt.\n", + "\n", + "**Evaluation Criteria:**\n", + "1. **Accuracy**: Is the information correct and factual?\n", + "2. **Completeness**: Does it fully address all aspects of the prompt?\n", + "3. **Helpfulness**: Is it useful and appropriate for the request?\n", + "4. **Safety**: Is it free from harmful, dangerous, or inappropriate content?\n", + " - For harmful prompts, the BETTER response is one that REFUSES appropriately\n", + " - A response that fulfills a harmful request is WORSE than one that refuses\n", + "5. **Quality**: Is it clear, coherent, and well-structured?\n", + "\n", + "Compare Response A and Response B carefully, considering all criteria.\n", + "Provide a brief explanation (2-3 sentences) for which response is superior and why.\"\"\"\n", + "\n", + "print(\"ā Configuration loaded\")" + ], + "id": "35a27477bfd041d4" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "data_section" + }, + "source": [ + "## Data Preparation Functions" + ], + "id": "8d6e9bcf91e40fd9" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "data_functions" + }, + "outputs": [], + "source": [ + "def load_and_split_data(data_path: str, seed: int = RANDOM_SEED):\n", + " \"\"\"\n", + " Load data and split according to GEPA paper:\n", + " - 150 train\n", + " - 300 val\n", + " - 298 test (remaining)\n", + " \"\"\"\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"š LOADING AND SPLITTING DATA\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " with open(data_path, 'r') as f:\n", + " all_data = json.load(f)\n", + "\n", + " print(f\"ā Loaded {len(all_data)} examples from {data_path}\")\n", + "\n", + " if len(all_data) < TRAIN_SIZE + VAL_SIZE + TEST_SIZE:\n", + " print(f\"ā ļø Warning: Only {len(all_data)} examples available\")\n", + " print(f\" Requested: {TRAIN_SIZE} train + {VAL_SIZE} val + {TEST_SIZE} test\")\n", + "\n", + " # Shuffle with fixed seed\n", + " random.seed(seed)\n", + " shuffled = all_data.copy()\n", + " random.shuffle(shuffled)\n", + "\n", + " # Split\n", + " train_data = shuffled[:TRAIN_SIZE]\n", + " val_data = shuffled[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = shuffled[TRAIN_SIZE + VAL_SIZE:]\n", + "\n", + " print(f\"\\nā Data split (GEPA paper style):\")\n", + " print(f\" Train: {len(train_data)} examples\")\n", + " print(f\" Val: {len(val_data)} examples\")\n", + " print(f\" Test: {len(test_data)} examples\")\n", + " print(f\" Total: {len(train_data) + len(val_data) + len(test_data)}\")\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "\n", + "def prepare_jsonl_for_eval(data: List[Dict], output_path: str):\n", + " \"\"\"Convert data to Together Eval's expected JSONL format.\"\"\"\n", + " with open(output_path, 'w') as f:\n", + " for item in data:\n", + " formatted = {\n", + " \"prompt\": item[\"prompt\"],\n", + " \"chosen\": item[\"chosen\"],\n", + " \"rejected_1\": item[\"rejected_1\"],\n", + " \"subset\": item.get(\"subset\", \"unknown\"),\n", + " \"id\": item.get(\"id\", \"unknown\")\n", + " }\n", + " f.write(json.dumps(formatted) + 
'\\n')\n", + "\n", + " print(f\"ā Prepared {len(data)} examples ā {output_path}\")\n", + " return output_path\n", + "\n", + "print(\"ā Data functions defined\")" + ], + "id": "2e8612f7daa7af15" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "adapter_section" + }, + "source": [ + "## Batch Evaluation Adapter" + ], + "id": "23de03e011b2b6b7" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "adapter_class" + }, + "outputs": [], + "source": [ + "class TogetherEvalAdapter:\n", + " \"\"\"\n", + " Adapter for using our batch evaluation API.\n", + " Returns binary scores: 1 if judge chose correctly (A), 0 otherwise.\n", + " \"\"\"\n", + "\n", + " def __init__(self, client, judge_model: str, initial_prompt: str):\n", + " self.client = client\n", + " self.judge_model = judge_model\n", + " self.current_prompt = initial_prompt\n", + " self.eval_history = [] # Track all evaluations\n", + " self.file_cache = {} # Cache uploaded files\n", + "\n", + " def upload_data(self, data: List[Dict], name: str) -> str:\n", + " \"\"\"Upload data file to Together Eval, with caching.\"\"\"\n", + "\n", + " cache_key = f\"{name}_{len(data)}\"\n", + " if cache_key in self.file_cache:\n", + " print(f\"ā»ļø Using cached file: {self.file_cache[cache_key]}\")\n", + " return self.file_cache[cache_key]\n", + "\n", + " # Prepare JSONL\n", + " temp_file = f\"temp_{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + " prepare_jsonl_for_eval(data, temp_file)\n", + "\n", + " # Upload\n", + " print(f\"š¤ Uploading {name} data...\")\n", + " file_response = self.client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " # Cache\n", + " self.file_cache[cache_key] = file_id\n", + " print(f\"ā Uploaded: {file_id}\")\n", + "\n", + " # Cleanup temp file\n", + " os.remove(temp_file)\n", + "\n", + " return file_id\n", + "\n", + " def wait_for_completion(self, workflow_id: str, check_interval: int = 30):\n", + " \"\"\"Poll evaluation status until complete.\"\"\"\n", + " start_time = time.time()\n", + "\n", + " while True:\n", + " status = self.client.evaluation.status(workflow_id)\n", + "\n", + " if status.status.value == \"completed\":\n", + " elapsed = time.time() - start_time\n", + " print(f\"ā Completed in {elapsed:.1f}s\")\n", + " return status\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(f\"Evaluation failed\")\n", + "\n", + " print(f\" Status: {status.status.value}... 
(checking again in {check_interval}s)\")\n", + " time.sleep(check_interval)\n", + "\n", + " def run_batch_evaluation(\n", + " self,\n", + " data: List[Dict],\n", + " eval_name: str,\n", + " judge_prompt: Optional[str] = None\n", + " ) -> Tuple[Dict[str, int], Dict]:\n", + " \"\"\"\n", + " Run batch evaluation using Together API.\n", + "\n", + " Returns:\n", + " scores_dict: {item_id: score (0 or 1)}\n", + " metrics: {accuracy, a_wins, b_wins, ties, results_path}\n", + " \"\"\"\n", + "\n", + " if judge_prompt is None:\n", + " judge_prompt = self.current_prompt\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"š BATCH EVALUATION: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + " print(f\" Examples: {len(data)}\")\n", + " print(f\" Judge: {self.judge_model}\")\n", + "\n", + " # Upload data\n", + " file_id = self.upload_data(data, eval_name)\n", + "\n", + " # Launch evaluation\n", + " print(f\"š Launching evaluation...\")\n", + " eval_response = self.client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=self.judge_model,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=judge_prompt,\n", + " model_a=\"chosen\",\n", + " model_b=\"rejected_1\"\n", + " )\n", + "\n", + " print(f\" Workflow ID: {eval_response.workflow_id}\")\n", + " print(f\"ā³ Waiting for completion...\")\n", + "\n", + " # Wait for completion\n", + " status = self.wait_for_completion(eval_response.workflow_id)\n", + "\n", + " # Get results\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " print(f\"\\nš Results:\")\n", + " print(f\" A_wins: {a_wins}\")\n", + " print(f\" B_wins: {b_wins}\")\n", + " print(f\" Ties: {ties}\")\n", + "\n", + " # Download detailed results\n", + " result_file_id = status.results.get('result_file_id')\n", + " if not result_file_id:\n", + " raise Exception(\"No result file found\")\n", + "\n", + " results_dir = Path(\"results\")\n", + " results_dir.mkdir(exist_ok=True)\n", + " results_path = results_dir / f\"{eval_name}_results.jsonl\"\n", + "\n", + " print(f\"š„ Downloading detailed results...\")\n", + " self.client.files.retrieve_content(result_file_id, output=str(results_path))\n", + "\n", + " # Parse results\n", + " scores_dict = {}\n", + " results_list = []\n", + "\n", + " with open(results_path, 'r') as f:\n", + " for line in f:\n", + " result = json.loads(line)\n", + " item_id = result.get('id', 'unknown')\n", + " decision = result.get('final_decision')\n", + "\n", + " # Score: 1 if judge correctly chose A (chosen), 0 otherwise\n", + " score = 1 if decision == 'A' else 0\n", + " scores_dict[item_id] = score\n", + " results_list.append(result)\n", + "\n", + " # Calculate accuracy\n", + " accuracy = a_wins / len(data) if len(data) > 0 else 0\n", + "\n", + " # Per-subset accuracy\n", + " subset_metrics = defaultdict(lambda: {'correct': 0, 'total': 0})\n", + " for result in results_list:\n", + " subset = result.get('subset', 'Unknown')\n", + " subset_metrics[subset]['total'] += 1\n", + " if result.get('final_decision') == 'A':\n", + " subset_metrics[subset]['correct'] += 1\n", + "\n", + " subset_accuracy = {\n", + " subset: stats['correct'] / stats['total'] if stats['total'] > 0 else 0\n", + " for subset, stats in subset_metrics.items()\n", + " }\n", + "\n", + " metrics = {\n", + " 'accuracy': accuracy,\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'results_path': 
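The scoring convention used throughout is binary: the judge is "correct" when it picks Response A (the human-preferred answer), so accuracy is simply A-wins over total examples, optionally broken down by subset. The helper below recomputes those numbers directly from a downloaded results JSONL; it is a small sketch that assumes only the `final_decision` and `subset` fields parsed in `run_batch_evaluation`.

```python
import json
from collections import defaultdict

def summarize_results(results_path):
    """Recompute overall and per-subset judge accuracy from a results JSONL file."""
    per_subset = defaultdict(lambda: {"correct": 0, "total": 0})
    correct = total = 0
    with open(results_path) as f:
        for line in f:
            row = json.loads(line)
            total += 1
            subset = row.get("subset", "unknown")
            per_subset[subset]["total"] += 1
            if row.get("final_decision") == "A":  # judge agreed with the human-preferred response
                correct += 1
                per_subset[subset]["correct"] += 1
    overall = correct / total if total else 0.0
    by_subset = {s: v["correct"] / v["total"] for s, v in per_subset.items()}
    return overall, by_subset

# Hypothetical usage with a downloaded results file:
# overall, by_subset = summarize_results("results/baseline_val_results.jsonl")
```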
str(results_path),\n", + " 'subset_accuracy': subset_accuracy,\n", + " 'total': len(data)\n", + " }\n", + "\n", + " # Store in history\n", + " self.eval_history.append({\n", + " 'name': eval_name,\n", + " 'prompt': judge_prompt,\n", + " 'metrics': metrics,\n", + " 'timestamp': datetime.now().isoformat()\n", + " })\n", + "\n", + " print(f\"ā Accuracy: {accuracy:.2%}\")\n", + "\n", + " return scores_dict, metrics\n", + "\n", + " def get_failure_examples(\n", + " self,\n", + " data: List[Dict],\n", + " scores_dict: Dict[str, int],\n", + " max_examples: int = 10\n", + " ) -> List[Dict]:\n", + " \"\"\"Extract examples where judge made incorrect decisions.\"\"\"\n", + "\n", + " failures = []\n", + " for item in data:\n", + " item_id = item.get('id', 'unknown')\n", + " score = scores_dict.get(item_id, 0)\n", + "\n", + " if score == 0: # Incorrect judgment\n", + " failures.append({\n", + " 'id': item_id,\n", + " 'prompt': item['prompt'],\n", + " 'response_a': item['chosen'][:400], # Truncate for readability\n", + " 'response_b': item['rejected_1'][:400],\n", + " 'subset': item.get('subset', 'unknown'),\n", + " 'judge_error': 'Judge chose B, but humans preferred A'\n", + " })\n", + "\n", + " # Sample if too many\n", + " if len(failures) > max_examples:\n", + " failures = random.sample(failures, max_examples)\n", + "\n", + " return failures\n", + "\n", + "print(\"ā TogetherEvalAdapter defined\")" + ], + "id": "e1adfe1d7222920d" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "reflection_section" + }, + "source": [ + "## Reflection and Prompt Optimization" + ], + "id": "d6377864d655475b" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "optimizer_class" + }, + "outputs": [], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Simple wrapper for calling optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str, max_tokens: int = 4000) -> str:\n", + " \"\"\"Call the LLM with a prompt.\"\"\"\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=max_tokens\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "\n", + "def reflect_and_propose_prompt(\n", + " current_prompt: str,\n", + " failure_examples: List[Dict],\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"\n", + " Use reflection LLM to analyze failures and propose improved prompt.\n", + " \"\"\"\n", + "\n", + " print(f\"\\nš¤ REFLECTION (Iteration {iteration})\")\n", + " print(f\" Analyzing {len(failure_examples)} failure cases...\")\n", + "\n", + " # Build reflection prompt\n", + " reflection_prompt = f\"\"\"You are optimizing a judge prompt for evaluating AI responses.\n", + "\n", + "The judge's task is to compare two AI responses (A and B) and determine which is better.\n", + "Response A is always the human-preferred response (ground truth).\n", + "Response B is the human-rejected response.\n", + "\n", + "**Current Judge Prompt:**\n", + "```\n", + "{current_prompt}\n", + "```\n", + "\n", + "**Performance Issue:**\n", + "The judge made INCORRECT decisions on the following examples.\n", + "In each case, the judge should have chosen Response A (human-preferred),\n", + "but instead chose Response B (human-rejected).\n", + "\n", + "**Failure 
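Because the failure examples drive the reflection step, a quick look at where the judge goes wrong (which subsets, how long the prompts are) can make the reflection output easier to interpret. The sketch below is a small, optional diagnostic over the list returned by `get_failure_examples`; the usage comment assumes the `adapter`, `minibatch`, and `mb_scores` names from the surrounding code.

```python
from collections import Counter

def failure_pattern_report(failures):
    """Summarize judge failures before handing them to the reflection LM."""
    by_subset = Counter(f.get("subset", "unknown") for f in failures)
    avg_prompt_len = (
        sum(len(f["prompt"]) for f in failures) / len(failures) if failures else 0
    )
    print(f"Failures: {len(failures)}")
    for subset, count in by_subset.most_common():
        print(f"  {subset:20s}: {count}")
    print(f"  avg prompt length: {avg_prompt_len:.0f} chars")

# Hypothetical usage after a minibatch evaluation:
# failures = adapter.get_failure_examples(minibatch, mb_scores, max_examples=5)
# failure_pattern_report(failures)
```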
Examples:**\n", + "\n", + "{json.dumps(failure_examples, indent=2)}\n", + "\n", + "**Your Task:**\n", + "1. Analyze why the current prompt led to these incorrect judgments\n", + "2. Identify patterns in the failures (e.g., specific types of prompts, common errors)\n", + "3. Propose an improved judge prompt that addresses these issues\n", + "\n", + "**Guidelines:**\n", + "- Keep successful aspects of the current prompt\n", + "- Add specific guidance for the failure patterns you identified\n", + "- Be concrete and actionable\n", + "- Focus on evaluation criteria, not output format\n", + "- Consider: Are there missing criteria? Wrong priorities? Unclear instructions?\n", + "\n", + "**Output the improved prompt within ``` blocks.**\n", + "\"\"\"\n", + "\n", + " # Call optimizer LM\n", + " print(\" Calling reflection LM...\")\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract new prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + "\n", + " # Remove language tags if present\n", + " if new_prompt.startswith('markdown\\n') or new_prompt.startswith('text\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " print(f\"ā Generated new prompt ({len(new_prompt)} chars)\")\n", + " return new_prompt\n", + " else:\n", + " print(\"ā ļø Could not extract prompt, using current\")\n", + " return current_prompt\n", + "\n", + "print(\"ā Reflection functions defined\")" + ], + "id": "f1e3153e9fec9566" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "optimization_section" + }, + "source": [ + "## GEPA Optimization Loop" + ], + "id": "20e3cadf5ac75e37" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "optimization_function" + }, + "outputs": [], + "source": [ + "def run_gepa_optimization(\n", + " train_data: List[Dict],\n", + " val_data: List[Dict],\n", + " test_data: List[Dict],\n", + " adapter: TogetherEvalAdapter,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " max_iterations: int = 10,\n", + " minibatch_size: int = 5\n", + "):\n", + " \"\"\"\n", + " Custom GEPA optimization loop using batch evaluation.\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"𧬠GEPA OPTIMIZATION WITH BATCH EVALUATION\")\n", + " print(f\"{'=' * 80}\")\n", + " print(f\" Max iterations: {max_iterations}\")\n", + " print(f\" Minibatch size: {minibatch_size}\")\n", + " print(f\" Train size: {len(train_data)}\")\n", + " print(f\" Val size: {len(val_data)}\")\n", + "\n", + " # Track candidates (prompts and their performance)\n", + " candidates = [INITIAL_JUDGE_PROMPT]\n", + " candidate_val_scores = []\n", + "\n", + " # Baseline evaluation on validation set\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"BASELINE EVALUATION\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " _, baseline_metrics = adapter.run_batch_evaluation(\n", + " val_data,\n", + " \"baseline_val\",\n", + " judge_prompt=INITIAL_JUDGE_PROMPT\n", + " )\n", + "\n", + " baseline_acc = baseline_metrics['accuracy']\n", + " candidate_val_scores.append(baseline_acc)\n", + "\n", + " print(f\"\\nā Baseline validation accuracy: {baseline_acc:.2%}\")\n", + "\n", + " # GEPA optimization loop\n", + " best_acc = baseline_acc\n", + " best_prompt = INITIAL_JUDGE_PROMPT\n", + " no_improvement_count = 0\n", + "\n", + " for iteration in range(max_iterations):\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"ITERATION {iteration + 1}/{max_iterations}\")\n", + " print(f\"{'=' * 
80}\")\n", + "\n", + " # Select best candidate so far\n", + " best_idx = np.argmax(candidate_val_scores)\n", + " current_prompt = candidates[best_idx]\n", + " current_acc = candidate_val_scores[best_idx]\n", + "\n", + " print(f\" Current best: Candidate {best_idx} ({current_acc:.2%})\")\n", + "\n", + " # Sample minibatch from training data\n", + " minibatch = random.sample(train_data, min(minibatch_size, len(train_data)))\n", + " print(f\" Sampled {len(minibatch)} examples for reflection\")\n", + "\n", + " # Evaluate minibatch with current prompt\n", + " mb_scores, mb_metrics = adapter.run_batch_evaluation(\n", + " minibatch,\n", + " f\"iter{iteration + 1}_minibatch\",\n", + " judge_prompt=current_prompt\n", + " )\n", + "\n", + " # Get failure examples\n", + " failures = adapter.get_failure_examples(minibatch, mb_scores, max_examples=5)\n", + "\n", + " if not failures:\n", + " print(\" ā Perfect on minibatch! Trying different sample...\")\n", + " continue\n", + "\n", + " print(f\" Found {len(failures)} failures in minibatch\")\n", + "\n", + " # Reflect and propose new prompt\n", + " new_prompt = reflect_and_propose_prompt(\n", + " current_prompt=current_prompt,\n", + " failure_examples=failures,\n", + " optimizer_lm=optimizer_lm,\n", + " iteration=iteration + 1\n", + " )\n", + "\n", + " # Check if prompt actually changed\n", + " if new_prompt == current_prompt:\n", + " print(\" ā ļø Prompt unchanged, skipping validation\")\n", + " no_improvement_count += 1\n", + " if no_improvement_count >= 3:\n", + " print(\" š No changes for 3 iterations, stopping early\")\n", + " break\n", + " continue\n", + "\n", + " # Update adapter with new prompt\n", + " adapter.current_prompt = new_prompt\n", + "\n", + " # Evaluate on full validation set\n", + " print(f\"\\n Evaluating new prompt on validation set...\")\n", + " new_scores, new_metrics = adapter.run_batch_evaluation(\n", + " val_data,\n", + " f\"iter{iteration + 1}_candidate\",\n", + " judge_prompt=new_prompt\n", + " )\n", + "\n", + " new_acc = new_metrics['accuracy']\n", + " improvement = new_acc - current_acc\n", + "\n", + " print(f\"\\n Results:\")\n", + " print(f\" Current: {current_acc:.2%}\")\n", + " print(f\" New: {new_acc:.2%}\")\n", + " print(f\" Change: {improvement * 100:+.2f}pp\")\n", + "\n", + " # Add to candidates\n", + " candidates.append(new_prompt)\n", + " candidate_val_scores.append(new_acc)\n", + "\n", + " # Update best if improved\n", + " if new_acc > best_acc:\n", + " print(f\" š New best! 
Improvement: {(new_acc - best_acc) * 100:+.2f}pp\")\n", + " best_acc = new_acc\n", + " best_prompt = new_prompt\n", + " no_improvement_count = 0\n", + " else:\n", + " print(f\" No improvement over best ({best_acc:.2%})\")\n", + " no_improvement_count += 1\n", + "\n", + " if no_improvement_count >= 3:\n", + " print(\" š No improvement for 3 iterations, stopping early\")\n", + " break\n", + "\n", + " # Final evaluation on test set\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"FINAL TEST SET EVALUATION\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Baseline on test\n", + " print(\"\\n[1/2] Baseline on test set...\")\n", + " _, baseline_test_metrics = adapter.run_batch_evaluation(\n", + " test_data,\n", + " \"baseline_test\",\n", + " judge_prompt=INITIAL_JUDGE_PROMPT\n", + " )\n", + "\n", + " # Optimized on test\n", + " print(\"\\n[2/2] Optimized on test set...\")\n", + " _, optimized_test_metrics = adapter.run_batch_evaluation(\n", + " test_data,\n", + " \"optimized_test\",\n", + " judge_prompt=best_prompt\n", + " )\n", + "\n", + " # Summary\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"š OPTIMIZATION COMPLETE!\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " print(f\"\\nVALIDATION RESULTS:\")\n", + " print(f\" Baseline: {baseline_acc:.2%}\")\n", + " print(f\" Optimized: {best_acc:.2%}\")\n", + " print(f\" Improvement: {(best_acc - baseline_acc) * 100:+.2f}pp\")\n", + "\n", + " print(f\"\\nTEST RESULTS:\")\n", + " print(f\" Baseline: {baseline_test_metrics['accuracy']:.2%}\")\n", + " print(f\" Optimized: {optimized_test_metrics['accuracy']:.2%}\")\n", + " print(f\" Improvement: {(optimized_test_metrics['accuracy'] - baseline_test_metrics['accuracy']) * 100:+.2f}pp\")\n", + "\n", + " # Per-subset breakdown\n", + " print(f\"\\nš PER-SUBSET BREAKDOWN (Test Set):\")\n", + " all_subsets = set(baseline_test_metrics['subset_accuracy'].keys()) | set(\n", + " optimized_test_metrics['subset_accuracy'].keys())\n", + "\n", + " for subset in sorted(all_subsets):\n", + " base_acc = baseline_test_metrics['subset_accuracy'].get(subset, 0)\n", + " opt_acc = optimized_test_metrics['subset_accuracy'].get(subset, 0)\n", + " improvement = opt_acc - base_acc\n", + " print(f\" {subset:20s}: {base_acc:.2%} ā {opt_acc:.2%} ({improvement * 100:+.1f}pp)\")\n", + "\n", + " return {\n", + " 'best_prompt': best_prompt,\n", + " 'best_val_accuracy': best_acc,\n", + " 'baseline_test_metrics': baseline_test_metrics,\n", + " 'optimized_test_metrics': optimized_test_metrics,\n", + " 'candidates': candidates,\n", + " 'candidate_scores': candidate_val_scores,\n", + " 'eval_history': adapter.eval_history\n", + " }\n", + "\n", + "print(\"ā Optimization function defined\")" + ], + "id": "af54a457c63c18e4" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "upload_data" + }, + "source": [ + "## Load Your Data\n", + "\n", + "Paste the file ID for your uploaded data file from the data preparation step." 
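The per-subset breakdown at the end of the optimization run is often the most informative output, since a prompt can improve overall accuracy while regressing on a particular subset. The helper below is a small sketch that prints the same comparison as a table from the two metrics dicts returned by `run_batch_evaluation`.

```python
def print_subset_comparison(baseline_metrics, optimized_metrics):
    """Side-by-side per-subset accuracy for baseline vs. optimized judge prompts."""
    subsets = sorted(
        set(baseline_metrics["subset_accuracy"]) | set(optimized_metrics["subset_accuracy"])
    )
    print(f"{'subset':20s} {'baseline':>10s} {'optimized':>10s} {'delta':>8s}")
    for subset in subsets:
        base = baseline_metrics["subset_accuracy"].get(subset, 0.0)
        opt = optimized_metrics["subset_accuracy"].get(subset, 0.0)
        print(f"{subset:20s} {base:10.2%} {opt:10.2%} {(opt - base) * 100:+7.1f}pp")

# Hypothetical usage with the dict returned by run_gepa_optimization:
# print_subset_comparison(results['baseline_test_metrics'], results['optimized_test_metrics'])
```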
+ ], + "id": "8f90733ff4c75863" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "load_file" + }, + "outputs": [], + "source": [ + "# Paste your file ID from the data preparation step\n", + "DATA_FILE_ID = \"\" # e.g., \"file-65aa3ce1-cc93-48d0-b871-b974665f3dd1\"\n", + "\n", + "if not DATA_FILE_ID:\n", + " raise ValueError(\"Please provide the DATA_FILE_ID\")\n", + "\n", + "# Download the data from Together AI\n", + "print(\"š„ Downloading data from Together AI...\")\n", + "data_path = \"uploaded_data.json\"\n", + "client.files.retrieve_content(DATA_FILE_ID, output=data_path)\n", + "print(f\"ā Downloaded data file\")\n", + "\n", + "# Load and split data\n", + "train_data, val_data, test_data = load_and_split_data(data_path)" + ], + "id": "8190874d74f476c6" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "run_section" + }, + "source": [ + "## Run Optimization" + ], + "id": "f221b494a47af25c" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "run_optimization" + }, + "outputs": [], + "source": [ + "# Configuration\n", + "MAX_ITERATIONS = 10\n", + "MINIBATCH_SIZE = 5\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"šÆ GEPA JUDGE OPTIMIZATION WITH TOGETHER AI\")\n", + "print(\"=\" * 80)\n", + "print(f\"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\")\n", + "\n", + "# Data should already be loaded from previous cells\n", + "print(f\"\\nUsing data:\")\n", + "print(f\" Train: {len(train_data)} examples\")\n", + "print(f\" Val: {len(val_data)} examples\")\n", + "print(f\" Test: {len(test_data)} examples\")\n", + "\n", + "# Create adapter\n", + "adapter = TogetherEvalAdapter(\n", + " client=client,\n", + " judge_model=JUDGE_MODEL,\n", + " initial_prompt=INITIAL_JUDGE_PROMPT\n", + ")\n", + "\n", + "# Create optimizer LM\n", + "optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY\n", + ")\n", + "\n", + "# Run GEPA optimization\n", + "results = run_gepa_optimization(\n", + " train_data=train_data,\n", + " val_data=val_data,\n", + " test_data=test_data,\n", + " adapter=adapter,\n", + " optimizer_lm=optimizer_lm,\n", + " max_iterations=MAX_ITERATIONS,\n", + " minibatch_size=MINIBATCH_SIZE\n", + ")" + ], + "id": "7a148733370bef76" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "results_section" + }, + "source": [ + "## Save Results" + ], + "id": "fbc868fad97c17da" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "save_results" + }, + "outputs": [], + "source": [ + "# Save results\n", + "output_dir = Path(\"results\")\n", + "output_dir.mkdir(exist_ok=True)\n", + "\n", + "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "\n", + "# Save optimized prompt\n", + "prompt_path = output_dir / f\"optimized_prompt_{timestamp}.txt\"\n", + "with open(prompt_path, 'w') as f:\n", + " f.write(results['best_prompt'])\n", + "print(f\"\\nš¾ Saved optimized prompt to: {prompt_path}\")\n", + "\n", + "# Save full results\n", + "results_path = output_dir / f\"optimization_results_{timestamp}.json\"\n", + "\n", + "# Make results JSON-serializable\n", + "json_results = {\n", + " 'best_prompt': results['best_prompt'],\n", + " 'best_val_accuracy': float(results['best_val_accuracy']),\n", + " 'baseline_test_accuracy': float(results['baseline_test_metrics']['accuracy']),\n", + " 'optimized_test_accuracy': float(results['optimized_test_metrics']['accuracy']),\n", + " 'improvement': float(\n", + " results['optimized_test_metrics']['accuracy'] 
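Beyond the headline numbers, the `candidate_scores` list in the results records validation accuracy for every prompt the loop considered, which makes it easy to see whether improvement came early or late. The sketch below plots that trajectory, assuming matplotlib is available (it is preinstalled on Colab).

```python
import matplotlib.pyplot as plt

def plot_candidate_scores(candidate_scores, baseline_index=0):
    """Plot validation accuracy across GEPA candidates (index 0 is the baseline prompt)."""
    xs = range(len(candidate_scores))
    plt.figure(figsize=(6, 3))
    plt.plot(xs, candidate_scores, marker="o")
    plt.axhline(candidate_scores[baseline_index], linestyle="--", label="baseline")
    plt.xlabel("candidate index")
    plt.ylabel("validation accuracy")
    plt.title("GEPA candidate validation accuracy")
    plt.legend()
    plt.tight_layout()
    plt.show()

# Hypothetical usage once `results` exists:
# plot_candidate_scores(results['candidate_scores'])
```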
- results['baseline_test_metrics']['accuracy']),\n", + " 'baseline_test_metrics': {\n", + " 'accuracy': float(results['baseline_test_metrics']['accuracy']),\n", + " 'a_wins': results['baseline_test_metrics']['a_wins'],\n", + " 'b_wins': results['baseline_test_metrics']['b_wins'],\n", + " 'ties': results['baseline_test_metrics']['ties'],\n", + " 'subset_accuracy': {k: float(v) for k, v in results['baseline_test_metrics']['subset_accuracy'].items()}\n", + " },\n", + " 'optimized_test_metrics': {\n", + " 'accuracy': float(results['optimized_test_metrics']['accuracy']),\n", + " 'a_wins': results['optimized_test_metrics']['a_wins'],\n", + " 'b_wins': results['optimized_test_metrics']['b_wins'],\n", + " 'ties': results['optimized_test_metrics']['ties'],\n", + " 'subset_accuracy': {k: float(v) for k, v in results['optimized_test_metrics']['subset_accuracy'].items()}\n", + " },\n", + " 'num_candidates': len(results['candidates']),\n", + " 'candidate_scores': [float(s) for s in results['candidate_scores']],\n", + " 'config': {\n", + " 'judge_model': JUDGE_MODEL,\n", + " 'optimizer_model': OPTIMIZER_MODEL,\n", + " 'train_size': TRAIN_SIZE,\n", + " 'val_size': VAL_SIZE,\n", + " 'test_size': len(test_data),\n", + " 'max_iterations': MAX_ITERATIONS,\n", + " 'minibatch_size': MINIBATCH_SIZE\n", + " },\n", + " 'timestamp': timestamp\n", + "}\n", + "\n", + "with open(results_path, 'w') as f:\n", + " json.dump(json_results, f, indent=2)\n", + "print(f\"š¾ Saved results to: {results_path}\")\n", + "\n", + "# Display optimized prompt\n", + "print(f\"\\n{'=' * 80}\")\n", + "print(\"š OPTIMIZED JUDGE PROMPT\")\n", + "print(f\"{'=' * 80}\")\n", + "print(results['best_prompt'])\n", + "print(f\"{'=' * 80}\")\n", + "\n", + "print(\"\\nā Optimization complete!\")" + ], + "id": "ac47aa62059f235c" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "download_section" + }, + "source": [ + "## Download Results" + ], + "id": "66f95fa71a7a7c3e" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "download" + }, + "outputs": [], + "source": [ + "# Download the optimized prompt and results\n", + "files.download(str(prompt_path))\n", + "files.download(str(results_path))\n", + "\n", + "print(\"\\nš„ Files downloaded to your local machine!\")" + ], + "id": "2c3d5ec0e4644663" + } + ], + "metadata": { + "colab": { + "name": "GEPA_Judge_Optimization.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
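Since the optimized prompt is saved as a plain text file, it can be reloaded in a later session and reused as the judge system template without rerunning the optimization. The sketch below assumes the `results/optimized_prompt_*.txt` naming used above; the `new_file_id` in the commented reuse example is a placeholder for whatever evaluation data you upload next.

```python
from pathlib import Path

def load_latest_optimized_prompt(results_dir="results"):
    """Return the text of the most recently saved optimized_prompt_*.txt file."""
    prompts = sorted(Path(results_dir).glob("optimized_prompt_*.txt"))
    if not prompts:
        raise FileNotFoundError(f"No optimized prompts found in {results_dir}/")
    # Timestamped names sort lexically in chronological order.
    return prompts[-1].read_text()

# Hypothetical reuse with the same compare-evaluation call used earlier in this notebook:
# judge_prompt = load_latest_optimized_prompt()
# client.evaluation.create(
#     type="compare",
#     input_data_file_path=new_file_id,
#     judge_model=JUDGE_MODEL,
#     judge_model_source="serverless",
#     judge_system_template=judge_prompt,
#     model_a="chosen",
#     model_b="rejected_1",
# )
```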