diff --git a/backend/app.py b/backend/app.py index d6c6d373..45fc3929 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,41 +1,74 @@ import builtins -import io +import logging import os import sys +import time from typing import Any, Dict, Optional, Union -from flask import Flask, Response, jsonify, request, send_file, session -from flask_cors import CORS -from flask_socketio import SocketIO +import eventlet + +eventlet.monkey_patch() + +from flask import Flask, Response, jsonify, request, send_file, session # noqa: E402 +from flask_cors import CORS # noqa: E402 +from flask_socketio import SocketIO # noqa: E402 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if project_root not in sys.path: sys.path.insert(0, project_root) + original_print = builtins.print -original_stdout = sys.stdout -# Buffer to store messages until socketio is ready -log_buffer = [] +_LOG_BUFFER: list[dict[str, Any]] = [] +_MAX_LOG_BUFFER_LENGTH = 500 -class WebSocketCapture(io.StringIO): - def write(self, text: str) -> int: - # Also write to original stdout - original_stdout.write(text) - # Store for WebSocket emission - if text.strip(): # Only non-empty messages - log_buffer.append(text.strip()) - return len(text) +def _push_log(message: str, level: str = "info") -> None: + message = message.strip() + if not message: + return + payload = { + "message": message, + "level": level, + "timestamp": time.time(), + } + _LOG_BUFFER.append(payload) + if len(_LOG_BUFFER) > _MAX_LOG_BUFFER_LENGTH: + del _LOG_BUFFER[0] + try: + if "socketio" in globals(): + socketio.emit("log", payload) + except Exception: + pass + + +class SocketIOLogHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + try: + message = self.format(record) + except Exception: + message = record.getMessage() + _push_log(message, record.levelname.lower()) + + +def flush_log_buffer() -> None: + global _LOG_BUFFER + if not _LOG_BUFFER: + return + try: + for payload in _LOG_BUFFER: + socketio.emit("log", payload) + except Exception: + return + finally: + _LOG_BUFFER = [] def websocket_print(*args: Any, **kwargs: Any) -> None: - # Call original print - original_print(*args, **kwargs) - # Also emit via WebSocket in real-time message = " ".join(str(arg) for arg in args) - if message.strip(): - emit_log_realtime(message.strip()) + _push_log(message) + original_print(*args, **kwargs) # Override print globally before importing tiny_scientist modules @@ -53,41 +86,6 @@ def websocket_print(*args: Any, **kwargs: Any) -> None: pass -# Create a function to emit buffered logs when socketio is ready -def emit_buffered_logs() -> None: - global log_buffer - try: - for message in log_buffer: - socketio.emit( - "log", - { - "message": message, - "level": "info", - "timestamp": __import__("time").time(), - }, - ) - log_buffer = [] # Clear buffer after emitting - except Exception: - pass - - -# Create a function to emit logs in real-time -def emit_log_realtime(message: str, level: str = "info") -> None: - try: - # Check if socketio is available - if "socketio" in globals(): - socketio.emit( - "log", - { - "message": message, - "level": level, - "timestamp": __import__("time").time(), - }, - ) - except Exception: - pass - - from tiny_scientist.budget_checker import BudgetChecker # noqa: E402 from tiny_scientist.coder import Coder # noqa: E402 from tiny_scientist.reviewer import Reviewer # noqa: E402 @@ -129,7 +127,15 @@ def patch_module_print() -> None: "http://localhost:3000", ], ) -socketio = SocketIO(app, cors_allowed_origins="*") +socketio = SocketIO(app, cors_allowed_origins="*", async_mode="eventlet") +root_logger = logging.getLogger() +if not any(isinstance(handler, SocketIOLogHandler) for handler in root_logger.handlers): + socketio_handler = SocketIOLogHandler() + socketio_handler.setLevel(logging.INFO) + socketio_handler.setFormatter(logging.Formatter("%(message)s")) + root_logger.addHandler(socketio_handler) +if root_logger.level > logging.INFO: + root_logger.setLevel(logging.INFO) # Print override is now active print("๐Ÿš€ Backend server starting with WebSocket logging enabled!") @@ -245,7 +251,7 @@ def configure() -> Union[Response, tuple[Response, int]]: @app.route("/api/generate-initial", methods=["POST"]) def generate_initial() -> Union[Response, tuple[Response, int]]: """Generate initial ideas from an intent (handleAnalysisIntentSubmit)""" - emit_buffered_logs() # Emit any buffered logs from module initialization + flush_log_buffer() # Emit any buffered logs from module initialization data = request.json if data is None: return jsonify({"error": "No JSON data provided"}), 400 @@ -523,7 +529,7 @@ def format_idea_content(idea: Union[Dict[str, Any], str]) -> str: @app.route("/api/code", methods=["POST"]) def generate_code() -> Union[Response, tuple[Response, int]]: """Generate code synchronously and return when complete""" - emit_buffered_logs() # Emit any buffered logs + flush_log_buffer() # Emit any buffered logs global coder if coder is None: @@ -624,9 +630,8 @@ def generate_paper() -> Union[Response, tuple[Response, int]]: experiment_dir = data.get("experiment_dir", None) s2_api_key = data.get("s2_api_key", None) - - if not s2_api_key: - return jsonify({"error": "Semantic Scholar API key is required"}), 400 + if isinstance(s2_api_key, str): + s2_api_key = s2_api_key.strip() or None if not idea_data: print("ERROR: No idea provided in request") @@ -653,6 +658,10 @@ def generate_paper() -> Union[Response, tuple[Response, int]]: ), ) print(f"Writer initialized for this request with model: {writer.model}") + if not s2_api_key: + print( + "Proceeding without Semantic Scholar API key; using fallback sources." + ) # Extract the original idea data if isinstance(idea_data, dict) and "originalData" in idea_data: diff --git a/backend/tests/test_coder_demo.py b/backend/tests/test_coder_demo.py new file mode 100644 index 00000000..c3d5be5d --- /dev/null +++ b/backend/tests/test_coder_demo.py @@ -0,0 +1,76 @@ +import json +import os +from pathlib import Path + +import pytest + +from backend.app import app + + +@pytest.fixture(scope="module") +def client(): + app.config["TESTING"] = True + with app.test_client() as client: + yield client + + +def _load_demo_idea() -> dict: + idea_path = Path(__file__).resolve().parents[2] / "demo_test" / "idea.json" + if not idea_path.exists(): + raise FileNotFoundError(f"Idea file not found: {idea_path}") + return json.loads(idea_path.read_text()) + + +def _configure_backend(client) -> None: + model = "gpt-4o" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY not set; skipping coder integration test") + + response = client.post( + "/api/configure", + json={ + "model": model, + "api_key": api_key, + "budget": 10.0, + "budget_preference": "balanced", + }, + ) + assert response.status_code == 200, response.get_data(as_text=True) + + +def test_coder_with_demo_idea(client): + """ + Integration test: run backend /api/code with the demo idea. + Ensures coder executes and produces experiment outputs. + """ + _configure_backend(client) + + idea_payload = _load_demo_idea() + response = client.post( + "/api/code", + json={"idea": {"originalData": idea_payload}}, + ) + + assert response.status_code == 200, response.get_data(as_text=True) + + data = response.get_json() + assert data is not None, "No JSON body returned" + assert data.get("success") is True, data + + experiment_dir = data.get("experiment_dir") + assert experiment_dir, data + + generated_base = Path(__file__).resolve().parents[2] / "generated" + abs_experiment_dir = generated_base / experiment_dir + assert abs_experiment_dir.exists(), f"Experiment dir missing: {abs_experiment_dir}" + + expected_files = { + "experiment.py", + "notes.txt", + "experiment_results.txt", + } + missing = [ + name for name in expected_files if not (abs_experiment_dir / name).exists() + ] + assert not missing, f"Missing files in {abs_experiment_dir}: {missing}" diff --git a/frontend/package.json b/frontend/package.json index bd5d8fa2..076fb00d 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -2,7 +2,6 @@ "name": "hypo-eval", "version": "0.1.0", "private": true, - "proxy": "http://localhost:5000", "dependencies": { "@monaco-editor/react": "^4.7.0", "@testing-library/jest-dom": "^5.17.0", diff --git a/frontend/src/components/LogDisplay.jsx b/frontend/src/components/LogDisplay.jsx index 5d22c2a4..1a2f8be3 100644 --- a/frontend/src/components/LogDisplay.jsx +++ b/frontend/src/components/LogDisplay.jsx @@ -1,15 +1,45 @@ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useMemo, useRef } from 'react'; import { io } from 'socket.io-client'; +const resolveSocketBaseURL = () => { + if (typeof window === 'undefined') { + return 'http://localhost:5000'; + } + + const explicitBase = process.env.REACT_APP_SOCKET_BASE_URL; + if (explicitBase) { + return explicitBase; + } + + const { protocol, hostname, port } = window.location; + + // When running the dev server (port 3000) we need to talk to the backend port. + if (port === '3000') { + const backendPort = process.env.REACT_APP_SOCKET_PORT || '5000'; + return `${protocol}//${hostname}:${backendPort}`; + } + + return `${protocol}//${hostname}${port ? `:${port}` : ''}`; +}; + +const SOCKET_PATH = process.env.REACT_APP_SOCKET_PATH || '/socket.io'; + const LogDisplay = ({ isVisible, onToggle }) => { const [logs, setLogs] = useState([]); const [isConnected, setIsConnected] = useState(false); const socketRef = useRef(null); const logsEndRef = useRef(null); + const socketUrl = useMemo(resolveSocketBaseURL, []); useEffect(() => { - // Always maintain socket connection when component mounts - socketRef.current = io('http://localhost:5000'); + socketRef.current = io(socketUrl, { + path: SOCKET_PATH, + transports: ['websocket', 'polling'], + reconnection: true, + reconnectionDelay: 2000, + reconnectionDelayMax: 6000, + timeout: 10000, + }); socketRef.current.on('connect', () => { setIsConnected(true); @@ -26,8 +56,9 @@ const LogDisplay = ({ isVisible, onToggle }) => { return () => { socketRef.current?.disconnect(); + socketRef.current = null; }; - }, []); + }, [socketUrl]); useEffect(() => { // Auto-scroll to bottom when new logs arrive diff --git a/frontend/src/components/TreePlotVisualization.jsx b/frontend/src/components/TreePlotVisualization.jsx index dafc82dc..4bf56ede 100644 --- a/frontend/src/components/TreePlotVisualization.jsx +++ b/frontend/src/components/TreePlotVisualization.jsx @@ -8,6 +8,110 @@ import IdeaCard from './IdeaCard'; import IdeaFactorsAndScoresCard from './IdeaFactorsAndScoresCard'; import LogDisplay from './LogDisplay'; +const normalizeExperimentDir = (dir) => { + if (!dir) return 'experiments'; + let sanitized = dir.trim(); + if (!sanitized) return 'experiments'; + sanitized = sanitized.replace(/\\/g, '/'); + sanitized = sanitized.replace(/^\.?\/+/, ''); + if (sanitized.startsWith('generated/')) { + sanitized = sanitized.slice('generated/'.length); + } + return sanitized || 'experiments'; +}; + +const buildExperimentFileUrl = (dir, fileName) => { + const safeDir = normalizeExperimentDir(dir); + const encodedDir = safeDir + .split('/') + .filter(Boolean) + .map(encodeURIComponent) + .join('/'); + return `/api/files/${encodedDir}/${encodeURIComponent(fileName)}`; +}; + +const DEFAULT_RUN_DISCOVERY_LIMIT = 10; + +const fetchTextFile = async (dir, filePath) => { + const url = buildExperimentFileUrl(dir, filePath); + try { + console.log(`Attempting to fetch ${filePath} from ${url}`); + const response = await fetch(url); + if (!response.ok) { + console.log(`Fetch skipped for ${filePath}: ${response.status} ${response.statusText}`); + return null; + } + const data = await response.json(); + if (typeof data?.content === 'string') { + return data.content; + } + } catch (err) { + console.error(`Error fetching ${filePath}:`, err); + } + return null; +}; + +const getRunLabel = (runName, index = 0) => { + if (!runName) { + return `Run ${index + 1}`; + } + if (runName.startsWith('run_')) { + const parts = runName.split('_'); + if (parts.length > 1 && parts[1]) { + return `Run ${parts[1]}`; + } + } + return runName; +}; + +const getFileIcon = (fileName = '') => { + if (fileName.endsWith('.py')) return '๐Ÿ'; + if (fileName.endsWith('.json')) return '๐Ÿ“Š'; + if (fileName.endsWith('.txt')) return '๐Ÿ“„'; + return '๐Ÿ“'; +}; + +const normalizeCodeResult = (data) => { + if (!data || typeof data !== 'object') { + return data; + } + const safeDir = data.experiment_dir ? normalizeExperimentDir(data.experiment_dir) : null; + return { ...data, experiment_dir: safeDir }; +}; + +const formatMetricValue = (value) => { + if (typeof value === 'number' && Number.isFinite(value)) { + const fixed = value.toFixed(4); + return parseFloat(fixed).toString(); + } + return String(value); +}; + +const sanitizeS2ApiKey = (key) => (typeof key === 'string' ? key.trim() : ''); + +const getStoredS2ApiKey = () => { + if (typeof window === 'undefined') return ''; + try { + return window.localStorage.getItem('s2_api_key') || ''; + } catch (err) { + console.error('Failed to read Semantic Scholar key from storage:', err); + return ''; + } +}; + +const persistS2ApiKey = (key) => { + if (typeof window === 'undefined') return; + try { + if (key) { + window.localStorage.setItem('s2_api_key', key); + } else { + window.localStorage.removeItem('s2_api_key'); + } + } catch (err) { + console.error('Failed to persist Semantic Scholar key:', err); + } +}; + // Helper components defined outside the main component to preserve state const ContextAndGenerateCard = ({ @@ -398,7 +502,7 @@ const TreePlotVisualization = () => { const [isAddingCustom, setIsAddingCustom] = useState(false); const [customIdea, setCustomIdea] = useState({ title: '', content: '' }); // *** ๆ–ฐๅขž๏ผš็”จไบŽไธป็•Œ้ขๆจกๅž‹้€‰ๆ‹ฉๅ’Œapi-key่พ“ๅ…ฅ - const [selectedModel, setSelectedModel] = useState('deepseek-chat'); + const [selectedModel, setSelectedModel] = useState('gpt-4o'); const [apiKey, setApiKey] = useState(''); const [isConfigured, setIsConfigured] = useState(false); const [configError, setConfigError] = useState(''); @@ -438,7 +542,7 @@ const TreePlotVisualization = () => { const [codeFileName, setCodeFileName] = useState('experiment.py'); const [activeCodeTab, setActiveCodeTab] = useState('experiment.py'); const [experimentFiles, setExperimentFiles] = useState({}); - const [consoleOutput, setConsoleOutput] = useState(''); + const [experimentFileList, setExperimentFileList] = useState([]); const [experimentRuns, setExperimentRuns] = useState([]); const [isRunningExperiment, setIsRunningExperiment] = useState(false); const [pdfComments, setPdfComments] = useState([]); @@ -449,6 +553,36 @@ const TreePlotVisualization = () => { const [isReviewing, setIsReviewing] = useState(false); const [rightPanelTab, setRightPanelTab] = useState('comments'); // 'comments' or 'review' + const experimentFileCount = experimentFileList.length; + const hasExperimentFiles = experimentFileCount > 0; + + + const baseFileItems = experimentFileList.filter((item) => item.group === 'base'); + const runFileGroups = (() => { + const map = new Map(); + experimentFileList.forEach((item) => { + if (item.group !== 'run') { + return; + } + const key = item.runName || item.path; + if (!map.has(key)) { + map.set(key, { + runLabel: item.runLabel || getRunLabel(item.runName || '', map.size), + items: [], + }); + } + map.get(key).items.push(item); + }); + return Array.from(map.entries()); + })(); + + useEffect(() => { + const savedKey = sanitizeS2ApiKey(getStoredS2ApiKey()); + if (savedKey) { + setS2ApiKey(savedKey); + } + }, []); + // Track view changes const previousViewRef = useRef(currentView); // Initialize with current view @@ -515,9 +649,9 @@ const TreePlotVisualization = () => { }; // ============== ้…็ฝฎๆจกๅž‹ๅ’ŒAPI Key ============== const modelOptions = [ + { value: 'gpt-4o', label: 'GPT-4o' }, { value: 'deepseek-chat', label: 'DeepSeek Chat' }, { value: 'deepseek-reasoner', label: 'DeepSeek Reasoner' }, - { value: 'gpt-4o', label: 'GPT-4o' }, { value: 'o1-2024-12-17', label: 'GPT-o1' }, { value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' }, ]; @@ -2223,46 +2357,177 @@ const TreePlotVisualization = () => { // Load generated files from public directory const loadGeneratedFiles = async (experimentDir) => { + const safeDir = normalizeExperimentDir(experimentDir); try { - const fileUrls = { - 'experiment.py': `/api/files/${experimentDir}/experiment.py`, - 'notes.txt': `/api/files/${experimentDir}/notes.txt`, - 'experiment_results.txt': `/api/files/${experimentDir}/experiment_results.txt`, - }; - const loadedFiles = {}; - for (const [fileName, url] of Object.entries(fileUrls)) { + const fileItems = []; + + const baseFiles = ['experiment.py', 'notes.txt', 'experiment_results.txt']; + for (const fileName of baseFiles) { + const content = await fetchTextFile(safeDir, fileName); + if (content !== null) { + loadedFiles[fileName] = content; + fileItems.push({ + path: fileName, + tabLabel: fileName, + sidebarLabel: fileName, + icon: getFileIcon(fileName), + group: 'base', + }); + } + } + + const resultContent = loadedFiles['experiment_results.txt'] || null; + let runNames = []; + if (resultContent) { try { - console.log(`Attempting to fetch: ${fileName} from ${url}`); - const response = await fetch(url); - console.log(`Response for ${fileName}:`, response.status, response.statusText); - - if (response.ok) { - const data = await response.json(); - if (data.content) { - loadedFiles[fileName] = data.content; - console.log(`Loaded ${fileName} successfully.`); + const parsed = JSON.parse(resultContent); + if (parsed && typeof parsed === 'object') { + runNames = Object.keys(parsed); + } + } catch (parseErr) { + console.error('Failed to parse experiment_results.txt:', parseErr); + } + } + if (runNames.length === 0) { + runNames = Array.from({ length: DEFAULT_RUN_DISCOVERY_LIMIT }, (_, i) => `run_${i + 1}`); + } + + const runMeta = new Map(); + await Promise.all( + runNames.map(async (runName, index) => { + if (!runName) return; + + const runLabel = getRunLabel(runName, index); + let meta = runMeta.get(runName); + if (!meta) { + meta = { runName, runLabel }; + runMeta.set(runName, meta); + } + + let fetchedContent = false; + + const finalInfoPath = `${runName}/final_info.json`; + const finalInfo = await fetchTextFile(safeDir, finalInfoPath); + if (finalInfo !== null) { + loadedFiles[finalInfoPath] = finalInfo; + fetchedContent = true; + try { + meta.finalInfo = JSON.parse(finalInfo); + } catch (err) { + console.error(`Failed to parse ${finalInfoPath}:`, err); + meta.finalInfo = null; } + meta.finalInfoPath = finalInfoPath; + fileItems.push({ + path: finalInfoPath, + tabLabel: `${runLabel} โ€ข final_info.json`, + sidebarLabel: 'final_info.json', + icon: getFileIcon(finalInfoPath), + group: 'run', + runName, + runLabel, + }); + } + + const runScriptPath = `${runName}.py`; + const runScript = await fetchTextFile(safeDir, runScriptPath); + if (runScript !== null) { + loadedFiles[runScriptPath] = runScript; + fetchedContent = true; + meta.scriptPath = runScriptPath; + const scriptFileName = runScriptPath.split('/').pop() || runScriptPath; + fileItems.push({ + path: runScriptPath, + tabLabel: `${runLabel} โ€ข ${scriptFileName}`, + sidebarLabel: scriptFileName, + icon: getFileIcon(runScriptPath), + group: 'run', + runName, + runLabel, + }); + } + + if (!fetchedContent) { + runMeta.delete(runName); + } + }) + ); + + setExperimentFiles(loadedFiles); + setExperimentFileList(fileItems); + + const runMetaMap = Object.fromEntries(runMeta.entries()); + + if (resultContent) { + try { + const parsed = JSON.parse(resultContent); + if (parsed && typeof parsed === 'object') { + const runs = Object.entries(parsed).map(([name, metrics], index) => { + const safeMetrics = typeof metrics === 'object' && metrics !== null ? metrics : {}; + const meta = runMetaMap[name] || { runName: name, runLabel: getRunLabel(name, index) }; + return { + name, + runLabel: meta.runLabel || getRunLabel(name, index), + metrics: safeMetrics, + success: Object.keys(safeMetrics).length > 0, + }; + }); + setExperimentRuns(runs); } else { - console.log(`Could not load ${fileName}, but this might be expected (e.g., no notes).`); + setExperimentRuns([]); } - } catch (err) { - console.error(`Error fetching ${fileName}:`, err); + } catch (parseErr) { + console.error('Failed to parse experiment_results.txt:', parseErr); + setExperimentRuns([]); } + } else if (Object.keys(runMetaMap).length > 0) { + const runs = Object.entries(runMetaMap).map(([name, meta], index) => { + return { + name, + runLabel: meta.runLabel || getRunLabel(name, index), + metrics: meta.finalInfo && typeof meta.finalInfo === 'object' ? meta.finalInfo : {}, + success: true, + }; + }); + setExperimentRuns(runs); + } else { + setExperimentRuns([]); } - // Update state all at once - setExperimentFiles(prev => ({ ...prev, ...loadedFiles })); - - // Set the code content for the editor if experiment.py was loaded if (loadedFiles['experiment.py']) { - setCodeContent(loadedFiles['experiment.py']); setActiveCodeTab('experiment.py'); + setCodeFileName('experiment.py'); + setCodeContent(loadedFiles['experiment.py']); + } else if (fileItems.length > 0) { + const firstItem = fileItems[0]; + setActiveCodeTab(firstItem.path); + setCodeFileName(firstItem.tabLabel); + setCodeContent(loadedFiles[firstItem.path] || ''); + } else { + setActiveCodeTab(''); + setCodeFileName(''); + setCodeContent( + `# Generated experiment code +# Files are being generated in: ${safeDir} + +# Please check the directory for the actual code files.` + ); } + return safeDir; } catch (err) { - console.log("Error loading generated files:", err); - setCodeContent(`# Generated experiment code\n# Files are being generated in: ${experimentDir}\n\n# Please check the directory for the actual code files.`); + console.log(`Error loading generated files from ${safeDir}:`, err); + setExperimentFiles({}); + setExperimentFileList([]); + setCodeContent( + `# Generated experiment code +# Files are being generated in: ${safeDir} + +# Please check the directory for the actual code files.` + ); + setExperimentRuns([]); + return safeDir; } }; @@ -2281,11 +2546,11 @@ const TreePlotVisualization = () => { console.log("Manual file loading completed"); // Set a fake successful result to satisfy the UI - setCodeResult({ + setCodeResult(normalizeCodeResult({ status: true, success: true, experiment_dir: "experiments" - }); + })); } catch (err) { console.log("Manual file loading failed:", err); @@ -2304,6 +2569,7 @@ const TreePlotVisualization = () => { setOperationStatus('Retrying code generation...'); setCodeResult(null); setExperimentFiles({}); + setExperimentFileList([]); setShowLogs(true); // Show logs when starting code generation try { @@ -2325,25 +2591,28 @@ const TreePlotVisualization = () => { const codeData = await codeResponse.json(); console.log("Retry code generation completed:", codeData); - setCodeResult(codeData); - if (!codeData.success) { throw new Error(codeData.error || 'Code generation failed'); } + const normalizedResult = normalizeCodeResult(codeData); + setCodeResult(normalizedResult); + // Load generated files setOperationStatus('Loading generated files...'); - await loadGeneratedFiles(codeData.experiment_dir); + if (normalizedResult.experiment_dir) { + await loadGeneratedFiles(normalizedResult.experiment_dir); + } setOperationStatus('Code generation retry completed successfully!'); } catch (error) { console.error("Retry code generation failed:", error); setOperationStatus('Retry code generation failed: ' + error.message); - setCodeResult({ + setCodeResult(normalizeCodeResult({ success: false, error: error.message, error_details: error.message - }); + })); } finally { setIsGeneratingCode(false); } @@ -2369,15 +2638,13 @@ const TreePlotVisualization = () => { setShowLogs(true); // Show logs when starting paper generation try { - // Get S2 API key from localStorage or prompt user - let s2ApiKey = localStorage.getItem('s2_api_key'); - if (!s2ApiKey) { - s2ApiKey = prompt('Please enter your Semantic Scholar API Key:'); - if (!s2ApiKey) { - throw new Error('Semantic Scholar API key is required for paper generation'); - } - localStorage.setItem('s2_api_key', s2ApiKey); - } + const effectiveS2Key = + sanitizeS2ApiKey(s2ApiKey) || sanitizeS2ApiKey(getStoredS2ApiKey()); + persistS2ApiKey(effectiveS2Key); + setS2ApiKey(effectiveS2Key); + console.log( + `Semantic Scholar API key ${effectiveS2Key ? 'detected' : 'not provided'} for paper generation` + ); const paperResponse = await fetch('/api/write', { method: 'POST', @@ -2388,7 +2655,7 @@ const TreePlotVisualization = () => { body: JSON.stringify({ idea: node.originalData, experiment_dir: experimentDir, - s2_api_key: s2ApiKey + s2_api_key: effectiveS2Key || null }) }); @@ -2448,11 +2715,9 @@ const TreePlotVisualization = () => { // Check if the idea is experimental to determine S2 API key requirement const isExperimental = selectedNode.originalData.is_experimental === true; - // Validate Semantic Scholar API key only for non-experimental ideas - if (!isExperimental && !s2ApiKey.trim()) { - setProceedError('Semantic Scholar API key is required for non-experimental ideas'); - return; - } + const sanitizedS2Key = sanitizeS2ApiKey(s2ApiKey); + persistS2ApiKey(sanitizedS2Key); + setS2ApiKey(sanitizedS2Key); setShowProceedConfirm(false); setProceedError(null); @@ -2566,7 +2831,6 @@ const TreePlotVisualization = () => { codeData = await codeResponse.json(); console.log("Code generation completed:", codeData); - setCodeResult(codeData); if (!codeData.success) { throw new Error(codeData.error || 'Code generation failed'); @@ -2608,12 +2872,17 @@ const TreePlotVisualization = () => { } // Store results - setCodeResult(codeData); - const finalExperimentDir = codeData.experiment_dir; + const normalizedResult = normalizeCodeResult(codeData); + setCodeResult(normalizedResult); + const finalExperimentDir = normalizedResult?.experiment_dir; // Load generated files setOperationStatus('Loading generated files...'); - await loadGeneratedFiles(finalExperimentDir); + if (finalExperimentDir) { + await loadGeneratedFiles(finalExperimentDir); + } else { + console.warn("Code generation succeeded but experiment directory is missing."); + } // Mark that code has been generated (this will show Code View tab) setHasGeneratedCode(true); @@ -2643,7 +2912,7 @@ const TreePlotVisualization = () => { body: JSON.stringify({ idea: selectedNode.originalData, experiment_dir: null, // No experiment directory for non-experimental papers - s2_api_key: s2ApiKey.trim(), + s2_api_key: sanitizedS2Key || null, }), }); @@ -2776,7 +3045,14 @@ const TreePlotVisualization = () => { // Enhanced code editing functions const switchCodeTab = (tabName) => { - // Save current content before switching + if (!tabName) { + return; + } + + if (experimentFiles[tabName] === undefined) { + return; + } + if (activeCodeTab && experimentFiles[activeCodeTab] !== undefined) { setExperimentFiles(prev => ({ ...prev, @@ -2784,9 +3060,9 @@ const TreePlotVisualization = () => { })); } - // Switch to new tab + const fileItem = experimentFileList.find((item) => item.path === tabName); setActiveCodeTab(tabName); - setCodeFileName(tabName); + setCodeFileName(fileItem?.tabLabel || tabName); setCodeContent(experimentFiles[tabName] || ''); }; @@ -2894,14 +3170,14 @@ const TreePlotVisualization = () => { {/* Download All Button */} - - -
- {consoleOutput || 'No output yet...'} -
- {proceedError && ( @@ -3198,7 +3502,7 @@ const TreePlotVisualization = () => { )} - {!codeResult && !isGeneratingCode && !proceedError && Object.keys(experimentFiles).length === 0 && ( + {!codeResult && !isGeneratingCode && !proceedError && !hasExperimentFiles && (
{ color: '#374151', fontWeight: 500 }}> - Semantic Scholar API Key {!selectedNode?.originalData?.is_experimental ? '*' : '(Optional for now)'} + Semantic Scholar API Key (Optional) setS2ApiKey(e.target.value)} - placeholder={selectedNode?.originalData?.is_experimental ? - "Enter your Semantic Scholar API key (can be provided later for paper generation)" : - "Enter your Semantic Scholar API key"} + placeholder="Optional: improves citation quality" style={{ width: '100%', padding: '8px 12px', @@ -4128,10 +4430,7 @@ const TreePlotVisualization = () => { }} />
- {selectedNode?.originalData?.is_experimental ? - 'For experimental ideas: Required only when generating paper. You can provide it later.' : - 'Required for paper generation.' - } Get your API key from{' '} + Providing a key lets us fetch references from Semantic Scholar. Leave blank to skip for now. Get your API key from{' '} { diff --git a/frontend/src/setupProxy.js b/frontend/src/setupProxy.js new file mode 100644 index 00000000..082265b2 --- /dev/null +++ b/frontend/src/setupProxy.js @@ -0,0 +1,12 @@ +const { createProxyMiddleware } = require('http-proxy-middleware'); + +module.exports = function setupProxy(app) { + app.use( + ['/api', '/socket.io'], + createProxyMiddleware({ + target: 'http://localhost:5000', + changeOrigin: true, + ws: true, + }) + ); +}; diff --git a/poetry.lock b/poetry.lock index faf3eb5e..84ef3938 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "aider-chat" @@ -1056,6 +1056,27 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "dnspython" +version = "2.8.0" +description = "DNS toolkit" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af"}, + {file = "dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f"}, +] + +[package.extras] +dev = ["black (>=25.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "hypercorn (>=0.17.0)", "mypy (>=1.17)", "pylint (>=3)", "pytest (>=8.4)", "pytest-cov (>=6.2.0)", "quart-trio (>=0.12.0)", "sphinx (>=8.2.0)", "sphinx-rtd-theme (>=3.0.0)", "twine (>=6.1.0)", "wheel (>=0.45.0)"] +dnssec = ["cryptography (>=45)"] +doh = ["h2 (>=4.2.0)", "httpcore (>=1.0.0)", "httpx (>=0.28.0)"] +doq = ["aioquic (>=1.2.0)"] +idna = ["idna (>=3.10)"] +trio = ["trio (>=0.30)"] +wmi = ["wmi (>=1.5.1) ; platform_system == \"Windows\""] + [[package]] name = "docker" version = "7.1.0" @@ -1079,6 +1100,25 @@ docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] ssh = ["paramiko (>=2.4.3)"] websockets = ["websocket-client (>=1.3.0)"] +[[package]] +name = "eventlet" +version = "0.40.3" +description = "Highly concurrent networking library" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "eventlet-0.40.3-py3-none-any.whl", hash = "sha256:e681cae6ee956cfb066a966b5c0541e734cc14879bda6058024104790595ac9d"}, + {file = "eventlet-0.40.3.tar.gz", hash = "sha256:290852db0065d78cec17a821b78c8a51cafb820a792796a354592ae4d5fceeb0"}, +] + +[package.dependencies] +dnspython = ">=1.15.0" +greenlet = ">=1.0" + +[package.extras] +dev = ["black", "build", "commitizen", "isort", "pip-tools", "pre-commit", "twine"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1585,6 +1625,74 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0)"] +[[package]] +name = "greenlet" +version = "3.2.4" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c"}, + {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590"}, + {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f10fd42b5ee276335863712fa3da6608e93f70629c631bf77145021600abc23c"}, + {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c8c9e331e58180d0d83c5b7999255721b725913ff6bc6cf39fa2a45841a4fd4b"}, + {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:58b97143c9cc7b86fc458f215bd0932f1757ce649e05b640fea2e79b54cedb31"}, + {file = "greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d"}, + {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5"}, + {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f"}, + {file = "greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c"}, + {file = "greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2"}, + {file = "greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246"}, + {file = "greenlet-3.2.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:94abf90142c2a18151632371140b3dba4dee031633fe614cb592dbb6c9e17bc3"}, + {file = "greenlet-3.2.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:4d1378601b85e2e5171b99be8d2dc85f594c79967599328f95c1dc1a40f1c633"}, + {file = "greenlet-3.2.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0db5594dce18db94f7d1650d7489909b57afde4c580806b8d9203b6e79cdc079"}, + {file = "greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8"}, + {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52"}, + {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa"}, + {file = "greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9"}, + {file = "greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd"}, + {file = "greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb"}, + {file = "greenlet-3.2.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f28588772bb5fb869a8eb331374ec06f24a83a9c25bfa1f38b6993afe9c1e968"}, + {file = "greenlet-3.2.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5c9320971821a7cb77cfab8d956fa8e39cd07ca44b6070db358ceb7f8797c8c9"}, + {file = "greenlet-3.2.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c60a6d84229b271d44b70fb6e5fa23781abb5d742af7b808ae3f6efd7c9c60f6"}, + {file = "greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0"}, + {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0"}, + {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f"}, + {file = "greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02"}, + {file = "greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31"}, + {file = "greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945"}, + {file = "greenlet-3.2.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:710638eb93b1fa52823aa91bf75326f9ecdfd5e0466f00789246a5280f4ba0fc"}, + {file = "greenlet-3.2.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c5111ccdc9c88f423426df3fd1811bfc40ed66264d35aa373420a34377efc98a"}, + {file = "greenlet-3.2.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d76383238584e9711e20ebe14db6c88ddcedc1829a9ad31a584389463b5aa504"}, + {file = "greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671"}, + {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b"}, + {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae"}, + {file = "greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b"}, + {file = "greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0"}, + {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f"}, + {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c17b6b34111ea72fc5a4e4beec9711d2226285f0386ea83477cbb97c30a3f3a5"}, + {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1"}, + {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735"}, + {file = "greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337"}, + {file = "greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01"}, + {file = "greenlet-3.2.4-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:b6a7c19cf0d2742d0809a4c05975db036fdff50cd294a93632d6a310bf9ac02c"}, + {file = "greenlet-3.2.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:27890167f55d2387576d1f41d9487ef171849ea0359ce1510ca6e06c8bece11d"}, + {file = "greenlet-3.2.4-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:18d9260df2b5fbf41ae5139e1be4e796d99655f023a636cd0e11e6406cca7d58"}, + {file = "greenlet-3.2.4-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:671df96c1f23c4a0d4077a325483c1503c96a1b7d9db26592ae770daa41233d4"}, + {file = "greenlet-3.2.4-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:16458c245a38991aa19676900d48bd1a6f2ce3e16595051a4db9d012154e8433"}, + {file = "greenlet-3.2.4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9913f1a30e4526f432991f89ae263459b1c64d1608c0d22a5c79c287b3c70df"}, + {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b90654e092f928f110e0007f572007c9727b5265f7632c2fa7415b4689351594"}, + {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:81701fd84f26330f0d5f4944d4e92e61afe6319dcd9775e39396e39d7c3e5f98"}, + {file = "greenlet-3.2.4-cp39-cp39-win32.whl", hash = "sha256:65458b409c1ed459ea899e939f0e1cdb14f58dbc803f2f93c5eab5694d32671b"}, + {file = "greenlet-3.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:d2e685ade4dafd447ede19c31277a224a239a0a1a4eca4e6390efedf20260cfb"}, + {file = "greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil", "setuptools"] + [[package]] name = "grep-ast" version = "0.8.1" @@ -5930,4 +6038,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10, <3.12" -content-hash = "8afa87528e8080a6910bc3078d7fdb5829f5e204f9d8eecabf836ef6b9028d05" +content-hash = "f0335eecbf8ab3826abd39b6c2363e66a74f085533dd15c4ecc4e7b95867dd50" diff --git a/pyproject.toml b/pyproject.toml index 0b0e98d4..8865bd49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ flask-socketio = "^5.5.1" docker = "^7.1.0" fastmcp = "*" mcp = "*" +eventlet = "^0.40.3" [tool.poetry.group.dev.dependencies] pre-commit = "*" diff --git a/tiny_scientist/mcp/docker_runner_server.py b/tiny_scientist/mcp/docker_runner_server.py index 8bd7664b..e9ccc261 100644 --- a/tiny_scientist/mcp/docker_runner_server.py +++ b/tiny_scientist/mcp/docker_runner_server.py @@ -21,6 +21,27 @@ CONFIG_PATH = PACKAGE_ROOT / "config.toml" config = toml.load(CONFIG_PATH) if CONFIG_PATH.exists() else {"core": {}} +DEFAULT_BASE_PACKAGES = [ + "numpy", + "pandas", + "scikit-learn", + "matplotlib", + "seaborn", + "tqdm", + "requests", + "pillow", +] + + +def _resolve_configured_packages() -> List[str]: + docker_cfg = config.get("docker") if isinstance(config, dict) else None + configured = None + if isinstance(docker_cfg, dict): + configured = docker_cfg.get("base_packages") + if isinstance(configured, list) and all(isinstance(pkg, str) for pkg in configured): + return configured + return DEFAULT_BASE_PACKAGES.copy() + class DockerExperimentRunner: def __init__( @@ -32,6 +53,8 @@ def __init__( self.docker_base = docker_base self.docker_client = None self.use_docker = False + self.base_packages = _resolve_configured_packages() + self._base_package_set: Set[str] = set(self.base_packages) # Initialize Docker client try: @@ -49,24 +72,7 @@ def detect_required_packages( ) -> List[str]: """Detect required packages from import statements in a Python file.""" if base_packages is None: - base_packages = set( - [ - "numpy", - "pandas", - "scikit-learn", - "matplotlib", - "seaborn", - "torch", - "tensorflow", - "transformers", - "datasets", - "evaluate", - "wandb", - "tqdm", - "requests", - "pillow", - ] - ) + base_packages = set(DEFAULT_BASE_PACKAGES) # Common package name mappings (import_name -> pip_package_name) package_mapping = { @@ -298,10 +304,11 @@ def get_or_build_base_image(self) -> Optional[str]: print(f"[Docker] Using existing image: {self.docker_image}") except ImageNotFound: print(f"[Docker] Building image: {self.docker_image}") - dockerfile = f""" -FROM {self.docker_base} -RUN pip install --no-cache-dir numpy pandas scikit-learn matplotlib seaborn torch tensorflow transformers datasets evaluate wandb tqdm requests pillow -""" + dockerfile_lines = [f"FROM {self.docker_base}"] + if self.base_packages: + joined = " ".join(sorted(self.base_packages)) + dockerfile_lines.append(f"RUN pip install --no-cache-dir {joined}") + dockerfile = "\n".join(dockerfile_lines) + "\n" with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, "Dockerfile"), "w") as f: f.write(dockerfile) @@ -316,7 +323,9 @@ def get_or_build_experiment_image(self, experiment_py_path: str) -> Optional[str if not self.use_docker: return None base_image = self.get_or_build_base_image() - extra_pkgs = self.detect_required_packages(experiment_py_path) + extra_pkgs = self.detect_required_packages( + experiment_py_path, base_packages=self._base_package_set + ) if extra_pkgs: image_name = f"tiny-scientist-exp-{hash(tuple(extra_pkgs))}" if self.docker_client is not None: