diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 582da22..8774307 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,34 +68,34 @@ jobs: - name: Run type checking with tox run: tox -e type - # test: - # name: Unit Tests - # runs-on: ubuntu-latest - # needs: [format, lint, type] - # strategy: - # matrix: - # python-version: ["3.10", "3.11", "3.12"] - - # steps: - # - uses: actions/checkout@v4 - - # - name: Set up Python ${{ matrix.python-version }} - # uses: actions/setup-python@v5 - # with: - # python-version: ${{ matrix.python-version }} - - # - name: Install tox - # run: | - # python -m pip install --upgrade pip - # pip install tox - - # - name: Run tests with tox - # run: tox -e test + test: + name: Unit Tests + runs-on: ubuntu-latest + needs: [format, lint, type] + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install tox + run: | + python -m pip install --upgrade pip + pip install tox + + - name: Run tests with tox + run: tox -e test smoke-test: name: Smoke Tests runs-on: ubuntu-latest - needs: [format, lint, type] + needs: [format, lint, type, coverage] strategy: matrix: python-version: ["3.10", "3.11", "3.12"] @@ -116,3 +116,32 @@ jobs: - name: Test vtk-prompt CLI run: | vtk-prompt --help + + coverage: + name: Test Coverage + runs-on: ubuntu-latest + needs: [format, lint, type] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install pytest pytest-cov + + - name: Run tests with coverage (non-blocking when no tests) + run: | + pytest -q --cov=vtk_prompt --cov-report=term --cov-report=xml || true + + - name: Upload coverage report + if: ${{ hashFiles('coverage.xml') != '' }} + uses: actions/upload-artifact@v4 + with: + name: coverage-xml + path: coverage.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9957736..a14719e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,6 +61,7 @@ repos: - monai - nibabel - vtk + - types-PyYAML # Spelling - repo: https://github.com/codespell-project/codespell diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 2767e42..87fb0ab 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -10,12 +10,26 @@ pip install -e ".[all]" ## Running tests +### Test Suite + +The project includes a comprehensive test suite focused on the prompt assembly +system: + ```bash -# Lint and format -black src/ -flake8 src/ +# Run all tests with tox +tox -e test + +# Run specific test file +python -m pytest tests/test_prompt_assembly.py -v -# Test installation +# Run specific test methods +python -m pytest tests/test_prompt_assembly.py::TestPromptAssembly::test_default_values -v +``` + +### Manual Testing + +```bash +# Test CLI installation and basic functionality vtk-prompt --help vtk-prompt-ui --help ``` @@ -51,6 +65,75 @@ export VTK_PROMPT_LOG_FILE="vtk-prompt.log" setup_logging(level="DEBUG", log_file="vtk-prompt.log") ``` +## Component System Architecture + +The VTK Prompt system uses a modular component-based architecture for prompt +assembly. 
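For orientation, here is a minimal sketch (not part of the patch) of what a caller gets back from the high-level `assemble_vtk_prompt()` helper described below; the import path follows the package layout introduced in this change:

```python
# Minimal sketch, assuming the vtk_prompt.prompts package from this patch.
from vtk_prompt.prompts import assemble_vtk_prompt

prompt_data = assemble_vtk_prompt(
    request="Create a red sphere",
    ui_mode=False,      # True adds the web-UI renderer instructions
    rag_enabled=False,  # True requires context_snippets for RAG injection
)

# The result bundles the assembled messages plus component-level model defaults.
for message in prompt_data["messages"]:
    print(message["role"], "->", message["content"][:60])
print(prompt_data.get("model"), prompt_data.get("modelParameters"))
```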
+ +### Overview + +The component system allows you to: + +- **Compose prompts** from reusable YAML files +- **Inject variables** dynamically (`{{VAR_NAME}}`) +- **Configure model parameters** per component +- **Conditionally include** components based on context + +### Component Structure + +Components are YAML files stored in `src/vtk_prompt/prompts/components/`: + +```yaml +# example_component.yml +role: system | user | assistant +content: | + Your prompt content here with {{VARIABLE}} substitution. + VTK Version: {{VTK_VERSION}} + +# Optional: Merge with existing message instead of creating new one +append: true | false # Add content after existing user message (e.g., additional instructions) +prepend: true | false # Add content before existing user message (e.g., context injection) + +# Optional: Model configuration +model: "openai/gpt-5" +modelParameters: + temperature: 0.5 + max_tokens: 4000 +``` + +### Updating Existing Components + +When modifying components: + +1. **Preserve backward compatibility** - existing variable names +2. **Test thoroughly** - run full test suite +3. **Document changes** - update component comments +4. **Version carefully** - consider impact on existing prompts + +### Component Loading System + +The system uses these key classes: + +- **`PromptComponentLoader`**: Loads and caches YAML files +- **`VTKPromptAssembler`**: Chains components together +- **`YAMLPromptLoader`**: Handles variable substitution +- **`assemble_vtk_prompt()`**: High-level convenience function + +### Variable Substitution + +Components support these built-in variables: + +- `{{VTK_VERSION}}` - Current VTK version (e.g., "9.5.0") +- `{{PYTHON_VERSION}}` - Python requirements (e.g., ">=3.10") + +Custom variables can be passed via: + +```python +assembler.substitute_variables(CUSTOM_VAR="value") +# or +assemble_vtk_prompt("request", CUSTOM_VAR="value") +``` + ## Developer Mode The web UI includes a developer mode that enables hot reload and debug logging diff --git a/README.md b/README.md index 1cda31b..dab79ae 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Build and Publish](https://github.com/vicentebolea/vtk-prompt/actions/workflows/publish.yml/badge.svg)](https://github.com/vicentebolea/vtk-prompt/actions/workflows/publish.yml) [![PyPI version](https://badge.fury.io/py/vtk-prompt.svg)](https://badge.fury.io/py/vtk-prompt) [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) -[![Coverage](https://img.shields.io/badge/coverage-11.0%25-red.svg)](htmlcov/index.html) +[![codecov](https://codecov.io/github/Kitware/vtk-prompt/graph/badge.svg?token=gg8CHNeBKR)](https://codecov.io/github/Kitware/vtk-prompt) A command-line interface and web-based UI for generating VTK visualization code using Large Language Models (Anthropic Claude, OpenAI GPT, NVIDIA NIM, and local @@ -99,7 +99,7 @@ vtk-prompt "Create a red sphere" # Advanced options vtk-prompt "Create a textured cone with 32 resolution" \ --provider anthropic \ - --model claude-opus-4-1 \ + --model claude-opus-4-6 \ --max-tokens 4000 \ --rag \ --verbose @@ -141,6 +141,61 @@ code = client.generate_code("Create a red sphere") print(code) ``` +## Model Configuration + +**Model configuration with YAML prompt files:** + +```yaml +# Model and parameter configuration +model: anthropic/claude-opus-4-6 +modelParameters: + temperature: 0.2 + max_tokens: 6000 +``` + +**Using custom prompt files:** + +```bash +# CLI: Use your custom prompt file +vtk-prompt "Create a sphere" --prompt-file 
custom_vtk_prompt.yml + +# CLI: Or with additional CLI overrides +vtk-prompt "Create a complex scene" --prompt-file custom_vtk_prompt.yml --retry-attempts 3 + +# UI: Use your custom prompt file +vtk-prompt-ui --server --prompt-file custom_vtk_prompt.yml +``` + +### Model Parameters Guide + +**Temperature Settings:** + +- `0.1-0.3`: More focused, deterministic code generation +- `0.4-0.7`: Balanced creativity and consistency (recommended) +- `0.8-1.0`: More creative but potentially less reliable + +**Token Limits:** Token usage can vary significantly between models and +providers. These are general guidelines: + +- `1000-2000`: Simple visualizations and basic VTK objects +- `3000-4000`: Complex scenes with multiple objects +- `5000+`: Detailed implementations with extensive documentation + +_Note: Different models have different token limits and costs. Check your +provider's documentation for specific model capabilities._ + +## Testing + +Run the test suite using the project's standard tools: + +```bash +# Run all tests with tox +tox -e test + +# Run pre-commit hooks (includes testing) +pre-commit run --all-files +``` + ## Configuration ### Environment Variables @@ -152,7 +207,7 @@ print(code) | Provider | Default Model | Base URL | | ------------- | ------------------------ | ----------------------------------- | -| **anthropic** | claude-opus-4-1 | https://api.anthropic.com/v1 | +| **anthropic** | claude-opus-4-6 | https://api.anthropic.com/v1 | | **openai** | gpt-5 | https://api.openai.com/v1 | | **nim** | meta/llama3-70b-instruct | https://integrate.api.nvidia.com/v1 | | **custom** | User-defined | User-defined (for local models) | diff --git a/pyproject.toml b/pyproject.toml index e52f371..1ded06e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ authors = [ {name = "Vicente Adolfo Bolea Sanchez", email = "vicente.bolea@kitware.com"}, ] dependencies = [ + "PyYAML>=6.0", "chromadb>=0.6.3", "click>=8.0.0", "importlib_resources>=5.0.0", @@ -45,7 +46,6 @@ dev = [ "pre-commit", "ruff", "pytest >=6", - "pytest-cov >=3", "tox", "black>=23.0.0", "isort>=5.12.0", @@ -58,12 +58,12 @@ dev = [ "types-requests", "types-click", "types-setuptools", + "types-PyYAML", ] # Testing dependencies test = [ "pytest>=7.0.0", - "pytest-cov>=4.0.0", ] # All optional dependencies diff --git a/src/vtk_prompt/cli.py b/src/vtk_prompt/cli.py index 3481b99..226df92 100644 --- a/src/vtk_prompt/cli.py +++ b/src/vtk_prompt/cli.py @@ -57,6 +57,10 @@ "--conversation", help="Path to conversation file for maintaining chat history", ) +@click.option( + "--prompt-file", + help="Path to custom YAML prompt file (overrides built-in prompts and defaults)", +) def main( input_string: str, provider: str, @@ -72,6 +76,7 @@ def main( top_k: int, retry_attempts: int, conversation: Optional[str], + prompt_file: Optional[str], ) -> None: """ Generate and execute VTK code using LLMs. 
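As a reference for the new `--prompt-file` option, a custom prompt file is a small YAML document whose keys mirror what `cli.py` and `VTKPromptClient._format_custom_prompt()` read: `model`, `modelParameters`, and a `messages` list with `{{...}}` placeholders. The sketch below is illustrative only; the values are examples, not recommendations.

```python
# Illustrative only: the schema a --prompt-file YAML is expected to follow.
import yaml

example_prompt = yaml.safe_load("""
model: openai/gpt-5
modelParameters:
  temperature: 0.3
  max_tokens: 4000
messages:
  - role: system
    content: You are a Python {{PYTHON_VERSION}} source code producing entity.
  - role: user
    content: |
      Only use VTK {{VTK_VERSION}} basic components.
      Request: {{request}}
""")

print(example_prompt["model"], example_prompt["modelParameters"]["max_tokens"])
```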
@@ -90,12 +95,53 @@ def main( # Set default models based on provider if model == "gpt-5": default_models = { - "anthropic": "claude-opus-4-1", + "anthropic": "claude-opus-4-6", "gemini": "gemini-2.5-pro", "nim": "meta/llama3-70b-instruct", } model = default_models.get(provider, model) + # Load custom prompt file if provided + custom_prompt_data = None + if prompt_file: + try: + from pathlib import Path + + custom_file_path = Path(prompt_file) + if not custom_file_path.exists(): + logger.error("Custom prompt file not found: %s", prompt_file) + sys.exit(1) + + # Load the custom YAML prompt file manually + import yaml + + with open(custom_file_path, "r") as f: + custom_prompt_data = yaml.safe_load(f) + + logger.info("Loaded custom prompt file: %s", custom_file_path.name) + + # Override defaults with custom prompt parameters, but preserve CLI overrides + # Only override if CLI argument is still the default value + if custom_prompt_data and isinstance(custom_prompt_data, dict): + # Override model if CLI didn't specify a custom one + if model == "gpt-5" and custom_prompt_data.get("model"): + model = custom_prompt_data.get("model") + logger.info("Using model from prompt file: %s", model) + + # Override model parameters if CLI used defaults + model_params = custom_prompt_data.get("modelParameters", {}) + if temperature == 0.7 and "temperature" in model_params: + temperature = model_params["temperature"] + logger.info("Using temperature from prompt file: %s", temperature) + + if max_tokens == 1000 and "max_tokens" in model_params: + max_tokens = model_params["max_tokens"] + logger.info("Using max_tokens from prompt file: %s", max_tokens) + + except Exception as e: + logger.error("Failed to load custom prompt file %s: %s", prompt_file, e) + sys.exit(1) + # Handle temperature override for unsupported models if not supports_temperature(model): logger.warning( @@ -122,10 +168,24 @@ def main( top_k=top_k, rag=rag, retry_attempts=retry_attempts, + provider=provider, + custom_prompt=custom_prompt_data, ) - if isinstance(result, tuple) and len(result) == 3: - explanation, generated_code, usage = result + # Handle result with optional validation warnings + if isinstance(result, tuple): + if len(result) == 4: + # Result includes validation warnings + explanation, generated_code, usage, validation_warnings = result + # Display validation warnings + for warning in validation_warnings: + logger.warning("Custom prompt validation: %s", warning) + elif len(result) == 3: + explanation, generated_code, usage = result + else: + logger.info("Result: %s", result) + return + if verbose and usage: logger.info( "Used tokens: input=%d output=%d", diff --git a/src/vtk_prompt/client.py b/src/vtk_prompt/client.py index 65d3267..be5aa08 100644 --- a/src/vtk_prompt/client.py +++ b/src/vtk_prompt/client.py @@ -24,11 +24,7 @@ import openai from . import get_logger -from .prompts import ( - get_no_rag_context, - get_python_role, - get_rag_context, -) +from .prompts import assemble_vtk_prompt logger = get_logger(__name__) @@ -132,6 +128,169 @@ def run_code(self, code_string: str) -> None: logger.debug("Failed code:\n%s", code_string) return + def _validate_and_extract_model_params( + self, + custom_prompt_data: dict, + default_model: str, + default_temperature: float, + default_max_tokens: int, + ) -> tuple[str, float, int, list[str]]: + """Validate and extract model parameters from custom prompt data. 
+ + Args: + custom_prompt_data: The loaded YAML prompt data + default_model: Default model to use if validation fails + default_temperature: Default temperature to use if validation fails + default_max_tokens: Default max_tokens to use if validation fails + + Returns: + Tuple of (model, temperature, max_tokens, warnings) + """ + from .provider_utils import ( + get_available_models, + get_default_model, + get_supported_providers, + supports_temperature, + ) + + warnings = [] + model = default_model + temperature = default_temperature + max_tokens = default_max_tokens + + # Extract model if present + if "model" in custom_prompt_data: + custom_model = custom_prompt_data["model"] + + # Validate model format: should be "provider/model" + if "/" not in custom_model: + warnings.append( + f"Invalid model format '{custom_model}'. Expected 'provider/model'. " + f"Using default: {default_model}" + ) + else: + provider, model_name = custom_model.split("/", 1) + + # Validate provider is supported + if provider not in get_supported_providers(): + warnings.append( + f"Unsupported provider '{provider}'. " + f"Supported providers: {', '.join(get_supported_providers())}. " + f"Using default: {default_model}" + ) + else: + # Validate model exists for provider + available_models = get_available_models().get(provider, []) + if model_name not in available_models: + warnings.append( + f"Model '{model_name}' not in curated list for provider '{provider}'. " + f"Available: {', '.join(available_models)}. " + f"Using default: {get_default_model(provider)}" + ) + model = f"{provider}/{get_default_model(provider)}" + else: + model = custom_model + + # Extract modelParameters if present + if "modelParameters" in custom_prompt_data: + params = custom_prompt_data["modelParameters"] + + # Validate temperature + if "temperature" in params: + try: + custom_temp = float(params["temperature"]) + if 0.0 <= custom_temp <= 2.0: + temperature = custom_temp + + # Check if model supports temperature + if not supports_temperature(model): + warnings.append( + f"Model '{model}' does not support temperature control. " + f"Temperature will be set to 1.0." + ) + temperature = 1.0 + else: + warnings.append( + f"Temperature {custom_temp} out of range [0.0, 2.0]. " + f"Using default: {default_temperature}" + ) + except (ValueError, TypeError): + warnings.append( + f"Invalid temperature value '{params['temperature']}'. " + f"Using default: {default_temperature}" + ) + + # Validate max_tokens + if "max_tokens" in params: + try: + custom_max_tokens = int(params["max_tokens"]) + if 1 <= custom_max_tokens <= 100000: + max_tokens = custom_max_tokens + else: + warnings.append( + f"max_tokens {custom_max_tokens} out of range [1, 100000]. " + f"Using default: {default_max_tokens}" + ) + except (ValueError, TypeError): + warnings.append( + f"Invalid max_tokens value '{params['max_tokens']}'. " + f"Using default: {default_max_tokens}" + ) + + return model, temperature, max_tokens, warnings + + def _extract_context_snippets(self, rag_snippets: dict) -> str: + """Extract and format RAG context snippets. + + Args: + rag_snippets: RAG snippets dictionary + + Returns: + Formatted context snippets string + """ + return "\n\n".join(rag_snippets["code_snippets"]) + + def _format_custom_prompt( + self, custom_prompt_data: dict, message: str, rag_snippets: Optional[dict] = None + ) -> list[dict[str, str]]: + """Format custom prompt data into messages for LLM client. 
+ + Args: + custom_prompt_data: The loaded YAML prompt data + message: The user request + rag_snippets: Optional RAG snippets for context enhancement + + Returns: + Formatted messages ready for LLM client + + Note: + This method does NOT extract model/modelParameters from custom_prompt_data. + That is handled by _validate_and_extract_model_params() in the query() method. + """ + from .prompts import PYTHON_VERSION, VTK_VERSION, YAMLPromptLoader + + # Prepare variables for substitution + variables = { + "VTK_VERSION": VTK_VERSION, + "PYTHON_VERSION": PYTHON_VERSION, + "request": message, + } + + # Add RAG context if available + if rag_snippets: + variables["context_snippets"] = self._extract_context_snippets(rag_snippets) + + # Process messages from custom prompt + messages = custom_prompt_data.get("messages", []) + formatted_messages = [] + + yaml_loader = YAMLPromptLoader() + for msg in messages: + content = yaml_loader.substitute_yaml_variables(msg.get("content", ""), variables) + formatted_messages.append({"role": msg.get("role", "user"), "content": content}) + + return formatted_messages + def query( self, message: str = "", @@ -143,7 +302,10 @@ def query( top_k: int = 5, rag: bool = False, retry_attempts: int = 1, - ) -> Union[tuple[str, str, Any], str]: + provider: Optional[str] = None, + custom_prompt: Optional[dict] = None, + ui_mode: bool = False, + ) -> Union[tuple[str, str, Any], tuple[str, str, Any, list[str]], str]: """Generate VTK code with optional RAG enhancement and retry logic. Args: @@ -156,6 +318,9 @@ def query( top_k: Number of RAG examples to retrieve rag: Whether to use RAG enhancement retry_attempts: Number of times to retry if AST validation fails + provider: LLM provider to use (overrides instance provider if provided) + custom_prompt: Custom YAML prompt data (overrides built-in prompts) + ui_mode: Whether the request is coming from UI (affects prompt selection) """ if not api_key: api_key = os.environ.get("OPENAI_API_KEY") @@ -181,7 +346,6 @@ def query( if not check_rag_components_available(): raise ValueError("RAG components not available") - rag_snippets = get_rag_snippets( message, collection_name=self.collection_name, @@ -189,30 +353,77 @@ def query( top_k=top_k, ) - if not rag_snippets: - raise ValueError("Failed to load RAG snippets") + # Store validation warnings to return to caller + validation_warnings = [] + + # Use custom prompt if provided, otherwise use component-based assembly + if custom_prompt: + # Validate and extract model parameters from custom prompt + validated_model, validated_temp, validated_max_tokens, warnings = ( + self._validate_and_extract_model_params( + custom_prompt, model, temperature, max_tokens + ) + ) + validation_warnings.extend(warnings) - context_snippets = "\n\n".join(rag_snippets["code_snippets"]) - context = get_rag_context(message, context_snippets) + # Apply validated parameters + model = validated_model + temperature = validated_temp + max_tokens = validated_max_tokens + # Log warnings + for warning in warnings: + logger.warning(warning) + + # Process custom prompt data + yaml_messages = self._format_custom_prompt( + custom_prompt, message, rag_snippets if rag else None + ) if self.verbose: - logger.debug("RAG context: %s", context) - references = rag_snippets.get("references") - if references: - logger.info("Using examples from: %s", ", ".join(references)) + logger.debug("Using custom YAML prompt from file") + logger.debug( + "Applied model: %s, temperature: %s, max_tokens: %s", + model, + temperature, + max_tokens, + 
) else: - context = get_no_rag_context(message) + # Use component-based assembly system (now the default and only option) + from .prompts import PYTHON_VERSION, VTK_VERSION + + context_snippets = None + if rag and rag_snippets: + context_snippets = self._extract_context_snippets(rag_snippets) + + prompt_data = assemble_vtk_prompt( + request=message, + ui_mode=ui_mode, + rag_enabled=rag, + context_snippets=context_snippets, + VTK_VERSION=VTK_VERSION, + PYTHON_VERSION=PYTHON_VERSION, + ) + yaml_messages = prompt_data["messages"] + if self.verbose: - logger.debug("No-RAG context: %s", context) + mode_str = "UI" if ui_mode else "CLI" + rag_str = " + RAG" if rag else "" + logger.debug(f"Using component assembly ({mode_str}{rag_str})") + + if rag and rag_snippets: + references = rag_snippets.get("references") + if references: + logger.info("Using examples from: %s", ", ".join(references)) - # Initialize conversation with system message if empty + # Initialize conversation with YAML messages if empty if not self.conversation: self.conversation = [] - self.conversation.append({"role": "system", "content": get_python_role()}) - - # Add current user message - if message: - self.conversation.append({"role": "user", "content": context}) + # Add all messages from YAML prompt (system + user) + self.conversation.extend(yaml_messages) + else: + # If conversation exists, just add the user message (last message from YAML) + if message and yaml_messages: + self.conversation.append(yaml_messages[-1]) # Retry loop for AST validation for attempt in range(retry_attempts): @@ -259,6 +470,14 @@ def query( if message: self.conversation.append({"role": "assistant", "content": content}) self.save_conversation() + # Return warnings if custom prompt was used + if validation_warnings: + return ( + generated_explanation, + generated_code, + response.usage, + validation_warnings, + ) return generated_explanation, generated_code, response.usage elif attempt < retry_attempts - 1: # Don't log on last attempt @@ -283,6 +502,14 @@ def query( if message: self.conversation.append({"role": "assistant", "content": content}) self.save_conversation() + # Return warnings if custom prompt was used + if validation_warnings: + return ( + generated_explanation, + generated_code, + response.usage or {}, + validation_warnings, + ) return ( generated_explanation, generated_code, diff --git a/src/vtk_prompt/generate_files.py b/src/vtk_prompt/generate_files.py index bf2b889..80bc212 100644 --- a/src/vtk_prompt/generate_files.py +++ b/src/vtk_prompt/generate_files.py @@ -19,18 +19,14 @@ import os import sys from pathlib import Path -from typing import Optional +from typing import Optional, cast import click import openai +from openai.types.chat import ChatCompletionMessageParam from . 
import get_logger - -# Import our template system -from .prompts import ( - get_vtk_xml_context, - get_xml_role, -) +from .prompts import YAMLPromptLoader logger = get_logger(__name__) @@ -63,16 +59,14 @@ def generate_xml( else: _ = "" - context = get_vtk_xml_context(message) + yaml_loader = YAMLPromptLoader() + yaml_messages = yaml_loader.get_yaml_prompt("vtk_xml_generation", description=message) + messages_param = cast(list[ChatCompletionMessageParam], yaml_messages) response = self.client.chat.completions.create( model=model, - messages=[ - {"role": "system", "content": get_xml_role()}, - {"role": "user", "content": context}, - ], max_completion_tokens=max_tokens, - # max_tokens=max_tokens, + messages=messages_param, temperature=temperature, ) @@ -157,7 +151,7 @@ def main( # Set default models based on provider if model == "gpt-5": default_models = { - "anthropic": "claude-opus-4-1", + "anthropic": "claude-opus-4-6", "gemini": "gemini-2.5-pro", "nim": "meta/llama3-70b-instruct", } diff --git a/src/vtk_prompt/prompts/__init__.py b/src/vtk_prompt/prompts/__init__.py index ceec2bc..72fdfb5 100644 --- a/src/vtk_prompt/prompts/__init__.py +++ b/src/vtk_prompt/prompts/__init__.py @@ -1,92 +1,47 @@ """ -VTK Prompt Template System. +VTK Prompt System. -This module provides a template system for generating prompts used in VTK code generation. -It includes functions for loading and formatting various template types: +This module provides a component-based prompt system for VTK code generation. -- Base context templates with VTK and Python version information -- RAG (Retrieval-Augmented Generation) context templates -- No-RAG context templates for direct queries -- Role-based templates for Python and XML generation -- UI-specific templates for post-processing +The system supports two approaches: +1. Component assembly (primary): Modular, file-based components that can be + composed programmatically to create prompts for different scenarios +2. YAML prompt loading (for custom prompts): Direct loading of user-defined + YAML prompt files with variable substitution -Templates are stored as text files in the prompts directory and can be dynamically -formatted with runtime values like VTK version, Python version, user requests, -and context snippets. +Component types include: +- Base system messages and VTK coding instructions +- UI-specific renderer instructions for web interface +- RAG context injection for retrieval-augmented generation +- Output formatting and model parameter defaults + +Components are stored as YAML files and assembled at runtime with support for +conditional inclusion, variable substitution, and message composition. """ from pathlib import Path -import vtk - -PYTHON_VERSION = ">=3.10" -VTK_VERSION = vtk.__version__ +from .constants import PYTHON_VERSION, VTK_VERSION +from .prompt_component_assembler import ( + PromptComponentLoader, + VTKPromptAssembler, + assemble_vtk_prompt, +) +from .yaml_prompt_loader import YAMLPromptLoader # Path to the prompts directory PROMPTS_DIR = Path(__file__).parent - -def load_template(template_name: str) -> str: - """Load a template file from the prompts directory. 
- - Args: - template_name: Name of the template file (without .txt extension) - - Returns: - The template content as a string - """ - template_path = PROMPTS_DIR / f"{template_name}.txt" - if not template_path.exists(): - raise FileNotFoundError(f"Template {template_name} not found at {template_path}") - - return template_path.read_text() - - -def get_base_context() -> str: - """Get the base context template with version variables filled in.""" - template = load_template("base_context") - return template.format(VTK_VERSION=VTK_VERSION, PYTHON_VERSION=PYTHON_VERSION) - - -def get_no_rag_context(request: str) -> str: - """Get the no-RAG context template with request filled in.""" - base_context = get_base_context() - template = load_template("no_rag_context") - return template.format(BASE_CONTEXT=base_context, request=request) - - -def get_rag_context(request: str, context_snippets: str) -> str: - """Get the RAG context template with request and snippets filled in.""" - base_context = get_base_context() - template = load_template("rag_context") - return template.format( - BASE_CONTEXT=base_context, request=request, context_snippets=context_snippets - ) - - -def get_python_role() -> str: - """Get the Python role template with version filled in.""" - template = load_template("python_role") - return template.format(PYTHON_VERSION=PYTHON_VERSION) - - -def get_vtk_xml_context(description: str) -> str: - """Get the VTK XML context template with description filled in.""" - template = load_template("vtk_xml_context") - return template.format(VTK_VERSION=VTK_VERSION, description=description) - - -def get_xml_role() -> str: - """Get the XML role template.""" - return load_template("xml_role") - - -def get_ui_post_prompt() -> str: - """Get the UI post prompt template.""" - return load_template("ui_post_prompt") - - -def get_rag_chat_context(context: str, query: str) -> str: - """Get the RAG chat context template with context and query filled in.""" - template = load_template("rag_chat_context") - return template.format(CONTEXT=context, QUERY=query) +# Export classes and functions for public API +__all__ = [ + # Component assembly (default system) + "assemble_vtk_prompt", + "PromptComponentLoader", + "VTKPromptAssembler", + # YAML prompt loader (for custom prompts) + "YAMLPromptLoader", + # Constants + "PYTHON_VERSION", + "VTK_VERSION", + "PROMPTS_DIR", +] diff --git a/src/vtk_prompt/prompts/base_context.txt b/src/vtk_prompt/prompts/base_context.txt deleted file mode 100644 index 31f1f04..0000000 --- a/src/vtk_prompt/prompts/base_context.txt +++ /dev/null @@ -1,38 +0,0 @@ -Write Python source code with an explanation that uses VTK. - - -- DO NOT READ OUTSIDE DATA -- DO NOT DEFINE FUNCTIONS -- DO NOT USE MARKDOWN -- ALWAYS PROVIDE SOURCE CODE -- ONLY import VTK and numpy if needed -- Only use {VTK_VERSION} Python basic components. -- Only use {PYTHON_VERSION} or above. - - - -- First, provide a **short but complete explanation** written in **full sentences**. -- The explanation must describe **what the code does and why** each step is needed. -- The explanation must always come **before** the code. -- The explanation MUST begin with a "" tag and end with a "" tag. -- The code MUST begin with a "" tag and end with a "" tag. -- Do not summarize, introduce, or conclude outside the explanation or code itself. -- Output the Python code **exactly as written**, with no additional text before or after the code. -- **No** markdown markers like ```python or ``` anywhere. 
-- Do not add phrases like “Here is the source code” or similar. -- The explanation must stay **above the code**. -- You may use inline comments in the code if helpful for clarity. - - - -input: Only create a vtkSphere -output: - -This code starts by generating the sphere geometry using vtkSphereSource. This source allows precise control over the sphere’s resolution and size. - - - -# Create a sphere source -sphere = vtk.vtkSphereSource() - - diff --git a/src/vtk_prompt/prompts/components/base_system.yml b/src/vtk_prompt/prompts/components/base_system.yml new file mode 100644 index 0000000..46b44d1 --- /dev/null +++ b/src/vtk_prompt/prompts/components/base_system.yml @@ -0,0 +1,4 @@ +role: system +content: | + You are a python {{PYTHON_VERSION}} source code producing entity, your + output will be fed to a python interpreter diff --git a/src/vtk_prompt/prompts/components/model_defaults.yml b/src/vtk_prompt/prompts/components/model_defaults.yml new file mode 100644 index 0000000..b384ec3 --- /dev/null +++ b/src/vtk_prompt/prompts/components/model_defaults.yml @@ -0,0 +1,4 @@ +model: openai/gpt-5 +modelParameters: + temperature: 0.5 + max_tokens: 10000 diff --git a/src/vtk_prompt/prompts/components/output_format.yml b/src/vtk_prompt/prompts/components/output_format.yml new file mode 100644 index 0000000..11fe5e7 --- /dev/null +++ b/src/vtk_prompt/prompts/components/output_format.yml @@ -0,0 +1,16 @@ +role: assistant +append: true +content: | + + - First, provide a **short but complete explanation** written in **full sentences**. + - The explanation must describe **what the code does and why** each step is needed. + - The explanation must always come **before** the code. + - The explanation MUST begin with a "" tag and end with a "" tag. + - The code MUST begin with a "" tag and end with a "" tag. + - Do not summarize, introduce, or conclude outside the explanation or code itself. + - Output the Python code **exactly as written**, with no additional text before or after the code. + - **No** markdown markers like ```python or ``` anywhere. + - Do not add phrases like "Here is the source code" or similar. + - The explanation must stay **above the code**. + - You may use inline comments in the code if helpful for clarity. 
+ diff --git a/src/vtk_prompt/prompts/components/rag_context.yml b/src/vtk_prompt/prompts/components/rag_context.yml new file mode 100644 index 0000000..652ea70 --- /dev/null +++ b/src/vtk_prompt/prompts/components/rag_context.yml @@ -0,0 +1,7 @@ +role: assistant +prepend: true +content: | + + Here are relevant VTK examples and code snippets: + {{context_snippets}} + diff --git a/src/vtk_prompt/prompts/components/ui_renderer.yml b/src/vtk_prompt/prompts/components/ui_renderer.yml new file mode 100644 index 0000000..a8ae3f8 --- /dev/null +++ b/src/vtk_prompt/prompts/components/ui_renderer.yml @@ -0,0 +1,11 @@ +role: assistant +append: true +content: | + + - Do not create a new vtkRenderer + - Use the injected vtkrenderer object named renderer + - Do not manage rendering things + - You must connect the actors to the renderer injected object + - You must render what I ask even if I do not ask to render it + - Only avoid rendering if I explictitly ask you not to render it + diff --git a/src/vtk_prompt/prompts/components/vtk_instructions.yml b/src/vtk_prompt/prompts/components/vtk_instructions.yml new file mode 100644 index 0000000..ce23115 --- /dev/null +++ b/src/vtk_prompt/prompts/components/vtk_instructions.yml @@ -0,0 +1,13 @@ +role: assistant +content: | + Write Python source code with an explanation that uses VTK. + + + - DO NOT READ OUTSIDE DATA + - DO NOT DEFINE FUNCTIONS + - DO NOT USE MARKDOWN + - ALWAYS PROVIDE SOURCE CODE + - ONLY import VTK and numpy if needed + - Only use {{VTK_VERSION}} Python basic components. + - Only use {{PYTHON_VERSION}} or above. + diff --git a/src/vtk_prompt/prompts/constants.py b/src/vtk_prompt/prompts/constants.py new file mode 100644 index 0000000..3637c98 --- /dev/null +++ b/src/vtk_prompt/prompts/constants.py @@ -0,0 +1,11 @@ +""" +Constants for the VTK Prompt system. + +This module defines version constants used throughout the prompt system. +""" + +import vtk + +# Version constants used in prompt variable substitution +PYTHON_VERSION = ">=3.10" +VTK_VERSION = vtk.__version__ diff --git a/src/vtk_prompt/prompts/no_rag_context.txt b/src/vtk_prompt/prompts/no_rag_context.txt deleted file mode 100644 index d2b6d89..0000000 --- a/src/vtk_prompt/prompts/no_rag_context.txt +++ /dev/null @@ -1,4 +0,0 @@ -{BASE_CONTEXT} - -Request: -{request} diff --git a/src/vtk_prompt/prompts/prompt_component_assembler.py b/src/vtk_prompt/prompts/prompt_component_assembler.py new file mode 100644 index 0000000..4609774 --- /dev/null +++ b/src/vtk_prompt/prompts/prompt_component_assembler.py @@ -0,0 +1,235 @@ +""" +Component-based prompt assembly system for VTK Prompt. + +This module provides a flexible system for assembling prompts from reusable +file-based components. Components are stored as YAML files and can be composed +programmatically to create different prompt variations. +""" + +from functools import lru_cache +from pathlib import Path +from typing import Any, Optional, TypedDict + +import yaml + +from .yaml_prompt_loader import YAMLPromptLoader + +# Global instance for variable substitution in component content +_yaml_variable_substituter = YAMLPromptLoader() + +# Keys that define model parameters in component YAML files +_MODEL_PARAM_KEYS = frozenset(["model", "modelParameters"]) + + +@lru_cache(maxsize=32) +def _load_component_cached(components_dir: str, component_name: str) -> dict[str, Any]: + """Module-level cached loader to avoid caching bound methods. 
+ + Using a module-level cache prevents retaining references to the class instance + that would otherwise cause B019 (potential memory leak). + """ + component_file = Path(components_dir) / f"{component_name}.yml" + if not component_file.exists(): + raise FileNotFoundError(f"Component not found: {component_file}") + + with open(component_file) as f: + return yaml.safe_load(f) + + +class PromptData(TypedDict, total=False): + """Type definition for assembled prompt data.""" + + messages: list[dict[str, str]] + model: str + modelParameters: dict[str, Any] + + +class PromptComponentLoader: + """Load and cache prompt components from files.""" + + def __init__(self, components_dir: Optional[Path] = None): + """Initialize component loader. + + Args: + components_dir: Directory containing component YAML files + """ + self.components_dir = components_dir or Path(__file__).parent / "components" + + # Ensure components directory exists + if not self.components_dir.exists(): + raise FileNotFoundError(f"Components directory not found: {self.components_dir}") + + def load_component(self, component_name: str) -> dict[str, Any]: + """Load a component file with caching. + + Args: + component_name: Name of component file (without .yml extension) + + Returns: + Component data from YAML file + + Raises: + FileNotFoundError: If component file doesn't exist + """ + return _load_component_cached(str(self.components_dir), component_name) + + def clear_cache(self) -> None: + """Clear component cache (useful for development).""" + _load_component_cached.cache_clear() + + def list_components(self) -> list[str]: + """List available component names.""" + return [f.stem for f in self.components_dir.glob("*.yml")] + + +class VTKPromptAssembler: + """Assemble VTK prompts from file-based components.""" + + def __init__(self, loader: Optional[PromptComponentLoader] = None): + """Initialize prompt assembler. + + Args: + loader: Component loader instance (creates default if None) + """ + self.loader = loader or PromptComponentLoader() + self.messages: list[dict[str, str]] = [] + self.model_params: dict[str, Any] = {} + + def add_component(self, component_name: str) -> "VTKPromptAssembler": + """Add a component from file. + + Args: + component_name: Name of component to add + + Returns: + Self for method chaining + """ + component = self.loader.load_component(component_name) + + if "role" in component: + message = {"role": component["role"], "content": component["content"]} + + # Handle special message composition: some components need to merge + # with existing user messages rather than create new ones + if component.get("append") and self.messages and self.messages[-1]["role"] == "user": + # Append content to last user message (e.g., additional instructions) + self.messages[-1]["content"] += "\n\n" + component["content"] + elif component.get("prepend") and self.messages and self.messages[-1]["role"] == "user": + # Prepend content to last user message (e.g., context before instructions) + self.messages[-1]["content"] = ( + component["content"] + "\n\n" + self.messages[-1]["content"] + ) + else: + # Default: add as new message + self.messages.append(message) + + # Extract model parameters if present + if any(key in component for key in _MODEL_PARAM_KEYS): + self.model_params.update({k: v for k, v in component.items() if k in _MODEL_PARAM_KEYS}) + + return self + + def add_if(self, condition: bool, component_name: str) -> "VTKPromptAssembler": + """Conditionally add a component. 
+ + Args: + condition: Whether to add the component + component_name: Name of component to add if condition is True + + Returns: + Self for method chaining + """ + if condition: + self.add_component(component_name) + return self + + def add_request(self, request: str) -> "VTKPromptAssembler": + """Add the user request as a message. + + Args: + request: User's request text + + Returns: + Self for method chaining + """ + self.messages.append({"role": "user", "content": f"Request: {request}"}) + return self + + def substitute_variables(self, **variables: Any) -> "VTKPromptAssembler": + """Substitute variables in all message content. + + Args: + **variables: Variables to substitute in {{variable}} format + + Returns: + Self for method chaining + """ + for message in self.messages: + message["content"] = _yaml_variable_substituter.substitute_yaml_variables( + message["content"], variables + ) + + return self + + def build_prompt_data(self) -> PromptData: + """Build the final prompt data. + + Returns: + Dictionary with 'messages' and model parameters + """ + result: PromptData = {"messages": self.messages.copy()} + result.update(self.model_params) # type: ignore[typeddict-item] + return result + + +def assemble_vtk_prompt( + request: str, + ui_mode: bool = False, + rag_enabled: bool = False, + context_snippets: Optional[str] = None, + **variables: Any, +) -> PromptData: + """Assemble VTK prompt from file-based components. + + Args: + request: User's request text + ui_mode: Whether to include UI-specific instructions + rag_enabled: Whether to include RAG context + context_snippets: RAG context snippets (required if rag_enabled=True) + **variables: Additional variables for substitution + + Returns: + Complete prompt data ready for LLM client + + Raises: + ValueError: If rag_enabled=True but context_snippets is empty + """ + if rag_enabled and not context_snippets: + raise ValueError("context_snippets required when rag_enabled=True") + + assembler = VTKPromptAssembler() + + # Always add base components in order + assembler.add_component("model_defaults") + assembler.add_component("base_system") + assembler.add_component("vtk_instructions") + + # Conditional components (order matters for message composition) + assembler.add_if(rag_enabled, "rag_context") + assembler.add_if(ui_mode, "ui_renderer") + + # Always add output format and request last + assembler.add_component("output_format") + assembler.add_request(request) + + # Variable substitution with defaults + default_variables = { + "VTK_VERSION": variables.get("VTK_VERSION", "9.5.0"), + "PYTHON_VERSION": variables.get("PYTHON_VERSION", ">=3.10"), + "context_snippets": context_snippets or "", + } + default_variables.update(variables) + + assembler.substitute_variables(**default_variables) + + return assembler.build_prompt_data() diff --git a/src/vtk_prompt/prompts/python_role.txt b/src/vtk_prompt/prompts/python_role.txt deleted file mode 100644 index e0126b2..0000000 --- a/src/vtk_prompt/prompts/python_role.txt +++ /dev/null @@ -1 +0,0 @@ -You are a python {PYTHON_VERSION} source code producing entity, your output will be fed to a python interpreter diff --git a/src/vtk_prompt/prompts/rag_chat.prompt.yml b/src/vtk_prompt/prompts/rag_chat.prompt.yml new file mode 100644 index 0000000..de06b11 --- /dev/null +++ b/src/vtk_prompt/prompts/rag_chat.prompt.yml @@ -0,0 +1,54 @@ +name: VTK RAG Chat Assistant +description: | + AI assistant for VTK documentation queries using retrieval-augmented + generation +model: openai/gpt-5 +modelParameters: + 
temperature: 0.5 + max_tokens: 1500 +messages: + - role: system + content: | + You are an AI assistant specializing in VTK (Visualization Toolkit) + documentation. Your primary task is to provide accurate, concise, and helpful + responses to user queries about VTK, including relevant code snippets + - role: assistant + content: | + Here is the context information you should use to answer queries: + + {{CONTEXT}} + + + Here's the user's query: + + + {{QUERY}} + + + When responding to a user query, follow these guidelines: + + 1. Relevance Check: + + - If the query is not relevant to VTK, respond with "This question is not relevant to VTK." + + 2. Answer Formulation: + + - If you don't know the answer, clearly state that. + - If uncertain, ask the user for clarification. + - Respond in the same language as the user's query. + - Be concise while providing complete information. + - If the answer isn't in the context but you have the knowledge, explain this to the user and provide the answer based on your understanding. +testData: + - prompt: How do I create a sphere in VTK? + expected: Should provide clear instructions with code examples + - prompt: What is the difference between vtkPolyData and vtkUnstructuredGrid? + expected: Should explain data structure differences with use cases + - prompt: How to cook pasta? + expected: Should respond that this is not relevant to VTK +evaluators: + - type: relevance_check + description: Verifies that non-VTK questions are properly identified + - type: accuracy_assessment + description: Checks if VTK-related answers are technically accurate + - type: completeness_evaluation + description: Ensures answers provide sufficient detail without being verbose diff --git a/src/vtk_prompt/prompts/rag_chat_context.txt b/src/vtk_prompt/prompts/rag_chat_context.txt deleted file mode 100644 index 20d262f..0000000 --- a/src/vtk_prompt/prompts/rag_chat_context.txt +++ /dev/null @@ -1,28 +0,0 @@ -You are an AI assistant specializing in VTK (Visualization Toolkit) -documentation. Your primary task is to provide accurate, concise, and helpful -responses to user queries about VTK, including relevant code snippets - -Here is the context information you should use to answer queries: - -{CONTEXT} - - -Here's the user's query: - - -{QUERY} - - -When responding to a user query, follow these guidelines: - -1. Relevance Check: - - - If the query is not relevant to VTK, respond with "This question is not relevant to VTK." - -2. Answer Formulation: - - - If you don't know the answer, clearly state that. - - If uncertain, ask the user for clarification. - - Respond in the same language as the user's query. - - Be concise while providing complete information. - - If the answer isn't in the context but you have the knowledge, explain this to the user and provide the answer based on your understanding. 
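For context on how this new `rag_chat.prompt.yml` is consumed, here is a small sketch mirroring the `rag_chat_wrapper.py` change further down; the context string is an illustrative stand-in for retrieved documentation.

```python
# Sketch of loading the rag_chat prompt; CONTEXT/QUERY are substituted into the
# raw YAML before parsing (see YAMLPromptLoader below).
from vtk_prompt.prompts import YAMLPromptLoader

loader = YAMLPromptLoader()
messages = loader.get_yaml_prompt(
    "rag_chat",
    CONTEXT="vtkSphereSource creates a polygonal sphere; set radius and resolution.",
    QUERY="How do I create a sphere in VTK?",
)

# The last message carries the substituted context and query for the LLM.
print(messages[-1]["role"], len(messages))
```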
diff --git a/src/vtk_prompt/prompts/rag_context.txt b/src/vtk_prompt/prompts/rag_context.txt deleted file mode 100644 index 2bc91ef..0000000 --- a/src/vtk_prompt/prompts/rag_context.txt +++ /dev/null @@ -1,12 +0,0 @@ -{BASE_CONTEXT} - - -- Refer to the below vtk_examples snippets, this is the the main source of thruth - - - -{context_snippets} - - -Request: -{request} diff --git a/src/vtk_prompt/prompts/ui_post_prompt.txt b/src/vtk_prompt/prompts/ui_post_prompt.txt deleted file mode 100644 index 76edb83..0000000 --- a/src/vtk_prompt/prompts/ui_post_prompt.txt +++ /dev/null @@ -1,8 +0,0 @@ - -- Do not create a new vtkRenderer -- Use the injected vtkrenderer object named renderer -- Do not manager rendering things -- You must connect the actors to the renderer injected object -- You must render what I ask even if I do not ask to render it -- Only avoid rendering if I explictitly ask you not to render it - diff --git a/src/vtk_prompt/prompts/vtk_xml_context.txt b/src/vtk_prompt/prompts/vtk_xml_context.txt deleted file mode 100644 index 52af67c..0000000 --- a/src/vtk_prompt/prompts/vtk_xml_context.txt +++ /dev/null @@ -1,82 +0,0 @@ -Write only text that is the content of a XML VTK file. - - -- NO COMMENTS, ONLY CONTENT OF THE FILE -- Only use VTK {VTK_VERSION} basic components. - - - -- Only output verbatim XML content. -- No explanations -- No markup or code blocks - - - -input: A VTP file example of a 4 points with temperature and pressure data -output: - - - - - - - - 0.0 0.0 0.0 - 1.0 0.0 0.0 - 0.0 1.0 0.0 - 1.0 1.0 0.0 - - - - - - - - 25.5 - 26.7 - 24.3 - 27.1 - - - - 101.3 - 101.5 - 101.2 - 101.4 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -Request: -{description} diff --git a/src/vtk_prompt/prompts/vtk_xml_generation.prompt.yml b/src/vtk_prompt/prompts/vtk_xml_generation.prompt.yml new file mode 100644 index 0000000..c630e88 --- /dev/null +++ b/src/vtk_prompt/prompts/vtk_xml_generation.prompt.yml @@ -0,0 +1,105 @@ +name: VTK XML File Generation +description: Generates VTK XML file content for data storage and visualization +model: openai/gpt-5 +modelParameters: + temperature: 0.3 + max_tokens: 2000 +messages: + - role: system + content: | + You are a XML VTK file generator, the generated file will be read by VTK + file reader + - role: assistant + content: | + Write only text that is the content of a XML VTK file. + + + - NO COMMENTS, ONLY CONTENT OF THE FILE + - Only use VTK {{VTK_VERSION}} basic components. + + + + - Only output verbatim XML content. 
+ - No explanations + - No markup or code blocks + + + + input: A VTP file example of a 4 points with temperature and pressure data + output: + + + + + + + + 0.0 0.0 0.0 + 1.0 0.0 0.0 + 0.0 1.0 0.0 + 1.0 1.0 0.0 + + + + + + + + 25.5 + 26.7 + 24.3 + 27.1 + + + + 101.3 + 101.5 + 101.2 + 101.4 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Request: + {{description}} +testData: + - prompt: Create a VTU file with hexahedral cells and pressure data + expected: Should generate valid VTU XML with unstructured grid format + - prompt: Generate a VTP file for point cloud data + expected: Should generate valid VTP XML with polydata format +evaluators: + - type: xml_validity + description: Checks if generated XML is valid and well-formed + - type: vtk_compliance + description: Verifies that XML follows VTK file format specifications diff --git a/src/vtk_prompt/prompts/xml_role.txt b/src/vtk_prompt/prompts/xml_role.txt deleted file mode 100644 index 070077c..0000000 --- a/src/vtk_prompt/prompts/xml_role.txt +++ /dev/null @@ -1 +0,0 @@ -You are a XML VTK file generator, the generated file will be read by VTK file reader diff --git a/src/vtk_prompt/prompts/yaml_prompt_loader.py b/src/vtk_prompt/prompts/yaml_prompt_loader.py new file mode 100644 index 0000000..66fcc70 --- /dev/null +++ b/src/vtk_prompt/prompts/yaml_prompt_loader.py @@ -0,0 +1,115 @@ +""" +YAML Prompt Loader for VTK Prompt System. + +This module provides a singleton class for loading and processing YAML prompts +used in VTK code generation. It supports variable substitution and message +formatting for LLM clients. +""" + +from pathlib import Path +from typing import Any, Optional + +import yaml + +from .constants import PYTHON_VERSION, VTK_VERSION + +# Path to the prompts directory +PROMPTS_DIR = Path(__file__).parent + + +class YAMLPromptLoader: + """Singleton class for loading and processing YAML prompts.""" + + _instance: Optional["YAMLPromptLoader"] = None + _initialized: bool = False + + def __new__(cls) -> "YAMLPromptLoader": + """Ensure only one instance is created (singleton pattern).""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Initialize the loader if not already initialized.""" + if not YAMLPromptLoader._initialized: + self.prompts_dir = PROMPTS_DIR + self.vtk_version = VTK_VERSION + self.python_version = PYTHON_VERSION + YAMLPromptLoader._initialized = True + + def substitute_yaml_variables(self, content: str, variables: dict[str, Any]) -> str: + """Substitute {{variable}} placeholders in YAML content. + + Args: + content: String content with {{variable}} placeholders + variables: Dictionary of variable names to values + + Returns: + Content with variables substituted + """ + result = content + for key, value in variables.items(): + placeholder = f"{{{{{key}}}}}" + result = result.replace(placeholder, str(value)) + return result + + def load_yaml_prompt(self, prompt_name: str, **variables: Any) -> dict[str, Any]: + """Load a YAML prompt file and substitute variables. 
+ + Args: + prompt_name: Name of the prompt file (without .prompt.yml extension) + **variables: Variables to substitute in the prompt + + Returns: + Dictionary containing the prompt structure + """ + # Try .prompt.yml first, then .prompt.yaml + yaml_path = self.prompts_dir / f"{prompt_name}.prompt.yml" + if not yaml_path.exists(): + yaml_path = self.prompts_dir / f"{prompt_name}.prompt.yaml" + + if not yaml_path.exists(): + raise FileNotFoundError(f"YAML prompt {prompt_name} not found at {self.prompts_dir}") + + # Load YAML content + yaml_content = yaml_path.read_text() + + # Add default variables + default_variables = { + "VTK_VERSION": self.vtk_version, + "PYTHON_VERSION": self.python_version, + } + all_variables = {**default_variables, **variables} + + # Substitute variables in the raw YAML string + substituted_content = self.substitute_yaml_variables(yaml_content, all_variables) + + # Parse the substituted YAML + try: + return yaml.safe_load(substituted_content) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in prompt {prompt_name}: {e}") + + def format_messages_for_client(self, messages: list[dict[str, str]]) -> list[dict[str, str]]: + """Format messages from YAML prompt for LLM client. + + Args: + messages: List of message dictionaries from YAML prompt + + Returns: + Formatted messages ready for LLM client + """ + return [{"role": msg["role"], "content": msg["content"]} for msg in messages] + + def get_yaml_prompt(self, prompt_name: str, **variables: Any) -> list[dict[str, str]]: + """Get a YAML prompt and format it for the LLM client. + + Args: + prompt_name: Name of the prompt file (without .prompt.yml extension) + **variables: Variables to substitute in the prompt + + Returns: + Formatted messages ready for LLM client + """ + yaml_prompt = self.load_yaml_prompt(prompt_name, **variables) + return self.format_messages_for_client(yaml_prompt["messages"]) diff --git a/src/vtk_prompt/provider_utils.py b/src/vtk_prompt/provider_utils.py index 176f665..3ed1833 100644 --- a/src/vtk_prompt/provider_utils.py +++ b/src/vtk_prompt/provider_utils.py @@ -14,18 +14,17 @@ OPENAI_MODELS = ["gpt-5", "gpt-4.1", "o4-mini", "o3"] ANTHROPIC_MODELS = [ - "claude-opus-4-1", - "claude-sonnet-4", - "claude-3-7-sonnet", + "claude-opus-4-6", + "claude-sonnet-4-5", + "claude-haiku-4-5", ] GEMINI_MODELS = ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"] NIM_MODELS = [ "meta/llama3-70b-instruct", - "meta/llama3-8b-instruct", + "meta/llama-3.1-8b-instruct", "microsoft/phi-3-medium-4k-instruct", - "nvidia/llama-3.1-nemotron-70b-instruct", ] @@ -34,8 +33,17 @@ def supports_temperature(model: str) -> bool: - """Check if a model supports temperature control.""" - return model not in TEMPERATURE_UNSUPPORTED_MODELS + """Check if a model supports temperature control. 
+ + Args: + model: Model name, can be in format "provider/model" or just "model" + + Returns: + True if model supports temperature control, False otherwise + """ + # Extract model name from "provider/model" format if present + model_name = model.split("/")[-1] if "/" in model else model + return model_name not in TEMPERATURE_UNSUPPORTED_MODELS def get_model_temperature(model: str, requested_temperature: float = 0.7) -> float: @@ -71,7 +79,7 @@ def get_default_model(provider: str) -> str: """Get the default/recommended model for a provider.""" defaults = { "openai": "gpt-5", - "anthropic": "claude-opus-4-1", + "anthropic": "claude-opus-4-6", "gemini": "gemini-2.5-pro", "nim": "meta/llama3-70b-instruct", } diff --git a/src/vtk_prompt/rag_chat_wrapper.py b/src/vtk_prompt/rag_chat_wrapper.py index fda5bca..e7a3efa 100644 --- a/src/vtk_prompt/rag_chat_wrapper.py +++ b/src/vtk_prompt/rag_chat_wrapper.py @@ -29,7 +29,7 @@ from llama_index.llms.openai import OpenAI from . import get_logger -from .prompts import get_rag_chat_context +from .prompts import YAMLPromptLoader logger = get_logger(__name__) @@ -152,11 +152,17 @@ def ask( # Combine the retrieved documents into a single text retrieved_text = "\n\n## Next example:\n\n".join(snippets) - # Use our template system instead of the hardcoded PROMPT - content = get_rag_chat_context(retrieved_text, query.rstrip()) - - # Add the enhanced context as a message - self.history.append(ChatMessage(role="assistant", content=content.rstrip())) + # Use YAML prompt instead of legacy template function + yaml_loader = YAMLPromptLoader() + yaml_messages = yaml_loader.get_yaml_prompt( + "rag_chat", CONTEXT=retrieved_text, QUERY=query.rstrip() + ) + + # Extract the user message content (should be the last message) + if yaml_messages: + content = yaml_messages[-1]["content"] + # Add the enhanced context as a message + self.history.append(ChatMessage(role="assistant", content=content.rstrip())) # Generate a response using the LLM if self.llm is None: diff --git a/src/vtk_prompt/vtk_prompt_ui.py b/src/vtk_prompt/vtk_prompt_ui.py index 4df1964..31a4d6d 100644 --- a/src/vtk_prompt/vtk_prompt_ui.py +++ b/src/vtk_prompt/vtk_prompt_ui.py @@ -18,10 +18,12 @@ import json import re +import sys from pathlib import Path from typing import Any, Optional import vtk +import yaml from trame.app import TrameApp from trame.decorators import change, controller, trigger from trame.ui.vuetify3 import SinglePageWithDrawerLayout @@ -32,7 +34,6 @@ from . import get_logger from .client import VTKPromptClient -from .prompts import get_ui_post_prompt from .provider_utils import ( get_available_models, get_default_model, @@ -66,11 +67,29 @@ def load_js(server: Any) -> None: class VTKPromptApp(TrameApp): """VTK Prompt interactive application with 3D visualization and AI chat interface.""" - def __init__(self, server: Optional[Any] = None) -> None: - """Initialize VTK Prompt application.""" + def __init__( + self, server: Optional[Any] = None, custom_prompt_file: Optional[str] = None + ) -> None: + """Initialize VTK Prompt application. 
+ + Args: + server: Trame server instance + custom_prompt_file: Path to custom YAML prompt file + """ super().__init__(server=server, client_type="vue3") self.state.trame__title = "VTK Prompt" + # Store custom prompt file path and data + self.custom_prompt_file = custom_prompt_file + self.custom_prompt_data = None + + # Add CLI argument for custom prompt file + self.server.cli.add_argument( + "--prompt-file", + help="Path to custom YAML prompt file (overrides built-in prompts and defaults)", + dest="prompt_file", + ) + # Make sure JS is loaded load_js(self.server) @@ -94,8 +113,102 @@ def __init__(self, server: Optional[Any] = None) -> None: self._conversation_loading = False self._add_default_scene() - # Initial render - self.render_window.Render() + # Load custom prompt file after VTK initialization + if custom_prompt_file: + self._load_custom_prompt_file() + + def _load_custom_prompt_file(self) -> None: + """Load custom YAML prompt file and extract model parameters.""" + if not self.custom_prompt_file: + return + + try: + custom_file_path = Path(self.custom_prompt_file) + if not custom_file_path.exists(): + logger.error("Custom prompt file not found: %s", self.custom_prompt_file) + return + + with open(custom_file_path, "r") as f: + self.custom_prompt_data = yaml.safe_load(f) + + logger.info("Loaded custom prompt file: %s", custom_file_path.name) + + # Override UI defaults with custom prompt parameters + if self.custom_prompt_data and isinstance(self.custom_prompt_data, dict): + model_value = self.custom_prompt_data.get("model") + if isinstance(model_value, str) and model_value: + if "/" in model_value: + provider_part, model_part = model_value.split("/", 1) + # Validate provider + supported = set(get_supported_providers() + ["local"]) + if provider_part not in supported or not model_part.strip(): + msg = ( + "Invalid 'model' in prompt file. Expected '/' " + "with provider in {openai, anthropic, gemini, nim, local}." + ) + self.state.error_message = msg + raise ValueError(msg) + if provider_part == "local": + # Switch to local mode + self.state.use_cloud_models = False + self.state.tab_index = 1 + self.state.local_model = model_part + else: + # Cloud provider/model + self.state.use_cloud_models = True + self.state.tab_index = 0 + self.state.provider = provider_part + self.state.model = model_part + else: + # Enforce explicit provider/model format + msg = ( + "Invalid 'model' format in prompt file. Expected '/' " + "(e.g., 'openai/gpt-5' or 'local/llama3')." 
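+                            " Supported providers: openai, anthropic, gemini, nim, local."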
+ ) + self.state.error_message = msg + raise ValueError(msg) + + # RAG and generation controls + if "rag" in self.custom_prompt_data: + self.state.use_rag = bool(self.custom_prompt_data.get("rag")) + if "top_k" in self.custom_prompt_data: + _top_k = self.custom_prompt_data.get("top_k") + if isinstance(_top_k, int): + self.state.top_k = _top_k + elif isinstance(_top_k, str) and _top_k.isdigit(): + self.state.top_k = int(_top_k) + else: + logger.warning("Invalid top_k in prompt file: %r; keeping existing", _top_k) + if "retries" in self.custom_prompt_data: + _retries = self.custom_prompt_data.get("retries") + if isinstance(_retries, int): + self.state.retry_attempts = _retries + elif isinstance(_retries, str) and _retries.isdigit(): + self.state.retry_attempts = int(_retries) + else: + logger.warning( + "Invalid retries in prompt file: %r; keeping existing", _retries + ) + + self.state.temperature_supported = supports_temperature(model_part) + # Set model parameters from prompt file + model_params = self.custom_prompt_data.get("modelParameters", {}) + if isinstance(model_params, dict): + if "temperature" in model_params: + if not self.state.temperature_supported: + self.state.temperature = 1.0 # enforce + logger.warning( + "Temperature not supported for model %s; forcing 1.0", model_part + ) + else: + self.state.temperature = model_params["temperature"] + if "max_tokens" in model_params: + self.state.max_tokens = model_params["max_tokens"] + except (yaml.YAMLError, ValueError) as e: + # Log error and surface to UI as well + logger.error("Failed to load custom prompt file %s: %s", self.custom_prompt_file, e) + self.state.error_message = str(e) + self.custom_prompt_data = None def _add_default_scene(self) -> None: """Add default coordinate axes to prevent empty scene segfaults.""" @@ -135,6 +248,11 @@ def _add_default_scene(self) -> None: self.state.can_navigate_right = False self.state.is_viewing_history = False + # Toast notification state + self.state.toast_message = "" + self.state.toast_visible = False + self.state.toast_color = "warning" + # API configuration state self.state.use_cloud_models = True # Toggle between cloud and local self.state.tab_index = 0 # Tab navigation state @@ -143,11 +261,41 @@ def _add_default_scene(self) -> None: self.state.provider = "openai" self.state.model = "gpt-5" self.state.temperature_supported = True - # Initialize with supported providers and fallback models self.state.available_providers = get_supported_providers() self.state.available_models = get_available_models() + # Load component defaults and sync UI state + try: + from .prompts import assemble_vtk_prompt + + prompt_data = assemble_vtk_prompt("placeholder") # Just to get defaults + model_params = prompt_data.get("modelParameters", {}) + + # Update state with component model configuration + if "temperature" in model_params: + self.state.temperature = str(model_params["temperature"]) + if "max_tokens" in model_params: + self.state.max_tokens = str(model_params["max_tokens"]) + + # Parse default model from component data + default_model = prompt_data.get("model", "openai/gpt-5") + if "/" in default_model: + provider, model = default_model.split("/", 1) + self.state.provider = provider + logger.debug( + "Loaded component defaults: provider=%s, model=%s, temp=%s, max_tokens=%s", + self.state.provider, + self.state.model, + self.state.temperature, + self.state.max_tokens, + ) + except Exception as e: + logger.warning("Could not load component defaults: %s", e) + # Fall back to default values + 
self.state.temperature = "0.5" + self.state.max_tokens = "10000" + self.state.api_token = "" # Build UI @@ -282,6 +430,18 @@ def reset_camera(self) -> None: except Exception as e: logger.error("Error resetting camera: %s", e) + @controller.set("trigger_warning_toast") + def trigger_warning_toast(self, message: str) -> None: + """Display a warning toast notification. + + Args: + message: Warning message to display + """ + self.state.toast_message = message + self.state.toast_color = "warning" + self.state.toast_visible = True + logger.warning("Toast notification: %s", message) + def _generate_and_execute_code(self) -> None: """Generate VTK code using Anthropic API and execute it.""" self.state.is_loading = True @@ -289,11 +449,15 @@ def _generate_and_execute_code(self) -> None: try: if not self._conversation_loading: - # Generate code using prompt functionality - reuse existing methods - enhanced_query = self.state.query_text - if self.state.query_text: - post_prompt = get_ui_post_prompt() - enhanced_query = post_prompt + self.state.query_text + # Use custom prompt if provided, otherwise use built-in YAML prompts + if self.custom_prompt_data: + # Use the query text directly when using custom prompts + enhanced_query = self.state.query_text + logger.debug("Using custom prompt file") + else: + # Let the client handle prompt selection based on RAG and UI mode + enhanced_query = self.state.query_text + logger.debug("Using UI mode - client will select appropriate prompt") # Reinitialize client with current settings self._init_prompt_client() @@ -310,24 +474,44 @@ def _generate_and_execute_code(self) -> None: top_k=int(self.state.top_k), rag=self.state.use_rag, retry_attempts=int(self.state.retry_attempts), + provider=self.state.provider, + custom_prompt=self.custom_prompt_data, + ui_mode=True, # This tells the client to use UI-specific components ) # Keep UI in sync with conversation self.state.conversation = self.prompt_client.conversation - # Handle both code and usage information - if isinstance(result, tuple) and len(result) == 3: - generated_explanation, generated_code, usage = result + # Handle result with optional validation warnings + validation_warnings: list[str] = [] + if isinstance(result, tuple): + if len(result) == 4: + # Result includes validation warnings + generated_explanation, generated_code, usage, validation_warnings = result + elif len(result) == 3: + generated_explanation, generated_code, usage = result + else: + generated_explanation = str(result) + generated_code = "" + usage = None + if usage: self.state.input_tokens = usage.prompt_tokens self.state.output_tokens = usage.completion_tokens + else: + self.state.input_tokens = 0 + self.state.output_tokens = 0 else: # Handle string result generated_explanation = str(result) generated_code = "" - # Reset token counts if no usage info self.state.input_tokens = 0 self.state.output_tokens = 0 + # Display validation warnings as toast notifications + if validation_warnings: + for warning in validation_warnings: + self.ctrl.trigger_warning_toast(warning) + self.state.generated_explanation = generated_explanation self.state.generated_code = EXPLAIN_RENDERER + "\n" + generated_code @@ -563,6 +747,33 @@ def save_conversation(self) -> str: return json.dumps(self.prompt_client.conversation, indent=2) return "" + @trigger("save_config") + def save_config(self) -> str: + """Save current configuration as YAML string for download.""" + use_cloud = bool(getattr(self.state, "use_cloud_models", True)) + provider = getattr(self.state, 
"provider", "openai") + model = self._get_model() + provider_model = f"{provider}/{model}" if use_cloud else f"local/{model}" + temperature = float(getattr(self.state, "temperature", 0.0)) + max_tokens = int(getattr(self.state, "max_tokens", 1000)) + retries = int(getattr(self.state, "retry_attempts", 1)) + rag_enabled = bool(getattr(self.state, "use_rag", False)) + top_k = int(getattr(self.state, "top_k", 5)) + + content = { + "name": "Custom VTK Prompt config file", + "description": f"Exported from UI - {'Cloud' if use_cloud else 'Local'} configuration", + "model": provider_model, + "rag": rag_enabled, + "top_k": top_k, + "retries": retries, + "modelParameters": { + "temperature": temperature, + "max_tokens": max_tokens, + }, + } + return yaml.safe_dump(content, sort_keys=False) + @change("provider") def _on_provider_change(self, provider, **kwargs) -> None: """Handle provider selection change.""" @@ -575,7 +786,7 @@ def _on_provider_change(self, provider, **kwargs) -> None: def _build_ui(self) -> None: """Build a simplified Vuetify UI.""" # Initialize drawer state as collapsed - self.state.main_drawer = True + self.state.main_drawer = False with SinglePageWithDrawerLayout( self.server, theme=("theme_mode", "light"), style="max-height: 100vh;" @@ -583,43 +794,104 @@ def _build_ui(self) -> None: layout.title.set_text("VTK Prompt UI") with layout.toolbar: vuetify.VSpacer() - # Token usage display - with vuetify.VChip( - small=True, - color="primary", - text_color="white", - v_show="input_tokens > 0 || output_tokens > 0", - classes="mr-2", + with vuetify.VTooltip( + text=("conversation_file", "No file loaded"), + location="bottom", + disabled=("!conversation_object",), ): - html.Span("Tokens: In: {{ input_tokens }} | Out: {{ output_tokens }}") - - # VTK control buttons - with vuetify.VBtn( - click=self.ctrl.clear_scene, - icon=True, - v_tooltip_bottom="Clear Scene", + with vuetify.Template(v_slot_activator="{ props }"): + vuetify.VFileInput( + label="Conversation File", + v_model=("conversation_object", None), + accept=".json", + variant="solo", + density="compact", + prepend_icon="mdi-forum-outline", + hide_details="auto", + classes="py-1 pr-1 mr-2 text-truncate", + open_on_focus=False, + clearable=False, + v_bind="props", + rules=["[utils.vtk_prompt.rules.json_file]"], + color="primary", + style="max-width: 25%;", + ) + with vuetify.VTooltip( + text=( + "auto_run_conversation_file ? 
" + + "'Auto-run conversation files on load' : " + + "'Do not auto-run conversation files on load'", + "Auto-run conversation files on load", + ), + location="bottom", ): - vuetify.VIcon("mdi-reload") - with vuetify.VBtn( - click=self.ctrl.reset_camera, - icon=True, - v_tooltip_bottom="Reset Camera", + with vuetify.Template(v_slot_activator="{ props }"): + with vuetify.VBtn( + icon=True, + v_bind="props", + click="auto_run_conversation_file = !auto_run_conversation_file", + classes="mr-2", + color="primary", + ): + vuetify.VIcon( + "mdi-autorenew", + v_show="auto_run_conversation_file", + ) + vuetify.VIcon( + "mdi-autorenew-off", + v_show="!auto_run_conversation_file", + ) + with vuetify.VTooltip( + text="Download conversation file", + location="bottom", ): - vuetify.VIcon("mdi-camera-retake-outline") - + with vuetify.Template(v_slot_activator="{ props }"): + with vuetify.VBtn( + icon=True, + v_bind="props", + disabled=("!conversation",), + click="utils.download(" + + "`vtk-prompt_${provider}_${model}.json`," + + "trigger('save_conversation')," + + "'application/json'" + + ")", + classes="mr-2", + color="primary", + density="compact", + ): + vuetify.VIcon("mdi-file-download-outline") + with vuetify.VTooltip( + text="Download config file", + location="bottom", + ): + with vuetify.Template(v_slot_activator="{ props }"): + with vuetify.VBtn( + icon=True, + v_bind="props", + click="utils.download(" + + "`vtk-prompt_config.yml`," + + "trigger('save_config')," + + "'application/x-yaml'" + + ")", + classes="mr-4", + color="primary", + density="compact", + ): + vuetify.VIcon("mdi-content-save-cog-outline") vuetify.VSwitch( v_model=("theme_mode", "light"), hide_details=True, density="compact", - label="Dark Mode", classes="mr-2", - true_value="dark", - false_value="light", + true_value="light", + false_value="dark", + append_icon=( + "theme_mode === 'light' ? 
'mdi-weather-sunny' : 'mdi-weather-night'", + ), ) with layout.drawer as drawer: drawer.width = 350 - with vuetify.VContainer(): # Tab Navigation - Centered with vuetify.VRow(justify="center"): @@ -633,7 +905,6 @@ def _build_ui(self) -> None: ): vuetify.VTab("☁️ Cloud") vuetify.VTab("🏠Local") - # Tab Content with vuetify.VTabsWindow(v_model="tab_index"): # Cloud Providers Tab Content @@ -649,7 +920,6 @@ def _build_ui(self) -> None: variant="outlined", prepend_icon="mdi-cloud", ) - # Model selection vuetify.VSelect( label="Model", @@ -659,7 +929,6 @@ def _build_ui(self) -> None: variant="outlined", prepend_icon="mdi-brain", ) - # API Token vuetify.VTextField( label="API Token", @@ -673,7 +942,6 @@ def _build_ui(self) -> None: persistent_hint=True, error=("!api_token", False), ) - # Local Models Tab Content with vuetify.VTabsWindowItem(): with vuetify.VCard(flat=True, style="mt-2"): @@ -691,7 +959,6 @@ def _build_ui(self) -> None: hint="Ollama, LM Studio, etc.", persistent_hint=True, ) - vuetify.VTextField( label="Model Name", v_model=("local_model", "devstral"), @@ -702,7 +969,6 @@ def _build_ui(self) -> None: hint="Model identifier", persistent_hint=True, ) - # Optional API Token for local vuetify.VTextField( label="API Token (Optional)", @@ -715,7 +981,6 @@ def _build_ui(self) -> None: hint="Optional for local servers", persistent_hint=True, ) - with vuetify.VCard(classes="mt-2"): vuetify.VCardTitle("⚙️ RAG settings", classes="pb-0") with vuetify.VCardText(): @@ -736,7 +1001,6 @@ def _build_ui(self) -> None: variant="outlined", prepend_icon="mdi-chart-scatter-plot", ) - with vuetify.VCard(classes="mt-2"): vuetify.VCardTitle("⚙️ Generation Settings", classes="pb-0") with vuetify.VCardText(): @@ -771,70 +1035,21 @@ def _build_ui(self) -> None: prepend_icon="mdi-repeat", ) - with vuetify.VCard(classes="mt-2"): - vuetify.VCardTitle("⚙️ Files", hide_details=True, density="compact") - with vuetify.VCardText(): - vuetify.VCheckbox( - label="Run new conversation files", - v_model=("auto_run_conversation_file", True), - prepend_icon="mdi-file-refresh-outline", - density="compact", - color="primary", - hide_details=True, - ) - with html.Div(classes="d-flex align-center justify-space-between"): - with vuetify.VTooltip( - text=("conversation_file", "No file loaded"), - location="top", - disabled=("!conversation_object",), - ): - with vuetify.Template(v_slot_activator="{ props }"): - vuetify.VFileInput( - label="Conversation File", - v_model=("conversation_object", None), - accept=".json", - density="compact", - variant="solo", - prepend_icon="mdi-forum-outline", - hide_details="auto", - classes="py-1 pr-1 mr-1 text-truncate", - open_on_focus=False, - clearable=False, - v_bind="props", - rules=["[utils.vtk_prompt.rules.json_file]"], - ) - with vuetify.VTooltip( - text="Download conversation file", - location="right", - ): - with vuetify.Template(v_slot_activator="{ props }"): - with vuetify.VBtn( - icon=True, - density="comfortable", - color="secondary", - rounded="lg", - v_bind="props", - disabled=("!conversation",), - click="utils.download(" - + "`${model}_${new Date().toISOString()}.json`," - + "trigger('save_conversation')," - + "'application/json'" - + ")", - ): - vuetify.VIcon("mdi-file-download-outline") - with layout.content: - with vuetify.VContainer(classes="fluid fill-height pt-0", style="min-width: 100%;"): - with vuetify.VRow(rows=12, classes="fill-height"): + with vuetify.VContainer( + classes="fluid fill-height", style="min-width: 100%; padding: 0!important;" + ): + with 
vuetify.VRow(rows=12, classes="fill-height px-4 pt-1 pb-1"): # Left column - Generated code view - with vuetify.VCol(cols=7, classes="fill-height"): + with vuetify.VCol(cols=7, classes="fill-height pa-0"): with vuetify.VExpansionPanels( v_model=("explanation_expanded", [0, 1]), - classes="fill-height", + classes="fill-height pb-1 pr-1", multiple=True, ): with vuetify.VExpansionPanel( - classes="mt-1 flex-grow-1 flex-shrink-0 d-flex flex-column", + classes="flex-grow-1 flex-shrink-0 d-flex" + + "flex-column pa-0 mt-0", style="max-height: 25%;", ): vuetify.VExpansionPanelTitle("Explanation", classes="text-h6") @@ -856,13 +1071,14 @@ def _build_ui(self) -> None: ) with vuetify.VExpansionPanel( classes=( - "mt-1 fill-height flex-grow-2 flex-shrink-0" - + "d-flex flex-column" + "fill-height flex-grow-2 flex-shrink-0" + + " d-flex flex-column mt-1" ), readonly=True, style=( "explanation_expanded.length > 1 ? " + "'max-height: 75%;' : 'max-height: 95%;'", + "box-sizing: border-box;", ), ): vuetify.VExpansionPanelTitle( @@ -886,16 +1102,63 @@ def _build_ui(self) -> None: ) # Right column - VTK viewer and prompt - with vuetify.VCol(cols=5, classes="fill-height"): + with vuetify.VCol(cols=5, classes="fill-height pa-0"): with vuetify.VRow(no_gutters=True, classes="fill-height"): # Top: VTK render view with vuetify.VCol( cols=12, - classes="mb-2 flex-grow-1 flex-shrink-0", + classes="flex-grow-1 flex-shrink-0 pa-0", style="min-height: calc(100% - 256px);", ): with vuetify.VCard(classes="fill-height"): - vuetify.VCardTitle("VTK Visualization") + with vuetify.VCardTitle( + "VTK Visualization", classes="d-flex align-center" + ): + vuetify.VSpacer() + # Token usage display + with vuetify.VChip( + small=True, + color="secondary", + text_color="white", + v_show="input_tokens > 0 || output_tokens > 0", + classes="mr-2", + density="compact", + ): + html.Span( + "Tokens: In: {{ input_tokens }} " + + "| Out: {{ output_tokens }}" + ) + # VTK control buttons + with vuetify.VTooltip( + text="Clear Scene", + location="bottom", + ): + with vuetify.Template(v_slot_activator="{ props }"): + with vuetify.VBtn( + click=self.ctrl.clear_scene, + icon=True, + color="secondary", + v_bind="props", + classes="mr-2", + density="compact", + variant="text", + ): + vuetify.VIcon("mdi-reload") + with vuetify.VTooltip( + text="Reset Camera", + location="bottom", + ): + with vuetify.Template(v_slot_activator="{ props }"): + with vuetify.VBtn( + click=self.ctrl.reset_camera, + icon=True, + color="secondary", + v_bind="props", + classes="mr-2", + density="compact", + variant="text", + ): + vuetify.VIcon("mdi-camera-retake-outline") with vuetify.VCardText(style="height: 90%;"): # VTK render window view = vtk_widgets.VtkRemoteView( @@ -1048,6 +1311,24 @@ def _build_ui(self) -> None: icon="mdi-alert-outline", ) + # Toast notification snackbar for validation warnings + with vuetify.VSnackbar( + v_model=("toast_visible",), + timeout=5000, + color=("toast_color",), + location="top", + multi_line=True, + ): + vuetify.VIcon("mdi-alert", classes="mr-2") + html.Span("{{ toast_message }}") + with vuetify.Template(v_slot_actions=""): + vuetify.VBtn( + "Close", + color="white", + variant="text", + click="toast_visible = false", + ) + def start(self) -> None: """Start the trame server.""" self.server.start() @@ -1059,8 +1340,17 @@ def main() -> None: print("Supported providers: OpenAI, Anthropic, Google Gemini, NVIDIA NIM") print("For local Ollama, use custom base URL and model configuration.") + # Check for custom prompt file in CLI arguments + 
custom_prompt_file = None + + # Extract --prompt-file before Trame processes args + for i, arg in enumerate(sys.argv): + if arg == "--prompt-file" and i + 1 < len(sys.argv): + custom_prompt_file = sys.argv[i + 1] + break + # Create and start the app - app = VTKPromptApp() + app = VTKPromptApp(custom_prompt_file=custom_prompt_file) app.start() diff --git a/tests/test_cli.py b/tests/test_cli.py index 04374e9..b7481a7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -98,7 +98,7 @@ def test_max_tokens_error(self): "provider,expected_model", [ ("openai", "gpt-5"), - ("anthropic", "claude-opus-4-1-20250805"), + ("anthropic", "claude-opus-4-6"), ("gemini", "gemini-2.5-pro"), ("nim", "meta/llama3-70b-instruct"), ], diff --git a/tests/test_custom_prompt_validation.py b/tests/test_custom_prompt_validation.py new file mode 100644 index 0000000..25fc9c1 --- /dev/null +++ b/tests/test_custom_prompt_validation.py @@ -0,0 +1,188 @@ +""" +Test custom prompt validation and model parameter extraction. +""" + +from vtk_prompt.client import VTKPromptClient + + +class TestCustomPromptValidation: + """Test validation of custom prompt model parameters.""" + + def test_valid_model_and_params(self): + """Test that valid model and parameters are accepted.""" + client = VTKPromptClient(verbose=False) + + # Use gpt-4.1 which supports temperature + custom_prompt = { + "model": "openai/gpt-4.1", + "modelParameters": { + "temperature": 0.7, + "max_tokens": 5000, + }, + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert model == "openai/gpt-4.1" + assert temp == 0.7 + assert max_tokens == 5000 + assert len(warnings) == 0 + + def test_invalid_model_format(self): + """Test that invalid model format generates warning and uses default.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "model": "gpt-5", # Missing provider prefix + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert model == "openai/gpt-4o" # Falls back to default + assert len(warnings) == 1 + assert "Invalid model format" in warnings[0] + + def test_unsupported_provider(self): + """Test that unsupported provider generates warning.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "model": "unsupported/model-name", + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert model == "openai/gpt-4o" # Falls back to default + assert len(warnings) == 1 + assert "Unsupported provider" in warnings[0] + + def test_model_not_in_curated_list(self): + """Test that model not in curated list generates warning and uses provider default.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "model": "openai/gpt-99-ultra", # Not in curated list + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert model == "openai/gpt-5" # Falls back to provider default + assert len(warnings) == 1 + assert "not in curated list" in warnings[0] + + def test_temperature_out_of_range(self): + """Test that temperature out of range generates warning.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "modelParameters": { + "temperature": 3.0, # Out of range [0.0, 2.0] + }, + "messages": [], 
+ } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert temp == 0.5 # Falls back to default + assert len(warnings) == 1 + assert "out of range" in warnings[0] + + def test_invalid_temperature_type(self): + """Test that invalid temperature type generates warning.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "modelParameters": { + "temperature": "hot", # Invalid type + }, + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert temp == 0.5 # Falls back to default + assert len(warnings) == 1 + assert "Invalid temperature value" in warnings[0] + + def test_max_tokens_out_of_range(self): + """Test that max_tokens out of range generates warning.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "modelParameters": { + "max_tokens": 200000, # Out of range [1, 100000] + }, + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert max_tokens == 1000 # Falls back to default + assert len(warnings) == 1 + assert "out of range" in warnings[0] + + def test_temperature_unsupported_by_model(self): + """Test that temperature warning is generated for models that don't support it.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "model": "openai/gpt-5", # Doesn't support temperature + "modelParameters": { + "temperature": 0.7, + }, + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + assert model == "openai/gpt-5" + assert temp == 1.0 # Forced to 1.0 + assert len(warnings) == 1 + assert "does not support temperature control" in warnings[0] + + def test_multiple_validation_errors(self): + """Test that multiple validation errors all generate warnings.""" + client = VTKPromptClient(verbose=False) + + custom_prompt = { + "model": "invalid-format", + "modelParameters": { + "temperature": 5.0, + "max_tokens": -100, + }, + "messages": [], + } + + model, temp, max_tokens, warnings = client._validate_and_extract_model_params( + custom_prompt, "openai/gpt-4o", 0.5, 1000 + ) + + # All should fall back to defaults + assert model == "openai/gpt-4o" + assert temp == 0.5 + assert max_tokens == 1000 + + # Should have 3 warnings + assert len(warnings) == 3 + assert any("Invalid model format" in w for w in warnings) + assert any("out of range" in w for w in warnings) diff --git a/tests/test_prompt_assembly.py b/tests/test_prompt_assembly.py new file mode 100644 index 0000000..aa8bede --- /dev/null +++ b/tests/test_prompt_assembly.py @@ -0,0 +1,120 @@ +""" +Test suite for VTK prompt assembly system. + +Tests focusing on key prompt functionality. 
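+
+Covers CLI and UI assembly modes, RAG context injection, VTK_VERSION and
+PYTHON_VERSION overrides, and the ValueError raised when rag_enabled=True is
+used without context_snippets.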
+""" + +import pytest +import re +from vtk_prompt.prompts import assemble_vtk_prompt, PYTHON_VERSION + + +def _assert_basic_structure(result): + """Helper: Assert basic prompt structure is correct.""" + assert "messages" in result + assert isinstance(result["messages"], list) + assert len(result["messages"]) >= 3 + assert all("role" in msg and "content" in msg for msg in result["messages"]) + assert result.get("model") == "openai/gpt-5" + + +def _get_content(result): + """Helper: Get combined content from all messages.""" + return " ".join([msg["content"] for msg in result["messages"]]) + + +class TestPromptAssembly: + """Test prompt assembly for different scenarios.""" + + @pytest.mark.parametrize( + "ui_mode,expected_ui_content", + [ + (False, False), # CLI mode - no UI content + (True, True), # UI mode - has UI content + ], + ) + def test_default_values(self, ui_mode, expected_ui_content): + """Test default values work for both CLI and UI modes.""" + result = assemble_vtk_prompt("create a sphere", ui_mode=ui_mode) + + _assert_basic_structure(result) + # Check default model parameters + assert result.get("modelParameters", {}).get("temperature") == 0.5 + assert result.get("modelParameters", {}).get("max_tokens") == 10000 + + content = _get_content(result) + assert "create a sphere" in content + + # Check that VTK version is present but don't assume a specific version + assert re.search(r"9\.\d+\.\d+", content), "VTK version should be present in content" + assert PYTHON_VERSION in content + assert "DO NOT READ OUTSIDE DATA" in content + + # UI-specific content check + ui_content_present = "injected vtkrenderer object named renderer" in content + assert ui_content_present == expected_ui_content + + @pytest.mark.parametrize( + "ui_mode,rag_enabled", + [ + (False, True), # CLI with RAG + (True, False), # UI without RAG + (True, True), # UI with RAG + ], + ) + def test_feature_combinations(self, ui_mode, rag_enabled): + """Test different combinations of UI and RAG features.""" + kwargs = { + "ui_mode": ui_mode, + "VTK_VERSION": "9.3.0", # Override version + } + + if rag_enabled: + kwargs.update({"rag_enabled": True, "context_snippets": "example RAG content"}) + + result = assemble_vtk_prompt("create a cube", **kwargs) + + _assert_basic_structure(result) + content = _get_content(result) + + # Check overridden version + assert "9.3.0" in content + assert "create a cube" in content + + # Check feature-specific content + if ui_mode: + assert "injected vtkrenderer object named renderer" in content + if rag_enabled: + assert "example RAG content" in content + + def test_parameter_overrides(self): + """Test that parameter overrides work correctly.""" + result = assemble_vtk_prompt( + "create a torus", + ui_mode=True, + rag_enabled=True, + context_snippets="torus example code", + VTK_VERSION="9.1.0", + PYTHON_VERSION=">=3.12", + ) + + _assert_basic_structure(result) + content = _get_content(result) + + # All overridden values should be present + assert "create a torus" in content + assert "torus example code" in content # RAG + assert "injected vtkrenderer object named renderer" in content # UI + assert "9.1.0" in content # Overridden VTK + assert ">=3.12" in content # Overridden Python + + def test_error_conditions(self): + """Test error handling and edge cases.""" + # RAG without context should raise error + with pytest.raises(ValueError, match="context_snippets required when rag_enabled=True"): + assemble_vtk_prompt("test", rag_enabled=True) + + # Empty request should work + result = 
assemble_vtk_prompt("") + _assert_basic_structure(result) + assert "Request: " in _get_content(result) diff --git a/tests/test_providers.py b/tests/test_providers.py index 5442225..3dfd615 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -102,7 +102,7 @@ def test_anthropic_api_key_missing(self, monkeypatch): # Ensure no fallback is available via environment for duration of test monkeypatch.delenv("OPENAI_API_KEY", raising=False) with pytest.raises(ValueError, match="No API key provided"): - client.query(message="Create a sphere", api_key=None, model="claude-opus-4-1-20250805") + client.query(message="Create a sphere", api_key=None, model="claude-opus-4-6") # Gemini Provider Tests @pytest.mark.parametrize("model", GEMINI_MODELS)