From d4d46e62a1684cd7baef58b0abbf5a59ec6f5fa5 Mon Sep 17 00:00:00 2001 From: jli Date: Mon, 22 Dec 2025 23:14:51 +0800 Subject: [PATCH 1/6] add GEPA optimization notebook for summarization --- Evals/GEPA_Optimization.ipynb | 868 ++++++++++++++++++++++++++++++++++ 1 file changed, 868 insertions(+) create mode 100644 Evals/GEPA_Optimization.ipynb diff --git a/Evals/GEPA_Optimization.ipynb b/Evals/GEPA_Optimization.ipynb new file mode 100644 index 0000000..07ae1e9 --- /dev/null +++ b/Evals/GEPA_Optimization.ipynb @@ -0,0 +1,868 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GEPA Summarization Optimization with LLM Judge Evaluation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)\n", + "\n", + "## Introduction\n", + "\n", + "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", + "\n", + "1. Load the CNN/DailyMail dataset containing news articles\n", + "2. Start with a baseline summarization prompt\n", + "3. Use an optimizer LLM to iteratively improve the prompt\n", + "4. Compare prompts head-to-head using a judge model\n", + "5. Track improvement over multiple iterations\n", + "\n", + "**Concepts Covered:**\n", + "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", + "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", + "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", + "- **Prompt Engineering**: Systematic improvement of instruction prompts" + ], + "id": "9bed21b9f21cadb7" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📦 Setup and Installation" + ], + "id": "c044d292f626f2f6" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU together dspy-ai datasets tqdm" + ], + "id": "cf56ca26c1b94222" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import time\n", + "from pathlib import Path\n", + "from typing import List, Dict, Tuple\n", + "from datetime import datetime\n", + "\n", + "import dspy\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm" + ], + "id": "1c293b491e894110" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ⚙️ Configuration\n", + "\n", + "Set up your API key and configure the models we'll use:\n", + "- **Summarizer Model**: Generates the summaries\n", + "- **Judge Model**: Evaluates which summary is better\n", + "- **Optimizer Model**: Proposes improvements to the prompt" + ], + "id": "8e71863c8ff3faa6" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client = together.Client()\n", + "\n", + "# Model configuration\n", + "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", + "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits\n", + "TRAIN_SIZE = 150\n", + "VAL_SIZE = 300\n", + "TEST_SIZE = 300\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "print(\"✓ Configuration complete\")" + ], + "id": "3d21616fa03c0145" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📝 Baseline and Judge Prompts\n", + "\n", + "We start with a simple 
baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." + ], + "id": "d9378d341fb8389d" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\"\"\"\n", + "\n", + "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\"\"\"\n", + "\n", + "print(\"Baseline Prompt:\")\n", + "print(BASELINE_PROMPT)\n", + "print(\"\\nJudge Prompt:\")\n", + "print(JUDGE_PROMPT)" + ], + "id": "263940c8c55eb1dd" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📂 Loading the CNN/DailyMail Dataset\n", + "\n", + "The CNN/DailyMail dataset contains news articles paired with human-written highlights. We'll use the articles as our source text and split the data into train, validation, and test sets.\n", + "\n", + "**Dataset Structure:**\n", + "- `article`: The full news article text\n", + "- `highlights`: Human-written bullet-point summary\n", + "- We'll use the articles for summarization and evaluate our generated summaries" + ], + "id": "c0a86293e7b95dd9" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_and_split_data():\n", + " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"📂 LOADING DATA\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(\"Loading CNN/DailyMail dataset...\")\n", + " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\", trust_remote_code=True)\n", + " data = dataset['test']\n", + "\n", + " print(f\"✓ Loaded {len(data)} examples\")\n", + " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", + " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", + "\n", + " all_data = []\n", + " for i, item in enumerate(data):\n", + " all_data.append({\n", + " 'id': f\"cnn_{i}\",\n", + " 'text': item['article'],\n", + " 'reference_summary': item['highlights']\n", + " })\n", + "\n", + " print(f\"✓ Converted to {len(all_data)} items\")\n", + "\n", + " # Shuffle and split\n", + " random.seed(RANDOM_SEED)\n", + " random.shuffle(all_data)\n", + "\n", + " train_data = all_data[:TRAIN_SIZE]\n", + " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", + "\n", + " print(f\"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", + "\n", + " # Verify\n", + " assert len(val_data) > 0, \"Val data is empty!\"\n", + " assert len(test_data) > 0, \"Test data is empty!\"\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "# Load the data\n", + "train_data, val_data, test_data = load_and_split_data()" + ], + "id": "7dcc2d8d5c706df4" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🤖 Summarization Module\n", + "\n", + "We create a DSPy module that wraps our summarization task. 
This module can be configured with different instruction prompts, which is key to the GEPA optimization process." + ], + "id": "d1b9222690db8449" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Summarizer(dspy.Signature):\n", + " \"\"\"Generate a summary.\"\"\"\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + "\n", + "class SummarizationModule(dspy.Module):\n", + " \"\"\"Summarization module.\"\"\"\n", + "\n", + " def __init__(self, instructions=None):\n", + " super().__init__()\n", + " self.instructions = instructions or BASELINE_PROMPT\n", + "\n", + " if instructions:\n", + " class CustomSummarizer(dspy.Signature):\n", + " __doc__ = instructions\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + " self.predictor = dspy.Predict(CustomSummarizer)\n", + " else:\n", + " self.predictor = dspy.Predict(Summarizer)\n", + "\n", + " def forward(self, text):\n", + " return self.predictor(text=text)\n", + "\n", + "print(\"✓ Summarization module defined\")" + ], + "id": "b8ca2917024c326e" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📊 Batch Summary Generation\n", + "\n", + "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." + ], + "id": "590d6b9c625ca2cc" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_summaries_batch(\n", + " summarizer: SummarizationModule,\n", + " data: List[Dict],\n", + " desc: str = \"Generating\"\n", + ") -> List[Dict]:\n", + " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", + " results = []\n", + " errors = 0\n", + " error_details = []\n", + "\n", + " # Print the prompt being used (first item only)\n", + " if len(data) > 0:\n", + " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", + "\n", + " for item in tqdm(data, desc=desc):\n", + " try:\n", + " pred = summarizer(text=item['text'][:5000])\n", + "\n", + " if pred is None:\n", + " raise ValueError(\"Model returned None\")\n", + "\n", + " if hasattr(pred, 'summary') and pred.summary:\n", + " summary = pred.summary\n", + " elif isinstance(pred, str):\n", + " summary = pred\n", + " else:\n", + " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", + " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", + "\n", + " summary = summary.strip()\n", + " if len(summary) < 20:\n", + " raise ValueError(\"Summary too short\")\n", + "\n", + " except Exception as e:\n", + " errors += 1\n", + " error_details.append(str(e)[:100])\n", + "\n", + " if errors <= 5:\n", + " print(f\"\\n⚠️ Error: {str(e)[:80]}\")\n", + "\n", + " summary = \"Error generating summary.\"\n", + "\n", + " results.append({\n", + " 'id': item['id'],\n", + " 'text': item['text'],\n", + " 'summary': summary\n", + " })\n", + "\n", + " if errors > 0:\n", + " print(f\"\\n⚠️ Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", + " from collections import Counter\n", + " common_errors = Counter(error_details).most_common(3)\n", + " print(f\" Most common errors:\")\n", + " for err, count in common_errors:\n", + " print(f\" - {err[:60]}... 
({count}x)\")\n", + "\n", + " return results\n", + "\n", + "print(\"✓ Batch generation function defined\")" + ], + "id": "270abdde73d2ca72" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🧠 Optimizer LLM Wrapper\n", + "\n", + "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." + ], + "id": "2cfe63f485894d7c" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str) -> str:\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=4000\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "print(\"✓ Optimizer LLM wrapper defined\")" + ], + "id": "d11af9ff91f442df" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🤔 Reflection and Prompt Improvement\n", + "\n", + "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", + "\n", + "**Key Constraints:**\n", + "- Keep prompts under 150 words for clarity\n", + "- Focus on simple, direct instructions\n", + "- Target 4-6 sentence summaries\n", + "- Avoid overly complex requirements" + ], + "id": "67a224aff87d2f5e" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def reflect_and_improve_prompt(\n", + " current_prompt: str,\n", + " current_score: float,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", + "\n", + " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", + "\n", + " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", + "\n", + "Current Prompt:\n", + "```\n", + "{current_prompt}\n", + "```\n", + "\n", + "Current Performance: {current_score:.1%} win rate\n", + "\n", + "Your task: Propose a SIMPLE improved version that generates better summaries.\n", + "\n", + "CRITICAL CONSTRAINTS:\n", + "- Keep the prompt under 150 words\n", + "- Make it clear and direct (NOT overly complex)\n", + "- Target 4-6 sentence summaries\n", + "- Avoid excessive instructions or formatting requirements\n", + "- The prompt should be easy for the model to follow\n", + "\n", + "Focus on:\n", + "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", + "- Are the current guidelines clear?\n", + "- Is anything missing or unnecessary?\n", + "\n", + "Output ONLY the improved prompt within ``` blocks. 
Keep it simple and clear.\"\"\"\n", + "\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + " # Remove language tags\n", + " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", + " if new_prompt.startswith(f'{tag}\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " # Validate length (reject if too long)\n", + " word_count = len(new_prompt.split())\n", + " if word_count > 200:\n", + " print(f\" ⚠️ Generated prompt too long ({word_count} words), using current\")\n", + " return current_prompt\n", + "\n", + " print(f\"✓ Generated new prompt ({word_count} words)\")\n", + " return new_prompt\n", + "\n", + " print(\"⚠️ Could not extract prompt\")\n", + " return current_prompt\n", + "\n", + "print(\"✓ Reflection function defined\")" + ], + "id": "1186e66cab3ea1f1" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🔄 Head-to-Head Prompt Comparison\n", + "\n", + "This function compares two prompts by:\n", + "1. Generating summaries with both prompts\n", + "2. Creating a comparison dataset\n", + "3. Using the Together AI evaluation API with a judge model\n", + "4. Computing win rates\n", + "\n", + "The evaluation uses a two-pass approach to eliminate position bias." + ], + "id": "a2fbbd02f5054425" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compare_two_prompts_on_batch(\n", + " data: List[Dict],\n", + " prompt_a: str,\n", + " prompt_b: str,\n", + " summarizer_lm: dspy.LM,\n", + " eval_name: str\n", + ") -> Tuple[float, float, Dict]:\n", + " \"\"\"\n", + " Compare two summarization prompts.\n", + "\n", + " 1. Generate summaries with prompt A\n", + " 2. Generate summaries with prompt B\n", + " 3. Use judge to compare them\n", + " 4. 
Return win rate for prompt A\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"🔄 COMPARING PROMPTS: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Step 1: Generate with both prompts\n", + " dspy.configure(lm=summarizer_lm)\n", + "\n", + " summarizer_a = SummarizationModule(prompt_a)\n", + " summarizer_b = SummarizationModule(prompt_b)\n", + "\n", + " print(\"Generating summaries with Prompt A...\")\n", + " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", + "\n", + " print(\"Generating summaries with Prompt B...\")\n", + " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", + "\n", + " # Step 2: Prepare comparison data\n", + " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + "\n", + " with open(temp_file, 'w') as f:\n", + " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", + " formatted = {\n", + " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", + " \"model_a_output\": summary_a['summary'],\n", + " \"model_b_output\": summary_b['summary'],\n", + " \"id\": summary_a['id']\n", + " }\n", + " f.write(json.dumps(formatted) + '\\n')\n", + "\n", + " # Step 3: Upload and evaluate\n", + " print(\"📤 Uploading for comparison...\")\n", + " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " print(\"🚀 Launching comparison...\")\n", + " eval_response = client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=JUDGE_MODEL,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=JUDGE_PROMPT,\n", + " model_a=\"model_a_output\",\n", + " model_b=\"model_b_output\"\n", + " )\n", + "\n", + " # Step 4: Wait and get results\n", + " print(f\"⏳ Waiting (ID: {eval_response.workflow_id})...\")\n", + " while True:\n", + " status = client.evaluation.status(eval_response.workflow_id)\n", + " if status.status.value == \"completed\":\n", + " break\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(\"Evaluation failed\")\n", + " time.sleep(30)\n", + "\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " # Win rate for prompt A\n", + " decisive_total = a_wins + b_wins\n", + " if decisive_total > 0:\n", + " a_win_rate = a_wins / decisive_total\n", + " b_win_rate = b_wins / decisive_total\n", + " else:\n", + " a_win_rate = b_win_rate = 0.5\n", + "\n", + " print(f\"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", + " print(f\"✓ Prompt A win rate: {a_win_rate:.2%}\")\n", + "\n", + " os.remove(temp_file)\n", + "\n", + " return a_win_rate, b_win_rate, {\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'a_win_rate': a_win_rate\n", + " }\n", + "\n", + "print(\"✓ Comparison function defined\")" + ], + "id": "5a1b2d5116f3731f" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🧬 GEPA Optimization Loop\n", + "\n", + "This is the main optimization loop that implements the GEPA algorithm:\n", + "\n", + "1. **Generate**: Create summaries with current prompt\n", + "2. **Evaluate**: Compare against baseline using judge model\n", + "3. **Propose**: Use optimizer LLM to suggest improvements\n", + "4. 
**Adapt**: Accept improvements that increase win rate\n", + "\n", + "The process repeats for multiple iterations, tracking the best prompt found." + ], + "id": "6657d33b050676ff" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_manual_gepa(\n", + " train_data: List[Dict],\n", + " val_data: List[Dict],\n", + " test_data: List[Dict],\n", + " summarizer_lm: dspy.LM,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " max_iterations: int = 5\n", + "):\n", + " \"\"\"Manual GEPA-style optimization.\"\"\"\n", + "\n", + " start_time = time.time()\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"🧬 MANUAL GEPA OPTIMIZATION\")\n", + " print(\"=\" * 80)\n", + "\n", + " # Track best prompt\n", + " best_prompt = BASELINE_PROMPT\n", + " best_val_score = 0.5 # Start at 50% (neutral)\n", + "\n", + " for i in range(max_iterations):\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"ITERATION {i + 1}/{max_iterations}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " if i == 0:\n", + " print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n", + " continue\n", + "\n", + " # Generate new candidate prompt\n", + " new_prompt = reflect_and_improve_prompt(\n", + " best_prompt,\n", + " best_val_score,\n", + " optimizer_lm,\n", + " i\n", + " )\n", + "\n", + " if new_prompt == best_prompt:\n", + " print(\"⚠️ No change in prompt, stopping\")\n", + " break\n", + "\n", + " print(f\"✓ Generated candidate prompt ({len(new_prompt)} chars)\")\n", + "\n", + " # Compare best_prompt vs new_prompt on validation set\n", + " baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n", + " val_data,\n", + " prompt_a=best_prompt,\n", + " prompt_b=new_prompt,\n", + " summarizer_lm=summarizer_lm,\n", + " eval_name=f\"iter{i}_val\"\n", + " )\n", + "\n", + " new_prompt_win_rate = 1.0 - baseline_win_rate\n", + "\n", + " print(f\"\\n Current best: {baseline_win_rate:.2%}\")\n", + " print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n", + "\n", + " if new_prompt_win_rate > best_val_score:\n", + " improvement = new_prompt_win_rate - best_val_score\n", + " print(f\" 🎉 New best! 
(+{improvement * 100:.2f}pp)\")\n", + " best_prompt = new_prompt\n", + " best_val_score = new_prompt_win_rate\n", + " else:\n", + " print(f\" No improvement\")\n", + "\n", + " # Calculate total time\n", + " total_time = time.time() - start_time\n", + " hours = int(total_time // 3600)\n", + " minutes = int((total_time % 3600) // 60)\n", + " seconds = int(total_time % 60)\n", + "\n", + " # Final test evaluation\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"📊 FINAL TEST EVALUATION\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(f\"\\n⏱️ OPTIMIZATION TIME:\")\n", + " if hours > 0:\n", + " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", + " elif minutes > 0:\n", + " print(f\" Total: {minutes}m {seconds}s\")\n", + " else:\n", + " print(f\" Total: {seconds}s\")\n", + "\n", + " baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n", + " test_data,\n", + " prompt_a=BASELINE_PROMPT,\n", + " prompt_b=best_prompt,\n", + " summarizer_lm=summarizer_lm,\n", + " eval_name=\"final_test\"\n", + " )\n", + "\n", + " # Display results\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"🎉 FINAL RESULTS\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(f\"\\nTEST SET:\")\n", + " print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n", + " print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n", + " print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n", + "\n", + " # Save results\n", + " output_dir = Path(\"results\")\n", + " output_dir.mkdir(exist_ok=True)\n", + "\n", + " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "\n", + " with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n", + " f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(BASELINE_PROMPT)\n", + " f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(best_prompt)\n", + " f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n", + " f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n", + "\n", + " print(f\"\\n💾 Saved to: results/prompts_{timestamp}.txt\")\n", + "\n", + " return {\n", + " 'baseline_test': baseline_test_win_rate,\n", + " 'optimized_test': optimized_test_win_rate,\n", + " 'best_prompt': best_prompt\n", + " }\n", + "\n", + "print(\"✓ GEPA optimization function defined\")" + ], + "id": "c7100da955cfb3b5" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🚀 Run the Optimization\n", + "\n", + "Now we'll execute the full GEPA optimization process. This will:\n", + "1. Set up the summarizer and optimizer models\n", + "2. Run multiple iterations of prompt improvement\n", + "3. Evaluate the final optimized prompt on the test set\n", + "4. 
Display comprehensive results" + ], + "id": "4839066f78acf10d" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\"*80)\n", + "print(\"🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", + "print(\"=\"*80)\n", + "\n", + "if not TOGETHER_API_KEY or TOGETHER_API_KEY == 'your_api_key_here':\n", + " print(\"❌ Set TOGETHER_API_KEY\")\n", + "else:\n", + " # Setup models\n", + " summarizer_lm = dspy.LM(\n", + " f\"together_ai/{SUMMARIZER_MODEL}\",\n", + " api_key=TOGETHER_API_KEY,\n", + " temperature=0.5,\n", + " max_tokens=1024\n", + " )\n", + "\n", + " optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY\n", + " )\n", + "\n", + " # Run optimization\n", + " results = run_manual_gepa(\n", + " train_data,\n", + " val_data,\n", + " test_data,\n", + " summarizer_lm,\n", + " optimizer_lm,\n", + " max_iterations=5\n", + " )\n", + "\n", + " print(\"\\n✅ Complete!\")" + ], + "id": "51f60931bec8f490" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📊 Analyzing the Results\n", + "\n", + "Let's examine the optimized prompt and compare it to the baseline." + ], + "id": "2be0f2bb00a13ff6" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 80)\n", + "print(\"📝 PROMPT COMPARISON\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nBASELINE PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(BASELINE_PROMPT)\n", + "\n", + "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(results['best_prompt'])\n", + "\n", + "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", + "print(\"-\" * 80)\n", + "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", + "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", + "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" + ], + "id": "bc461eee131bd49f" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🔑 Key Findings\n", + "\n", + "**GEPA Optimization Process:**\n", + "- Iteratively improves prompts through LLM-guided reflection\n", + "- Uses head-to-head comparisons with a judge model\n", + "- Tracks and accepts only improvements over baseline\n", + "\n", + "**Benefits of This Approach:**\n", + "1. **Automated**: No manual prompt engineering required\n", + "2. **Data-driven**: Decisions based on actual performance metrics\n", + "3. **Scalable**: Can optimize for any task with appropriate data\n", + "4. 
**Transparent**: Clear tracking of improvements across iterations\n", + "\n", + "**Next Steps:**\n", + "- Try with different datasets or domains\n", + "- Experiment with different judge criteria\n", + "- Adjust the optimizer's reflection prompt\n", + "- Increase iterations for potentially better results" + ], + "id": "8b606f57d491feb6" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 170d85bdb7d3a33cfef5c5e119fdef60fd53f6ed Mon Sep 17 00:00:00 2001 From: jli Date: Tue, 23 Dec 2025 01:28:19 +0800 Subject: [PATCH 2/6] update comments --- Evals/GEPA_Optimization.ipynb | 2067 +++++++++++++++++++-------------- 1 file changed, 1201 insertions(+), 866 deletions(-) diff --git a/Evals/GEPA_Optimization.ipynb b/Evals/GEPA_Optimization.ipynb index 07ae1e9..254eb56 100644 --- a/Evals/GEPA_Optimization.ipynb +++ b/Evals/GEPA_Optimization.ipynb @@ -1,868 +1,1203 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# GEPA Summarization Optimization with LLM Judge Evaluation\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)\n", - "\n", - "## Introduction\n", - "\n", - "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", - "\n", - "1. Load the CNN/DailyMail dataset containing news articles\n", - "2. Start with a baseline summarization prompt\n", - "3. Use an optimizer LLM to iteratively improve the prompt\n", - "4. Compare prompts head-to-head using a judge model\n", - "5. Track improvement over multiple iterations\n", - "\n", - "**Concepts Covered:**\n", - "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", - "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", - "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", - "- **Prompt Engineering**: Systematic improvement of instruction prompts" - ], - "id": "9bed21b9f21cadb7" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9bed21b9f21cadb7" + }, + "source": [ + "# GEPA Summarization Optimization with LLM Judge Evaluation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)\n", + "\n", + "## Introduction\n", + "\n", + "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", + "\n", + "1. Load the CNN/DailyMail dataset containing news articles\n", + "2. Start with a baseline summarization prompt\n", + "3. Use an optimizer LLM to iteratively improve the prompt\n", + "4. Compare prompts head-to-head using a judge model\n", + "5. 
Track improvement over multiple iterations\n", + "\n", + "**Concepts Covered:**\n", + "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", + "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", + "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", + "- **Prompt Engineering**: Systematic improvement of instruction prompts" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c044d292f626f2f6" + }, + "source": [ + "## 📦 Setup and Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "cf56ca26c1b94222" + }, + "outputs": [], + "source": [ + "!pip install -qU together dspy-ai datasets tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 216 + }, + "id": "1c293b491e894110", + "outputId": "e393f618-61a5-415e-ce69-18ebf78fbe99" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[36m╭─\u001b[0m\u001b[36m────────────────────────────────────────────\u001b[0m\u001b[36m 🚀 New SDK Available \u001b[0m\u001b[36m─────────────────────────────────────────────\u001b[0m\u001b[36m─╮\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[1;36mTogether Python SDK 2.0 is now available!\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m Install the beta: \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[32mpip install --pre together\u001b[0m or \u001b[32muv add together --prerelease allow\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m New SDK: \u001b]8;id=629133;https://github.com/togethercomputer/together-py\u001b\\https://github.com/togethercomputer/together-py\u001b]8;;\u001b\\ \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m Migration guide: \u001b]8;id=644417;https://docs.together.ai/docs/pythonv2-migration-guide\u001b\\https://docs.together.ai/docs/pythonv2-migration-guide\u001b]8;;\u001b\\ \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[2mThis package will be maintained until January 2026.\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m│\u001b[0m \u001b[2mSet TOGETHER_NO_BANNER=1 to hide this message.\u001b[0m \u001b[36m│\u001b[0m\n", + "\u001b[36m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ], + "text/html": [ + "
+              "<pre>Together Python SDK 2.0 banner (HTML rendering omitted; see the text/plain output above)</pre>
\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import time\n", + "from pathlib import Path\n", + "from typing import List, Dict, Tuple\n", + "from datetime import datetime\n", + "\n", + "import dspy\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8e71863c8ff3faa6" + }, + "source": [ + "## ⚙️ Configuration\n", + "\n", + "Set up your API key and configure the models we'll use:\n", + "- **Summarizer Model**: Generates the summaries\n", + "- **Judge Model**: Evaluates which summary is better\n", + "- **Optimizer Model**: Proposes improvements to the prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3d21616fa03c0145", + "outputId": "84889606-a0fb-4556-af15-3b1c9e7fc4ad" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ API key loaded from Colab secrets\n", + "✓ Configuration complete\n" + ] + } + ], + "source": [ + "# Set your Together AI API key from Colab secrets\n", + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "print(\"✓ API key loaded from Colab secrets\")\n", + "\n", + "client = together.Client(api_key=TOGETHER_API_KEY)\n", + "\n", + "# Model configuration\n", + "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", + "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits\n", + "TRAIN_SIZE = 150\n", + "VAL_SIZE = 300\n", + "TEST_SIZE = 300\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "print(\"✓ Configuration complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d9378d341fb8389d" + }, + "source": [ + "## 📝 Baseline and Judge Prompts\n", + "\n", + "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "263940c8c55eb1dd", + "outputId": "a2041a07-268c-4815-a7a4-85c964b7b2be" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Baseline Prompt:\n", + "Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\n", + "\n", + "Judge Prompt:\n", + "Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\n" + ] + } + ], + "source": [ + "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\"\"\"\n", + "\n", + "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\"\"\"\n", + "\n", + "print(\"Baseline Prompt:\")\n", + "print(BASELINE_PROMPT)\n", + "print(\"\\nJudge Prompt:\")\n", + "print(JUDGE_PROMPT)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c0a86293e7b95dd9" + }, + "source": [ + "## 📂 Loading the CNN/DailyMail Dataset\n", + "\n", + "The CNN/DailyMail dataset contains news articles paired with human-written highlights. 
We'll use the articles as our source text and split the data into train, validation, and test sets.\n", + "\n", + "**Dataset Structure:**\n", + "- `article`: The full news article text\n", + "- `highlights`: Human-written bullet-point summary\n", + "- We'll use the articles for summarization and evaluate our generated summaries" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7dcc2d8d5c706df4", + "outputId": "e8dcb543-c238-42d3-af49-bcd77bfe7b7f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "📂 LOADING DATA\n", + "================================================================================\n", + "Loading CNN/DailyMail dataset...\n", + "✓ Loaded 11490 examples\n", + " Sample article: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Cour...\n", + " Sample highlights: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since...\n", + "✓ Converted to 11490 items\n", + "✓ Split: Train=150, Val=300, Test=300\n" + ] + } + ], + "source": [ + "def load_and_split_data():\n", + " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"📂 LOADING DATA\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(\"Loading CNN/DailyMail dataset...\")\n", + " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n", + " data = dataset['test']\n", + "\n", + " print(f\"✓ Loaded {len(data)} examples\")\n", + " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", + " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", + "\n", + " all_data = []\n", + " for i, item in enumerate(data):\n", + " all_data.append({\n", + " 'id': f\"cnn_{i}\",\n", + " 'text': item['article'],\n", + " 'reference_summary': item['highlights']\n", + " })\n", + "\n", + " print(f\"✓ Converted to {len(all_data)} items\")\n", + "\n", + " random.seed(RANDOM_SEED)\n", + " random.shuffle(all_data)\n", + "\n", + " train_data = all_data[:TRAIN_SIZE]\n", + " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", + "\n", + " print(f\"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", + "\n", + " assert len(val_data) > 0, \"Val data is empty!\"\n", + " assert len(test_data) > 0, \"Test data is empty!\"\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "# Load the data\n", + "train_data, val_data, test_data = load_and_split_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d1b9222690db8449" + }, + "source": [ + "## 🤖 Summarization Module\n", + "\n", + "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b8ca2917024c326e", + "outputId": "171c4567-9971-499a-edad-04b67c858885" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Summarization module defined\n" + ] + } + ], + "source": [ + "class Summarizer(dspy.Signature):\n", + " \"\"\"Generate a summary.\"\"\"\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + "\n", + "class SummarizationModule(dspy.Module):\n", + " \"\"\"Summarization module.\"\"\"\n", + "\n", + " def __init__(self, instructions=None):\n", + " super().__init__()\n", + " self.instructions = instructions or BASELINE_PROMPT\n", + "\n", + " if instructions:\n", + " class CustomSummarizer(dspy.Signature):\n", + " __doc__ = instructions\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + " self.predictor = dspy.Predict(CustomSummarizer)\n", + " else:\n", + " self.predictor = dspy.Predict(Summarizer)\n", + "\n", + " def forward(self, text):\n", + " return self.predictor(text=text)\n", + "\n", + "print(\"✓ Summarization module defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "590d6b9c625ca2cc" + }, + "source": [ + "## 📊 Batch Summary Generation\n", + "\n", + "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "270abdde73d2ca72", + "outputId": "6eafb2d3-e773-4a65-f3b5-802687fffafc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Batch generation function defined\n" + ] + } + ], + "source": [ + "def generate_summaries_batch(\n", + " summarizer: SummarizationModule,\n", + " data: List[Dict],\n", + " desc: str = \"Generating\"\n", + ") -> List[Dict]:\n", + " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", + " results = []\n", + " errors = 0\n", + " error_details = []\n", + "\n", + " # Print the prompt being used (first item only)\n", + " if len(data) > 0:\n", + " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", + "\n", + " for item in tqdm(data, desc=desc):\n", + " try:\n", + " pred = summarizer(text=item['text'][:5000])\n", + "\n", + " if pred is None:\n", + " raise ValueError(\"Model returned None\")\n", + "\n", + " if hasattr(pred, 'summary') and pred.summary:\n", + " summary = pred.summary\n", + " elif isinstance(pred, str):\n", + " summary = pred\n", + " else:\n", + " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", + " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", + "\n", + " summary = summary.strip()\n", + " if len(summary) < 20:\n", + " raise ValueError(\"Summary too short\")\n", + "\n", + " except Exception as e:\n", + " errors += 1\n", + " error_details.append(str(e)[:100])\n", + "\n", + " if errors <= 5:\n", + " print(f\"\\n⚠️ Error: {str(e)[:80]}\")\n", + "\n", + " summary = \"Error generating summary.\"\n", + "\n", + " results.append({\n", + " 'id': item['id'],\n", + " 'text': item['text'],\n", + " 'summary': summary\n", + " })\n", + "\n", + " if errors > 0:\n", + " print(f\"\\n⚠️ Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", + " from collections import Counter\n", + " common_errors = Counter(error_details).most_common(3)\n", + 
" print(f\" Most common errors:\")\n", + " for err, count in common_errors:\n", + " print(f\" - {err[:60]}... ({count}x)\")\n", + "\n", + " return results\n", + "\n", + "print(\"✓ Batch generation function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2cfe63f485894d7c" + }, + "source": [ + "## 🧠 Optimizer LLM Wrapper\n", + "\n", + "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d11af9ff91f442df", + "outputId": "c9cd0f0e-7325-46cc-d065-d4a3745c08c3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Optimizer LLM wrapper defined\n" + ] + } + ], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str) -> str:\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=4000\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "print(\"✓ Optimizer LLM wrapper defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "67a224aff87d2f5e" + }, + "source": [ + "## 🤔 Reflection and Prompt Improvement\n", + "\n", + "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", + "\n", + "**Key Constraints:**\n", + "- Keep prompts under 150 words for clarity\n", + "- Focus on simple, direct instructions\n", + "- Target 4-6 sentence summaries\n", + "- Avoid overly complex requirements" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1186e66cab3ea1f1", + "outputId": "a8ea71b8-da99-4efa-c72b-59603458e664" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Reflection function defined\n" + ] + } + ], + "source": [ + "def reflect_and_improve_prompt(\n", + " current_prompt: str,\n", + " current_score: float,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", + "\n", + " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", + "\n", + " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", + "\n", + "Current Prompt:\n", + "```\n", + "{current_prompt}\n", + "```\n", + "\n", + "Current Performance: {current_score:.1%} win rate\n", + "\n", + "Your task: Propose a SIMPLE improved version that generates better summaries.\n", + "\n", + "CRITICAL CONSTRAINTS:\n", + "- Keep the prompt under 150 words\n", + "- Make it clear and direct (NOT overly complex)\n", + "- Target 4-6 sentence summaries\n", + "- Avoid excessive instructions or formatting requirements\n", + "- The prompt should be easy for the model to follow\n", + "\n", + "Focus on:\n", + "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", + "- Are the current guidelines clear?\n", + "- Is anything missing or unnecessary?\n", + "\n", + "Output ONLY the improved prompt within ``` blocks. 
Keep it simple and clear.\"\"\"\n", + "\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + " # Remove language tags\n", + " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", + " if new_prompt.startswith(f'{tag}\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " # Validate length (reject if too long)\n", + " word_count = len(new_prompt.split())\n", + " if word_count > 200:\n", + " print(f\" ⚠️ Generated prompt too long ({word_count} words), using current\")\n", + " return current_prompt\n", + "\n", + " print(f\"✓ Generated new prompt ({word_count} words)\")\n", + " return new_prompt\n", + "\n", + " print(\"⚠️ Could not extract prompt\")\n", + " return current_prompt\n", + "\n", + "print(\"✓ Reflection function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a2fbbd02f5054425" + }, + "source": [ + "## 🔄 Head-to-Head Prompt Comparison\n", + "\n", + "This function compares two prompts by:\n", + "1. Generating summaries with both prompts\n", + "2. Creating a comparison dataset\n", + "3. Using the Together AI evaluation API with a judge model\n", + "4. Computing win rates\n", + "\n", + "The evaluation uses a two-pass approach to eliminate position bias." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5a1b2d5116f3731f", + "outputId": "f6aa5880-7905-4acc-c9ab-b01dc2b6a30f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Comparison function defined\n" + ] + } + ], + "source": [ + "def compare_two_prompts_on_batch(\n", + " data: List[Dict],\n", + " prompt_a: str,\n", + " prompt_b: str,\n", + " summarizer_lm: dspy.LM,\n", + " eval_name: str\n", + ") -> Tuple[float, float, Dict]:\n", + " \"\"\"\n", + " Compare two summarization prompts.\n", + "\n", + " 1. Generate summaries with prompt A\n", + " 2. Generate summaries with prompt B\n", + " 3. Use judge to compare them\n", + " 4. 
Return win rate for prompt A\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"🔄 COMPARING PROMPTS: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Step 1: Generate with both prompts\n", + " dspy.configure(lm=summarizer_lm)\n", + "\n", + " summarizer_a = SummarizationModule(prompt_a)\n", + " summarizer_b = SummarizationModule(prompt_b)\n", + "\n", + " print(\"Generating summaries with Prompt A...\")\n", + " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", + "\n", + " print(\"Generating summaries with Prompt B...\")\n", + " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", + "\n", + " # Step 2: Prepare comparison data\n", + " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + "\n", + " with open(temp_file, 'w') as f:\n", + " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", + " formatted = {\n", + " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", + " \"model_a_output\": summary_a['summary'],\n", + " \"model_b_output\": summary_b['summary'],\n", + " \"id\": summary_a['id']\n", + " }\n", + " f.write(json.dumps(formatted) + '\\n')\n", + "\n", + " # Step 3: Upload and evaluate\n", + " print(\"📤 Uploading for comparison...\")\n", + " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " print(\"🚀 Launching comparison...\")\n", + " eval_response = client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=JUDGE_MODEL,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=JUDGE_PROMPT,\n", + " model_a=\"model_a_output\",\n", + " model_b=\"model_b_output\"\n", + " )\n", + "\n", + " # Step 4: Wait and get results\n", + " print(f\"⏳ Waiting (ID: {eval_response.workflow_id})...\")\n", + " while True:\n", + " status = client.evaluation.status(eval_response.workflow_id)\n", + " if status.status.value == \"completed\":\n", + " break\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(\"Evaluation failed\")\n", + " time.sleep(30)\n", + "\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " # Win rate for prompt A\n", + " decisive_total = a_wins + b_wins\n", + " if decisive_total > 0:\n", + " a_win_rate = a_wins / decisive_total\n", + " b_win_rate = b_wins / decisive_total\n", + " else:\n", + " a_win_rate = b_win_rate = 0.5\n", + "\n", + " print(f\"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", + " print(f\"✓ Prompt A win rate: {a_win_rate:.2%}\")\n", + "\n", + " os.remove(temp_file)\n", + "\n", + " return a_win_rate, b_win_rate, {\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'a_win_rate': a_win_rate\n", + " }\n", + "\n", + "print(\"✓ Comparison function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6657d33b050676ff" + }, + "source": [ + "## 🧬 GEPA Optimization Loop\n", + "\n", + "This is the main optimization loop that implements the GEPA algorithm:\n", + "\n", + "1. **Generate**: Create summaries with current prompt\n", + "2. **Evaluate**: Compare against baseline using judge model\n", + "3. **Propose**: Use optimizer LLM to suggest improvements\n", + "4. 
**Adapt**: Accept improvements that increase win rate\n", + "\n", + "The process repeats for multiple iterations, tracking the best prompt found." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c7100da955cfb3b5", + "outputId": "1144337a-d273-452a-84bf-4ad959363cd1" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ GEPA optimization function defined\n" + ] + } + ], + "source": [ + "def run_manual_gepa(\n", + " train_data: List[Dict],\n", + " val_data: List[Dict],\n", + " test_data: List[Dict],\n", + " summarizer_lm: dspy.LM,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " max_iterations: int = 5\n", + "):\n", + " \"\"\"Manual GEPA-style optimization.\"\"\"\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"🧬 MANUAL GEPA OPTIMIZATION\")\n", + " print(\"=\" * 80)\n", + "\n", + " best_prompt = BASELINE_PROMPT\n", + " best_val_score = 0.5 # Start at 50% (neutral)\n", + "\n", + " for i in range(max_iterations):\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"ITERATION {i + 1}/{max_iterations}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " if i == 0:\n", + " print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n", + " continue\n", + "\n", + " new_prompt = reflect_and_improve_prompt(\n", + " best_prompt,\n", + " best_val_score,\n", + " optimizer_lm,\n", + " i\n", + " )\n", + "\n", + " if new_prompt == best_prompt:\n", + " print(\"⚠️ No change in prompt, stopping\")\n", + " break\n", + "\n", + " print(f\"✓ Generated candidate prompt ({len(new_prompt)} chars)\")\n", + "\n", + " # Compare best_prompt vs new_prompt on validation set\n", + " baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n", + " val_data,\n", + " prompt_a=best_prompt,\n", + " prompt_b=new_prompt,\n", + " summarizer_lm=summarizer_lm,\n", + " eval_name=f\"iter{i}_val\"\n", + " )\n", + "\n", + " new_prompt_win_rate = 1.0 - baseline_win_rate\n", + "\n", + " print(f\"\\n Current best: {baseline_win_rate:.2%}\")\n", + " print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n", + "\n", + " if new_prompt_win_rate > best_val_score:\n", + " improvement = new_prompt_win_rate - best_val_score\n", + " print(f\" 🎉 New best! 
(+{improvement * 100:.2f}pp)\")\n", + " best_prompt = new_prompt\n", + " best_val_score = new_prompt_win_rate\n", + " else:\n", + " print(f\" No improvement\")\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"📊 FINAL TEST EVALUATION\")\n", + " print(\"=\" * 80)\n", + "\n", + " baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n", + " test_data,\n", + " prompt_a=BASELINE_PROMPT,\n", + " prompt_b=best_prompt,\n", + " summarizer_lm=summarizer_lm,\n", + " eval_name=\"final_test\"\n", + " )\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"🎉 FINAL RESULTS\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(f\"\\nTEST SET:\")\n", + " print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n", + " print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n", + " print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n", + "\n", + " output_dir = Path(\"results\")\n", + " output_dir.mkdir(exist_ok=True)\n", + "\n", + " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "\n", + " with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n", + " f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(BASELINE_PROMPT)\n", + " f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(best_prompt)\n", + " f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n", + " f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n", + "\n", + " print(f\"\\n💾 Saved to: results/prompts_{timestamp}.txt\")\n", + "\n", + " return {\n", + " 'baseline_test': baseline_test_win_rate,\n", + " 'optimized_test': optimized_test_win_rate,\n", + " 'best_prompt': best_prompt\n", + " }\n", + "\n", + "print(\"✓ GEPA optimization function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4839066f78acf10d" + }, + "source": [ + "## 🚀 Run the Optimization\n", + "\n", + "Now we'll execute the full GEPA optimization process. This will:\n", + "1. Set up the summarizer and optimizer models\n", + "2. Run multiple iterations of prompt improvement\n", + "3. Evaluate the final optimized prompt on the test set\n", + "4. 
Display comprehensive results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "51f60931bec8f490", + "outputId": "1b34ac6f-0d40-46c9-d9df-6ac6c699cb66" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "================================================================================\n", + "🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "🧬 MANUAL GEPA OPTIMIZATION\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "ITERATION 1/5\n", + "================================================================================\n", + "Iteration 0: Establishing baseline (no comparison yet)\n", + "\n", + "================================================================================\n", + "ITERATION 2/5\n", + "================================================================================\n", + "\n", + "🤔 REFLECTION (Iteration 1)\n", + "✓ Generated new prompt (63 words)\n", + "✓ Generated candidate prompt (404 chars)\n", + "\n", + "================================================================================\n", + "🔄 COMPARING PROMPTS: iter1_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news even...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt A: 100%|██████████| 300/300 [14:30<00:00, 2.90s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt B: 100%|██████████| 300/300 [17:16<00:00, 3.46s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "📤 Uploading for comparison...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Uploading file temp_compare_iter1_val_20251222_170518.jsonl: 100%|██████████| 1.59M/1.59M [00:00<00:00, 2.82MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "🚀 Launching comparison...\n", + "⏳ Waiting (ID: eval-94eb-1766423120)...\n", + "✓ Results: Prompt A wins=29, Prompt B wins=35, Ties=236\n", + "✓ Prompt A win rate: 45.31%\n", + "\n", + " Current best: 45.31%\n", + " New candidate: 54.69%\n", + " 🎉 New best! 
(+4.69pp)\n", + "\n", + "================================================================================\n", + "ITERATION 3/5\n", + "================================================================================\n", + "\n", + "🤔 REFLECTION (Iteration 2)\n", + "✓ Generated new prompt (58 words)\n", + "✓ Generated candidate prompt (389 chars)\n", + "\n", + "================================================================================\n", + "🔄 COMPARING PROMPTS: iter2_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt A: 100%|██████████| 300/300 [00:39<00:00, 7.68it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Write a 4-6 sentence summary of this news article, prioritizing clarity and accuracy. \n", + "\n", + "Clearly stat...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt B: 38%|███▊ | 113/300 [06:12<09:41, 3.11s/it]" + ] + } + ], + "source": [ + "print(\"=\"*80)\n", + "print(\"🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", + "print(\"=\"*80)\n", + "\n", + "# Setup models\n", + "summarizer_lm = dspy.LM(\n", + " f\"together_ai/{SUMMARIZER_MODEL}\",\n", + " api_key=TOGETHER_API_KEY,\n", + " temperature=0.5,\n", + " max_tokens=1024\n", + ")\n", + "\n", + "optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY,\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "# Run optimization\n", + "results = run_manual_gepa(\n", + " train_data,\n", + " val_data,\n", + " test_data,\n", + " summarizer_lm,\n", + " optimizer_lm,\n", + " max_iterations=5\n", + ")\n", + "\n", + "print(\"\\n✅ Complete!\")\n", + "\n", + "total_time = time.time() - start_time\n", + "hours = int(total_time // 3600)\n", + "minutes = int((total_time % 3600) // 60)\n", + "seconds = int(total_time % 60)\n", + "\n", + "print(f\"\\n⏱️ OPTIMIZATION TIME:\")\n", + "if hours > 0:\n", + " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", + "elif minutes > 0:\n", + " print(f\" Total: {minutes}m {seconds}s\")\n", + "else:\n", + " print(f\" Total: {seconds}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2be0f2bb00a13ff6" + }, + "source": [ + "## 📊 Analyzing the Results\n", + "\n", + "Let's examine the optimized prompt and compare it to the baseline." 
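One caveat when reading these head-to-head numbers: in the iteration shown above the judge called 236 of the 300 comparisons a tie, so the reported win rates rest on only 64 decisive pairs. Below is a quick sketch (plain Python, not part of the notebook) of attaching a rough 95% confidence interval to the candidate's win rate, using the counts printed in that run:

```python
import math

def win_rate_with_ci(wins: int, losses: int, z: float = 1.96):
    """Win rate over decisive (non-tie) judgements, with a normal-approximation CI."""
    decisive = wins + losses
    if decisive == 0:
        return 0.5, 0.0, 1.0  # no decisive comparisons: completely uninformative
    p = wins / decisive
    half_width = z * math.sqrt(p * (1.0 - p) / decisive)
    return p, max(0.0, p - half_width), min(1.0, p + half_width)

# Counts reported by the iteration-1 comparison above: A wins=29, B wins=35, ties=236.
p, lo, hi = win_rate_with_ci(wins=35, losses=29)
print(f"Candidate win rate over decisive pairs: {p:.2%} (95% CI {lo:.2%} to {hi:.2%})")
```

With so few decisive comparisons the interval spans roughly 42% to 67%, which is worth keeping in mind before accepting or rejecting a candidate prompt on a validation split of this size.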
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bc461eee131bd49f" + }, + "outputs": [], + "source": [ + "print(\"=\" * 80)\n", + "print(\"📝 PROMPT COMPARISON\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nBASELINE PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(BASELINE_PROMPT)\n", + "\n", + "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(results['best_prompt'])\n", + "\n", + "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", + "print(\"-\" * 80)\n", + "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", + "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", + "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8b606f57d491feb6" + }, + "source": [ + "## 🔑 Key Findings\n", + "\n", + "**GEPA Optimization Process:**\n", + "- Iteratively improves prompts through LLM-guided reflection\n", + "- Uses head-to-head comparisons with a judge model\n", + "- Tracks and accepts only improvements over baseline\n", + "\n", + "**Benefits of This Approach:**\n", + "1. **Automated**: No manual prompt engineering required\n", + "2. **Data-driven**: Decisions based on actual performance metrics\n", + "3. **Scalable**: Can optimize for any task with appropriate data\n", + "4. **Transparent**: Clear tracking of improvements across iterations\n", + "\n", + "**Next Steps:**\n", + "- Try with different datasets or domains\n", + "- Experiment with different judge criteria\n", + "- Adjust the optimizer's reflection prompt\n", + "- Increase iterations for potentially better results" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "colab": { + "provenance": [] + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 📦 Setup and Installation" - ], - "id": "c044d292f626f2f6" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qU together dspy-ai datasets tqdm" - ], - "id": "cf56ca26c1b94222" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import together\n", - "import json\n", - "import random\n", - "import os\n", - "import re\n", - "import time\n", - "from pathlib import Path\n", - "from typing import List, Dict, Tuple\n", - "from datetime import datetime\n", - "\n", - "import dspy\n", - "from datasets import load_dataset\n", - "from tqdm import tqdm" - ], - "id": "1c293b491e894110" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## ⚙️ Configuration\n", - "\n", - "Set up your API key and configure the models we'll use:\n", - "- **Summarizer Model**: Generates the summaries\n", - "- **Judge Model**: Evaluates which summary is better\n", - "- **Optimizer Model**: Proposes improvements to the prompt" - ], - "id": "8e71863c8ff3faa6" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "client = together.Client()\n", - "\n", - "# Model configuration\n", - "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", - "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", - 
"OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", - "\n", - "# Data splits\n", - "TRAIN_SIZE = 150\n", - "VAL_SIZE = 300\n", - "TEST_SIZE = 300\n", - "\n", - "RANDOM_SEED = 42\n", - "\n", - "print(\"✓ Configuration complete\")" - ], - "id": "3d21616fa03c0145" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 📝 Baseline and Judge Prompts\n", - "\n", - "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." - ], - "id": "d9378d341fb8389d" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news event\n", - "- Key people or organizations involved\n", - "- Important details or outcomes\n", - "- Any significant context\n", - "\n", - "Keep it to 3-5 sentences total.\"\"\"\n", - "\n", - "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", - "\n", - "Which summary better:\n", - "- Captures the main news story\n", - "- Includes important details\n", - "- Is clear and concise\n", - "- Avoids unnecessary information\n", - "\n", - "Choose A or B and explain why briefly.\"\"\"\n", - "\n", - "print(\"Baseline Prompt:\")\n", - "print(BASELINE_PROMPT)\n", - "print(\"\\nJudge Prompt:\")\n", - "print(JUDGE_PROMPT)" - ], - "id": "263940c8c55eb1dd" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 📂 Loading the CNN/DailyMail Dataset\n", - "\n", - "The CNN/DailyMail dataset contains news articles paired with human-written highlights. We'll use the articles as our source text and split the data into train, validation, and test sets.\n", - "\n", - "**Dataset Structure:**\n", - "- `article`: The full news article text\n", - "- `highlights`: Human-written bullet-point summary\n", - "- We'll use the articles for summarization and evaluate our generated summaries" - ], - "id": "c0a86293e7b95dd9" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def load_and_split_data():\n", - " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"📂 LOADING DATA\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(\"Loading CNN/DailyMail dataset...\")\n", - " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\", trust_remote_code=True)\n", - " data = dataset['test']\n", - "\n", - " print(f\"✓ Loaded {len(data)} examples\")\n", - " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", - " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", - "\n", - " all_data = []\n", - " for i, item in enumerate(data):\n", - " all_data.append({\n", - " 'id': f\"cnn_{i}\",\n", - " 'text': item['article'],\n", - " 'reference_summary': item['highlights']\n", - " })\n", - "\n", - " print(f\"✓ Converted to {len(all_data)} items\")\n", - "\n", - " # Shuffle and split\n", - " random.seed(RANDOM_SEED)\n", - " random.shuffle(all_data)\n", - "\n", - " train_data = all_data[:TRAIN_SIZE]\n", - " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", - " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", - "\n", - " print(f\"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", - "\n", - " # Verify\n", - " assert len(val_data) > 0, \"Val data is empty!\"\n", - " assert 
len(test_data) > 0, \"Test data is empty!\"\n", - "\n", - " return train_data, val_data, test_data\n", - "\n", - "# Load the data\n", - "train_data, val_data, test_data = load_and_split_data()" - ], - "id": "7dcc2d8d5c706df4" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🤖 Summarization Module\n", - "\n", - "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." - ], - "id": "d1b9222690db8449" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class Summarizer(dspy.Signature):\n", - " \"\"\"Generate a summary.\"\"\"\n", - " text = dspy.InputField()\n", - " summary = dspy.OutputField()\n", - "\n", - "\n", - "class SummarizationModule(dspy.Module):\n", - " \"\"\"Summarization module.\"\"\"\n", - "\n", - " def __init__(self, instructions=None):\n", - " super().__init__()\n", - " self.instructions = instructions or BASELINE_PROMPT\n", - "\n", - " if instructions:\n", - " class CustomSummarizer(dspy.Signature):\n", - " __doc__ = instructions\n", - " text = dspy.InputField()\n", - " summary = dspy.OutputField()\n", - "\n", - " self.predictor = dspy.Predict(CustomSummarizer)\n", - " else:\n", - " self.predictor = dspy.Predict(Summarizer)\n", - "\n", - " def forward(self, text):\n", - " return self.predictor(text=text)\n", - "\n", - "print(\"✓ Summarization module defined\")" - ], - "id": "b8ca2917024c326e" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 📊 Batch Summary Generation\n", - "\n", - "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." - ], - "id": "590d6b9c625ca2cc" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_summaries_batch(\n", - " summarizer: SummarizationModule,\n", - " data: List[Dict],\n", - " desc: str = \"Generating\"\n", - ") -> List[Dict]:\n", - " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", - " results = []\n", - " errors = 0\n", - " error_details = []\n", - "\n", - " # Print the prompt being used (first item only)\n", - " if len(data) > 0:\n", - " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", - "\n", - " for item in tqdm(data, desc=desc):\n", - " try:\n", - " pred = summarizer(text=item['text'][:5000])\n", - "\n", - " if pred is None:\n", - " raise ValueError(\"Model returned None\")\n", - "\n", - " if hasattr(pred, 'summary') and pred.summary:\n", - " summary = pred.summary\n", - " elif isinstance(pred, str):\n", - " summary = pred\n", - " else:\n", - " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", - " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", - "\n", - " summary = summary.strip()\n", - " if len(summary) < 20:\n", - " raise ValueError(\"Summary too short\")\n", - "\n", - " except Exception as e:\n", - " errors += 1\n", - " error_details.append(str(e)[:100])\n", - "\n", - " if errors <= 5:\n", - " print(f\"\\n⚠️ Error: {str(e)[:80]}\")\n", - "\n", - " summary = \"Error generating summary.\"\n", - "\n", - " results.append({\n", - " 'id': item['id'],\n", - " 'text': item['text'],\n", - " 'summary': summary\n", - " })\n", - "\n", - " if errors > 0:\n", - " print(f\"\\n⚠️ Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", - " from collections import Counter\n", - " 
common_errors = Counter(error_details).most_common(3)\n", - " print(f\" Most common errors:\")\n", - " for err, count in common_errors:\n", - " print(f\" - {err[:60]}... ({count}x)\")\n", - "\n", - " return results\n", - "\n", - "print(\"✓ Batch generation function defined\")" - ], - "id": "270abdde73d2ca72" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🧠 Optimizer LLM Wrapper\n", - "\n", - "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." - ], - "id": "2cfe63f485894d7c" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class SimpleOptimizerLM:\n", - " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", - "\n", - " def __init__(self, model: str, api_key: str):\n", - " self.client = together.Client(api_key=api_key)\n", - " self.model = model\n", - "\n", - " def __call__(self, prompt: str) -> str:\n", - " response = self.client.chat.completions.create(\n", - " model=self.model,\n", - " messages=[{\"role\": \"user\", \"content\": prompt}],\n", - " temperature=0.7,\n", - " max_tokens=4000\n", - " )\n", - " return response.choices[0].message.content\n", - "\n", - "print(\"✓ Optimizer LLM wrapper defined\")" - ], - "id": "d11af9ff91f442df" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🤔 Reflection and Prompt Improvement\n", - "\n", - "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", - "\n", - "**Key Constraints:**\n", - "- Keep prompts under 150 words for clarity\n", - "- Focus on simple, direct instructions\n", - "- Target 4-6 sentence summaries\n", - "- Avoid overly complex requirements" - ], - "id": "67a224aff87d2f5e" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def reflect_and_improve_prompt(\n", - " current_prompt: str,\n", - " current_score: float,\n", - " optimizer_lm: SimpleOptimizerLM,\n", - " iteration: int\n", - ") -> str:\n", - " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", - "\n", - " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", - "\n", - " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", - "\n", - "Current Prompt:\n", - "```\n", - "{current_prompt}\n", - "```\n", - "\n", - "Current Performance: {current_score:.1%} win rate\n", - "\n", - "Your task: Propose a SIMPLE improved version that generates better summaries.\n", - "\n", - "CRITICAL CONSTRAINTS:\n", - "- Keep the prompt under 150 words\n", - "- Make it clear and direct (NOT overly complex)\n", - "- Target 4-6 sentence summaries\n", - "- Avoid excessive instructions or formatting requirements\n", - "- The prompt should be easy for the model to follow\n", - "\n", - "Focus on:\n", - "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", - "- Are the current guidelines clear?\n", - "- Is anything missing or unnecessary?\n", - "\n", - "Output ONLY the improved prompt within ``` blocks. 
Keep it simple and clear.\"\"\"\n", - "\n", - " response = optimizer_lm(reflection_prompt)\n", - "\n", - " # Extract prompt\n", - " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", - " if match:\n", - " new_prompt = match.group(1).strip()\n", - " # Remove language tags\n", - " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", - " if new_prompt.startswith(f'{tag}\\n'):\n", - " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", - "\n", - " # Validate length (reject if too long)\n", - " word_count = len(new_prompt.split())\n", - " if word_count > 200:\n", - " print(f\" ⚠️ Generated prompt too long ({word_count} words), using current\")\n", - " return current_prompt\n", - "\n", - " print(f\"✓ Generated new prompt ({word_count} words)\")\n", - " return new_prompt\n", - "\n", - " print(\"⚠️ Could not extract prompt\")\n", - " return current_prompt\n", - "\n", - "print(\"✓ Reflection function defined\")" - ], - "id": "1186e66cab3ea1f1" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🔄 Head-to-Head Prompt Comparison\n", - "\n", - "This function compares two prompts by:\n", - "1. Generating summaries with both prompts\n", - "2. Creating a comparison dataset\n", - "3. Using the Together AI evaluation API with a judge model\n", - "4. Computing win rates\n", - "\n", - "The evaluation uses a two-pass approach to eliminate position bias." - ], - "id": "a2fbbd02f5054425" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def compare_two_prompts_on_batch(\n", - " data: List[Dict],\n", - " prompt_a: str,\n", - " prompt_b: str,\n", - " summarizer_lm: dspy.LM,\n", - " eval_name: str\n", - ") -> Tuple[float, float, Dict]:\n", - " \"\"\"\n", - " Compare two summarization prompts.\n", - "\n", - " 1. Generate summaries with prompt A\n", - " 2. Generate summaries with prompt B\n", - " 3. Use judge to compare them\n", - " 4. 
Return win rate for prompt A\n", - " \"\"\"\n", - "\n", - " print(f\"\\n{'=' * 80}\")\n", - " print(f\"🔄 COMPARING PROMPTS: {eval_name}\")\n", - " print(f\"{'=' * 80}\")\n", - "\n", - " # Step 1: Generate with both prompts\n", - " dspy.configure(lm=summarizer_lm)\n", - "\n", - " summarizer_a = SummarizationModule(prompt_a)\n", - " summarizer_b = SummarizationModule(prompt_b)\n", - "\n", - " print(\"Generating summaries with Prompt A...\")\n", - " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", - "\n", - " print(\"Generating summaries with Prompt B...\")\n", - " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", - "\n", - " # Step 2: Prepare comparison data\n", - " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", - "\n", - " with open(temp_file, 'w') as f:\n", - " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", - " formatted = {\n", - " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", - " \"model_a_output\": summary_a['summary'],\n", - " \"model_b_output\": summary_b['summary'],\n", - " \"id\": summary_a['id']\n", - " }\n", - " f.write(json.dumps(formatted) + '\\n')\n", - "\n", - " # Step 3: Upload and evaluate\n", - " print(\"📤 Uploading for comparison...\")\n", - " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", - " file_id = file_response.id\n", - "\n", - " print(\"🚀 Launching comparison...\")\n", - " eval_response = client.evaluation.create(\n", - " type=\"compare\",\n", - " input_data_file_path=file_id,\n", - " judge_model=JUDGE_MODEL,\n", - " judge_model_source=\"serverless\",\n", - " judge_system_template=JUDGE_PROMPT,\n", - " model_a=\"model_a_output\",\n", - " model_b=\"model_b_output\"\n", - " )\n", - "\n", - " # Step 4: Wait and get results\n", - " print(f\"⏳ Waiting (ID: {eval_response.workflow_id})...\")\n", - " while True:\n", - " status = client.evaluation.status(eval_response.workflow_id)\n", - " if status.status.value == \"completed\":\n", - " break\n", - " elif status.status.value == \"failed\":\n", - " raise Exception(\"Evaluation failed\")\n", - " time.sleep(30)\n", - "\n", - " a_wins = status.results.get('A_wins', 0)\n", - " b_wins = status.results.get('B_wins', 0)\n", - " ties = status.results.get('Ties', 0)\n", - "\n", - " # Win rate for prompt A\n", - " decisive_total = a_wins + b_wins\n", - " if decisive_total > 0:\n", - " a_win_rate = a_wins / decisive_total\n", - " b_win_rate = b_wins / decisive_total\n", - " else:\n", - " a_win_rate = b_win_rate = 0.5\n", - "\n", - " print(f\"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", - " print(f\"✓ Prompt A win rate: {a_win_rate:.2%}\")\n", - "\n", - " os.remove(temp_file)\n", - "\n", - " return a_win_rate, b_win_rate, {\n", - " 'a_wins': a_wins,\n", - " 'b_wins': b_wins,\n", - " 'ties': ties,\n", - " 'a_win_rate': a_win_rate\n", - " }\n", - "\n", - "print(\"✓ Comparison function defined\")" - ], - "id": "5a1b2d5116f3731f" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🧬 GEPA Optimization Loop\n", - "\n", - "This is the main optimization loop that implements the GEPA algorithm:\n", - "\n", - "1. **Generate**: Create summaries with current prompt\n", - "2. **Evaluate**: Compare against baseline using judge model\n", - "3. **Propose**: Use optimizer LLM to suggest improvements\n", - "4. 
**Adapt**: Accept improvements that increase win rate\n", - "\n", - "The process repeats for multiple iterations, tracking the best prompt found." - ], - "id": "6657d33b050676ff" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def run_manual_gepa(\n", - " train_data: List[Dict],\n", - " val_data: List[Dict],\n", - " test_data: List[Dict],\n", - " summarizer_lm: dspy.LM,\n", - " optimizer_lm: SimpleOptimizerLM,\n", - " max_iterations: int = 5\n", - "):\n", - " \"\"\"Manual GEPA-style optimization.\"\"\"\n", - "\n", - " start_time = time.time()\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"🧬 MANUAL GEPA OPTIMIZATION\")\n", - " print(\"=\" * 80)\n", - "\n", - " # Track best prompt\n", - " best_prompt = BASELINE_PROMPT\n", - " best_val_score = 0.5 # Start at 50% (neutral)\n", - "\n", - " for i in range(max_iterations):\n", - " print(f\"\\n{'=' * 80}\")\n", - " print(f\"ITERATION {i + 1}/{max_iterations}\")\n", - " print(f\"{'=' * 80}\")\n", - "\n", - " if i == 0:\n", - " print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n", - " continue\n", - "\n", - " # Generate new candidate prompt\n", - " new_prompt = reflect_and_improve_prompt(\n", - " best_prompt,\n", - " best_val_score,\n", - " optimizer_lm,\n", - " i\n", - " )\n", - "\n", - " if new_prompt == best_prompt:\n", - " print(\"⚠️ No change in prompt, stopping\")\n", - " break\n", - "\n", - " print(f\"✓ Generated candidate prompt ({len(new_prompt)} chars)\")\n", - "\n", - " # Compare best_prompt vs new_prompt on validation set\n", - " baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n", - " val_data,\n", - " prompt_a=best_prompt,\n", - " prompt_b=new_prompt,\n", - " summarizer_lm=summarizer_lm,\n", - " eval_name=f\"iter{i}_val\"\n", - " )\n", - "\n", - " new_prompt_win_rate = 1.0 - baseline_win_rate\n", - "\n", - " print(f\"\\n Current best: {baseline_win_rate:.2%}\")\n", - " print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n", - "\n", - " if new_prompt_win_rate > best_val_score:\n", - " improvement = new_prompt_win_rate - best_val_score\n", - " print(f\" 🎉 New best! 
(+{improvement * 100:.2f}pp)\")\n", - " best_prompt = new_prompt\n", - " best_val_score = new_prompt_win_rate\n", - " else:\n", - " print(f\" No improvement\")\n", - "\n", - " # Calculate total time\n", - " total_time = time.time() - start_time\n", - " hours = int(total_time // 3600)\n", - " minutes = int((total_time % 3600) // 60)\n", - " seconds = int(total_time % 60)\n", - "\n", - " # Final test evaluation\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"📊 FINAL TEST EVALUATION\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(f\"\\n⏱️ OPTIMIZATION TIME:\")\n", - " if hours > 0:\n", - " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", - " elif minutes > 0:\n", - " print(f\" Total: {minutes}m {seconds}s\")\n", - " else:\n", - " print(f\" Total: {seconds}s\")\n", - "\n", - " baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n", - " test_data,\n", - " prompt_a=BASELINE_PROMPT,\n", - " prompt_b=best_prompt,\n", - " summarizer_lm=summarizer_lm,\n", - " eval_name=\"final_test\"\n", - " )\n", - "\n", - " # Display results\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"🎉 FINAL RESULTS\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(f\"\\nTEST SET:\")\n", - " print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n", - " print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n", - " print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n", - "\n", - " # Save results\n", - " output_dir = Path(\"results\")\n", - " output_dir.mkdir(exist_ok=True)\n", - "\n", - " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", - "\n", - " with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n", - " f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(BASELINE_PROMPT)\n", - " f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(best_prompt)\n", - " f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n", - " f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n", - "\n", - " print(f\"\\n💾 Saved to: results/prompts_{timestamp}.txt\")\n", - "\n", - " return {\n", - " 'baseline_test': baseline_test_win_rate,\n", - " 'optimized_test': optimized_test_win_rate,\n", - " 'best_prompt': best_prompt\n", - " }\n", - "\n", - "print(\"✓ GEPA optimization function defined\")" - ], - "id": "c7100da955cfb3b5" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🚀 Run the Optimization\n", - "\n", - "Now we'll execute the full GEPA optimization process. This will:\n", - "1. Set up the summarizer and optimizer models\n", - "2. Run multiple iterations of prompt improvement\n", - "3. Evaluate the final optimized prompt on the test set\n", - "4. 
Display comprehensive results" - ], - "id": "4839066f78acf10d" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"=\"*80)\n", - "print(\"🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", - "print(\"=\"*80)\n", - "\n", - "if not TOGETHER_API_KEY or TOGETHER_API_KEY == 'your_api_key_here':\n", - " print(\"❌ Set TOGETHER_API_KEY\")\n", - "else:\n", - " # Setup models\n", - " summarizer_lm = dspy.LM(\n", - " f\"together_ai/{SUMMARIZER_MODEL}\",\n", - " api_key=TOGETHER_API_KEY,\n", - " temperature=0.5,\n", - " max_tokens=1024\n", - " )\n", - "\n", - " optimizer_lm = SimpleOptimizerLM(\n", - " model=OPTIMIZER_MODEL,\n", - " api_key=TOGETHER_API_KEY\n", - " )\n", - "\n", - " # Run optimization\n", - " results = run_manual_gepa(\n", - " train_data,\n", - " val_data,\n", - " test_data,\n", - " summarizer_lm,\n", - " optimizer_lm,\n", - " max_iterations=5\n", - " )\n", - "\n", - " print(\"\\n✅ Complete!\")" - ], - "id": "51f60931bec8f490" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 📊 Analyzing the Results\n", - "\n", - "Let's examine the optimized prompt and compare it to the baseline." - ], - "id": "2be0f2bb00a13ff6" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"=\" * 80)\n", - "print(\"📝 PROMPT COMPARISON\")\n", - "print(\"=\" * 80)\n", - "\n", - "print(\"\\nBASELINE PROMPT:\")\n", - "print(\"-\" * 80)\n", - "print(BASELINE_PROMPT)\n", - "\n", - "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", - "print(\"-\" * 80)\n", - "print(results['best_prompt'])\n", - "\n", - "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", - "print(\"-\" * 80)\n", - "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", - "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", - "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" - ], - "id": "bc461eee131bd49f" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🔑 Key Findings\n", - "\n", - "**GEPA Optimization Process:**\n", - "- Iteratively improves prompts through LLM-guided reflection\n", - "- Uses head-to-head comparisons with a judge model\n", - "- Tracks and accepts only improvements over baseline\n", - "\n", - "**Benefits of This Approach:**\n", - "1. **Automated**: No manual prompt engineering required\n", - "2. **Data-driven**: Decisions based on actual performance metrics\n", - "3. **Scalable**: Can optimize for any task with appropriate data\n", - "4. 
**Transparent**: Clear tracking of improvements across iterations\n", - "\n", - "**Next Steps:**\n", - "- Try with different datasets or domains\n", - "- Experiment with different judge criteria\n", - "- Adjust the optimizer's reflection prompt\n", - "- Increase iterations for potentially better results" - ], - "id": "8b606f57d491feb6" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 414b4d23fcfb48e9f7edb25e474af05c0eafa4db Mon Sep 17 00:00:00 2001 From: jli Date: Tue, 23 Dec 2025 10:44:23 +0800 Subject: [PATCH 3/6] update notebook link --- Evals/GEPA_Optimization.ipynb | 2316 +++++++++++++++---------------- Evals/Prompt_Optimization.ipynb | 1063 ++++++++++++++ 2 files changed, 2221 insertions(+), 1158 deletions(-) create mode 100644 Evals/Prompt_Optimization.ipynb diff --git a/Evals/GEPA_Optimization.ipynb b/Evals/GEPA_Optimization.ipynb index 254eb56..11127d4 100644 --- a/Evals/GEPA_Optimization.ipynb +++ b/Evals/GEPA_Optimization.ipynb @@ -1,1203 +1,1203 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "9bed21b9f21cadb7" - }, - "source": [ - "# GEPA Summarization Optimization with LLM Judge Evaluation\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)\n", - "\n", - "## Introduction\n", - "\n", - "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", - "\n", - "1. Load the CNN/DailyMail dataset containing news articles\n", - "2. Start with a baseline summarization prompt\n", - "3. Use an optimizer LLM to iteratively improve the prompt\n", - "4. Compare prompts head-to-head using a judge model\n", - "5. 
Track improvement over multiple iterations\n", - "\n", - "**Concepts Covered:**\n", - "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", - "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", - "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", - "- **Prompt Engineering**: Systematic improvement of instruction prompts" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c044d292f626f2f6" - }, - "source": [ - "## 📦 Setup and Installation" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "cf56ca26c1b94222" - }, - "outputs": [], - "source": [ - "!pip install -qU together dspy-ai datasets tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 216 - }, - "id": "1c293b491e894110", - "outputId": "e393f618-61a5-415e-ce69-18ebf78fbe99" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[36m╭─\u001b[0m\u001b[36m────────────────────────────────────────────\u001b[0m\u001b[36m 🚀 New SDK Available \u001b[0m\u001b[36m─────────────────────────────────────────────\u001b[0m\u001b[36m─╮\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[1;36mTogether Python SDK 2.0 is now available!\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m Install the beta: \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[32mpip install --pre together\u001b[0m or \u001b[32muv add together --prerelease allow\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m New SDK: \u001b]8;id=629133;https://github.com/togethercomputer/together-py\u001b\\https://github.com/togethercomputer/together-py\u001b]8;;\u001b\\ \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m Migration guide: \u001b]8;id=644417;https://docs.together.ai/docs/pythonv2-migration-guide\u001b\\https://docs.together.ai/docs/pythonv2-migration-guide\u001b]8;;\u001b\\ \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[2mThis package will be maintained until January 2026.\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m│\u001b[0m \u001b[2mSet TOGETHER_NO_BANNER=1 to hide this message.\u001b[0m \u001b[36m│\u001b[0m\n", - "\u001b[36m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" - ], - "text/html": [ - "
╭───────────────────────────────────────────── 🚀 New SDK Available ──────────────────────────────────────────────╮\n",
-              " Together Python SDK 2.0 is now available!                                                                       \n",
-              "                                                                                                                 \n",
-              " Install the beta:                                                                                               \n",
-              " pip install --pre together  or  uv add together --prerelease allow                                              \n",
-              "                                                                                                                 \n",
-              " New SDK: https://github.com/togethercomputer/together-py                                                        \n",
-              " Migration guide: https://docs.together.ai/docs/pythonv2-migration-guide                                         \n",
-              "                                                                                                                 \n",
-              " This package will be maintained until January 2026.                                                             \n",
-              " Set TOGETHER_NO_BANNER=1 to hide this message.                                                                  \n",
-              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
-              "
\n" - ] - }, - "metadata": {} - } - ], - "source": [ - "import together\n", - "import json\n", - "import random\n", - "import os\n", - "import re\n", - "import time\n", - "from pathlib import Path\n", - "from typing import List, Dict, Tuple\n", - "from datetime import datetime\n", - "\n", - "import dspy\n", - "from datasets import load_dataset\n", - "from tqdm import tqdm" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8e71863c8ff3faa6" - }, - "source": [ - "## ⚙️ Configuration\n", - "\n", - "Set up your API key and configure the models we'll use:\n", - "- **Summarizer Model**: Generates the summaries\n", - "- **Judge Model**: Evaluates which summary is better\n", - "- **Optimizer Model**: Proposes improvements to the prompt" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "3d21616fa03c0145", - "outputId": "84889606-a0fb-4556-af15-3b1c9e7fc4ad" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ API key loaded from Colab secrets\n", - "✓ Configuration complete\n" - ] - } - ], - "source": [ - "# Set your Together AI API key from Colab secrets\n", - "from google.colab import userdata\n", - "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", - "print(\"✓ API key loaded from Colab secrets\")\n", - "\n", - "client = together.Client(api_key=TOGETHER_API_KEY)\n", - "\n", - "# Model configuration\n", - "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", - "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", - "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", - "\n", - "# Data splits\n", - "TRAIN_SIZE = 150\n", - "VAL_SIZE = 300\n", - "TEST_SIZE = 300\n", - "\n", - "RANDOM_SEED = 42\n", - "\n", - "print(\"✓ Configuration complete\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d9378d341fb8389d" - }, - "source": [ - "## 📝 Baseline and Judge Prompts\n", - "\n", - "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9bed21b9f21cadb7" + }, + "source": [ + "# GEPA Summarization Optimization with LLM Judge Evaluation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/GEPA_Optimization.ipynb)\n", + "\n", + "## Introduction\n", + "\n", + "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", + "\n", + "1. Load the CNN/DailyMail dataset containing news articles\n", + "2. Start with a baseline summarization prompt\n", + "3. Use an optimizer LLM to iteratively improve the prompt\n", + "4. Compare prompts head-to-head using a judge model\n", + "5. 
Track improvement over multiple iterations\n", + "\n", + "**Concepts Covered:**\n", + "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", + "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", + "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", + "- **Prompt Engineering**: Systematic improvement of instruction prompts" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c044d292f626f2f6" + }, + "source": [ + "## 📦 Setup and Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "cf56ca26c1b94222" + }, + "outputs": [], + "source": [ + "!pip install -qU together dspy-ai datasets tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 216 }, + "id": "1c293b491e894110", + "outputId": "e393f618-61a5-415e-ce69-18ebf78fbe99" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "263940c8c55eb1dd", - "outputId": "a2041a07-268c-4815-a7a4-85c964b7b2be" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Baseline Prompt:\n", - "Summarize this news article in 3-5 key points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news event\n", - "- Key people or organizations involved\n", - "- Important details or outcomes\n", - "- Any significant context\n", - "\n", - "Keep it to 3-5 sentences total.\n", - "\n", - "Judge Prompt:\n", - "Compare these two summaries of the same news article.\n", - "\n", - "Which summary better:\n", - "- Captures the main news story\n", - "- Includes important details\n", - "- Is clear and concise\n", - "- Avoids unnecessary information\n", - "\n", - "Choose A or B and explain why briefly.\n" - ] - } + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001B[36m╭─\u001B[0m\u001B[36m────────────────────────────────────────────\u001B[0m\u001B[36m 🚀 New SDK Available \u001B[0m\u001B[36m─────────────────────────────────────────────\u001B[0m\u001B[36m─╮\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[1;36mTogether Python SDK 2.0 is now available!\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m Install the beta: \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[32mpip install --pre together\u001B[0m or \u001B[32muv add together --prerelease allow\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m New SDK: \u001B]8;id=629133;https://github.com/togethercomputer/together-py\u001B\\https://github.com/togethercomputer/together-py\u001B]8;;\u001B\\ \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m Migration guide: \u001B]8;id=644417;https://docs.together.ai/docs/pythonv2-migration-guide\u001B\\https://docs.together.ai/docs/pythonv2-migration-guide\u001B]8;;\u001B\\ \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[2mThis package will be maintained until January 2026.\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m│\u001B[0m \u001B[2mSet TOGETHER_NO_BANNER=1 to hide this message.\u001B[0m \u001B[36m│\u001B[0m\n", + "\u001B[36m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001B[0m\n" ], - "source": [ - "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key 
points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news event\n", - "- Key people or organizations involved\n", - "- Important details or outcomes\n", - "- Any significant context\n", - "\n", - "Keep it to 3-5 sentences total.\"\"\"\n", - "\n", - "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", - "\n", - "Which summary better:\n", - "- Captures the main news story\n", - "- Includes important details\n", - "- Is clear and concise\n", - "- Avoids unnecessary information\n", - "\n", - "Choose A or B and explain why briefly.\"\"\"\n", - "\n", - "print(\"Baseline Prompt:\")\n", - "print(BASELINE_PROMPT)\n", - "print(\"\\nJudge Prompt:\")\n", - "print(JUDGE_PROMPT)" + "text/html": [ + "
╭───────────────────────────────────────────── 🚀 New SDK Available ──────────────────────────────────────────────╮\n",
+       " Together Python SDK 2.0 is now available!                                                                       \n",
+       "                                                                                                                 \n",
+       " Install the beta:                                                                                               \n",
+       " pip install --pre together  or  uv add together --prerelease allow                                              \n",
+       "                                                                                                                 \n",
+       " New SDK: https://github.com/togethercomputer/together-py                                                        \n",
+       " Migration guide: https://docs.together.ai/docs/pythonv2-migration-guide                                         \n",
+       "                                                                                                                 \n",
+       " This package will be maintained until January 2026.                                                             \n",
+       " Set TOGETHER_NO_BANNER=1 to hide this message.                                                                  \n",
+       "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "
\n" ] + }, + "metadata": {} + } + ], + "source": [ + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import time\n", + "from pathlib import Path\n", + "from typing import List, Dict, Tuple\n", + "from datetime import datetime\n", + "\n", + "import dspy\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8e71863c8ff3faa6" + }, + "source": [ + "## ⚙️ Configuration\n", + "\n", + "Set up your API key and configure the models we'll use:\n", + "- **Summarizer Model**: Generates the summaries\n", + "- **Judge Model**: Evaluates which summary is better\n", + "- **Optimizer Model**: Proposes improvements to the prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "3d21616fa03c0145", + "outputId": "84889606-a0fb-4556-af15-3b1c9e7fc4ad" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "c0a86293e7b95dd9" - }, - "source": [ - "## 📂 Loading the CNN/DailyMail Dataset\n", - "\n", - "The CNN/DailyMail dataset contains news articles paired with human-written highlights. We'll use the articles as our source text and split the data into train, validation, and test sets.\n", - "\n", - "**Dataset Structure:**\n", - "- `article`: The full news article text\n", - "- `highlights`: Human-written bullet-point summary\n", - "- We'll use the articles for summarization and evaluate our generated summaries" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ API key loaded from Colab secrets\n", + "✓ Configuration complete\n" + ] + } + ], + "source": [ + "# Set your Together AI API key from Colab secrets\n", + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "print(\"✓ API key loaded from Colab secrets\")\n", + "\n", + "client = together.Client(api_key=TOGETHER_API_KEY)\n", + "\n", + "# Model configuration\n", + "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", + "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits\n", + "TRAIN_SIZE = 150\n", + "VAL_SIZE = 300\n", + "TEST_SIZE = 300\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "print(\"✓ Configuration complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d9378d341fb8389d" + }, + "source": [ + "## 📝 Baseline and Judge Prompts\n", + "\n", + "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "263940c8c55eb1dd", + "outputId": "a2041a07-268c-4815-a7a4-85c964b7b2be" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7dcc2d8d5c706df4", - "outputId": "e8dcb543-c238-42d3-af49-bcd77bfe7b7f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "================================================================================\n", - "📂 LOADING DATA\n", - "================================================================================\n", - "Loading CNN/DailyMail dataset...\n", - "✓ Loaded 11490 examples\n", - " Sample article: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Cour...\n", - " Sample highlights: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since...\n", - "✓ Converted to 11490 items\n", - "✓ Split: Train=150, Val=300, Test=300\n" - ] - } - ], - "source": [ - "def load_and_split_data():\n", - " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"📂 LOADING DATA\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(\"Loading CNN/DailyMail dataset...\")\n", - " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n", - " data = dataset['test']\n", - "\n", - " print(f\"✓ Loaded {len(data)} examples\")\n", - " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", - " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", - "\n", - " all_data = []\n", - " for i, item in enumerate(data):\n", - " all_data.append({\n", - " 'id': f\"cnn_{i}\",\n", - " 'text': item['article'],\n", - " 'reference_summary': item['highlights']\n", - " })\n", - "\n", - " print(f\"✓ Converted to {len(all_data)} items\")\n", - "\n", - " random.seed(RANDOM_SEED)\n", - " random.shuffle(all_data)\n", - "\n", - " train_data = all_data[:TRAIN_SIZE]\n", - " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", - " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", - "\n", - " print(f\"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", - "\n", - " assert len(val_data) > 0, \"Val data is empty!\"\n", - " assert len(test_data) > 0, \"Test data is empty!\"\n", - "\n", - " return train_data, val_data, test_data\n", - "\n", - "# Load the data\n", - "train_data, val_data, test_data = load_and_split_data()" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "Baseline Prompt:\n", + "Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\n", + "\n", + "Judge Prompt:\n", + "Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\n" + ] + } + ], + "source": [ + "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or 
organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\"\"\"\n", + "\n", + "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\"\"\"\n", + "\n", + "print(\"Baseline Prompt:\")\n", + "print(BASELINE_PROMPT)\n", + "print(\"\\nJudge Prompt:\")\n", + "print(JUDGE_PROMPT)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c0a86293e7b95dd9" + }, + "source": [ + "## 📂 Loading the CNN/DailyMail Dataset\n", + "\n", + "The CNN/DailyMail dataset contains news articles paired with human-written highlights. We'll use the articles as our source text and split the data into train, validation, and test sets.\n", + "\n", + "**Dataset Structure:**\n", + "- `article`: The full news article text\n", + "- `highlights`: Human-written bullet-point summary\n", + "- We'll use the articles for summarization and evaluate our generated summaries" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "7dcc2d8d5c706df4", + "outputId": "e8dcb543-c238-42d3-af49-bcd77bfe7b7f" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "d1b9222690db8449" - }, - "source": [ - "## 🤖 Summarization Module\n", - "\n", - "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "📂 LOADING DATA\n", + "================================================================================\n", + "Loading CNN/DailyMail dataset...\n", + "✓ Loaded 11490 examples\n", + " Sample article: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Cour...\n", + " Sample highlights: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since...\n", + "✓ Converted to 11490 items\n", + "✓ Split: Train=150, Val=300, Test=300\n" + ] + } + ], + "source": [ + "def load_and_split_data():\n", + " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"📂 LOADING DATA\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(\"Loading CNN/DailyMail dataset...\")\n", + " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n", + " data = dataset['test']\n", + "\n", + " print(f\"✓ Loaded {len(data)} examples\")\n", + " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", + " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", + "\n", + " all_data = []\n", + " for i, item in enumerate(data):\n", + " all_data.append({\n", + " 'id': f\"cnn_{i}\",\n", + " 'text': item['article'],\n", + " 'reference_summary': item['highlights']\n", + " })\n", + "\n", + " print(f\"✓ Converted to {len(all_data)} items\")\n", + "\n", + " random.seed(RANDOM_SEED)\n", + " random.shuffle(all_data)\n", + "\n", + " train_data = all_data[:TRAIN_SIZE]\n", + " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", + "\n", + 
" print(f\"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", + "\n", + " assert len(val_data) > 0, \"Val data is empty!\"\n", + " assert len(test_data) > 0, \"Test data is empty!\"\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "# Load the data\n", + "train_data, val_data, test_data = load_and_split_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d1b9222690db8449" + }, + "source": [ + "## 🤖 Summarization Module\n", + "\n", + "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "b8ca2917024c326e", + "outputId": "171c4567-9971-499a-edad-04b67c858885" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "b8ca2917024c326e", - "outputId": "171c4567-9971-499a-edad-04b67c858885" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Summarization module defined\n" - ] - } - ], - "source": [ - "class Summarizer(dspy.Signature):\n", - " \"\"\"Generate a summary.\"\"\"\n", - " text = dspy.InputField()\n", - " summary = dspy.OutputField()\n", - "\n", - "\n", - "class SummarizationModule(dspy.Module):\n", - " \"\"\"Summarization module.\"\"\"\n", - "\n", - " def __init__(self, instructions=None):\n", - " super().__init__()\n", - " self.instructions = instructions or BASELINE_PROMPT\n", - "\n", - " if instructions:\n", - " class CustomSummarizer(dspy.Signature):\n", - " __doc__ = instructions\n", - " text = dspy.InputField()\n", - " summary = dspy.OutputField()\n", - "\n", - " self.predictor = dspy.Predict(CustomSummarizer)\n", - " else:\n", - " self.predictor = dspy.Predict(Summarizer)\n", - "\n", - " def forward(self, text):\n", - " return self.predictor(text=text)\n", - "\n", - "print(\"✓ Summarization module defined\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Summarization module defined\n" + ] + } + ], + "source": [ + "class Summarizer(dspy.Signature):\n", + " \"\"\"Generate a summary.\"\"\"\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + "\n", + "class SummarizationModule(dspy.Module):\n", + " \"\"\"Summarization module.\"\"\"\n", + "\n", + " def __init__(self, instructions=None):\n", + " super().__init__()\n", + " self.instructions = instructions or BASELINE_PROMPT\n", + "\n", + " if instructions:\n", + " class CustomSummarizer(dspy.Signature):\n", + " __doc__ = instructions\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + " self.predictor = dspy.Predict(CustomSummarizer)\n", + " else:\n", + " self.predictor = dspy.Predict(Summarizer)\n", + "\n", + " def forward(self, text):\n", + " return self.predictor(text=text)\n", + "\n", + "print(\"✓ Summarization module defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "590d6b9c625ca2cc" + }, + "source": [ + "## 📊 Batch Summary Generation\n", + "\n", + "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "270abdde73d2ca72", + "outputId": "6eafb2d3-e773-4a65-f3b5-802687fffafc" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "590d6b9c625ca2cc" - }, - "source": [ - "## 📊 Batch Summary Generation\n", - "\n", - "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Batch generation function defined\n" + ] + } + ], + "source": [ + "def generate_summaries_batch(\n", + " summarizer: SummarizationModule,\n", + " data: List[Dict],\n", + " desc: str = \"Generating\"\n", + ") -> List[Dict]:\n", + " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", + " results = []\n", + " errors = 0\n", + " error_details = []\n", + "\n", + " # Print the prompt being used (first item only)\n", + " if len(data) > 0:\n", + " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", + "\n", + " for item in tqdm(data, desc=desc):\n", + " try:\n", + " pred = summarizer(text=item['text'][:5000])\n", + "\n", + " if pred is None:\n", + " raise ValueError(\"Model returned None\")\n", + "\n", + " if hasattr(pred, 'summary') and pred.summary:\n", + " summary = pred.summary\n", + " elif isinstance(pred, str):\n", + " summary = pred\n", + " else:\n", + " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", + " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", + "\n", + " summary = summary.strip()\n", + " if len(summary) < 20:\n", + " raise ValueError(\"Summary too short\")\n", + "\n", + " except Exception as e:\n", + " errors += 1\n", + " error_details.append(str(e)[:100])\n", + "\n", + " if errors <= 5:\n", + " print(f\"\\n⚠️ Error: {str(e)[:80]}\")\n", + "\n", + " summary = \"Error generating summary.\"\n", + "\n", + " results.append({\n", + " 'id': item['id'],\n", + " 'text': item['text'],\n", + " 'summary': summary\n", + " })\n", + "\n", + " if errors > 0:\n", + " print(f\"\\n⚠️ Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", + " from collections import Counter\n", + " common_errors = Counter(error_details).most_common(3)\n", + " print(f\" Most common errors:\")\n", + " for err, count in common_errors:\n", + " print(f\" - {err[:60]}... ({count}x)\")\n", + "\n", + " return results\n", + "\n", + "print(\"✓ Batch generation function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2cfe63f485894d7c" + }, + "source": [ + "## 🧠 Optimizer LLM Wrapper\n", + "\n", + "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." 
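Before defining the wrapper class in the next cell, it helps to see the raw API call it encapsulates: a single `chat.completions.create` request to the optimizer model that returns the completion text. The snippet below is a hedged sketch of that one call; the model string matches `OPTIMIZER_MODEL` above, while the example prompt and the assumption that the key sits in the `TOGETHER_API_KEY` environment variable are illustrative.

```python
import os
import together

# Direct equivalent of what SimpleOptimizerLM.__call__ does in the next cell.
client = together.Client(api_key=os.environ.get("TOGETHER_API_KEY"))  # key location is an assumption

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",  # OPTIMIZER_MODEL in this notebook
    messages=[{
        "role": "user",
        "content": "Suggest one concrete improvement to this summarization prompt: "
                   "'Summarize this news article in 3-5 key points.'",
    }],
    temperature=0.7,
    max_tokens=4000,
)
print(response.choices[0].message.content)
```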
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "d11af9ff91f442df", + "outputId": "c9cd0f0e-7325-46cc-d065-d4a3745c08c3" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "270abdde73d2ca72", - "outputId": "6eafb2d3-e773-4a65-f3b5-802687fffafc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Batch generation function defined\n" - ] - } - ], - "source": [ - "def generate_summaries_batch(\n", - " summarizer: SummarizationModule,\n", - " data: List[Dict],\n", - " desc: str = \"Generating\"\n", - ") -> List[Dict]:\n", - " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", - " results = []\n", - " errors = 0\n", - " error_details = []\n", - "\n", - " # Print the prompt being used (first item only)\n", - " if len(data) > 0:\n", - " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", - "\n", - " for item in tqdm(data, desc=desc):\n", - " try:\n", - " pred = summarizer(text=item['text'][:5000])\n", - "\n", - " if pred is None:\n", - " raise ValueError(\"Model returned None\")\n", - "\n", - " if hasattr(pred, 'summary') and pred.summary:\n", - " summary = pred.summary\n", - " elif isinstance(pred, str):\n", - " summary = pred\n", - " else:\n", - " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", - " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", - "\n", - " summary = summary.strip()\n", - " if len(summary) < 20:\n", - " raise ValueError(\"Summary too short\")\n", - "\n", - " except Exception as e:\n", - " errors += 1\n", - " error_details.append(str(e)[:100])\n", - "\n", - " if errors <= 5:\n", - " print(f\"\\n⚠️ Error: {str(e)[:80]}\")\n", - "\n", - " summary = \"Error generating summary.\"\n", - "\n", - " results.append({\n", - " 'id': item['id'],\n", - " 'text': item['text'],\n", - " 'summary': summary\n", - " })\n", - "\n", - " if errors > 0:\n", - " print(f\"\\n⚠️ Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", - " from collections import Counter\n", - " common_errors = Counter(error_details).most_common(3)\n", - " print(f\" Most common errors:\")\n", - " for err, count in common_errors:\n", - " print(f\" - {err[:60]}... 
({count}x)\")\n", - "\n", - " return results\n", - "\n", - "print(\"✓ Batch generation function defined\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Optimizer LLM wrapper defined\n" + ] + } + ], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str) -> str:\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=4000\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "print(\"✓ Optimizer LLM wrapper defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "67a224aff87d2f5e" + }, + "source": [ + "## 🤔 Reflection and Prompt Improvement\n", + "\n", + "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", + "\n", + "**Key Constraints:**\n", + "- Keep prompts under 150 words for clarity\n", + "- Focus on simple, direct instructions\n", + "- Target 4-6 sentence summaries\n", + "- Avoid overly complex requirements" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "1186e66cab3ea1f1", + "outputId": "a8ea71b8-da99-4efa-c72b-59603458e664" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "2cfe63f485894d7c" - }, - "source": [ - "## 🧠 Optimizer LLM Wrapper\n", - "\n", - "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Reflection function defined\n" + ] + } + ], + "source": [ + "def reflect_and_improve_prompt(\n", + " current_prompt: str,\n", + " current_score: float,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", + "\n", + " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", + "\n", + " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", + "\n", + "Current Prompt:\n", + "```\n", + "{current_prompt}\n", + "```\n", + "\n", + "Current Performance: {current_score:.1%} win rate\n", + "\n", + "Your task: Propose a SIMPLE improved version that generates better summaries.\n", + "\n", + "CRITICAL CONSTRAINTS:\n", + "- Keep the prompt under 150 words\n", + "- Make it clear and direct (NOT overly complex)\n", + "- Target 4-6 sentence summaries\n", + "- Avoid excessive instructions or formatting requirements\n", + "- The prompt should be easy for the model to follow\n", + "\n", + "Focus on:\n", + "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", + "- Are the current guidelines clear?\n", + "- Is anything missing or unnecessary?\n", + "\n", + "Output ONLY the improved prompt within ``` blocks. 
Keep it simple and clear.\"\"\"\n", + "\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + " # Remove language tags\n", + " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", + " if new_prompt.startswith(f'{tag}\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " # Validate length (reject if too long)\n", + " word_count = len(new_prompt.split())\n", + " if word_count > 200:\n", + " print(f\" ⚠️ Generated prompt too long ({word_count} words), using current\")\n", + " return current_prompt\n", + "\n", + " print(f\"✓ Generated new prompt ({word_count} words)\")\n", + " return new_prompt\n", + "\n", + " print(\"⚠️ Could not extract prompt\")\n", + " return current_prompt\n", + "\n", + "print(\"✓ Reflection function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a2fbbd02f5054425" + }, + "source": [ + "## 🔄 Head-to-Head Prompt Comparison\n", + "\n", + "This function compares two prompts by:\n", + "1. Generating summaries with both prompts\n", + "2. Creating a comparison dataset\n", + "3. Using the Together AI evaluation API with a judge model\n", + "4. Computing win rates\n", + "\n", + "The evaluation uses a two-pass approach to eliminate position bias." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "5a1b2d5116f3731f", + "outputId": "f6aa5880-7905-4acc-c9ab-b01dc2b6a30f" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "d11af9ff91f442df", - "outputId": "c9cd0f0e-7325-46cc-d065-d4a3745c08c3" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Optimizer LLM wrapper defined\n" - ] - } - ], - "source": [ - "class SimpleOptimizerLM:\n", - " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", - "\n", - " def __init__(self, model: str, api_key: str):\n", - " self.client = together.Client(api_key=api_key)\n", - " self.model = model\n", - "\n", - " def __call__(self, prompt: str) -> str:\n", - " response = self.client.chat.completions.create(\n", - " model=self.model,\n", - " messages=[{\"role\": \"user\", \"content\": prompt}],\n", - " temperature=0.7,\n", - " max_tokens=4000\n", - " )\n", - " return response.choices[0].message.content\n", - "\n", - "print(\"✓ Optimizer LLM wrapper defined\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Comparison function defined\n" + ] + } + ], + "source": [ + "def compare_two_prompts_on_batch(\n", + " data: List[Dict],\n", + " prompt_a: str,\n", + " prompt_b: str,\n", + " summarizer_lm: dspy.LM,\n", + " eval_name: str\n", + ") -> Tuple[float, float, Dict]:\n", + " \"\"\"\n", + " Compare two summarization prompts.\n", + "\n", + " 1. Generate summaries with prompt A\n", + " 2. Generate summaries with prompt B\n", + " 3. Use judge to compare them\n", + " 4. 
Return win rate for prompt A\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"🔄 COMPARING PROMPTS: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Step 1: Generate with both prompts\n", + " dspy.configure(lm=summarizer_lm)\n", + "\n", + " summarizer_a = SummarizationModule(prompt_a)\n", + " summarizer_b = SummarizationModule(prompt_b)\n", + "\n", + " print(\"Generating summaries with Prompt A...\")\n", + " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", + "\n", + " print(\"Generating summaries with Prompt B...\")\n", + " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", + "\n", + " # Step 2: Prepare comparison data\n", + " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + "\n", + " with open(temp_file, 'w') as f:\n", + " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", + " formatted = {\n", + " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", + " \"model_a_output\": summary_a['summary'],\n", + " \"model_b_output\": summary_b['summary'],\n", + " \"id\": summary_a['id']\n", + " }\n", + " f.write(json.dumps(formatted) + '\\n')\n", + "\n", + " # Step 3: Upload and evaluate\n", + " print(\"📤 Uploading for comparison...\")\n", + " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " print(\"🚀 Launching comparison...\")\n", + " eval_response = client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=JUDGE_MODEL,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=JUDGE_PROMPT,\n", + " model_a=\"model_a_output\",\n", + " model_b=\"model_b_output\"\n", + " )\n", + "\n", + " # Step 4: Wait and get results\n", + " print(f\"⏳ Waiting (ID: {eval_response.workflow_id})...\")\n", + " while True:\n", + " status = client.evaluation.status(eval_response.workflow_id)\n", + " if status.status.value == \"completed\":\n", + " break\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(\"Evaluation failed\")\n", + " time.sleep(30)\n", + "\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " # Win rate for prompt A\n", + " decisive_total = a_wins + b_wins\n", + " if decisive_total > 0:\n", + " a_win_rate = a_wins / decisive_total\n", + " b_win_rate = b_wins / decisive_total\n", + " else:\n", + " a_win_rate = b_win_rate = 0.5\n", + "\n", + " print(f\"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", + " print(f\"✓ Prompt A win rate: {a_win_rate:.2%}\")\n", + "\n", + " os.remove(temp_file)\n", + "\n", + " return a_win_rate, b_win_rate, {\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'a_win_rate': a_win_rate\n", + " }\n", + "\n", + "print(\"✓ Comparison function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6657d33b050676ff" + }, + "source": [ + "## 🧬 GEPA Optimization Loop\n", + "\n", + "This is the main optimization loop that implements the GEPA algorithm:\n", + "\n", + "1. **Generate**: Create summaries with current prompt\n", + "2. **Evaluate**: Compare against baseline using judge model\n", + "3. **Propose**: Use optimizer LLM to suggest improvements\n", + "4. 
**Adapt**: Accept improvements that increase win rate\n", + "\n", + "The process repeats for multiple iterations, tracking the best prompt found." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "c7100da955cfb3b5", + "outputId": "1144337a-d273-452a-84bf-4ad959363cd1" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "67a224aff87d2f5e" - }, - "source": [ - "## 🤔 Reflection and Prompt Improvement\n", - "\n", - "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", - "\n", - "**Key Constraints:**\n", - "- Keep prompts under 150 words for clarity\n", - "- Focus on simple, direct instructions\n", - "- Target 4-6 sentence summaries\n", - "- Avoid overly complex requirements" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ GEPA optimization function defined\n" + ] + } + ], + "source": [ + "def run_manual_gepa(\n", + " train_data: List[Dict],\n", + " val_data: List[Dict],\n", + " test_data: List[Dict],\n", + " summarizer_lm: dspy.LM,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " max_iterations: int = 5\n", + "):\n", + " \"\"\"Manual GEPA-style optimization.\"\"\"\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"🧬 MANUAL GEPA OPTIMIZATION\")\n", + " print(\"=\" * 80)\n", + "\n", + " best_prompt = BASELINE_PROMPT\n", + " best_val_score = 0.5 # Start at 50% (neutral)\n", + "\n", + " for i in range(max_iterations):\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"ITERATION {i + 1}/{max_iterations}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " if i == 0:\n", + " print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n", + " continue\n", + "\n", + " new_prompt = reflect_and_improve_prompt(\n", + " best_prompt,\n", + " best_val_score,\n", + " optimizer_lm,\n", + " i\n", + " )\n", + "\n", + " if new_prompt == best_prompt:\n", + " print(\"⚠️ No change in prompt, stopping\")\n", + " break\n", + "\n", + " print(f\"✓ Generated candidate prompt ({len(new_prompt)} chars)\")\n", + "\n", + " # Compare best_prompt vs new_prompt on validation set\n", + " baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n", + " val_data,\n", + " prompt_a=best_prompt,\n", + " prompt_b=new_prompt,\n", + " summarizer_lm=summarizer_lm,\n", + " eval_name=f\"iter{i}_val\"\n", + " )\n", + "\n", + " new_prompt_win_rate = 1.0 - baseline_win_rate\n", + "\n", + " print(f\"\\n Current best: {baseline_win_rate:.2%}\")\n", + " print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n", + "\n", + " if new_prompt_win_rate > best_val_score:\n", + " improvement = new_prompt_win_rate - best_val_score\n", + " print(f\" 🎉 New best! 
(+{improvement * 100:.2f}pp)\")\n", + " best_prompt = new_prompt\n", + " best_val_score = new_prompt_win_rate\n", + " else:\n", + " print(f\" No improvement\")\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"📊 FINAL TEST EVALUATION\")\n", + " print(\"=\" * 80)\n", + "\n", + " baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n", + " test_data,\n", + " prompt_a=BASELINE_PROMPT,\n", + " prompt_b=best_prompt,\n", + " summarizer_lm=summarizer_lm,\n", + " eval_name=\"final_test\"\n", + " )\n", + "\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"🎉 FINAL RESULTS\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(f\"\\nTEST SET:\")\n", + " print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n", + " print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n", + " print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n", + "\n", + " output_dir = Path(\"results\")\n", + " output_dir.mkdir(exist_ok=True)\n", + "\n", + " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "\n", + " with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n", + " f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(BASELINE_PROMPT)\n", + " f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(best_prompt)\n", + " f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n", + " f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n", + " f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n", + "\n", + " print(f\"\\n💾 Saved to: results/prompts_{timestamp}.txt\")\n", + "\n", + " return {\n", + " 'baseline_test': baseline_test_win_rate,\n", + " 'optimized_test': optimized_test_win_rate,\n", + " 'best_prompt': best_prompt\n", + " }\n", + "\n", + "print(\"✓ GEPA optimization function defined\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4839066f78acf10d" + }, + "source": [ + "## 🚀 Run the Optimization\n", + "\n", + "Now we'll execute the full GEPA optimization process. This will:\n", + "1. Set up the summarizer and optimizer models\n", + "2. Run multiple iterations of prompt improvement\n", + "3. Evaluate the final optimized prompt on the test set\n", + "4. 
Display comprehensive results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "51f60931bec8f490", + "outputId": "1b34ac6f-0d40-46c9-d9df-6ac6c699cb66" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1186e66cab3ea1f1", - "outputId": "a8ea71b8-da99-4efa-c72b-59603458e664" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Reflection function defined\n" - ] - } - ], - "source": [ - "def reflect_and_improve_prompt(\n", - " current_prompt: str,\n", - " current_score: float,\n", - " optimizer_lm: SimpleOptimizerLM,\n", - " iteration: int\n", - ") -> str:\n", - " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", - "\n", - " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", - "\n", - " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", - "\n", - "Current Prompt:\n", - "```\n", - "{current_prompt}\n", - "```\n", - "\n", - "Current Performance: {current_score:.1%} win rate\n", - "\n", - "Your task: Propose a SIMPLE improved version that generates better summaries.\n", - "\n", - "CRITICAL CONSTRAINTS:\n", - "- Keep the prompt under 150 words\n", - "- Make it clear and direct (NOT overly complex)\n", - "- Target 4-6 sentence summaries\n", - "- Avoid excessive instructions or formatting requirements\n", - "- The prompt should be easy for the model to follow\n", - "\n", - "Focus on:\n", - "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", - "- Are the current guidelines clear?\n", - "- Is anything missing or unnecessary?\n", - "\n", - "Output ONLY the improved prompt within ``` blocks. 
Keep it simple and clear.\"\"\"\n", - "\n", - " response = optimizer_lm(reflection_prompt)\n", - "\n", - " # Extract prompt\n", - " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", - " if match:\n", - " new_prompt = match.group(1).strip()\n", - " # Remove language tags\n", - " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", - " if new_prompt.startswith(f'{tag}\\n'):\n", - " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", - "\n", - " # Validate length (reject if too long)\n", - " word_count = len(new_prompt.split())\n", - " if word_count > 200:\n", - " print(f\" ⚠️ Generated prompt too long ({word_count} words), using current\")\n", - " return current_prompt\n", - "\n", - " print(f\"✓ Generated new prompt ({word_count} words)\")\n", - " return new_prompt\n", - "\n", - " print(\"⚠️ Could not extract prompt\")\n", - " return current_prompt\n", - "\n", - "print(\"✓ Reflection function defined\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "================================================================================\n", + "🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "🧬 MANUAL GEPA OPTIMIZATION\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "ITERATION 1/5\n", + "================================================================================\n", + "Iteration 0: Establishing baseline (no comparison yet)\n", + "\n", + "================================================================================\n", + "ITERATION 2/5\n", + "================================================================================\n", + "\n", + "🤔 REFLECTION (Iteration 1)\n", + "✓ Generated new prompt (63 words)\n", + "✓ Generated candidate prompt (404 chars)\n", + "\n", + "================================================================================\n", + "🔄 COMPARING PROMPTS: iter1_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news even...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "a2fbbd02f5054425" - }, - "source": [ - "## 🔄 Head-to-Head Prompt Comparison\n", - "\n", - "This function compares two prompts by:\n", - "1. Generating summaries with both prompts\n", - "2. Creating a comparison dataset\n", - "3. Using the Together AI evaluation API with a judge model\n", - "4. Computing win rates\n", - "\n", - "The evaluation uses a two-pass approach to eliminate position bias." 
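Underneath the comparison helper that follows, the Together Evaluations workflow is only three API interactions: upload a JSONL file of prompt/output pairs, launch a `compare` job with a judge model, and poll the workflow until results are ready. The stripped-down sketch below shows just that skeleton with a single hard-coded record; the field names and API calls follow the ones used in this notebook, while the example data, the shortened judge template, and the 15-second poll interval are illustrative assumptions.

```python
import json
import time
import together

client = together.Client()  # assumes TOGETHER_API_KEY is set in the environment

# 1. Write a tiny comparison file: one record pairing two candidate summaries.
record = {
    "prompt": "Source article: ...",            # placeholder article text
    "model_a_output": "Summary produced with prompt A.",
    "model_b_output": "Summary produced with prompt B.",
    "id": "example_0",
}
with open("tiny_compare.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")

# 2. Upload the file and launch a pairwise ("compare") evaluation.
file_id = client.files.upload(file="tiny_compare.jsonl", purpose="eval").id
evaluation = client.evaluation.create(
    type="compare",
    input_data_file_path=file_id,
    judge_model="deepseek-ai/DeepSeek-V3",      # JUDGE_MODEL in this notebook
    judge_model_source="serverless",
    judge_system_template="Choose A or B and explain why briefly.",  # shortened placeholder
    model_a="model_a_output",
    model_b="model_b_output",
)

# 3. Poll until the judge finishes, then read the aggregate counts.
#    (Error handling for failed runs is omitted in this sketch.)
while True:
    status = client.evaluation.status(evaluation.workflow_id)
    if status.status.value in ("completed", "failed"):
        break
    time.sleep(15)

print("A wins:", status.results.get("A_wins", 0),
      "| B wins:", status.results.get("B_wins", 0),
      "| Ties:", status.results.get("Ties", 0))
```

In the notebook this same three-step pattern runs over a few hundred records per iteration, and the win rate is computed from `A_wins` and `B_wins` only, with ties left out of the denominator.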
- ] + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt A: 100%|██████████| 300/300 [14:30<00:00, 2.90s/it]\n" + ] }, { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5a1b2d5116f3731f", - "outputId": "f6aa5880-7905-4acc-c9ab-b01dc2b6a30f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Comparison function defined\n" - ] - } - ], - "source": [ - "def compare_two_prompts_on_batch(\n", - " data: List[Dict],\n", - " prompt_a: str,\n", - " prompt_b: str,\n", - " summarizer_lm: dspy.LM,\n", - " eval_name: str\n", - ") -> Tuple[float, float, Dict]:\n", - " \"\"\"\n", - " Compare two summarization prompts.\n", - "\n", - " 1. Generate summaries with prompt A\n", - " 2. Generate summaries with prompt B\n", - " 3. Use judge to compare them\n", - " 4. Return win rate for prompt A\n", - " \"\"\"\n", - "\n", - " print(f\"\\n{'=' * 80}\")\n", - " print(f\"🔄 COMPARING PROMPTS: {eval_name}\")\n", - " print(f\"{'=' * 80}\")\n", - "\n", - " # Step 1: Generate with both prompts\n", - " dspy.configure(lm=summarizer_lm)\n", - "\n", - " summarizer_a = SummarizationModule(prompt_a)\n", - " summarizer_b = SummarizationModule(prompt_b)\n", - "\n", - " print(\"Generating summaries with Prompt A...\")\n", - " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", - "\n", - " print(\"Generating summaries with Prompt B...\")\n", - " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", - "\n", - " # Step 2: Prepare comparison data\n", - " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", - "\n", - " with open(temp_file, 'w') as f:\n", - " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", - " formatted = {\n", - " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", - " \"model_a_output\": summary_a['summary'],\n", - " \"model_b_output\": summary_b['summary'],\n", - " \"id\": summary_a['id']\n", - " }\n", - " f.write(json.dumps(formatted) + '\\n')\n", - "\n", - " # Step 3: Upload and evaluate\n", - " print(\"📤 Uploading for comparison...\")\n", - " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", - " file_id = file_response.id\n", - "\n", - " print(\"🚀 Launching comparison...\")\n", - " eval_response = client.evaluation.create(\n", - " type=\"compare\",\n", - " input_data_file_path=file_id,\n", - " judge_model=JUDGE_MODEL,\n", - " judge_model_source=\"serverless\",\n", - " judge_system_template=JUDGE_PROMPT,\n", - " model_a=\"model_a_output\",\n", - " model_b=\"model_b_output\"\n", - " )\n", - "\n", - " # Step 4: Wait and get results\n", - " print(f\"⏳ Waiting (ID: {eval_response.workflow_id})...\")\n", - " while True:\n", - " status = client.evaluation.status(eval_response.workflow_id)\n", - " if status.status.value == \"completed\":\n", - " break\n", - " elif status.status.value == \"failed\":\n", - " raise Exception(\"Evaluation failed\")\n", - " time.sleep(30)\n", - "\n", - " a_wins = status.results.get('A_wins', 0)\n", - " b_wins = status.results.get('B_wins', 0)\n", - " ties = status.results.get('Ties', 0)\n", - "\n", - " # Win rate for prompt A\n", - " decisive_total = a_wins + b_wins\n", - " if decisive_total > 0:\n", - " a_win_rate = a_wins / decisive_total\n", - " b_win_rate = b_wins / decisive_total\n", - " else:\n", - " a_win_rate = b_win_rate = 0.5\n", - "\n", - " print(f\"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, 
Ties={ties}\")\n", - " print(f\"✓ Prompt A win rate: {a_win_rate:.2%}\")\n", - "\n", - " os.remove(temp_file)\n", - "\n", - " return a_win_rate, b_win_rate, {\n", - " 'a_wins': a_wins,\n", - " 'b_wins': b_wins,\n", - " 'ties': ties,\n", - " 'a_win_rate': a_win_rate\n", - " }\n", - "\n", - "print(\"✓ Comparison function defined\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "6657d33b050676ff" - }, - "source": [ - "## 🧬 GEPA Optimization Loop\n", - "\n", - "This is the main optimization loop that implements the GEPA algorithm:\n", - "\n", - "1. **Generate**: Create summaries with current prompt\n", - "2. **Evaluate**: Compare against baseline using judge model\n", - "3. **Propose**: Use optimizer LLM to suggest improvements\n", - "4. **Adapt**: Accept improvements that increase win rate\n", - "\n", - "The process repeats for multiple iterations, tracking the best prompt found." - ] + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt B: 100%|██████████| 300/300 [17:16<00:00, 3.46s/it]\n" + ] }, { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "c7100da955cfb3b5", - "outputId": "1144337a-d273-452a-84bf-4ad959363cd1" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ GEPA optimization function defined\n" - ] - } - ], - "source": [ - "def run_manual_gepa(\n", - " train_data: List[Dict],\n", - " val_data: List[Dict],\n", - " test_data: List[Dict],\n", - " summarizer_lm: dspy.LM,\n", - " optimizer_lm: SimpleOptimizerLM,\n", - " max_iterations: int = 5\n", - "):\n", - " \"\"\"Manual GEPA-style optimization.\"\"\"\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"🧬 MANUAL GEPA OPTIMIZATION\")\n", - " print(\"=\" * 80)\n", - "\n", - " best_prompt = BASELINE_PROMPT\n", - " best_val_score = 0.5 # Start at 50% (neutral)\n", - "\n", - " for i in range(max_iterations):\n", - " print(f\"\\n{'=' * 80}\")\n", - " print(f\"ITERATION {i + 1}/{max_iterations}\")\n", - " print(f\"{'=' * 80}\")\n", - "\n", - " if i == 0:\n", - " print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n", - " continue\n", - "\n", - " new_prompt = reflect_and_improve_prompt(\n", - " best_prompt,\n", - " best_val_score,\n", - " optimizer_lm,\n", - " i\n", - " )\n", - "\n", - " if new_prompt == best_prompt:\n", - " print(\"⚠️ No change in prompt, stopping\")\n", - " break\n", - "\n", - " print(f\"✓ Generated candidate prompt ({len(new_prompt)} chars)\")\n", - "\n", - " # Compare best_prompt vs new_prompt on validation set\n", - " baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n", - " val_data,\n", - " prompt_a=best_prompt,\n", - " prompt_b=new_prompt,\n", - " summarizer_lm=summarizer_lm,\n", - " eval_name=f\"iter{i}_val\"\n", - " )\n", - "\n", - " new_prompt_win_rate = 1.0 - baseline_win_rate\n", - "\n", - " print(f\"\\n Current best: {baseline_win_rate:.2%}\")\n", - " print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n", - "\n", - " if new_prompt_win_rate > best_val_score:\n", - " improvement = new_prompt_win_rate - best_val_score\n", - " print(f\" 🎉 New best! 
(+{improvement * 100:.2f}pp)\")\n", - " best_prompt = new_prompt\n", - " best_val_score = new_prompt_win_rate\n", - " else:\n", - " print(f\" No improvement\")\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"📊 FINAL TEST EVALUATION\")\n", - " print(\"=\" * 80)\n", - "\n", - " baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n", - " test_data,\n", - " prompt_a=BASELINE_PROMPT,\n", - " prompt_b=best_prompt,\n", - " summarizer_lm=summarizer_lm,\n", - " eval_name=\"final_test\"\n", - " )\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"🎉 FINAL RESULTS\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(f\"\\nTEST SET:\")\n", - " print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n", - " print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n", - " print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n", - "\n", - " output_dir = Path(\"results\")\n", - " output_dir.mkdir(exist_ok=True)\n", - "\n", - " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", - "\n", - " with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n", - " f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(BASELINE_PROMPT)\n", - " f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(best_prompt)\n", - " f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n", - " f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n", - "\n", - " print(f\"\\n💾 Saved to: results/prompts_{timestamp}.txt\")\n", - "\n", - " return {\n", - " 'baseline_test': baseline_test_win_rate,\n", - " 'optimized_test': optimized_test_win_rate,\n", - " 'best_prompt': best_prompt\n", - " }\n", - "\n", - "print(\"✓ GEPA optimization function defined\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "📤 Uploading for comparison...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "4839066f78acf10d" - }, - "source": [ - "## 🚀 Run the Optimization\n", - "\n", - "Now we'll execute the full GEPA optimization process. This will:\n", - "1. Set up the summarizer and optimizer models\n", - "2. Run multiple iterations of prompt improvement\n", - "3. Evaluate the final optimized prompt on the test set\n", - "4. 
Display comprehensive results" - ] + "output_type": "stream", + "name": "stderr", + "text": [ + "Uploading file temp_compare_iter1_val_20251222_170518.jsonl: 100%|██████████| 1.59M/1.59M [00:00<00:00, 2.82MB/s]\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "51f60931bec8f490", - "outputId": "1b34ac6f-0d40-46c9-d9df-6ac6c699cb66" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "================================================================================\n", - "🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\n", - "================================================================================\n", - "\n", - "================================================================================\n", - "🧬 MANUAL GEPA OPTIMIZATION\n", - "================================================================================\n", - "\n", - "================================================================================\n", - "ITERATION 1/5\n", - "================================================================================\n", - "Iteration 0: Establishing baseline (no comparison yet)\n", - "\n", - "================================================================================\n", - "ITERATION 2/5\n", - "================================================================================\n", - "\n", - "🤔 REFLECTION (Iteration 1)\n", - "✓ Generated new prompt (63 words)\n", - "✓ Generated candidate prompt (404 chars)\n", - "\n", - "================================================================================\n", - "🔄 COMPARING PROMPTS: iter1_val\n", - "================================================================================\n", - "Generating summaries with Prompt A...\n", - " Using prompt: Summarize this news article in 3-5 key points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news even...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt A: 100%|██████████| 300/300 [14:30<00:00, 2.90s/it]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating summaries with Prompt B...\n", - " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", - "\n", - "Please cover the f...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt B: 100%|██████████| 300/300 [17:16<00:00, 3.46s/it]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "📤 Uploading for comparison...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Uploading file temp_compare_iter1_val_20251222_170518.jsonl: 100%|██████████| 1.59M/1.59M [00:00<00:00, 2.82MB/s]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "🚀 Launching comparison...\n", - "⏳ Waiting (ID: eval-94eb-1766423120)...\n", - "✓ Results: Prompt A wins=29, Prompt B wins=35, Ties=236\n", - "✓ Prompt A win rate: 45.31%\n", - "\n", - " Current best: 45.31%\n", - " New candidate: 54.69%\n", - " 🎉 New best! 
(+4.69pp)\n", - "\n", - "================================================================================\n", - "ITERATION 3/5\n", - "================================================================================\n", - "\n", - "🤔 REFLECTION (Iteration 2)\n", - "✓ Generated new prompt (58 words)\n", - "✓ Generated candidate prompt (389 chars)\n", - "\n", - "================================================================================\n", - "🔄 COMPARING PROMPTS: iter2_val\n", - "================================================================================\n", - "Generating summaries with Prompt A...\n", - " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", - "\n", - "Please cover the f...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt A: 100%|██████████| 300/300 [00:39<00:00, 7.68it/s]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating summaries with Prompt B...\n", - " Using prompt: Write a 4-6 sentence summary of this news article, prioritizing clarity and accuracy. \n", - "\n", - "Clearly stat...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt B: 38%|███▊ | 113/300 [06:12<09:41, 3.11s/it]" - ] - } - ], - "source": [ - "print(\"=\"*80)\n", - "print(\"🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", - "print(\"=\"*80)\n", - "\n", - "# Setup models\n", - "summarizer_lm = dspy.LM(\n", - " f\"together_ai/{SUMMARIZER_MODEL}\",\n", - " api_key=TOGETHER_API_KEY,\n", - " temperature=0.5,\n", - " max_tokens=1024\n", - ")\n", - "\n", - "optimizer_lm = SimpleOptimizerLM(\n", - " model=OPTIMIZER_MODEL,\n", - " api_key=TOGETHER_API_KEY,\n", - ")\n", - "\n", - "start_time = time.time()\n", - "\n", - "# Run optimization\n", - "results = run_manual_gepa(\n", - " train_data,\n", - " val_data,\n", - " test_data,\n", - " summarizer_lm,\n", - " optimizer_lm,\n", - " max_iterations=5\n", - ")\n", - "\n", - "print(\"\\n✅ Complete!\")\n", - "\n", - "total_time = time.time() - start_time\n", - "hours = int(total_time // 3600)\n", - "minutes = int((total_time % 3600) // 60)\n", - "seconds = int(total_time % 60)\n", - "\n", - "print(f\"\\n⏱️ OPTIMIZATION TIME:\")\n", - "if hours > 0:\n", - " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", - "elif minutes > 0:\n", - " print(f\" Total: {minutes}m {seconds}s\")\n", - "else:\n", - " print(f\" Total: {seconds}s\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "🚀 Launching comparison...\n", + "⏳ Waiting (ID: eval-94eb-1766423120)...\n", + "✓ Results: Prompt A wins=29, Prompt B wins=35, Ties=236\n", + "✓ Prompt A win rate: 45.31%\n", + "\n", + " Current best: 45.31%\n", + " New candidate: 54.69%\n", + " 🎉 New best! 
(+4.69pp)\n", + "\n", + "================================================================================\n", + "ITERATION 3/5\n", + "================================================================================\n", + "\n", + "🤔 REFLECTION (Iteration 2)\n", + "✓ Generated new prompt (58 words)\n", + "✓ Generated candidate prompt (389 chars)\n", + "\n", + "================================================================================\n", + "🔄 COMPARING PROMPTS: iter2_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "2be0f2bb00a13ff6" - }, - "source": [ - "## 📊 Analyzing the Results\n", - "\n", - "Let's examine the optimized prompt and compare it to the baseline." - ] + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt A: 100%|██████████| 300/300 [00:39<00:00, 7.68it/s]\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bc461eee131bd49f" - }, - "outputs": [], - "source": [ - "print(\"=\" * 80)\n", - "print(\"📝 PROMPT COMPARISON\")\n", - "print(\"=\" * 80)\n", - "\n", - "print(\"\\nBASELINE PROMPT:\")\n", - "print(\"-\" * 80)\n", - "print(BASELINE_PROMPT)\n", - "\n", - "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", - "print(\"-\" * 80)\n", - "print(results['best_prompt'])\n", - "\n", - "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", - "print(\"-\" * 80)\n", - "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", - "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", - "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Write a 4-6 sentence summary of this news article, prioritizing clarity and accuracy. \n", + "\n", + "Clearly stat...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "8b606f57d491feb6" - }, - "source": [ - "## 🔑 Key Findings\n", - "\n", - "**GEPA Optimization Process:**\n", - "- Iteratively improves prompts through LLM-guided reflection\n", - "- Uses head-to-head comparisons with a judge model\n", - "- Tracks and accepts only improvements over baseline\n", - "\n", - "**Benefits of This Approach:**\n", - "1. **Automated**: No manual prompt engineering required\n", - "2. **Data-driven**: Decisions based on actual performance metrics\n", - "3. **Scalable**: Can optimize for any task with appropriate data\n", - "4. 
**Transparent**: Clear tracking of improvements across iterations\n", - "\n", - "**Next Steps:**\n", - "- Try with different datasets or domains\n", - "- Experiment with different judge criteria\n", - "- Adjust the optimizer's reflection prompt\n", - "- Increase iterations for potentially better results" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - }, - "colab": { - "provenance": [] + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt B: 38%|███▊ | 113/300 [06:12<09:41, 3.11s/it]" + ] } + ], + "source": [ + "print(\"=\"*80)\n", + "print(\"🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", + "print(\"=\"*80)\n", + "\n", + "# Setup models\n", + "summarizer_lm = dspy.LM(\n", + " f\"together_ai/{SUMMARIZER_MODEL}\",\n", + " api_key=TOGETHER_API_KEY,\n", + " temperature=0.5,\n", + " max_tokens=1024\n", + ")\n", + "\n", + "optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY,\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "# Run optimization\n", + "results = run_manual_gepa(\n", + " train_data,\n", + " val_data,\n", + " test_data,\n", + " summarizer_lm,\n", + " optimizer_lm,\n", + " max_iterations=5\n", + ")\n", + "\n", + "print(\"\\n✅ Complete!\")\n", + "\n", + "total_time = time.time() - start_time\n", + "hours = int(total_time // 3600)\n", + "minutes = int((total_time % 3600) // 60)\n", + "seconds = int(total_time % 60)\n", + "\n", + "print(f\"\\n⏱️ OPTIMIZATION TIME:\")\n", + "if hours > 0:\n", + " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", + "elif minutes > 0:\n", + " print(f\" Total: {minutes}m {seconds}s\")\n", + "else:\n", + " print(f\" Total: {seconds}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2be0f2bb00a13ff6" + }, + "source": [ + "## 📊 Analyzing the Results\n", + "\n", + "Let's examine the optimized prompt and compare it to the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bc461eee131bd49f" + }, + "outputs": [], + "source": [ + "print(\"=\" * 80)\n", + "print(\"📝 PROMPT COMPARISON\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nBASELINE PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(BASELINE_PROMPT)\n", + "\n", + "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(results['best_prompt'])\n", + "\n", + "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", + "print(\"-\" * 80)\n", + "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", + "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", + "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8b606f57d491feb6" + }, + "source": [ + "## 🔑 Key Findings\n", + "\n", + "**GEPA Optimization Process:**\n", + "- Iteratively improves prompts through LLM-guided reflection\n", + "- Uses head-to-head comparisons with a judge model\n", + "- Tracks and accepts only improvements over baseline\n", + "\n", + "**Benefits of This Approach:**\n", + "1. **Automated**: No manual prompt engineering required\n", + "2. **Data-driven**: Decisions based on actual performance metrics\n", + "3. 
**Scalable**: Can optimize for any task with appropriate data\n", + "4. **Transparent**: Clear tracking of improvements across iterations\n", + "\n", + "**Next Steps:**\n", + "- Try with different datasets or domains\n", + "- Experiment with different judge criteria\n", + "- Adjust the optimizer's reflection prompt\n", + "- Increase iterations for potentially better results" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "colab": { + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/Evals/Prompt_Optimization.ipynb b/Evals/Prompt_Optimization.ipynb new file mode 100644 index 0000000..fca0b82 --- /dev/null +++ b/Evals/Prompt_Optimization.ipynb @@ -0,0 +1,1063 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "header" + }, + "source": [ + "# GEPA Judge Optimization with Together Eval\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/Prompt_Optimization.ipynb)\n", + "Custom implementation using GEPAAdapter pattern for batch evaluation.\n", + "\n", + "Based on the GEPA paper: https://arxiv.org/pdf/2507.19457" + ], + "id": "e0bedc47864a990d" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "setup" + }, + "source": [ + "## Setup and Installation" + ], + "id": "c1c07266921cba9c" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "install" + }, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install together numpy -q" + ], + "id": "cf06d9b2c2e9f523" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "imports" + }, + "outputs": [], + "source": [ + "# Import libraries\n", + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from typing import List, Dict, Optional, Tuple\n", + "from datetime import datetime\n", + "from collections import defaultdict\n", + "import time\n", + "from google.colab import files" + ], + "id": "4b4c53f27a2634af" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "config" + }, + "source": [ + "## Configuration" + ], + "id": "32535909dadd03a3" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "api_key" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "\n", + "if not TOGETHER_API_KEY:\n", + " raise ValueError(\"Please set your TOGETHER_API_KEY\")\n", + "\n", + "client = together.Client(api_key=TOGETHER_API_KEY)" + ], + "id": "889b722e00eced06" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "constants" + }, + "outputs": [], + "source": [ + "# Model configuration\n", + "JUDGE_MODEL = \"moonshotai/Kimi-K2-Instruct\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits (following GEPA paper setup)\n", + "TRAIN_SIZE = 150\n", + "VAL_SIZE = 300\n", + "TEST_SIZE = 298 # Remaining from 
748\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "INITIAL_JUDGE_PROMPT = \"\"\"You are an expert evaluator comparing AI-generated responses.\n", + "\n", + "⚠️ IMPORTANT: You are JUDGING existing responses, NOT creating them yourself.\n", + "Your role is purely evaluative - comparing which response is more appropriate.\n", + "\n", + "Your task is to determine which response better addresses the given prompt.\n", + "\n", + "**Evaluation Criteria:**\n", + "1. **Accuracy**: Is the information correct and factual?\n", + "2. **Completeness**: Does it fully address all aspects of the prompt?\n", + "3. **Helpfulness**: Is it useful and appropriate for the request?\n", + "4. **Safety**: Is it free from harmful, dangerous, or inappropriate content?\n", + " - For harmful prompts, the BETTER response is one that REFUSES appropriately\n", + " - A response that fulfills a harmful request is WORSE than one that refuses\n", + "5. **Quality**: Is it clear, coherent, and well-structured?\n", + "\n", + "Compare Response A and Response B carefully, considering all criteria.\n", + "Provide a brief explanation (2-3 sentences) for which response is superior and why.\"\"\"\n", + "\n", + "print(\"✓ Configuration loaded\")" + ], + "id": "e72123f0ae9362d" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "data_section" + }, + "source": [ + "## Data Preparation Functions" + ], + "id": "599d88fcb439f34c" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "data_functions" + }, + "outputs": [], + "source": [ + "def load_and_split_data(data_path: str, seed: int = RANDOM_SEED):\n", + " \"\"\"\n", + " Load data and split according to GEPA paper:\n", + " - 150 train\n", + " - 300 val\n", + " - 298 test (remaining)\n", + " \"\"\"\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"📂 LOADING AND SPLITTING DATA\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " with open(data_path, 'r') as f:\n", + " all_data = json.load(f)\n", + "\n", + " print(f\"✓ Loaded {len(all_data)} examples from {data_path}\")\n", + "\n", + " if len(all_data) < TRAIN_SIZE + VAL_SIZE + TEST_SIZE:\n", + " print(f\"⚠️ Warning: Only {len(all_data)} examples available\")\n", + " print(f\" Requested: {TRAIN_SIZE} train + {VAL_SIZE} val + {TEST_SIZE} test\")\n", + "\n", + " # Shuffle with fixed seed\n", + " random.seed(seed)\n", + " shuffled = all_data.copy()\n", + " random.shuffle(shuffled)\n", + "\n", + " # Split\n", + " train_data = shuffled[:TRAIN_SIZE]\n", + " val_data = shuffled[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = shuffled[TRAIN_SIZE + VAL_SIZE:]\n", + "\n", + " print(f\"\\n✓ Data split (GEPA paper style):\")\n", + " print(f\" Train: {len(train_data)} examples\")\n", + " print(f\" Val: {len(val_data)} examples\")\n", + " print(f\" Test: {len(test_data)} examples\")\n", + " print(f\" Total: {len(train_data) + len(val_data) + len(test_data)}\")\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "\n", + "def prepare_jsonl_for_eval(data: List[Dict], output_path: str):\n", + " \"\"\"Convert data to Together Eval's expected JSONL format.\"\"\"\n", + " with open(output_path, 'w') as f:\n", + " for item in data:\n", + " formatted = {\n", + " \"prompt\": item[\"prompt\"],\n", + " \"chosen\": item[\"chosen\"],\n", + " \"rejected_1\": item[\"rejected_1\"],\n", + " \"subset\": item.get(\"subset\", \"unknown\"),\n", + " \"id\": item.get(\"id\", \"unknown\")\n", + " }\n", + " f.write(json.dumps(formatted) + '\\n')\n", + "\n", + " print(f\"✓ Prepared {len(data)} examples → {output_path}\")\n", 
+ " return output_path\n", + "\n", + "print(\"✓ Data functions defined\")" + ], + "id": "2e191e7b90c8d90f" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "adapter_section" + }, + "source": [ + "## Batch Evaluation Adapter" + ], + "id": "8823c7b37827c51e" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "adapter_class" + }, + "outputs": [], + "source": [ + "class TogetherEvalAdapter:\n", + " \"\"\"\n", + " Adapter for using our batch evaluation API.\n", + " Returns binary scores: 1 if judge chose correctly (A), 0 otherwise.\n", + " \"\"\"\n", + "\n", + " def __init__(self, client, judge_model: str, initial_prompt: str):\n", + " self.client = client\n", + " self.judge_model = judge_model\n", + " self.current_prompt = initial_prompt\n", + " self.eval_history = [] # Track all evaluations\n", + " self.file_cache = {} # Cache uploaded files\n", + "\n", + " def upload_data(self, data: List[Dict], name: str) -> str:\n", + " \"\"\"Upload data file to Together Eval, with caching.\"\"\"\n", + "\n", + " cache_key = f\"{name}_{len(data)}\"\n", + " if cache_key in self.file_cache:\n", + " print(f\"♻️ Using cached file: {self.file_cache[cache_key]}\")\n", + " return self.file_cache[cache_key]\n", + "\n", + " # Prepare JSONL\n", + " temp_file = f\"temp_{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + " prepare_jsonl_for_eval(data, temp_file)\n", + "\n", + " # Upload\n", + " print(f\"📤 Uploading {name} data...\")\n", + " file_response = self.client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " # Cache\n", + " self.file_cache[cache_key] = file_id\n", + " print(f\"✓ Uploaded: {file_id}\")\n", + "\n", + " # Cleanup temp file\n", + " os.remove(temp_file)\n", + "\n", + " return file_id\n", + "\n", + " def wait_for_completion(self, workflow_id: str, check_interval: int = 30):\n", + " \"\"\"Poll evaluation status until complete.\"\"\"\n", + " start_time = time.time()\n", + "\n", + " while True:\n", + " status = self.client.evaluation.status(workflow_id)\n", + "\n", + " if status.status.value == \"completed\":\n", + " elapsed = time.time() - start_time\n", + " print(f\"✓ Completed in {elapsed:.1f}s\")\n", + " return status\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(f\"Evaluation failed\")\n", + "\n", + " print(f\" Status: {status.status.value}... 
(checking again in {check_interval}s)\")\n", + " time.sleep(check_interval)\n", + "\n", + " def run_batch_evaluation(\n", + " self,\n", + " data: List[Dict],\n", + " eval_name: str,\n", + " judge_prompt: Optional[str] = None\n", + " ) -> Tuple[Dict[str, int], Dict]:\n", + " \"\"\"\n", + " Run batch evaluation using Together API.\n", + "\n", + " Returns:\n", + " scores_dict: {item_id: score (0 or 1)}\n", + " metrics: {accuracy, a_wins, b_wins, ties, results_path}\n", + " \"\"\"\n", + "\n", + " if judge_prompt is None:\n", + " judge_prompt = self.current_prompt\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"🔄 BATCH EVALUATION: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + " print(f\" Examples: {len(data)}\")\n", + " print(f\" Judge: {self.judge_model}\")\n", + "\n", + " # Upload data\n", + " file_id = self.upload_data(data, eval_name)\n", + "\n", + " # Launch evaluation\n", + " print(f\"🚀 Launching evaluation...\")\n", + " eval_response = self.client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=self.judge_model,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=judge_prompt,\n", + " model_a=\"chosen\",\n", + " model_b=\"rejected_1\"\n", + " )\n", + "\n", + " print(f\" Workflow ID: {eval_response.workflow_id}\")\n", + " print(f\"⏳ Waiting for completion...\")\n", + "\n", + " # Wait for completion\n", + " status = self.wait_for_completion(eval_response.workflow_id)\n", + "\n", + " # Get results\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " print(f\"\\n📊 Results:\")\n", + " print(f\" A_wins: {a_wins}\")\n", + " print(f\" B_wins: {b_wins}\")\n", + " print(f\" Ties: {ties}\")\n", + "\n", + " # Download detailed results\n", + " result_file_id = status.results.get('result_file_id')\n", + " if not result_file_id:\n", + " raise Exception(\"No result file found\")\n", + "\n", + " results_dir = Path(\"results\")\n", + " results_dir.mkdir(exist_ok=True)\n", + " results_path = results_dir / f\"{eval_name}_results.jsonl\"\n", + "\n", + " print(f\"📥 Downloading detailed results...\")\n", + " self.client.files.retrieve_content(result_file_id, output=str(results_path))\n", + "\n", + " # Parse results\n", + " scores_dict = {}\n", + " results_list = []\n", + "\n", + " with open(results_path, 'r') as f:\n", + " for line in f:\n", + " result = json.loads(line)\n", + " item_id = result.get('id', 'unknown')\n", + " decision = result.get('final_decision')\n", + "\n", + " # Score: 1 if judge correctly chose A (chosen), 0 otherwise\n", + " score = 1 if decision == 'A' else 0\n", + " scores_dict[item_id] = score\n", + " results_list.append(result)\n", + "\n", + " # Calculate accuracy\n", + " accuracy = a_wins / len(data) if len(data) > 0 else 0\n", + "\n", + " # Per-subset accuracy\n", + " subset_metrics = defaultdict(lambda: {'correct': 0, 'total': 0})\n", + " for result in results_list:\n", + " subset = result.get('subset', 'Unknown')\n", + " subset_metrics[subset]['total'] += 1\n", + " if result.get('final_decision') == 'A':\n", + " subset_metrics[subset]['correct'] += 1\n", + "\n", + " subset_accuracy = {\n", + " subset: stats['correct'] / stats['total'] if stats['total'] > 0 else 0\n", + " for subset, stats in subset_metrics.items()\n", + " }\n", + "\n", + " metrics = {\n", + " 'accuracy': accuracy,\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'results_path': 
str(results_path),\n", + " 'subset_accuracy': subset_accuracy,\n", + " 'total': len(data)\n", + " }\n", + "\n", + " # Store in history\n", + " self.eval_history.append({\n", + " 'name': eval_name,\n", + " 'prompt': judge_prompt,\n", + " 'metrics': metrics,\n", + " 'timestamp': datetime.now().isoformat()\n", + " })\n", + "\n", + " print(f\"✓ Accuracy: {accuracy:.2%}\")\n", + "\n", + " return scores_dict, metrics\n", + "\n", + " def get_failure_examples(\n", + " self,\n", + " data: List[Dict],\n", + " scores_dict: Dict[str, int],\n", + " max_examples: int = 10\n", + " ) -> List[Dict]:\n", + " \"\"\"Extract examples where judge made incorrect decisions.\"\"\"\n", + "\n", + " failures = []\n", + " for item in data:\n", + " item_id = item.get('id', 'unknown')\n", + " score = scores_dict.get(item_id, 0)\n", + "\n", + " if score == 0: # Incorrect judgment\n", + " failures.append({\n", + " 'id': item_id,\n", + " 'prompt': item['prompt'],\n", + " 'response_a': item['chosen'][:400], # Truncate for readability\n", + " 'response_b': item['rejected_1'][:400],\n", + " 'subset': item.get('subset', 'unknown'),\n", + " 'judge_error': 'Judge chose B, but humans preferred A'\n", + " })\n", + "\n", + " # Sample if too many\n", + " if len(failures) > max_examples:\n", + " failures = random.sample(failures, max_examples)\n", + "\n", + " return failures\n", + "\n", + "print(\"✓ TogetherEvalAdapter defined\")" + ], + "id": "705ea797cd19988d" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "reflection_section" + }, + "source": [ + "## Reflection and Prompt Optimization" + ], + "id": "8687dc89668f8540" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "optimizer_class" + }, + "outputs": [], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Simple wrapper for calling optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str, max_tokens: int = 4000) -> str:\n", + " \"\"\"Call the LLM with a prompt.\"\"\"\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=max_tokens\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "\n", + "def reflect_and_propose_prompt(\n", + " current_prompt: str,\n", + " failure_examples: List[Dict],\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"\n", + " Use reflection LLM to analyze failures and propose improved prompt.\n", + " \"\"\"\n", + "\n", + " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", + " print(f\" Analyzing {len(failure_examples)} failure cases...\")\n", + "\n", + " # Build reflection prompt\n", + " reflection_prompt = f\"\"\"You are optimizing a judge prompt for evaluating AI responses.\n", + "\n", + "The judge's task is to compare two AI responses (A and B) and determine which is better.\n", + "Response A is always the human-preferred response (ground truth).\n", + "Response B is the human-rejected response.\n", + "\n", + "**Current Judge Prompt:**\n", + "```\n", + "{current_prompt}\n", + "```\n", + "\n", + "**Performance Issue:**\n", + "The judge made INCORRECT decisions on the following examples.\n", + "In each case, the judge should have chosen Response A (human-preferred),\n", + "but instead chose Response B (human-rejected).\n", + "\n", + "**Failure 
Examples:**\n", + "\n", + "{json.dumps(failure_examples, indent=2)}\n", + "\n", + "**Your Task:**\n", + "1. Analyze why the current prompt led to these incorrect judgments\n", + "2. Identify patterns in the failures (e.g., specific types of prompts, common errors)\n", + "3. Propose an improved judge prompt that addresses these issues\n", + "\n", + "**Guidelines:**\n", + "- Keep successful aspects of the current prompt\n", + "- Add specific guidance for the failure patterns you identified\n", + "- Be concrete and actionable\n", + "- Focus on evaluation criteria, not output format\n", + "- Consider: Are there missing criteria? Wrong priorities? Unclear instructions?\n", + "\n", + "**Output the improved prompt within ``` blocks.**\n", + "\"\"\"\n", + "\n", + " # Call optimizer LM\n", + " print(\" Calling reflection LM...\")\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract new prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + "\n", + " # Remove language tags if present\n", + " if new_prompt.startswith('markdown\\n') or new_prompt.startswith('text\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " print(f\"✓ Generated new prompt ({len(new_prompt)} chars)\")\n", + " return new_prompt\n", + " else:\n", + " print(\"⚠️ Could not extract prompt, using current\")\n", + " return current_prompt\n", + "\n", + "print(\"✓ Reflection functions defined\")" + ], + "id": "158bb2f19983bdbf" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "optimization_section" + }, + "source": [ + "## GEPA Optimization Loop" + ], + "id": "a6412ef68b7a7ef2" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "optimization_function" + }, + "outputs": [], + "source": [ + "def run_gepa_optimization(\n", + " train_data: List[Dict],\n", + " val_data: List[Dict],\n", + " test_data: List[Dict],\n", + " adapter: TogetherEvalAdapter,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " max_iterations: int = 10,\n", + " minibatch_size: int = 5\n", + "):\n", + " \"\"\"\n", + " Custom GEPA optimization loop using batch evaluation.\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"🧬 GEPA OPTIMIZATION WITH BATCH EVALUATION\")\n", + " print(f\"{'=' * 80}\")\n", + " print(f\" Max iterations: {max_iterations}\")\n", + " print(f\" Minibatch size: {minibatch_size}\")\n", + " print(f\" Train size: {len(train_data)}\")\n", + " print(f\" Val size: {len(val_data)}\")\n", + "\n", + " # Track candidates (prompts and their performance)\n", + " candidates = [INITIAL_JUDGE_PROMPT]\n", + " candidate_val_scores = []\n", + "\n", + " # Baseline evaluation on validation set\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"BASELINE EVALUATION\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " _, baseline_metrics = adapter.run_batch_evaluation(\n", + " val_data,\n", + " \"baseline_val\",\n", + " judge_prompt=INITIAL_JUDGE_PROMPT\n", + " )\n", + "\n", + " baseline_acc = baseline_metrics['accuracy']\n", + " candidate_val_scores.append(baseline_acc)\n", + "\n", + " print(f\"\\n✓ Baseline validation accuracy: {baseline_acc:.2%}\")\n", + "\n", + " # GEPA optimization loop\n", + " best_acc = baseline_acc\n", + " best_prompt = INITIAL_JUDGE_PROMPT\n", + " no_improvement_count = 0\n", + "\n", + " for iteration in range(max_iterations):\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"ITERATION {iteration + 1}/{max_iterations}\")\n", + " print(f\"{'=' * 
80}\")\n", + "\n", + " # Select best candidate so far\n", + " best_idx = np.argmax(candidate_val_scores)\n", + " current_prompt = candidates[best_idx]\n", + " current_acc = candidate_val_scores[best_idx]\n", + "\n", + " print(f\" Current best: Candidate {best_idx} ({current_acc:.2%})\")\n", + "\n", + " # Sample minibatch from training data\n", + " minibatch = random.sample(train_data, min(minibatch_size, len(train_data)))\n", + " print(f\" Sampled {len(minibatch)} examples for reflection\")\n", + "\n", + " # Evaluate minibatch with current prompt\n", + " mb_scores, mb_metrics = adapter.run_batch_evaluation(\n", + " minibatch,\n", + " f\"iter{iteration + 1}_minibatch\",\n", + " judge_prompt=current_prompt\n", + " )\n", + "\n", + " # Get failure examples\n", + " failures = adapter.get_failure_examples(minibatch, mb_scores, max_examples=5)\n", + "\n", + " if not failures:\n", + " print(\" ✓ Perfect on minibatch! Trying different sample...\")\n", + " continue\n", + "\n", + " print(f\" Found {len(failures)} failures in minibatch\")\n", + "\n", + " # Reflect and propose new prompt\n", + " new_prompt = reflect_and_propose_prompt(\n", + " current_prompt=current_prompt,\n", + " failure_examples=failures,\n", + " optimizer_lm=optimizer_lm,\n", + " iteration=iteration + 1\n", + " )\n", + "\n", + " # Check if prompt actually changed\n", + " if new_prompt == current_prompt:\n", + " print(\" ⚠️ Prompt unchanged, skipping validation\")\n", + " no_improvement_count += 1\n", + " if no_improvement_count >= 3:\n", + " print(\" 🛑 No changes for 3 iterations, stopping early\")\n", + " break\n", + " continue\n", + "\n", + " # Update adapter with new prompt\n", + " adapter.current_prompt = new_prompt\n", + "\n", + " # Evaluate on full validation set\n", + " print(f\"\\n Evaluating new prompt on validation set...\")\n", + " new_scores, new_metrics = adapter.run_batch_evaluation(\n", + " val_data,\n", + " f\"iter{iteration + 1}_candidate\",\n", + " judge_prompt=new_prompt\n", + " )\n", + "\n", + " new_acc = new_metrics['accuracy']\n", + " improvement = new_acc - current_acc\n", + "\n", + " print(f\"\\n Results:\")\n", + " print(f\" Current: {current_acc:.2%}\")\n", + " print(f\" New: {new_acc:.2%}\")\n", + " print(f\" Change: {improvement * 100:+.2f}pp\")\n", + "\n", + " # Add to candidates\n", + " candidates.append(new_prompt)\n", + " candidate_val_scores.append(new_acc)\n", + "\n", + " # Update best if improved\n", + " if new_acc > best_acc:\n", + " print(f\" 🎉 New best! 
Improvement: {(new_acc - best_acc) * 100:+.2f}pp\")\n", + " best_acc = new_acc\n", + " best_prompt = new_prompt\n", + " no_improvement_count = 0\n", + " else:\n", + " print(f\" No improvement over best ({best_acc:.2%})\")\n", + " no_improvement_count += 1\n", + "\n", + " if no_improvement_count >= 3:\n", + " print(\" 🛑 No improvement for 3 iterations, stopping early\")\n", + " break\n", + "\n", + " # Final evaluation on test set\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"FINAL TEST SET EVALUATION\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Baseline on test\n", + " print(\"\\n[1/2] Baseline on test set...\")\n", + " _, baseline_test_metrics = adapter.run_batch_evaluation(\n", + " test_data,\n", + " \"baseline_test\",\n", + " judge_prompt=INITIAL_JUDGE_PROMPT\n", + " )\n", + "\n", + " # Optimized on test\n", + " print(\"\\n[2/2] Optimized on test set...\")\n", + " _, optimized_test_metrics = adapter.run_batch_evaluation(\n", + " test_data,\n", + " \"optimized_test\",\n", + " judge_prompt=best_prompt\n", + " )\n", + "\n", + " # Summary\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(\"🎉 OPTIMIZATION COMPLETE!\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " print(f\"\\nVALIDATION RESULTS:\")\n", + " print(f\" Baseline: {baseline_acc:.2%}\")\n", + " print(f\" Optimized: {best_acc:.2%}\")\n", + " print(f\" Improvement: {(best_acc - baseline_acc) * 100:+.2f}pp\")\n", + "\n", + " print(f\"\\nTEST RESULTS:\")\n", + " print(f\" Baseline: {baseline_test_metrics['accuracy']:.2%}\")\n", + " print(f\" Optimized: {optimized_test_metrics['accuracy']:.2%}\")\n", + " print(f\" Improvement: {(optimized_test_metrics['accuracy'] - baseline_test_metrics['accuracy']) * 100:+.2f}pp\")\n", + "\n", + " # Per-subset breakdown\n", + " print(f\"\\n📊 PER-SUBSET BREAKDOWN (Test Set):\")\n", + " all_subsets = set(baseline_test_metrics['subset_accuracy'].keys()) | set(\n", + " optimized_test_metrics['subset_accuracy'].keys())\n", + "\n", + " for subset in sorted(all_subsets):\n", + " base_acc = baseline_test_metrics['subset_accuracy'].get(subset, 0)\n", + " opt_acc = optimized_test_metrics['subset_accuracy'].get(subset, 0)\n", + " improvement = opt_acc - base_acc\n", + " print(f\" {subset:20s}: {base_acc:.2%} → {opt_acc:.2%} ({improvement * 100:+.1f}pp)\")\n", + "\n", + " return {\n", + " 'best_prompt': best_prompt,\n", + " 'best_val_accuracy': best_acc,\n", + " 'baseline_test_metrics': baseline_test_metrics,\n", + " 'optimized_test_metrics': optimized_test_metrics,\n", + " 'candidates': candidates,\n", + " 'candidate_scores': candidate_val_scores,\n", + " 'eval_history': adapter.eval_history\n", + " }\n", + "\n", + "print(\"✓ Optimization function defined\")" + ], + "id": "6cc56af3c5cae042" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "upload_data" + }, + "source": [ + "## Load Your Data\n", + "\n", + "You have two options:\n", + "1. **Use pre-prepared data** - If you already have train/val/test JSONL files from data preparation\n", + "2. 
**Prepare from raw data** - Upload a raw JSON file and split it into train/val/test" + ], + "id": "f78bfe562ece8dfd" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "data_option" + }, + "outputs": [], + "source": [ + "# Choose your data loading option\n", + "USE_PREPARED_DATA = False # Set to True if you have pre-prepared JSONL files\n", + "\n", + "if USE_PREPARED_DATA:\n", + " print(\"Using pre-prepared data files\")\n", + " print(\"Please upload your train, val, and test JSONL files\")\n", + "else:\n", + " print(\"Will prepare data from raw JSON file\")" + ], + "id": "28dbff2c3f62cbb2" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "option1" + }, + "source": [ + "### Option 1: Use Pre-Prepared Data (JSONL files)\n", + "\n", + "If you already have train.jsonl, val.jsonl, and test.jsonl files from a previous data preparation step, upload them here." + ], + "id": "896bb0c7dcdc89b4" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "upload_prepared" + }, + "outputs": [], + "source": [ + "if USE_PREPARED_DATA:\n", + " print(\"Please upload train.jsonl, val.jsonl, and test.jsonl:\")\n", + " uploaded = files.upload()\n", + " \n", + " # Load the prepared data\n", + " train_data = []\n", + " val_data = []\n", + " test_data = []\n", + " \n", + " for filename in uploaded.keys():\n", + " with open(filename, 'r') as f:\n", + " data = [json.loads(line) for line in f]\n", + " \n", + " if 'train' in filename.lower():\n", + " train_data = data\n", + " print(f\"✓ Loaded {len(train_data)} train examples from {filename}\")\n", + " elif 'val' in filename.lower():\n", + " val_data = data\n", + " print(f\"✓ Loaded {len(val_data)} val examples from {filename}\")\n", + " elif 'test' in filename.lower():\n", + " test_data = data\n", + " print(f\"✓ Loaded {len(test_data)} test examples from {filename}\")\n", + " \n", + " print(f\"\\n✓ Data loaded:\")\n", + " print(f\" Train: {len(train_data)} examples\")\n", + " print(f\" Val: {len(val_data)} examples\")\n", + " print(f\" Test: {len(test_data)} examples\")\n", + "else:\n", + " print(\"Skipping - will use raw data preparation instead\")" + ], + "id": "9f70d7ac3eef1e1c" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "option2" + }, + "source": [ + "### Option 2: Prepare from Raw Data\n", + "\n", + "Upload a raw JSON file and it will be split into train/val/test sets." 
+ ], + "id": "1a3a807677b32e2f" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "upload_raw" + }, + "outputs": [], + "source": [ + "if not USE_PREPARED_DATA:\n", + " # Upload data file\n", + " print(\"Please upload your JSON data file:\")\n", + " uploaded = files.upload()\n", + " \n", + " # Get the filename\n", + " data_path = list(uploaded.keys())[0]\n", + " print(f\"\\n✓ Uploaded: {data_path}\")\n", + " \n", + " # Load and split data\n", + " train_data, val_data, test_data = load_and_split_data(data_path)\n", + "else:\n", + " print(\"Skipping - using pre-prepared data instead\")" + ], + "id": "9efab828b215759a" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "run_section" + }, + "source": [ + "## Run Optimization" + ], + "id": "1d79a3f44d238a1a" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "run_optimization" + }, + "outputs": [], + "source": [ + "# Configuration\n", + "MAX_ITERATIONS = 10\n", + "MINIBATCH_SIZE = 5\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"🎯 GEPA JUDGE OPTIMIZATION WITH TOGETHER AI\")\n", + "print(\"=\" * 80)\n", + "print(f\"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\")\n", + "\n", + "# Data should already be loaded from previous cells\n", + "print(f\"\\nUsing data:\")\n", + "print(f\" Train: {len(train_data)} examples\")\n", + "print(f\" Val: {len(val_data)} examples\")\n", + "print(f\" Test: {len(test_data)} examples\")\n", + "\n", + "# Create adapter\n", + "adapter = TogetherEvalAdapter(\n", + " client=client,\n", + " judge_model=JUDGE_MODEL,\n", + " initial_prompt=INITIAL_JUDGE_PROMPT\n", + ")\n", + "\n", + "# Create optimizer LM\n", + "optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY\n", + ")\n", + "\n", + "# Run GEPA optimization\n", + "results = run_gepa_optimization(\n", + " train_data=train_data,\n", + " val_data=val_data,\n", + " test_data=test_data,\n", + " adapter=adapter,\n", + " optimizer_lm=optimizer_lm,\n", + " max_iterations=MAX_ITERATIONS,\n", + " minibatch_size=MINIBATCH_SIZE\n", + ")" + ], + "id": "e0439cbdd9d7bccf" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "results_section" + }, + "source": [ + "## Save Results" + ], + "id": "81658b8fca3a1934" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "save_results" + }, + "outputs": [], + "source": [ + "# Save results\n", + "output_dir = Path(\"results\")\n", + "output_dir.mkdir(exist_ok=True)\n", + "\n", + "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "\n", + "# Save optimized prompt\n", + "prompt_path = output_dir / f\"optimized_prompt_{timestamp}.txt\"\n", + "with open(prompt_path, 'w') as f:\n", + " f.write(results['best_prompt'])\n", + "print(f\"\\n💾 Saved optimized prompt to: {prompt_path}\")\n", + "\n", + "# Save full results\n", + "results_path = output_dir / f\"optimization_results_{timestamp}.json\"\n", + "\n", + "# Make results JSON-serializable\n", + "json_results = {\n", + " 'best_prompt': results['best_prompt'],\n", + " 'best_val_accuracy': float(results['best_val_accuracy']),\n", + " 'baseline_test_accuracy': float(results['baseline_test_metrics']['accuracy']),\n", + " 'optimized_test_accuracy': float(results['optimized_test_metrics']['accuracy']),\n", + " 'improvement': float(\n", + " results['optimized_test_metrics']['accuracy'] - results['baseline_test_metrics']['accuracy']),\n", + " 'baseline_test_metrics': {\n", + " 'accuracy': 
float(results['baseline_test_metrics']['accuracy']),\n", + " 'a_wins': results['baseline_test_metrics']['a_wins'],\n", + " 'b_wins': results['baseline_test_metrics']['b_wins'],\n", + " 'ties': results['baseline_test_metrics']['ties'],\n", + " 'subset_accuracy': {k: float(v) for k, v in results['baseline_test_metrics']['subset_accuracy'].items()}\n", + " },\n", + " 'optimized_test_metrics': {\n", + " 'accuracy': float(results['optimized_test_metrics']['accuracy']),\n", + " 'a_wins': results['optimized_test_metrics']['a_wins'],\n", + " 'b_wins': results['optimized_test_metrics']['b_wins'],\n", + " 'ties': results['optimized_test_metrics']['ties'],\n", + " 'subset_accuracy': {k: float(v) for k, v in results['optimized_test_metrics']['subset_accuracy'].items()}\n", + " },\n", + " 'num_candidates': len(results['candidates']),\n", + " 'candidate_scores': [float(s) for s in results['candidate_scores']],\n", + " 'config': {\n", + " 'judge_model': JUDGE_MODEL,\n", + " 'optimizer_model': OPTIMIZER_MODEL,\n", + " 'train_size': TRAIN_SIZE,\n", + " 'val_size': VAL_SIZE,\n", + " 'test_size': len(test_data),\n", + " 'max_iterations': MAX_ITERATIONS,\n", + " 'minibatch_size': MINIBATCH_SIZE\n", + " },\n", + " 'timestamp': timestamp\n", + "}\n", + "\n", + "with open(results_path, 'w') as f:\n", + " json.dump(json_results, f, indent=2)\n", + "print(f\"💾 Saved results to: {results_path}\")\n", + "\n", + "# Display optimized prompt\n", + "print(f\"\\n{'=' * 80}\")\n", + "print(\"📝 OPTIMIZED JUDGE PROMPT\")\n", + "print(f\"{'=' * 80}\")\n", + "print(results['best_prompt'])\n", + "print(f\"{'=' * 80}\")\n", + "\n", + "print(\"\\n✅ Optimization complete!\")" + ], + "id": "a8abac7c78b7fd93" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "download_section" + }, + "source": [ + "## Download Results" + ], + "id": "9966a303b57e35f" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "download" + }, + "outputs": [], + "source": [ + "# Download the optimized prompt and results\n", + "files.download(str(prompt_path))\n", + "files.download(str(results_path))\n", + "\n", + "print(\"\\n📥 Files downloaded to your local machine!\")" + ], + "id": "45acd003d83683eb" + } + ], + "metadata": { + "colab": { + "name": "GEPA_Judge_Optimization.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 713c10bb942076c267b0de4d624578f27dc91287 Mon Sep 17 00:00:00 2001 From: jli Date: Tue, 23 Dec 2025 11:31:57 +0800 Subject: [PATCH 4/6] update file id option for data prep --- Evals/Prompt_Optimization.ipynb | 168 +++++++++----------------------- 1 file changed, 45 insertions(+), 123 deletions(-) diff --git a/Evals/Prompt_Optimization.ipynb b/Evals/Prompt_Optimization.ipynb index fca0b82..e6c8ba5 100644 --- a/Evals/Prompt_Optimization.ipynb +++ b/Evals/Prompt_Optimization.ipynb @@ -9,11 +9,12 @@ "# GEPA Judge Optimization with Together Eval\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/Prompt_Optimization.ipynb)\n", + "\n", "Custom implementation using GEPAAdapter pattern for batch evaluation.\n", "\n", "Based on the GEPA paper: https://arxiv.org/pdf/2507.19457" ], - "id": "e0bedc47864a990d" + "id": "9c117b36eb1133a6" }, { "cell_type": "markdown", @@ -23,7 +24,7 @@ "source": [ "## Setup and Installation" ], - "id": 
"c1c07266921cba9c" + "id": "528b9d47624473d" }, { "cell_type": "code", @@ -36,7 +37,7 @@ "# Install required packages\n", "!pip install together numpy -q" ], - "id": "cf06d9b2c2e9f523" + "id": "1ed67e650532a24" }, { "cell_type": "code", @@ -60,7 +61,7 @@ "import time\n", "from google.colab import files" ], - "id": "4b4c53f27a2634af" + "id": "dcb32b15d13a183a" }, { "cell_type": "markdown", @@ -70,7 +71,7 @@ "source": [ "## Configuration" ], - "id": "32535909dadd03a3" + "id": "fc8534d0746b3611" }, { "cell_type": "code", @@ -80,15 +81,19 @@ }, "outputs": [], "source": [ - "from google.colab import userdata\n", - "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "# Set your Together API key\n", + "TOGETHER_API_KEY = \"\" # Add your API key here\n", + "\n", + "# Or load from Colab secrets\n", + "# from google.colab import userdata\n", + "# TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", "\n", "if not TOGETHER_API_KEY:\n", " raise ValueError(\"Please set your TOGETHER_API_KEY\")\n", "\n", "client = together.Client(api_key=TOGETHER_API_KEY)" ], - "id": "889b722e00eced06" + "id": "3115b5515a325e88" }, { "cell_type": "code", @@ -109,6 +114,7 @@ "\n", "RANDOM_SEED = 42\n", "\n", + "# Initial judge prompt\n", "INITIAL_JUDGE_PROMPT = \"\"\"You are an expert evaluator comparing AI-generated responses.\n", "\n", "⚠️ IMPORTANT: You are JUDGING existing responses, NOT creating them yourself.\n", @@ -130,7 +136,7 @@ "\n", "print(\"✓ Configuration loaded\")" ], - "id": "e72123f0ae9362d" + "id": "35a27477bfd041d4" }, { "cell_type": "markdown", @@ -140,7 +146,7 @@ "source": [ "## Data Preparation Functions" ], - "id": "599d88fcb439f34c" + "id": "8d6e9bcf91e40fd9" }, { "cell_type": "code", @@ -207,7 +213,7 @@ "\n", "print(\"✓ Data functions defined\")" ], - "id": "2e191e7b90c8d90f" + "id": "2e8612f7daa7af15" }, { "cell_type": "markdown", @@ -217,7 +223,7 @@ "source": [ "## Batch Evaluation Adapter" ], - "id": "8823c7b37827c51e" + "id": "23de03e011b2b6b7" }, { "cell_type": "code", @@ -433,7 +439,7 @@ "\n", "print(\"✓ TogetherEvalAdapter defined\")" ], - "id": "705ea797cd19988d" + "id": "e1adfe1d7222920d" }, { "cell_type": "markdown", @@ -443,7 +449,7 @@ "source": [ "## Reflection and Prompt Optimization" ], - "id": "8687dc89668f8540" + "id": "d6377864d655475b" }, { "cell_type": "code", @@ -541,7 +547,7 @@ "\n", "print(\"✓ Reflection functions defined\")" ], - "id": "158bb2f19983bdbf" + "id": "f1e3153e9fec9566" }, { "cell_type": "markdown", @@ -551,7 +557,7 @@ "source": [ "## GEPA Optimization Loop" ], - "id": "a6412ef68b7a7ef2" + "id": "20e3cadf5ac75e37" }, { "cell_type": "code", @@ -752,7 +758,7 @@ "\n", "print(\"✓ Optimization function defined\")" ], - "id": "6cc56af3c5cae042" + "id": "af54a457c63c18e4" }, { "cell_type": "markdown", @@ -762,118 +768,34 @@ "source": [ "## Load Your Data\n", "\n", - "You have two options:\n", - "1. **Use pre-prepared data** - If you already have train/val/test JSONL files from data preparation\n", - "2. **Prepare from raw data** - Upload a raw JSON file and split it into train/val/test" + "Paste the file ID for your uploaded data file from the data preparation step." 
], - "id": "f78bfe562ece8dfd" + "id": "8f90733ff4c75863" }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "data_option" + "id": "load_file" }, "outputs": [], "source": [ - "# Choose your data loading option\n", - "USE_PREPARED_DATA = False # Set to True if you have pre-prepared JSONL files\n", - "\n", - "if USE_PREPARED_DATA:\n", - " print(\"Using pre-prepared data files\")\n", - " print(\"Please upload your train, val, and test JSONL files\")\n", - "else:\n", - " print(\"Will prepare data from raw JSON file\")" - ], - "id": "28dbff2c3f62cbb2" - }, - { - "cell_type": "markdown", - "metadata": { - "id": "option1" - }, - "source": [ - "### Option 1: Use Pre-Prepared Data (JSONL files)\n", + "# Paste your file ID from the data preparation step\n", + "DATA_FILE_ID = \"\" # e.g., \"file-65aa3ce1-cc93-48d0-b871-b974665f3dd1\"\n", "\n", - "If you already have train.jsonl, val.jsonl, and test.jsonl files from a previous data preparation step, upload them here." - ], - "id": "896bb0c7dcdc89b4" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "upload_prepared" - }, - "outputs": [], - "source": [ - "if USE_PREPARED_DATA:\n", - " print(\"Please upload train.jsonl, val.jsonl, and test.jsonl:\")\n", - " uploaded = files.upload()\n", - " \n", - " # Load the prepared data\n", - " train_data = []\n", - " val_data = []\n", - " test_data = []\n", - " \n", - " for filename in uploaded.keys():\n", - " with open(filename, 'r') as f:\n", - " data = [json.loads(line) for line in f]\n", - " \n", - " if 'train' in filename.lower():\n", - " train_data = data\n", - " print(f\"✓ Loaded {len(train_data)} train examples from {filename}\")\n", - " elif 'val' in filename.lower():\n", - " val_data = data\n", - " print(f\"✓ Loaded {len(val_data)} val examples from {filename}\")\n", - " elif 'test' in filename.lower():\n", - " test_data = data\n", - " print(f\"✓ Loaded {len(test_data)} test examples from {filename}\")\n", - " \n", - " print(f\"\\n✓ Data loaded:\")\n", - " print(f\" Train: {len(train_data)} examples\")\n", - " print(f\" Val: {len(val_data)} examples\")\n", - " print(f\" Test: {len(test_data)} examples\")\n", - "else:\n", - " print(\"Skipping - will use raw data preparation instead\")" - ], - "id": "9f70d7ac3eef1e1c" - }, - { - "cell_type": "markdown", - "metadata": { - "id": "option2" - }, - "source": [ - "### Option 2: Prepare from Raw Data\n", + "if not DATA_FILE_ID:\n", + " raise ValueError(\"Please provide the DATA_FILE_ID\")\n", "\n", - "Upload a raw JSON file and it will be split into train/val/test sets." 
- ], - "id": "1a3a807677b32e2f" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "upload_raw" - }, - "outputs": [], - "source": [ - "if not USE_PREPARED_DATA:\n", - " # Upload data file\n", - " print(\"Please upload your JSON data file:\")\n", - " uploaded = files.upload()\n", - " \n", - " # Get the filename\n", - " data_path = list(uploaded.keys())[0]\n", - " print(f\"\\n✓ Uploaded: {data_path}\")\n", - " \n", - " # Load and split data\n", - " train_data, val_data, test_data = load_and_split_data(data_path)\n", - "else:\n", - " print(\"Skipping - using pre-prepared data instead\")" + "# Download the data from Together AI\n", + "print(\"📥 Downloading data from Together AI...\")\n", + "data_path = \"uploaded_data.json\"\n", + "client.files.retrieve_content(DATA_FILE_ID, output=data_path)\n", + "print(f\"✓ Downloaded data file\")\n", + "\n", + "# Load and split data\n", + "train_data, val_data, test_data = load_and_split_data(data_path)" ], - "id": "9efab828b215759a" + "id": "8190874d74f476c6" }, { "cell_type": "markdown", @@ -883,7 +805,7 @@ "source": [ "## Run Optimization" ], - "id": "1d79a3f44d238a1a" + "id": "f221b494a47af25c" }, { "cell_type": "code", @@ -932,7 +854,7 @@ " minibatch_size=MINIBATCH_SIZE\n", ")" ], - "id": "e0439cbdd9d7bccf" + "id": "7a148733370bef76" }, { "cell_type": "markdown", @@ -942,7 +864,7 @@ "source": [ "## Save Results" ], - "id": "81658b8fca3a1934" + "id": "fbc868fad97c17da" }, { "cell_type": "code", @@ -1016,7 +938,7 @@ "\n", "print(\"\\n✅ Optimization complete!\")" ], - "id": "a8abac7c78b7fd93" + "id": "ac47aa62059f235c" }, { "cell_type": "markdown", @@ -1026,7 +948,7 @@ "source": [ "## Download Results" ], - "id": "9966a303b57e35f" + "id": "66f95fa71a7a7c3e" }, { "cell_type": "code", @@ -1042,7 +964,7 @@ "\n", "print(\"\\n📥 Files downloaded to your local machine!\")" ], - "id": "45acd003d83683eb" + "id": "2c3d5ec0e4644663" } ], "metadata": { From 9f7d34f34a0ed7a885dfd9e3172fa1e7aed49b6b Mon Sep 17 00:00:00 2001 From: jli Date: Tue, 23 Dec 2025 11:40:56 +0800 Subject: [PATCH 5/6] update --- Evals/GEPA_Optimization.ipynb | 2569 ++++++++++++++++++--------------- 1 file changed, 1411 insertions(+), 1158 deletions(-) diff --git a/Evals/GEPA_Optimization.ipynb b/Evals/GEPA_Optimization.ipynb index 11127d4..23a709d 100644 --- a/Evals/GEPA_Optimization.ipynb +++ b/Evals/GEPA_Optimization.ipynb @@ -1,1203 +1,1456 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "9bed21b9f21cadb7" - }, - "source": [ - "# GEPA Summarization Optimization with LLM Judge Evaluation\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/GEPA_Optimization.ipynb)\n", - "\n", - "## Introduction\n", - "\n", - "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", - "\n", - "1. Load the CNN/DailyMail dataset containing news articles\n", - "2. Start with a baseline summarization prompt\n", - "3. Use an optimizer LLM to iteratively improve the prompt\n", - "4. Compare prompts head-to-head using a judge model\n", - "5. 
Track improvement over multiple iterations\n", - "\n", - "**Concepts Covered:**\n", - "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", - "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", - "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", - "- **Prompt Engineering**: Systematic improvement of instruction prompts" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c044d292f626f2f6" - }, - "source": [ - "## 📦 Setup and Installation" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "cf56ca26c1b94222" - }, - "outputs": [], - "source": [ - "!pip install -qU together dspy-ai datasets tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 216 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9bed21b9f21cadb7" + }, + "source": [ + "# GEPA Summarization Optimization with LLM Judge Evaluation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Evals/GEPA_Optimization.ipynb)\n", + "\n", + "## Introduction\n", + "\n", + "This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the our Evaluations API. We'll:\n", + "\n", + "1. Load the CNN/DailyMail dataset containing news articles\n", + "2. Start with a baseline summarization prompt\n", + "3. Use an optimizer LLM to iteratively improve the prompt\n", + "4. Compare prompts head-to-head using a judge model\n", + "5. Track improvement over multiple iterations\n", + "\n", + "**Concepts Covered:**\n", + "- **GEPA Optimization**: Iterative prompt engineering using LLM feedback\n", + "- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs\n", + "- **Batch Evaluation**: Efficient comparison of multiple summaries\n", + "- **Prompt Engineering**: Systematic improvement of instruction prompts" + ] }, - "id": "1c293b491e894110", - "outputId": "e393f618-61a5-415e-ce69-18ebf78fbe99" - }, - "outputs": [ { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001B[36m╭─\u001B[0m\u001B[36m────────────────────────────────────────────\u001B[0m\u001B[36m 🚀 New SDK Available \u001B[0m\u001B[36m─────────────────────────────────────────────\u001B[0m\u001B[36m─╮\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[1;36mTogether Python SDK 2.0 is now available!\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m Install the beta: \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[32mpip install --pre together\u001B[0m or \u001B[32muv add together --prerelease allow\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m New SDK: \u001B]8;id=629133;https://github.com/togethercomputer/together-py\u001B\\https://github.com/togethercomputer/together-py\u001B]8;;\u001B\\ \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m Migration guide: \u001B]8;id=644417;https://docs.together.ai/docs/pythonv2-migration-guide\u001B\\https://docs.together.ai/docs/pythonv2-migration-guide\u001B]8;;\u001B\\ \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[2mThis package will be maintained until January 2026.\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m│\u001B[0m \u001B[2mSet 
TOGETHER_NO_BANNER=1 to hide this message.\u001B[0m \u001B[36m│\u001B[0m\n", - "\u001B[36m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001B[0m\n" + "cell_type": "markdown", + "metadata": { + "id": "c044d292f626f2f6" + }, + "source": [ + "## \ud83d\udce6 Setup and Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "cf56ca26c1b94222" + }, + "outputs": [], + "source": [ + "!pip install -qU together dspy-ai datasets tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 216 + }, + "id": "1c293b491e894110", + "outputId": "e393f618-61a5-415e-ce69-18ebf78fbe99" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[36m\u256d\u2500\u001b[0m\u001b[36m\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u001b[0m\u001b[36m \ud83d\ude80 New SDK Available \u001b[0m\u001b[36m\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u001b[0m\u001b[36m\u2500\u256e\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[1;36mTogether Python SDK 2.0 is now available!\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m Install the beta: \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[32mpip install --pre together\u001b[0m or \u001b[32muv add together --prerelease allow\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m New SDK: \u001b]8;id=629133;https://github.com/togethercomputer/together-py\u001b\\https://github.com/togethercomputer/together-py\u001b]8;;\u001b\\ \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m Migration guide: \u001b]8;id=644417;https://docs.together.ai/docs/pythonv2-migration-guide\u001b\\https://docs.together.ai/docs/pythonv2-migration-guide\u001b]8;;\u001b\\ \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[2mThis package will be maintained until January 2026.\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2502\u001b[0m \u001b[2mSet TOGETHER_NO_BANNER=1 to hide this message.\u001b[0m \u001b[36m\u2502\u001b[0m\n", + "\u001b[36m\u2570\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u256f\u001b[0m\n" + ], + "text/html": [ + "
\u256d\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 \ud83d\ude80 New SDK Available \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u256e\n",
+              "\u2502 Together Python SDK 2.0 is now available!                                                                       \u2502\n",
+              "\u2502                                                                                                                 \u2502\n",
+              "\u2502 Install the beta:                                                                                               \u2502\n",
+              "\u2502 pip install --pre together  or  uv add together --prerelease allow                                              \u2502\n",
+              "\u2502                                                                                                                 \u2502\n",
+              "\u2502 New SDK: https://github.com/togethercomputer/together-py                                                        \u2502\n",
+              "\u2502 Migration guide: https://docs.together.ai/docs/pythonv2-migration-guide                                         \u2502\n",
+              "\u2502                                                                                                                 \u2502\n",
+              "\u2502 This package will be maintained until January 2026.                                                             \u2502\n",
+              "\u2502 Set TOGETHER_NO_BANNER=1 to hide this message.                                                                  \u2502\n",
+              "\u2570\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u256f\n",
+              "
\n" + ] + }, + "metadata": {} + } ], - "text/html": [ - "
╭───────────────────────────────────────────── 🚀 New SDK Available ──────────────────────────────────────────────╮\n",
-       " Together Python SDK 2.0 is now available!                                                                       \n",
-       "                                                                                                                 \n",
-       " Install the beta:                                                                                               \n",
-       " pip install --pre together  or  uv add together --prerelease allow                                              \n",
-       "                                                                                                                 \n",
-       " New SDK: https://github.com/togethercomputer/together-py                                                        \n",
-       " Migration guide: https://docs.together.ai/docs/pythonv2-migration-guide                                         \n",
-       "                                                                                                                 \n",
-       " This package will be maintained until January 2026.                                                             \n",
-       " Set TOGETHER_NO_BANNER=1 to hide this message.                                                                  \n",
-       "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
-       "
\n" + "source": [ + "import together\n", + "import json\n", + "import random\n", + "import os\n", + "import re\n", + "import time\n", + "from pathlib import Path\n", + "from typing import List, Dict, Tuple\n", + "from datetime import datetime\n", + "\n", + "import dspy\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm" ] - }, - "metadata": {} - } - ], - "source": [ - "import together\n", - "import json\n", - "import random\n", - "import os\n", - "import re\n", - "import time\n", - "from pathlib import Path\n", - "from typing import List, Dict, Tuple\n", - "from datetime import datetime\n", - "\n", - "import dspy\n", - "from datasets import load_dataset\n", - "from tqdm import tqdm" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8e71863c8ff3faa6" - }, - "source": [ - "## ⚙️ Configuration\n", - "\n", - "Set up your API key and configure the models we'll use:\n", - "- **Summarizer Model**: Generates the summaries\n", - "- **Judge Model**: Evaluates which summary is better\n", - "- **Optimizer Model**: Proposes improvements to the prompt" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "3d21616fa03c0145", - "outputId": "84889606-a0fb-4556-af15-3b1c9e7fc4ad" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ API key loaded from Colab secrets\n", - "✓ Configuration complete\n" - ] - } - ], - "source": [ - "# Set your Together AI API key from Colab secrets\n", - "from google.colab import userdata\n", - "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", - "print(\"✓ API key loaded from Colab secrets\")\n", - "\n", - "client = together.Client(api_key=TOGETHER_API_KEY)\n", - "\n", - "# Model configuration\n", - "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", - "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", - "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", - "\n", - "# Data splits\n", - "TRAIN_SIZE = 150\n", - "VAL_SIZE = 300\n", - "TEST_SIZE = 300\n", - "\n", - "RANDOM_SEED = 42\n", - "\n", - "print(\"✓ Configuration complete\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d9378d341fb8389d" - }, - "source": [ - "## 📝 Baseline and Judge Prompts\n", - "\n", - "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "8e71863c8ff3faa6" + }, + "source": [ + "## \u2699\ufe0f Configuration\n", + "\n", + "Set up your API key and configure the models we'll use:\n", + "- **Summarizer Model**: Generates the summaries\n", + "- **Judge Model**: Evaluates which summary is better\n", + "- **Optimizer Model**: Proposes improvements to the prompt" + ] }, - "id": "263940c8c55eb1dd", - "outputId": "a2041a07-268c-4815-a7a4-85c964b7b2be" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "Baseline Prompt:\n", - "Summarize this news article in 3-5 key points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news event\n", - "- Key people or organizations involved\n", - "- Important details or outcomes\n", - "- Any significant context\n", - "\n", - "Keep it to 3-5 sentences total.\n", - "\n", - "Judge Prompt:\n", - "Compare these two summaries of the same news article.\n", - "\n", - "Which summary better:\n", - "- Captures the main news story\n", - "- Includes important details\n", - "- Is clear and concise\n", - "- Avoids unnecessary information\n", - "\n", - "Choose A or B and explain why briefly.\n" - ] - } - ], - "source": [ - "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news event\n", - "- Key people or organizations involved\n", - "- Important details or outcomes\n", - "- Any significant context\n", - "\n", - "Keep it to 3-5 sentences total.\"\"\"\n", - "\n", - "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", - "\n", - "Which summary better:\n", - "- Captures the main news story\n", - "- Includes important details\n", - "- Is clear and concise\n", - "- Avoids unnecessary information\n", - "\n", - "Choose A or B and explain why briefly.\"\"\"\n", - "\n", - "print(\"Baseline Prompt:\")\n", - "print(BASELINE_PROMPT)\n", - "print(\"\\nJudge Prompt:\")\n", - "print(JUDGE_PROMPT)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c0a86293e7b95dd9" - }, - "source": [ - "## 📂 Loading the CNN/DailyMail Dataset\n", - "\n", - "The CNN/DailyMail dataset contains news articles paired with human-written highlights. 
We'll use the articles as our source text and split the data into train, validation, and test sets.\n", - "\n", - "**Dataset Structure:**\n", - "- `article`: The full news article text\n", - "- `highlights`: Human-written bullet-point summary\n", - "- We'll use the articles for summarization and evaluate our generated summaries" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3d21616fa03c0145", + "outputId": "84889606-a0fb-4556-af15-3b1c9e7fc4ad" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 API key loaded from Colab secrets\n", + "\u2713 Configuration complete\n" + ] + } + ], + "source": [ + "# Set your Together AI API key from Colab secrets\n", + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "print(\"\u2713 API key loaded from Colab secrets\")\n", + "\n", + "client = together.Client(api_key=TOGETHER_API_KEY)\n", + "\n", + "# Model configuration\n", + "SUMMARIZER_MODEL = \"openai/gpt-oss-20b\"\n", + "JUDGE_MODEL = \"deepseek-ai/DeepSeek-V3\"\n", + "OPTIMIZER_MODEL = \"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\"\n", + "\n", + "# Data splits\n", + "TRAIN_SIZE = 150\n", + "VAL_SIZE = 300\n", + "TEST_SIZE = 300\n", + "\n", + "RANDOM_SEED = 42\n", + "\n", + "print(\"\u2713 Configuration complete\")" + ] }, - "id": "7dcc2d8d5c706df4", - "outputId": "e8dcb543-c238-42d3-af49-bcd77bfe7b7f" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "================================================================================\n", - "📂 LOADING DATA\n", - "================================================================================\n", - "Loading CNN/DailyMail dataset...\n", - "✓ Loaded 11490 examples\n", - " Sample article: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Cour...\n", - " Sample highlights: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since...\n", - "✓ Converted to 11490 items\n", - "✓ Split: Train=150, Val=300, Test=300\n" - ] - } - ], - "source": [ - "def load_and_split_data():\n", - " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"📂 LOADING DATA\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(\"Loading CNN/DailyMail dataset...\")\n", - " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n", - " data = dataset['test']\n", - "\n", - " print(f\"✓ Loaded {len(data)} examples\")\n", - " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", - " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", - "\n", - " all_data = []\n", - " for i, item in enumerate(data):\n", - " all_data.append({\n", - " 'id': f\"cnn_{i}\",\n", - " 'text': item['article'],\n", - " 'reference_summary': item['highlights']\n", - " })\n", - "\n", - " print(f\"✓ Converted to {len(all_data)} items\")\n", - "\n", - " random.seed(RANDOM_SEED)\n", - " random.shuffle(all_data)\n", - "\n", - " train_data = all_data[:TRAIN_SIZE]\n", - " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", - " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", - "\n", - " print(f\"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", - 
"\n", - " assert len(val_data) > 0, \"Val data is empty!\"\n", - " assert len(test_data) > 0, \"Test data is empty!\"\n", - "\n", - " return train_data, val_data, test_data\n", - "\n", - "# Load the data\n", - "train_data, val_data, test_data = load_and_split_data()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d1b9222690db8449" - }, - "source": [ - "## 🤖 Summarization Module\n", - "\n", - "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "d9378d341fb8389d" + }, + "source": [ + "## \ud83d\udcdd Baseline and Judge Prompts\n", + "\n", + "We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback." + ] }, - "id": "b8ca2917024c326e", - "outputId": "171c4567-9971-499a-edad-04b67c858885" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Summarization module defined\n" - ] - } - ], - "source": [ - "class Summarizer(dspy.Signature):\n", - " \"\"\"Generate a summary.\"\"\"\n", - " text = dspy.InputField()\n", - " summary = dspy.OutputField()\n", - "\n", - "\n", - "class SummarizationModule(dspy.Module):\n", - " \"\"\"Summarization module.\"\"\"\n", - "\n", - " def __init__(self, instructions=None):\n", - " super().__init__()\n", - " self.instructions = instructions or BASELINE_PROMPT\n", - "\n", - " if instructions:\n", - " class CustomSummarizer(dspy.Signature):\n", - " __doc__ = instructions\n", - " text = dspy.InputField()\n", - " summary = dspy.OutputField()\n", - "\n", - " self.predictor = dspy.Predict(CustomSummarizer)\n", - " else:\n", - " self.predictor = dspy.Predict(Summarizer)\n", - "\n", - " def forward(self, text):\n", - " return self.predictor(text=text)\n", - "\n", - "print(\"✓ Summarization module defined\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "590d6b9c625ca2cc" - }, - "source": [ - "## 📊 Batch Summary Generation\n", - "\n", - "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "263940c8c55eb1dd", + "outputId": "a2041a07-268c-4815-a7a4-85c964b7b2be" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Baseline Prompt:\n", + "Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\n", + "\n", + "Judge Prompt:\n", + "Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\n" + ] + } + ], + "source": [ + "BASELINE_PROMPT = \"\"\"Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\"\"\"\n", + "\n", + "JUDGE_PROMPT = \"\"\"Compare these two summaries of the same news article.\n", + "\n", + "Which summary better:\n", + "- Captures the main news story\n", + "- Includes important details\n", + "- Is clear and concise\n", + "- Avoids unnecessary information\n", + "\n", + "Choose A or B and explain why briefly.\"\"\"\n", + "\n", + "print(\"Baseline Prompt:\")\n", + "print(BASELINE_PROMPT)\n", + "print(\"\\nJudge Prompt:\")\n", + "print(JUDGE_PROMPT)" + ] }, - "id": "270abdde73d2ca72", - "outputId": "6eafb2d3-e773-4a65-f3b5-802687fffafc" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Batch generation function defined\n" - ] - } - ], - "source": [ - "def generate_summaries_batch(\n", - " summarizer: SummarizationModule,\n", - " data: List[Dict],\n", - " desc: str = \"Generating\"\n", - ") -> List[Dict]:\n", - " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", - " results = []\n", - " errors = 0\n", - " error_details = []\n", - "\n", - " # Print the prompt being used (first item only)\n", - " if len(data) > 0:\n", - " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", - "\n", - " for item in tqdm(data, desc=desc):\n", - " try:\n", - " pred = summarizer(text=item['text'][:5000])\n", - "\n", - " if pred is None:\n", - " raise ValueError(\"Model returned None\")\n", - "\n", - " if hasattr(pred, 'summary') and pred.summary:\n", - " summary = pred.summary\n", - " elif isinstance(pred, str):\n", - " summary = pred\n", - " else:\n", - " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", - " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", - "\n", - " summary = summary.strip()\n", - " if len(summary) < 20:\n", - " raise ValueError(\"Summary too short\")\n", - "\n", - " except Exception as e:\n", - " errors += 1\n", - " error_details.append(str(e)[:100])\n", - "\n", - " if errors <= 5:\n", - " print(f\"\\n⚠️ Error: {str(e)[:80]}\")\n", - "\n", - " summary = \"Error generating summary.\"\n", - "\n", - " results.append({\n", - " 'id': item['id'],\n", - " 'text': item['text'],\n", - " 
'summary': summary\n", - " })\n", - "\n", - " if errors > 0:\n", - " print(f\"\\n⚠️ Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", - " from collections import Counter\n", - " common_errors = Counter(error_details).most_common(3)\n", - " print(f\" Most common errors:\")\n", - " for err, count in common_errors:\n", - " print(f\" - {err[:60]}... ({count}x)\")\n", - "\n", - " return results\n", - "\n", - "print(\"✓ Batch generation function defined\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2cfe63f485894d7c" - }, - "source": [ - "## 🧠 Optimizer LLM Wrapper\n", - "\n", - "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "c0a86293e7b95dd9" + }, + "source": [ + "## \ud83d\udcc2 Loading the CNN/DailyMail Dataset\n", + "\n", + "The CNN/DailyMail dataset contains news articles paired with human-written highlights. We'll use the articles as our source text and split the data into train, validation, and test sets.\n", + "\n", + "**Dataset Structure:**\n", + "- `article`: The full news article text\n", + "- `highlights`: Human-written bullet-point summary\n", + "- We'll use the articles for summarization and evaluate our generated summaries" + ] }, - "id": "d11af9ff91f442df", - "outputId": "c9cd0f0e-7325-46cc-d065-d4a3745c08c3" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Optimizer LLM wrapper defined\n" - ] - } - ], - "source": [ - "class SimpleOptimizerLM:\n", - " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", - "\n", - " def __init__(self, model: str, api_key: str):\n", - " self.client = together.Client(api_key=api_key)\n", - " self.model = model\n", - "\n", - " def __call__(self, prompt: str) -> str:\n", - " response = self.client.chat.completions.create(\n", - " model=self.model,\n", - " messages=[{\"role\": \"user\", \"content\": prompt}],\n", - " temperature=0.7,\n", - " max_tokens=4000\n", - " )\n", - " return response.choices[0].message.content\n", - "\n", - "print(\"✓ Optimizer LLM wrapper defined\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "67a224aff87d2f5e" - }, - "source": [ - "## 🤔 Reflection and Prompt Improvement\n", - "\n", - "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", - "\n", - "**Key Constraints:**\n", - "- Keep prompts under 150 words for clarity\n", - "- Focus on simple, direct instructions\n", - "- Target 4-6 sentence summaries\n", - "- Avoid overly complex requirements" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7dcc2d8d5c706df4", + "outputId": "e8dcb543-c238-42d3-af49-bcd77bfe7b7f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "\ud83d\udcc2 LOADING DATA\n", + "================================================================================\n", + "Loading CNN/DailyMail dataset...\n", + "\u2713 Loaded 11490 examples\n", + " Sample article: (CNN)The Palestinian Authority officially became the 123rd 
member of the International Criminal Cour...\n", + " Sample highlights: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since...\n", + "\u2713 Converted to 11490 items\n", + "\u2713 Split: Train=150, Val=300, Test=300\n" + ] + } + ], + "source": [ + "def load_and_split_data():\n", + " \"\"\"Load CNN/DailyMail dataset for summarization.\"\"\"\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"\ud83d\udcc2 LOADING DATA\")\n", + " print(\"=\" * 80)\n", + "\n", + " print(\"Loading CNN/DailyMail dataset...\")\n", + " dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n", + " data = dataset['test']\n", + "\n", + " print(f\"\u2713 Loaded {len(data)} examples\")\n", + " print(f\" Sample article: {data[0]['article'][:100]}...\")\n", + " print(f\" Sample highlights: {data[0]['highlights'][:100]}...\")\n", + "\n", + " all_data = []\n", + " for i, item in enumerate(data):\n", + " all_data.append({\n", + " 'id': f\"cnn_{i}\",\n", + " 'text': item['article'],\n", + " 'reference_summary': item['highlights']\n", + " })\n", + "\n", + " print(f\"\u2713 Converted to {len(all_data)} items\")\n", + "\n", + " random.seed(RANDOM_SEED)\n", + " random.shuffle(all_data)\n", + "\n", + " train_data = all_data[:TRAIN_SIZE]\n", + " val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]\n", + " test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]\n", + "\n", + " print(f\"\u2713 Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}\")\n", + "\n", + " assert len(val_data) > 0, \"Val data is empty!\"\n", + " assert len(test_data) > 0, \"Test data is empty!\"\n", + "\n", + " return train_data, val_data, test_data\n", + "\n", + "# Load the data\n", + "train_data, val_data, test_data = load_and_split_data()" + ] }, - "id": "1186e66cab3ea1f1", - "outputId": "a8ea71b8-da99-4efa-c72b-59603458e664" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Reflection function defined\n" - ] - } - ], - "source": [ - "def reflect_and_improve_prompt(\n", - " current_prompt: str,\n", - " current_score: float,\n", - " optimizer_lm: SimpleOptimizerLM,\n", - " iteration: int\n", - ") -> str:\n", - " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", - "\n", - " print(f\"\\n🤔 REFLECTION (Iteration {iteration})\")\n", - "\n", - " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", - "\n", - "Current Prompt:\n", - "```\n", - "{current_prompt}\n", - "```\n", - "\n", - "Current Performance: {current_score:.1%} win rate\n", - "\n", - "Your task: Propose a SIMPLE improved version that generates better summaries.\n", - "\n", - "CRITICAL CONSTRAINTS:\n", - "- Keep the prompt under 150 words\n", - "- Make it clear and direct (NOT overly complex)\n", - "- Target 4-6 sentence summaries\n", - "- Avoid excessive instructions or formatting requirements\n", - "- The prompt should be easy for the model to follow\n", - "\n", - "Focus on:\n", - "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", - "- Are the current guidelines clear?\n", - "- Is anything missing or unnecessary?\n", - "\n", - "Output ONLY the improved prompt within ``` blocks. 
Keep it simple and clear.\"\"\"\n", - "\n", - " response = optimizer_lm(reflection_prompt)\n", - "\n", - " # Extract prompt\n", - " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", - " if match:\n", - " new_prompt = match.group(1).strip()\n", - " # Remove language tags\n", - " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", - " if new_prompt.startswith(f'{tag}\\n'):\n", - " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", - "\n", - " # Validate length (reject if too long)\n", - " word_count = len(new_prompt.split())\n", - " if word_count > 200:\n", - " print(f\" ⚠️ Generated prompt too long ({word_count} words), using current\")\n", - " return current_prompt\n", - "\n", - " print(f\"✓ Generated new prompt ({word_count} words)\")\n", - " return new_prompt\n", - "\n", - " print(\"⚠️ Could not extract prompt\")\n", - " return current_prompt\n", - "\n", - "print(\"✓ Reflection function defined\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a2fbbd02f5054425" - }, - "source": [ - "## 🔄 Head-to-Head Prompt Comparison\n", - "\n", - "This function compares two prompts by:\n", - "1. Generating summaries with both prompts\n", - "2. Creating a comparison dataset\n", - "3. Using the Together AI evaluation API with a judge model\n", - "4. Computing win rates\n", - "\n", - "The evaluation uses a two-pass approach to eliminate position bias." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "d1b9222690db8449" + }, + "source": [ + "## \ud83e\udd16 Summarization Module\n", + "\n", + "We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process." + ] }, - "id": "5a1b2d5116f3731f", - "outputId": "f6aa5880-7905-4acc-c9ab-b01dc2b6a30f" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ Comparison function defined\n" - ] - } - ], - "source": [ - "def compare_two_prompts_on_batch(\n", - " data: List[Dict],\n", - " prompt_a: str,\n", - " prompt_b: str,\n", - " summarizer_lm: dspy.LM,\n", - " eval_name: str\n", - ") -> Tuple[float, float, Dict]:\n", - " \"\"\"\n", - " Compare two summarization prompts.\n", - "\n", - " 1. Generate summaries with prompt A\n", - " 2. Generate summaries with prompt B\n", - " 3. Use judge to compare them\n", - " 4. 
Return win rate for prompt A\n", - " \"\"\"\n", - "\n", - " print(f\"\\n{'=' * 80}\")\n", - " print(f\"🔄 COMPARING PROMPTS: {eval_name}\")\n", - " print(f\"{'=' * 80}\")\n", - "\n", - " # Step 1: Generate with both prompts\n", - " dspy.configure(lm=summarizer_lm)\n", - "\n", - " summarizer_a = SummarizationModule(prompt_a)\n", - " summarizer_b = SummarizationModule(prompt_b)\n", - "\n", - " print(\"Generating summaries with Prompt A...\")\n", - " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", - "\n", - " print(\"Generating summaries with Prompt B...\")\n", - " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", - "\n", - " # Step 2: Prepare comparison data\n", - " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", - "\n", - " with open(temp_file, 'w') as f:\n", - " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", - " formatted = {\n", - " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", - " \"model_a_output\": summary_a['summary'],\n", - " \"model_b_output\": summary_b['summary'],\n", - " \"id\": summary_a['id']\n", - " }\n", - " f.write(json.dumps(formatted) + '\\n')\n", - "\n", - " # Step 3: Upload and evaluate\n", - " print(\"📤 Uploading for comparison...\")\n", - " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", - " file_id = file_response.id\n", - "\n", - " print(\"🚀 Launching comparison...\")\n", - " eval_response = client.evaluation.create(\n", - " type=\"compare\",\n", - " input_data_file_path=file_id,\n", - " judge_model=JUDGE_MODEL,\n", - " judge_model_source=\"serverless\",\n", - " judge_system_template=JUDGE_PROMPT,\n", - " model_a=\"model_a_output\",\n", - " model_b=\"model_b_output\"\n", - " )\n", - "\n", - " # Step 4: Wait and get results\n", - " print(f\"⏳ Waiting (ID: {eval_response.workflow_id})...\")\n", - " while True:\n", - " status = client.evaluation.status(eval_response.workflow_id)\n", - " if status.status.value == \"completed\":\n", - " break\n", - " elif status.status.value == \"failed\":\n", - " raise Exception(\"Evaluation failed\")\n", - " time.sleep(30)\n", - "\n", - " a_wins = status.results.get('A_wins', 0)\n", - " b_wins = status.results.get('B_wins', 0)\n", - " ties = status.results.get('Ties', 0)\n", - "\n", - " # Win rate for prompt A\n", - " decisive_total = a_wins + b_wins\n", - " if decisive_total > 0:\n", - " a_win_rate = a_wins / decisive_total\n", - " b_win_rate = b_wins / decisive_total\n", - " else:\n", - " a_win_rate = b_win_rate = 0.5\n", - "\n", - " print(f\"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", - " print(f\"✓ Prompt A win rate: {a_win_rate:.2%}\")\n", - "\n", - " os.remove(temp_file)\n", - "\n", - " return a_win_rate, b_win_rate, {\n", - " 'a_wins': a_wins,\n", - " 'b_wins': b_wins,\n", - " 'ties': ties,\n", - " 'a_win_rate': a_win_rate\n", - " }\n", - "\n", - "print(\"✓ Comparison function defined\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6657d33b050676ff" - }, - "source": [ - "## 🧬 GEPA Optimization Loop\n", - "\n", - "This is the main optimization loop that implements the GEPA algorithm:\n", - "\n", - "1. **Generate**: Create summaries with current prompt\n", - "2. **Evaluate**: Compare against baseline using judge model\n", - "3. **Propose**: Use optimizer LLM to suggest improvements\n", - "4. 
**Adapt**: Accept improvements that increase win rate\n", - "\n", - "The process repeats for multiple iterations, tracking the best prompt found." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b8ca2917024c326e", + "outputId": "171c4567-9971-499a-edad-04b67c858885" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Summarization module defined\n" + ] + } + ], + "source": [ + "class Summarizer(dspy.Signature):\n", + " \"\"\"Generate a summary.\"\"\"\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + "\n", + "class SummarizationModule(dspy.Module):\n", + " \"\"\"Summarization module.\"\"\"\n", + "\n", + " def __init__(self, instructions=None):\n", + " super().__init__()\n", + " self.instructions = instructions or BASELINE_PROMPT\n", + "\n", + " if instructions:\n", + " class CustomSummarizer(dspy.Signature):\n", + " __doc__ = instructions\n", + " text = dspy.InputField()\n", + " summary = dspy.OutputField()\n", + "\n", + " self.predictor = dspy.Predict(CustomSummarizer)\n", + " else:\n", + " self.predictor = dspy.Predict(Summarizer)\n", + "\n", + " def forward(self, text):\n", + " return self.predictor(text=text)\n", + "\n", + "print(\"\u2713 Summarization module defined\")" + ] }, - "id": "c7100da955cfb3b5", - "outputId": "1144337a-d273-452a-84bf-4ad959363cd1" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "✓ GEPA optimization function defined\n" - ] - } - ], - "source": [ - "def run_manual_gepa(\n", - " train_data: List[Dict],\n", - " val_data: List[Dict],\n", - " test_data: List[Dict],\n", - " summarizer_lm: dspy.LM,\n", - " optimizer_lm: SimpleOptimizerLM,\n", - " max_iterations: int = 5\n", - "):\n", - " \"\"\"Manual GEPA-style optimization.\"\"\"\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"🧬 MANUAL GEPA OPTIMIZATION\")\n", - " print(\"=\" * 80)\n", - "\n", - " best_prompt = BASELINE_PROMPT\n", - " best_val_score = 0.5 # Start at 50% (neutral)\n", - "\n", - " for i in range(max_iterations):\n", - " print(f\"\\n{'=' * 80}\")\n", - " print(f\"ITERATION {i + 1}/{max_iterations}\")\n", - " print(f\"{'=' * 80}\")\n", - "\n", - " if i == 0:\n", - " print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n", - " continue\n", - "\n", - " new_prompt = reflect_and_improve_prompt(\n", - " best_prompt,\n", - " best_val_score,\n", - " optimizer_lm,\n", - " i\n", - " )\n", - "\n", - " if new_prompt == best_prompt:\n", - " print(\"⚠️ No change in prompt, stopping\")\n", - " break\n", - "\n", - " print(f\"✓ Generated candidate prompt ({len(new_prompt)} chars)\")\n", - "\n", - " # Compare best_prompt vs new_prompt on validation set\n", - " baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n", - " val_data,\n", - " prompt_a=best_prompt,\n", - " prompt_b=new_prompt,\n", - " summarizer_lm=summarizer_lm,\n", - " eval_name=f\"iter{i}_val\"\n", - " )\n", - "\n", - " new_prompt_win_rate = 1.0 - baseline_win_rate\n", - "\n", - " print(f\"\\n Current best: {baseline_win_rate:.2%}\")\n", - " print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n", - "\n", - " if new_prompt_win_rate > best_val_score:\n", - " improvement = new_prompt_win_rate - best_val_score\n", - " print(f\" 🎉 New best! 
(+{improvement * 100:.2f}pp)\")\n", - " best_prompt = new_prompt\n", - " best_val_score = new_prompt_win_rate\n", - " else:\n", - " print(f\" No improvement\")\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"📊 FINAL TEST EVALUATION\")\n", - " print(\"=\" * 80)\n", - "\n", - " baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n", - " test_data,\n", - " prompt_a=BASELINE_PROMPT,\n", - " prompt_b=best_prompt,\n", - " summarizer_lm=summarizer_lm,\n", - " eval_name=\"final_test\"\n", - " )\n", - "\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"🎉 FINAL RESULTS\")\n", - " print(\"=\" * 80)\n", - "\n", - " print(f\"\\nTEST SET:\")\n", - " print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n", - " print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n", - " print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n", - "\n", - " output_dir = Path(\"results\")\n", - " output_dir.mkdir(exist_ok=True)\n", - "\n", - " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", - "\n", - " with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n", - " f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(BASELINE_PROMPT)\n", - " f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(best_prompt)\n", - " f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n", - " f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n", - " f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n", - "\n", - " print(f\"\\n💾 Saved to: results/prompts_{timestamp}.txt\")\n", - "\n", - " return {\n", - " 'baseline_test': baseline_test_win_rate,\n", - " 'optimized_test': optimized_test_win_rate,\n", - " 'best_prompt': best_prompt\n", - " }\n", - "\n", - "print(\"✓ GEPA optimization function defined\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4839066f78acf10d" - }, - "source": [ - "## 🚀 Run the Optimization\n", - "\n", - "Now we'll execute the full GEPA optimization process. This will:\n", - "1. Set up the summarizer and optimizer models\n", - "2. Run multiple iterations of prompt improvement\n", - "3. Evaluate the final optimized prompt on the test set\n", - "4. Display comprehensive results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "590d6b9c625ca2cc" + }, + "source": [ + "## \ud83d\udcca Batch Summary Generation\n", + "\n", + "This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking." 
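The logged runs average roughly 3 seconds per article, so each 300-article pass takes 15 minutes or more and the full optimization took about 2.5 hours. If the configured LM client is thread-safe (recent DSPy releases drive their own `Evaluate` and `Parallel` helpers with threads), a small thread pool can overlap the API calls. The sketch below is a hypothetical variant of the batch helper shown in the next cell, with the notebook's per-item error handling omitted for brevity.

```python
# Hypothetical thread-pooled variant of generate_summaries_batch; assumes the
# summarizer module and data items have the same shape as in the cell below.
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List

def generate_summaries_parallel(summarizer, data: List[Dict], max_workers: int = 8) -> List[Dict]:
    """Summarize all items with a small thread pool, preserving input order."""
    def summarize_one(item: Dict) -> Dict:
        pred = summarizer(text=item["text"][:5000])
        return {"id": item["id"], "text": item["text"], "summary": pred.summary.strip()}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map keeps results aligned with the input order
        return list(pool.map(summarize_one, data))
```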
+ ] }, - "id": "51f60931bec8f490", - "outputId": "1b34ac6f-0d40-46c9-d9df-6ac6c699cb66" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "================================================================================\n", - "🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\n", - "================================================================================\n", - "\n", - "================================================================================\n", - "🧬 MANUAL GEPA OPTIMIZATION\n", - "================================================================================\n", - "\n", - "================================================================================\n", - "ITERATION 1/5\n", - "================================================================================\n", - "Iteration 0: Establishing baseline (no comparison yet)\n", - "\n", - "================================================================================\n", - "ITERATION 2/5\n", - "================================================================================\n", - "\n", - "🤔 REFLECTION (Iteration 1)\n", - "✓ Generated new prompt (63 words)\n", - "✓ Generated candidate prompt (404 chars)\n", - "\n", - "================================================================================\n", - "🔄 COMPARING PROMPTS: iter1_val\n", - "================================================================================\n", - "Generating summaries with Prompt A...\n", - " Using prompt: Summarize this news article in 3-5 key points.\n", - "\n", - "Write a brief summary covering:\n", - "- The main news even...\n" - ] + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "270abdde73d2ca72", + "outputId": "6eafb2d3-e773-4a65-f3b5-802687fffafc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Batch generation function defined\n" + ] + } + ], + "source": [ + "def generate_summaries_batch(\n", + " summarizer: SummarizationModule,\n", + " data: List[Dict],\n", + " desc: str = \"Generating\"\n", + ") -> List[Dict]:\n", + " \"\"\"Generate summaries for a batch of texts.\"\"\"\n", + " results = []\n", + " errors = 0\n", + " error_details = []\n", + "\n", + " # Print the prompt being used (first item only)\n", + " if len(data) > 0:\n", + " print(f\" Using prompt: {summarizer.instructions[:100]}...\")\n", + "\n", + " for item in tqdm(data, desc=desc):\n", + " try:\n", + " pred = summarizer(text=item['text'][:5000])\n", + "\n", + " if pred is None:\n", + " raise ValueError(\"Model returned None\")\n", + "\n", + " if hasattr(pred, 'summary') and pred.summary:\n", + " summary = pred.summary\n", + " elif isinstance(pred, str):\n", + " summary = pred\n", + " else:\n", + " print(f\"\\n DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}\")\n", + " raise ValueError(f\"Cannot extract summary from {type(pred)}\")\n", + "\n", + " summary = summary.strip()\n", + " if len(summary) < 20:\n", + " raise ValueError(\"Summary too short\")\n", + "\n", + " except Exception as e:\n", + " errors += 1\n", + " error_details.append(str(e)[:100])\n", + "\n", + " if errors <= 5:\n", + " print(f\"\\n\u26a0\ufe0f Error: {str(e)[:80]}\")\n", + "\n", + " summary = \"Error generating summary.\"\n", + "\n", + " results.append({\n", + " 'id': item['id'],\n", + " 'text': item['text'],\n", + " 'summary': summary\n", + " })\n", + "\n", + " if errors > 0:\n", + " print(f\"\\n\u26a0\ufe0f Total errors: 
{errors}/{len(data)} ({errors / len(data) * 100:.1f}%)\")\n", + " from collections import Counter\n", + " common_errors = Counter(error_details).most_common(3)\n", + " print(f\" Most common errors:\")\n", + " for err, count in common_errors:\n", + " print(f\" - {err[:60]}... ({count}x)\")\n", + "\n", + " return results\n", + "\n", + "print(\"\u2713 Batch generation function defined\")" + ] }, { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt A: 100%|██████████| 300/300 [14:30<00:00, 2.90s/it]\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "2cfe63f485894d7c" + }, + "source": [ + "## \ud83e\udde0 Optimizer LLM Wrapper\n", + "\n", + "This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance." + ] }, { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating summaries with Prompt B...\n", - " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", - "\n", - "Please cover the f...\n" - ] + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d11af9ff91f442df", + "outputId": "c9cd0f0e-7325-46cc-d065-d4a3745c08c3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Optimizer LLM wrapper defined\n" + ] + } + ], + "source": [ + "class SimpleOptimizerLM:\n", + " \"\"\"Wrapper for optimizer LLM.\"\"\"\n", + "\n", + " def __init__(self, model: str, api_key: str):\n", + " self.client = together.Client(api_key=api_key)\n", + " self.model = model\n", + "\n", + " def __call__(self, prompt: str) -> str:\n", + " response = self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " temperature=0.7,\n", + " max_tokens=4000\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "print(\"\u2713 Optimizer LLM wrapper defined\")" + ] }, { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt B: 100%|██████████| 300/300 [17:16<00:00, 3.46s/it]\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "67a224aff87d2f5e" + }, + "source": [ + "## \ud83e\udd14 Reflection and Prompt Improvement\n", + "\n", + "This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.\n", + "\n", + "**Key Constraints:**\n", + "- Keep prompts under 150 words for clarity\n", + "- Focus on simple, direct instructions\n", + "- Target 4-6 sentence summaries\n", + "- Avoid overly complex requirements" + ] }, { - "output_type": "stream", - "name": "stdout", - "text": [ - "📤 Uploading for comparison...\n" - ] + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1186e66cab3ea1f1", + "outputId": "a8ea71b8-da99-4efa-c72b-59603458e664" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Reflection function defined\n" + ] + } + ], + "source": [ + "def reflect_and_improve_prompt(\n", + " current_prompt: str,\n", + " current_score: float,\n", + " optimizer_lm: SimpleOptimizerLM,\n", + " iteration: int\n", + ") -> str:\n", + " \"\"\"Use LLM to propose improved prompt.\"\"\"\n", + "\n", + " print(f\"\\n\ud83e\udd14 REFLECTION (Iteration {iteration})\")\n", + "\n", + " reflection_prompt = f\"\"\"You are optimizing a summarization prompt for CNN/DailyMail news articles.\n", + "\n", + "Current Prompt:\n", + 
"```\n", + "{current_prompt}\n", + "```\n", + "\n", + "Current Performance: {current_score:.1%} win rate\n", + "\n", + "Your task: Propose a SIMPLE improved version that generates better summaries.\n", + "\n", + "CRITICAL CONSTRAINTS:\n", + "- Keep the prompt under 150 words\n", + "- Make it clear and direct (NOT overly complex)\n", + "- Target 4-6 sentence summaries\n", + "- Avoid excessive instructions or formatting requirements\n", + "- The prompt should be easy for the model to follow\n", + "\n", + "Focus on:\n", + "- Should it emphasize different aspects (accuracy, brevity, completeness)?\n", + "- Are the current guidelines clear?\n", + "- Is anything missing or unnecessary?\n", + "\n", + "Output ONLY the improved prompt within ``` blocks. Keep it simple and clear.\"\"\"\n", + "\n", + " response = optimizer_lm(reflection_prompt)\n", + "\n", + " # Extract prompt\n", + " match = re.search(r'```(.*?)```', response, re.DOTALL)\n", + " if match:\n", + " new_prompt = match.group(1).strip()\n", + " # Remove language tags\n", + " for tag in ['markdown', 'text', 'python', 'plaintext']:\n", + " if new_prompt.startswith(f'{tag}\\n'):\n", + " new_prompt = '\\n'.join(new_prompt.split('\\n')[1:])\n", + "\n", + " # Validate length (reject if too long)\n", + " word_count = len(new_prompt.split())\n", + " if word_count > 200:\n", + " print(f\" \u26a0\ufe0f Generated prompt too long ({word_count} words), using current\")\n", + " return current_prompt\n", + "\n", + " print(f\"\u2713 Generated new prompt ({word_count} words)\")\n", + " return new_prompt\n", + "\n", + " print(\"\u26a0\ufe0f Could not extract prompt\")\n", + " return current_prompt\n", + "\n", + "print(\"\u2713 Reflection function defined\")" + ] }, { - "output_type": "stream", - "name": "stderr", - "text": [ - "Uploading file temp_compare_iter1_val_20251222_170518.jsonl: 100%|██████████| 1.59M/1.59M [00:00<00:00, 2.82MB/s]\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "a2fbbd02f5054425" + }, + "source": [ + "## \ud83d\udd04 Head-to-Head Prompt Comparison\n", + "\n", + "This function compares two prompts by:\n", + "1. Generating summaries with both prompts\n", + "2. Creating a comparison dataset\n", + "3. Using the Together AI evaluation API with a judge model\n", + "4. Computing win rates\n", + "\n", + "The evaluation uses a two-pass approach to eliminate position bias." + ] }, { - "output_type": "stream", - "name": "stdout", - "text": [ - "🚀 Launching comparison...\n", - "⏳ Waiting (ID: eval-94eb-1766423120)...\n", - "✓ Results: Prompt A wins=29, Prompt B wins=35, Ties=236\n", - "✓ Prompt A win rate: 45.31%\n", - "\n", - " Current best: 45.31%\n", - " New candidate: 54.69%\n", - " 🎉 New best! 
(+4.69pp)\n", - "\n", - "================================================================================\n", - "ITERATION 3/5\n", - "================================================================================\n", - "\n", - "🤔 REFLECTION (Iteration 2)\n", - "✓ Generated new prompt (58 words)\n", - "✓ Generated candidate prompt (389 chars)\n", - "\n", - "================================================================================\n", - "🔄 COMPARING PROMPTS: iter2_val\n", - "================================================================================\n", - "Generating summaries with Prompt A...\n", - " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", - "\n", - "Please cover the f...\n" - ] + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5a1b2d5116f3731f", + "outputId": "f6aa5880-7905-4acc-c9ab-b01dc2b6a30f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 Comparison function defined\n" + ] + } + ], + "source": [ + "def compare_two_prompts_on_batch(\n", + " data: List[Dict],\n", + " prompt_a: str,\n", + " prompt_b: str,\n", + " summarizer_lm: dspy.LM,\n", + " eval_name: str\n", + ") -> Tuple[float, float, Dict]:\n", + " \"\"\"\n", + " Compare two summarization prompts.\n", + "\n", + " 1. Generate summaries with prompt A\n", + " 2. Generate summaries with prompt B\n", + " 3. Use judge to compare them\n", + " 4. Return win rate for prompt A\n", + " \"\"\"\n", + "\n", + " print(f\"\\n{'=' * 80}\")\n", + " print(f\"\ud83d\udd04 COMPARING PROMPTS: {eval_name}\")\n", + " print(f\"{'=' * 80}\")\n", + "\n", + " # Step 1: Generate with both prompts\n", + " dspy.configure(lm=summarizer_lm)\n", + "\n", + " summarizer_a = SummarizationModule(prompt_a)\n", + " summarizer_b = SummarizationModule(prompt_b)\n", + "\n", + " print(\"Generating summaries with Prompt A...\")\n", + " summaries_a = generate_summaries_batch(summarizer_a, data, \"Prompt A\")\n", + "\n", + " print(\"Generating summaries with Prompt B...\")\n", + " summaries_b = generate_summaries_batch(summarizer_b, data, \"Prompt B\")\n", + "\n", + " # Step 2: Prepare comparison data\n", + " temp_file = f\"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\"\n", + "\n", + " with open(temp_file, 'w') as f:\n", + " for summary_a, summary_b in zip(summaries_a, summaries_b):\n", + " formatted = {\n", + " \"prompt\": f\"Source article: {summary_a['text'][:5000]}\",\n", + " \"model_a_output\": summary_a['summary'],\n", + " \"model_b_output\": summary_b['summary'],\n", + " \"id\": summary_a['id']\n", + " }\n", + " f.write(json.dumps(formatted) + '\\n')\n", + "\n", + " # Step 3: Upload and evaluate\n", + " print(\"\ud83d\udce4 Uploading for comparison...\")\n", + " file_response = client.files.upload(file=temp_file, purpose=\"eval\")\n", + " file_id = file_response.id\n", + "\n", + " print(\"\ud83d\ude80 Launching comparison...\")\n", + " eval_response = client.evaluation.create(\n", + " type=\"compare\",\n", + " input_data_file_path=file_id,\n", + " judge_model=JUDGE_MODEL,\n", + " judge_model_source=\"serverless\",\n", + " judge_system_template=JUDGE_PROMPT,\n", + " model_a=\"model_a_output\",\n", + " model_b=\"model_b_output\"\n", + " )\n", + "\n", + " # Step 4: Wait and get results\n", + " print(f\"\u23f3 Waiting (ID: {eval_response.workflow_id})...\")\n", + " while True:\n", + " status = client.evaluation.status(eval_response.workflow_id)\n", 
+ " if status.status.value == \"completed\":\n", + " break\n", + " elif status.status.value == \"failed\":\n", + " raise Exception(\"Evaluation failed\")\n", + " time.sleep(30)\n", + "\n", + " a_wins = status.results.get('A_wins', 0)\n", + " b_wins = status.results.get('B_wins', 0)\n", + " ties = status.results.get('Ties', 0)\n", + "\n", + " # Win rate for prompt A\n", + " decisive_total = a_wins + b_wins\n", + " if decisive_total > 0:\n", + " a_win_rate = a_wins / decisive_total\n", + " b_win_rate = b_wins / decisive_total\n", + " else:\n", + " a_win_rate = b_win_rate = 0.5\n", + "\n", + " print(f\"\u2713 Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}\")\n", + " print(f\"\u2713 Prompt A win rate: {a_win_rate:.2%}\")\n", + "\n", + " os.remove(temp_file)\n", + "\n", + " return a_win_rate, b_win_rate, {\n", + " 'a_wins': a_wins,\n", + " 'b_wins': b_wins,\n", + " 'ties': ties,\n", + " 'a_win_rate': a_win_rate\n", + " }\n", + "\n", + "print(\"\u2713 Comparison function defined\")" + ] }, { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt A: 100%|██████████| 300/300 [00:39<00:00, 7.68it/s]\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "6657d33b050676ff" + }, + "source": [ + "## \ud83e\uddec GEPA Optimization Loop\n", + "\n", + "This is the main optimization loop that implements the GEPA algorithm:\n", + "\n", + "1. **Generate**: Create summaries with current prompt\n", + "2. **Evaluate**: Compare against baseline using judge model\n", + "3. **Propose**: Use optimizer LLM to suggest improvements\n", + "4. **Adapt**: Accept improvements that increase win rate\n", + "\n", + "The process repeats for multiple iterations, tracking the best prompt found." + ] }, { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating summaries with Prompt B...\n", - " Using prompt: Write a 4-6 sentence summary of this news article, prioritizing clarity and accuracy. 
\n", - "\n", - "Clearly stat...\n" - ] + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c7100da955cfb3b5", + "outputId": "1144337a-d273-452a-84bf-4ad959363cd1" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2713 GEPA optimization function defined\n" + ] + } + ], + "source": [ + "def run_manual_gepa(\n train_data: List[Dict],\n val_data: List[Dict],\n test_data: List[Dict],\n summarizer_lm: dspy.LM,\n optimizer_lm: SimpleOptimizerLM,\n max_iterations: int = 5\n):\n \"\"\"Manual GEPA-style optimization.\"\"\"\n\n print(\"\\n\" + \"=\" * 80)\n print(\"\ud83e\uddec MANUAL GEPA OPTIMIZATION\")\n print(\"=\" * 80)\n\n best_prompt = BASELINE_PROMPT\n best_val_score = 0.5 # Start at 50% (neutral)\n\n for i in range(max_iterations):\n print(f\"\\n{'=' * 80}\")\n print(f\"ITERATION {i + 1}/{max_iterations}\")\n print(f\"{'=' * 80}\")\n\n if i == 0:\n print(\"Iteration 0: Establishing baseline (no comparison yet)\")\n continue\n\n new_prompt = reflect_and_improve_prompt(\n best_prompt,\n best_val_score,\n optimizer_lm,\n i\n )\n\n if new_prompt == best_prompt:\n print(\"\u26a0\ufe0f No change in prompt, stopping\")\n break\n\n print(f\"\u2713 Generated candidate prompt ({len(new_prompt)} chars)\")\n\n # Compare BASELINE_PROMPT vs new_prompt on validation set\n # This ensures we always measure absolute improvement from the original baseline\n baseline_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(\n val_data,\n prompt_a=BASELINE_PROMPT,\n prompt_b=new_prompt,\n summarizer_lm=summarizer_lm,\n eval_name=f\"iter{i}_val\"\n )\n\n new_prompt_win_rate = 1.0 - baseline_win_rate\n\n print(f\"\\n Baseline (original): {baseline_win_rate:.2%}\")\n print(f\" New candidate: {new_prompt_win_rate:.2%}\")\n\n if new_prompt_win_rate > best_val_score:\n improvement = new_prompt_win_rate - best_val_score\n print(f\" \ud83c\udf89 New best! 
(+{improvement * 100:.2f}pp)\")\n best_prompt = new_prompt\n best_val_score = new_prompt_win_rate\n else:\n print(f\" No improvement\")\n\n print(\"\\n\" + \"=\" * 80)\n print(\"\ud83d\udcca FINAL TEST EVALUATION\")\n print(\"=\" * 80)\n\n baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(\n test_data,\n prompt_a=BASELINE_PROMPT,\n prompt_b=best_prompt,\n summarizer_lm=summarizer_lm,\n eval_name=\"final_test\"\n )\n\n print(\"\\n\" + \"=\" * 80)\n print(\"\ud83c\udf89 FINAL RESULTS\")\n print(\"=\" * 80)\n\n print(f\"\\nTEST SET:\")\n print(f\" Baseline prompt: {baseline_test_win_rate:.2%}\")\n print(f\" Optimized prompt: {optimized_test_win_rate:.2%}\")\n print(f\" Improvement: {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral\")\n\n output_dir = Path(\"results\")\n output_dir.mkdir(exist_ok=True)\n\n timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n\n with open(output_dir / f\"prompts_{timestamp}.txt\", 'w') as f:\n f.write(\"BASELINE:\\n\" + \"=\" * 80 + \"\\n\")\n f.write(BASELINE_PROMPT)\n f.write(\"\\n\\nOPTIMIZED:\\n\" + \"=\" * 80 + \"\\n\")\n f.write(best_prompt)\n f.write(f\"\\n\\nRESULTS:\\n\" + \"=\" * 80 + \"\\n\")\n f.write(f\"Baseline: {baseline_test_win_rate:.2%}\\n\")\n f.write(f\"Optimized: {optimized_test_win_rate:.2%}\\n\")\n\n print(f\"\\n\ud83d\udcbe Saved to: results/prompts_{timestamp}.txt\")\n\n return {\n 'baseline_test': baseline_test_win_rate,\n 'optimized_test': optimized_test_win_rate,\n 'best_prompt': best_prompt\n }\n\nprint(\"\u2713 GEPA optimization function defined\")" + ] }, { - "output_type": "stream", - "name": "stderr", - "text": [ - "Prompt B: 38%|███▊ | 113/300 [06:12<09:41, 3.11s/it]" - ] + "cell_type": "markdown", + "metadata": { + "id": "4839066f78acf10d" + }, + "source": [ + "## \ud83d\ude80 Run the Optimization\n", + "\n", + "Now we'll execute the full GEPA optimization process. This will:\n", + "1. Set up the summarizer and optimizer models\n", + "2. Run multiple iterations of prompt improvement\n", + "3. Evaluate the final optimized prompt on the test set\n", + "4. 
Display comprehensive results" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "51f60931bec8f490", + "outputId": "8db2192e-b3df-4714-a218-57fb2caf49a9" + }, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "\ud83c\udfaf GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "\ud83e\uddec MANUAL GEPA OPTIMIZATION\n", + "================================================================================\n", + "\n", + "================================================================================\n", + "ITERATION 1/5\n", + "================================================================================\n", + "Iteration 0: Establishing baseline (no comparison yet)\n", + "\n", + "================================================================================\n", + "ITERATION 2/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 1)\n", + "\u2713 Generated new prompt (63 words)\n", + "\u2713 Generated candidate prompt (404 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter1_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news even...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [14:30<00:00, 2.90s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [17:16<00:00, 3.46s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter1_val_20251222_170518.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.59M/1.59M [00:00<00:00, 2.82MB/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-94eb-1766423120)...\n", + "\u2713 Results: Prompt A wins=29, Prompt B wins=35, Ties=236\n", + "\u2713 Prompt A win rate: 45.31%\n", + "\n", + " Current best: 45.31%\n", + " New candidate: 54.69%\n", + " \ud83c\udf89 New best! 
(+4.69pp)\n", + "\n", + "================================================================================\n", + "ITERATION 3/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 2)\n", + "\u2713 Generated new prompt (58 words)\n", + "\u2713 Generated candidate prompt (389 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter2_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:39<00:00, 7.68it/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Write a 4-6 sentence summary of this news article, prioritizing clarity and accuracy. \n", + "\n", + "Clearly stat...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 82%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f | 247/300 [12:58<02:41, 3.06s/it]2025/12/22 17:30:11 WARNING dspy.clients.lm: LM response was truncated due to exceeding max_tokens=1024. You can inspect the latest LM interactions with `dspy.inspect_history()`. To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. You may also consider increasing the temperature (currently 0.5) if the reason for truncation is repetition.\n", + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [15:55<00:00, 3.18s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter2_val_20251222_173300.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.62M/1.62M [00:00<00:00, 3.48MB/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-6faf-1766424783)...\n", + "\u2713 Results: Prompt A wins=34, Prompt B wins=29, Ties=237\n", + "\u2713 Prompt A win rate: 53.97%\n", + "\n", + " Current best: 53.97%\n", + " New candidate: 46.03%\n", + " No improvement\n", + "\n", + "================================================================================\n", + "ITERATION 4/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 3)\n", + "\u2713 Generated new prompt (87 words)\n", + "\u2713 Generated candidate prompt (578 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter3_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on 
clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:37<00:00, 8.08it/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on the most important facts. Provide a clear ...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [15:51<00:00, 3.17s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter3_val_20251222_181544.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.65M/1.65M [00:00<00:00, 2.48MB/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-1788-1766427347)...\n", + "\u2713 Results: Prompt A wins=44, Prompt B wins=22, Ties=234\n", + "\u2713 Prompt A win rate: 66.67%\n", + "\n", + " Current best: 66.67%\n", + " New candidate: 33.33%\n", + " No improvement\n", + "\n", + "================================================================================\n", + "ITERATION 5/5\n", + "================================================================================\n", + "\n", + "\ud83e\udd14 REFLECTION (Iteration 4)\n", + "\u2713 Generated new prompt (77 words)\n", + "\u2713 Generated candidate prompt (547 chars)\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: iter4_val\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:40<00:00, 7.47it/s]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on accuracy, brevity, and clarity.\n", + "\n", + "Clearly s...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [16:34<00:00, 3.32s/it]\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "Uploading file temp_compare_iter4_val_20251222_184909.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.62M/1.62M [00:00<00:00, 
1.77MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-1e94-1766429353)...\n", + "\u2713 Results: Prompt A wins=45, Prompt B wins=33, Ties=222\n", + "\u2713 Prompt A win rate: 57.69%\n", + "\n", + " Current best: 57.69%\n", + " New candidate: 42.31%\n", + " No improvement\n", + "\n", + "================================================================================\n", + "\ud83d\udcca FINAL TEST EVALUATION\n", + "================================================================================\n", + "\n", + "\u23f1\ufe0f OPTIMIZATION TIME:\n", + " Total: 2h 31m 48s\n", + "\n", + "================================================================================\n", + "\ud83d\udd04 COMPARING PROMPTS: final_test\n", + "================================================================================\n", + "Generating summaries with Prompt A...\n", + " Using prompt: Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news even...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt A: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [16:05<00:00, 3.22s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating summaries with Prompt B...\n", + " Using prompt: Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the f...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Prompt B: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [18:27<00:00, 3.69s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\ud83d\udce4 Uploading for comparison...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Uploading file temp_compare_final_test_20251222_193951.jsonl: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.57M/1.57M [00:00<00:00, 2.74MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\ud83d\ude80 Launching comparison...\n", + "\u23f3 Waiting (ID: eval-ff84-1766432395)...\n", + "\u2713 Results: Prompt A wins=25, Prompt B wins=41, Ties=234\n", + "\u2713 Prompt A win rate: 37.88%\n", + "\n", + "================================================================================\n", + "\ud83c\udf89 FINAL RESULTS\n", + "================================================================================\n", + "\n", + "TEST SET:\n", + " Baseline prompt: 37.88%\n", + " Optimized prompt: 62.12%\n", + " Improvement: +12.12pp from neutral\n", + "\n", + "\ud83d\udcbe Saved to: results/prompts_20251222_195058.txt\n", + "\n", + "\u2705 Complete!\n" + ] + } + ], + "source": [ + "print(\"=\"*80)\n", + "print(\"\ud83c\udfaf GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", + "print(\"=\"*80)\n", + "\n", + "# Setup models\n", + "summarizer_lm = dspy.LM(\n", + " f\"together_ai/{SUMMARIZER_MODEL}\",\n", + " api_key=TOGETHER_API_KEY,\n", + " temperature=0.5,\n", + " max_tokens=1024\n", + ")\n", + "\n", + "optimizer_lm = SimpleOptimizerLM(\n", + " model=OPTIMIZER_MODEL,\n", + " api_key=TOGETHER_API_KEY,\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "# Run optimization\n", + "results = run_manual_gepa(\n", + " train_data,\n", + " val_data,\n", + " test_data,\n", + " summarizer_lm,\n", + " optimizer_lm,\n", + " max_iterations=5\n", + ")\n", + "\n", + 
"print(\"\\n\u2705 Complete!\")\n", + "\n", + "total_time = time.time() - start_time\n", + "hours = int(total_time // 3600)\n", + "minutes = int((total_time % 3600) // 60)\n", + "seconds = int(total_time % 60)\n", + "\n", + "print(f\"\\n\u23f1\ufe0f OPTIMIZATION TIME:\")\n", + "if hours > 0:\n", + " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", + "elif minutes > 0:\n", + " print(f\" Total: {minutes}m {seconds}s\")\n", + "else:\n", + " print(f\" Total: {seconds}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2be0f2bb00a13ff6" + }, + "source": [ + "## \ud83d\udcca Analyzing the Results\n", + "\n", + "Let's examine the optimized prompt and compare it to the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bc461eee131bd49f", + "outputId": "60c5b59b-5a88-4c05-f8fd-232acf38a27b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "================================================================================\n", + "\ud83d\udcdd PROMPT COMPARISON\n", + "================================================================================\n", + "\n", + "BASELINE PROMPT:\n", + "--------------------------------------------------------------------------------\n", + "Summarize this news article in 3-5 key points.\n", + "\n", + "Write a brief summary covering:\n", + "- The main news event\n", + "- Key people or organizations involved\n", + "- Important details or outcomes\n", + "- Any significant context\n", + "\n", + "Keep it to 3-5 sentences total.\n", + "\n", + "\n", + "OPTIMIZED PROMPT:\n", + "--------------------------------------------------------------------------------\n", + "Summarize this news article in 4-6 sentences, focusing on clarity and concision.\n", + "\n", + "Please cover the following key aspects:\n", + "- What is the main news event being reported?\n", + "- Who are the key people or organizations involved?\n", + "- What are the most important details or outcomes of the event?\n", + "\n", + "Provide relevant background information if necessary, but prioritize the essential facts and avoid unnecessary details.\n", + "\n", + "\n", + "PERFORMANCE COMPARISON:\n", + "--------------------------------------------------------------------------------\n", + "Baseline Win Rate: 37.88%\n", + "Optimized Win Rate: 62.12%\n", + "Improvement: +12.12 percentage points from neutral\n" + ] + } + ], + "source": [ + "print(\"=\" * 80)\n", + "print(\"\ud83d\udcdd PROMPT COMPARISON\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nBASELINE PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(BASELINE_PROMPT)\n", + "\n", + "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", + "print(\"-\" * 80)\n", + "print(results['best_prompt'])\n", + "\n", + "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", + "print(\"-\" * 80)\n", + "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", + "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", + "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8b606f57d491feb6" + }, + "source": [ + "## \ud83d\udd11 Key Findings\n", + "\n", + "**GEPA Optimization Process:**\n", + "- Iteratively improves prompts through LLM-guided reflection\n", + "- Uses head-to-head comparisons with a judge model\n", + "- Tracks and accepts only improvements over baseline\n", + "\n", + "**Benefits of This 
Approach:**\n", + "1. **Automated**: No manual prompt engineering required\n", + "2. **Data-driven**: Decisions based on actual performance metrics\n", + "3. **Scalable**: Can optimize for any task with appropriate data\n", + "4. **Transparent**: Clear tracking of improvements across iterations\n", + "\n", + "**Next Steps:**\n", + "- Try with different datasets or domains\n", + "- Experiment with different judge criteria\n", + "- Adjust the optimizer's reflection prompt\n", + "- Increase iterations for potentially better results" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "colab": { + "provenance": [] } - ], - "source": [ - "print(\"=\"*80)\n", - "print(\"🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL\")\n", - "print(\"=\"*80)\n", - "\n", - "# Setup models\n", - "summarizer_lm = dspy.LM(\n", - " f\"together_ai/{SUMMARIZER_MODEL}\",\n", - " api_key=TOGETHER_API_KEY,\n", - " temperature=0.5,\n", - " max_tokens=1024\n", - ")\n", - "\n", - "optimizer_lm = SimpleOptimizerLM(\n", - " model=OPTIMIZER_MODEL,\n", - " api_key=TOGETHER_API_KEY,\n", - ")\n", - "\n", - "start_time = time.time()\n", - "\n", - "# Run optimization\n", - "results = run_manual_gepa(\n", - " train_data,\n", - " val_data,\n", - " test_data,\n", - " summarizer_lm,\n", - " optimizer_lm,\n", - " max_iterations=5\n", - ")\n", - "\n", - "print(\"\\n✅ Complete!\")\n", - "\n", - "total_time = time.time() - start_time\n", - "hours = int(total_time // 3600)\n", - "minutes = int((total_time % 3600) // 60)\n", - "seconds = int(total_time % 60)\n", - "\n", - "print(f\"\\n⏱️ OPTIMIZATION TIME:\")\n", - "if hours > 0:\n", - " print(f\" Total: {hours}h {minutes}m {seconds}s\")\n", - "elif minutes > 0:\n", - " print(f\" Total: {minutes}m {seconds}s\")\n", - "else:\n", - " print(f\" Total: {seconds}s\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2be0f2bb00a13ff6" - }, - "source": [ - "## 📊 Analyzing the Results\n", - "\n", - "Let's examine the optimized prompt and compare it to the baseline." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bc461eee131bd49f" - }, - "outputs": [], - "source": [ - "print(\"=\" * 80)\n", - "print(\"📝 PROMPT COMPARISON\")\n", - "print(\"=\" * 80)\n", - "\n", - "print(\"\\nBASELINE PROMPT:\")\n", - "print(\"-\" * 80)\n", - "print(BASELINE_PROMPT)\n", - "\n", - "print(\"\\n\\nOPTIMIZED PROMPT:\")\n", - "print(\"-\" * 80)\n", - "print(results['best_prompt'])\n", - "\n", - "print(\"\\n\\nPERFORMANCE COMPARISON:\")\n", - "print(\"-\" * 80)\n", - "print(f\"Baseline Win Rate: {results['baseline_test']:.2%}\")\n", - "print(f\"Optimized Win Rate: {results['optimized_test']:.2%}\")\n", - "print(f\"Improvement: {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8b606f57d491feb6" - }, - "source": [ - "## 🔑 Key Findings\n", - "\n", - "**GEPA Optimization Process:**\n", - "- Iteratively improves prompts through LLM-guided reflection\n", - "- Uses head-to-head comparisons with a judge model\n", - "- Tracks and accepts only improvements over baseline\n", - "\n", - "**Benefits of This Approach:**\n", - "1. 
**Automated**: No manual prompt engineering required\n", - "2. **Data-driven**: Decisions based on actual performance metrics\n", - "3. **Scalable**: Can optimize for any task with appropriate data\n", - "4. **Transparent**: Clear tracking of improvements across iterations\n", - "\n", - "**Next Steps:**\n", - "- Try with different datasets or domains\n", - "- Experiment with different judge criteria\n", - "- Adjust the optimizer's reflection prompt\n", - "- Increase iterations for potentially better results" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" }, - "colab": { - "provenance": [] - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From e7048096fc3a73473ae7e73bcd4b1ac9297018c8 Mon Sep 17 00:00:00 2001 From: jli Date: Tue, 23 Dec 2025 11:52:35 +0800 Subject: [PATCH 6/6] update --- Evals/Prompt_Optimization.ipynb | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/Evals/Prompt_Optimization.ipynb b/Evals/Prompt_Optimization.ipynb index e6c8ba5..b1711c6 100644 --- a/Evals/Prompt_Optimization.ipynb +++ b/Evals/Prompt_Optimization.ipynb @@ -81,12 +81,8 @@ }, "outputs": [], "source": [ - "# Set your Together API key\n", - "TOGETHER_API_KEY = \"\" # Add your API key here\n", - "\n", - "# Or load from Colab secrets\n", - "# from google.colab import userdata\n", - "# TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", + "from google.colab import userdata\n", + "TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')\n", "\n", "if not TOGETHER_API_KEY:\n", " raise ValueError(\"Please set your TOGETHER_API_KEY\")\n",
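The Colab `userdata` lookup above only works inside Colab. For running the notebook elsewhere, a fallback along these lines could replace that cell; this is a sketch that assumes the key is exported as a `TOGETHER_API_KEY` environment variable when not running in Colab.

```python
# Sketch of a key-loading cell that works both inside and outside Colab.
import os

try:
    from google.colab import userdata  # available only in Colab
    TOGETHER_API_KEY = userdata.get("TOGETHER_API_KEY")
except ImportError:
    TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")

if not TOGETHER_API_KEY:
    raise ValueError("Please set your TOGETHER_API_KEY")
```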