diff --git a/pyproject.toml b/pyproject.toml
index e52f371..53a05a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,9 +66,22 @@ test = [
     "pytest-cov>=4.0.0",
 ]
 
+# Provider-specific dependencies
+providers-anthropic = [
+    "anthropic>=0.22.0",
+]
+
+providers-gemini = [
+    "google-generativeai>=0.8.0",
+]
+
+providers-all = [
+    "vtk-prompt[providers-anthropic,providers-gemini]",
+]
+
 # All optional dependencies
 all = [
-    "vtk-prompt[dev,test]",
+    "vtk-prompt[dev,test,providers-all]",
 ]
 
 [project.urls]
@@ -132,7 +145,9 @@ module = [
     "sentence_transformers.*",
     "tree_sitter_languages.*",
     "llama_index.*",
-    "query_db.*"
+    "query_db.*",
+    "google.*",
+    "anthropic.*",
 ]
 ignore_missing_imports = true
 
diff --git a/src/vtk_prompt/cli.py b/src/vtk_prompt/cli.py
index 3481b99..0e93f07 100644
--- a/src/vtk_prompt/cli.py
+++ b/src/vtk_prompt/cli.py
@@ -122,16 +122,18 @@ def main(
         top_k=top_k,
         rag=rag,
         retry_attempts=retry_attempts,
+        provider=provider,
     )
 
     if isinstance(result, tuple) and len(result) == 3:
         explanation, generated_code, usage = result
         if verbose and usage:
             logger.info(
                 "Used tokens: input=%d output=%d",
-                usage.prompt_tokens,
-                usage.completion_tokens,
+                usage["prompt_tokens"],
+                usage["completion_tokens"],
             )
+        logger.info("Explanation:\n%s", explanation)
         client.run_code(generated_code)
     else:
         # Handle string result
diff --git a/src/vtk_prompt/client.py b/src/vtk_prompt/client.py
index 65d3267..a00a6b2 100644
--- a/src/vtk_prompt/client.py
+++ b/src/vtk_prompt/client.py
@@ -29,13 +29,14 @@
     get_python_role,
     get_rag_context,
 )
+from .provider_utils import supports_temperature
 
 logger = get_logger(__name__)
 
 
 @dataclass
 class VTKPromptClient:
-    """OpenAI client for VTK code generation."""
+    """Multi-provider LLM client for VTK code generation."""
 
     _instance: Optional["VTKPromptClient"] = None
     _initialized: bool = False
@@ -44,6 +45,7 @@ class VTKPromptClient:
     verbose: bool = False
     conversation_file: Optional[str] = None
     conversation: Optional[list[dict[str, str]]] = None
+    provider: str = "openai"
 
     def __new__(cls, **kwargs: Any) -> "VTKPromptClient":
         """Create singleton instance of VTKPromptClient."""
@@ -132,6 +134,165 @@ def run_code(self, code_string: str) -> None:
             logger.debug("Failed code:\n%s", code_string)
             return
 
+    def _get_provider_client(self, api_key: str, base_url: Optional[str] = None) -> Any:
+        """Create provider-specific client."""
+        if self.provider == "anthropic":
+            try:
+                import anthropic
+
+                return anthropic.Anthropic(api_key=api_key)
+            except ImportError:
+                raise ValueError(
+                    "Anthropic provider requires 'anthropic' package. "
+                    "Install with: pip install 'vtk-prompt[providers-anthropic]'"
+                )
+        elif self.provider == "gemini":
+            try:
+                import google.generativeai as genai
+
+                genai.configure(api_key=api_key)
+                return genai
+            except ImportError:
+                raise ValueError(
+                    "Gemini provider requires 'google-generativeai' package. "
+                    "Install with: pip install 'vtk-prompt[providers-gemini]'"
+                )
+        else:
+            # OpenAI-compatible providers (openai, nim)
+            return openai.OpenAI(api_key=api_key, base_url=base_url)
+
+    def _convert_to_provider_format(self, messages: list[dict[str, str]]) -> Any:
+        """Convert OpenAI format messages to provider-specific format."""
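+        # Illustrative example of the normalization (system prompt "s", user turn "u"):
+        #   OpenAI input:  [{"role": "system", "content": "s"}, {"role": "user", "content": "u"}]
+        #   anthropic ->   {"system": "s", "messages": [{"role": "user", "content": "u"}]}
+        #   gemini    ->   [{"role": "user", "parts": [{"text": "s\n\nu"}]}]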
" + "Install with: pip install 'vtk-prompt[providers-gemini]'" + ) + else: + # OpenAI-compatible providers (openai, nim) + return openai.OpenAI(api_key=api_key, base_url=base_url) + + def _convert_to_provider_format(self, messages: list[dict[str, str]]) -> Any: + """Convert OpenAI format messages to provider-specific format.""" + if self.provider == "anthropic": + system_message = None + conversation_messages = [] + + for msg in messages: + if msg["role"] == "system": + system_message = msg["content"] + else: + conversation_messages.append({"role": msg["role"], "content": msg["content"]}) + + return {"system": system_message, "messages": conversation_messages} + elif self.provider == "gemini": + # Gemini uses a different format - convert to their chat format + gemini_messages = [] + for msg in messages: + if msg["role"] == "system": + # Gemini doesn't have system role, prepend to first user message + continue + elif msg["role"] == "user": + gemini_messages.append({"role": "user", "parts": [{"text": msg["content"]}]}) + elif msg["role"] == "assistant": + gemini_messages.append({"role": "model", "parts": [{"text": msg["content"]}]}) + return gemini_messages + else: + # OpenAI-compatible providers use the same format + return messages + + def _make_provider_request( + self, + client: Any, + messages: Any, + model: str, + max_tokens: int, + temperature: float, + ) -> Any: + """Make provider-specific API request.""" + if self.provider == "anthropic": + return client.messages.create( + model=model, + system=messages["system"], + messages=messages["messages"], + max_tokens=max_tokens, + temperature=temperature, + ) + elif self.provider == "gemini": + model_instance = client.GenerativeModel(model) + # Gemini API is different - convert our format + return model_instance.generate_content( + messages[-1]["parts"][0]["text"] if messages else "", + generation_config=client.types.GenerationConfig( + max_output_tokens=max_tokens, + temperature=temperature, + ), + ) + else: + # OpenAI-compatible providers + request_params = { + "model": model, + "messages": messages, + "max_completion_tokens": max_tokens, + } + if supports_temperature(model): + request_params["temperature"] = temperature + return client.chat.completions.create(**request_params) + + def _extract_response_content(self, response: Any) -> tuple[str, str]: + """Extract content and finish reason from provider-specific response.""" + if self.provider == "anthropic": + content = response.content[0].text if response.content else "No content in response" + finish_reason = response.stop_reason or "unknown" + elif self.provider == "gemini": + content = response.text if hasattr(response, "text") else "No content in response" + finish_reason = "stop" # Gemini doesn't provide finish_reason in the same way + else: + # OpenAI-compatible providers + if hasattr(response, "choices") and len(response.choices) > 0: + content = response.choices[0].message.content or "No content in response" + finish_reason = response.choices[0].finish_reason or "unknown" + else: + content = "No content in response" + finish_reason = "unknown" + + return content, finish_reason + + def _get_usage_info(self, response: Any) -> dict[str, int]: + """Extract usage information from provider-specific response.""" + if self.provider == "anthropic": + usage = response.usage + return { + "prompt_tokens": usage.input_tokens, + "completion_tokens": usage.output_tokens, + } + elif self.provider == "gemini": + if response.usage_metadata: + prompt_token_count = 
+        if self.provider == "anthropic":
+            return client.messages.create(
+                model=model,
+                system=messages["system"],
+                messages=messages["messages"],
+                max_tokens=max_tokens,
+                temperature=temperature,
+            )
+        elif self.provider == "gemini":
+            model_instance = client.GenerativeModel(model)
+            # The Gemini text API takes a plain prompt, so only the latest turn is sent
+            return model_instance.generate_content(
+                messages[-1]["parts"][0]["text"] if messages else "",
+                generation_config=client.types.GenerationConfig(
+                    max_output_tokens=max_tokens,
+                    temperature=temperature,
+                ),
+            )
+        else:
+            # OpenAI-compatible providers
+            request_params = {
+                "model": model,
+                "messages": messages,
+                "max_completion_tokens": max_tokens,
+            }
+            if supports_temperature(model):
+                request_params["temperature"] = temperature
+            return client.chat.completions.create(**request_params)
+
+    def _extract_response_content(self, response: Any) -> tuple[str, str]:
+        """Extract content and finish reason from provider-specific response."""
+        if self.provider == "anthropic":
+            content = response.content[0].text if response.content else "No content in response"
+            finish_reason = response.stop_reason or "unknown"
+        elif self.provider == "gemini":
+            content = response.text if hasattr(response, "text") else "No content in response"
+            finish_reason = "stop"  # Gemini doesn't provide finish_reason in the same way
+        else:
+            # OpenAI-compatible providers
+            if hasattr(response, "choices") and len(response.choices) > 0:
+                content = response.choices[0].message.content or "No content in response"
+                finish_reason = response.choices[0].finish_reason or "unknown"
+            else:
+                content = "No content in response"
+                finish_reason = "unknown"
+
+        return content, finish_reason
+
+    def _get_usage_info(self, response: Any) -> dict[str, int]:
+        """Extract usage information from provider-specific response."""
+        if self.provider == "anthropic":
+            usage = response.usage
+            return {
+                "prompt_tokens": usage.input_tokens,
+                "completion_tokens": usage.output_tokens,
+            }
+        elif self.provider == "gemini":
+            if response.usage_metadata:
+                prompt_token_count = response.usage_metadata.prompt_token_count
+                completion_token_count = response.usage_metadata.candidates_token_count
+                return {
+                    "prompt_tokens": prompt_token_count,
+                    "completion_tokens": completion_token_count,
+                }
+            return {"prompt_tokens": 0, "completion_tokens": 0}
+        else:
+            # OpenAI-compatible providers
+            if hasattr(response, "usage") and response.usage:
+                return {
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                }
+            return {"prompt_tokens": 0, "completion_tokens": 0}
+
+    def _switch_provider(self, new_provider: str) -> None:
+        """Switch to a new provider. Conversation is already in OpenAI format (normalized)."""
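+        # Because the history stays in OpenAI format, a chat can in principle be
+        # resumed under a different provider, e.g. query(..., provider="anthropic")
+        # after starting on the default "openai".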
+        if self.verbose:
+            logger.debug("Switching provider from %s to %s", self.provider, new_provider)
+        self.provider = new_provider
+
     def query(
         self,
         message: str = "",
@@ -143,6 +304,7 @@
         top_k: int = 5,
         rag: bool = False,
         retry_attempts: int = 1,
+        provider: Optional[str] = None,
     ) -> Union[tuple[str, str, Any], str]:
         """Generate VTK code with optional RAG enhancement and retry logic.
 
@@ -156,15 +318,19 @@
             top_k: Number of RAG examples to retrieve
             rag: Whether to use RAG enhancement
             retry_attempts: Number of times to retry if AST validation fails
+            provider: LLM provider to use (overrides the instance provider if given)
         """
+        # Handle provider switching
+        if provider and provider != self.provider:
+            self._switch_provider(provider)
         if not api_key:
             api_key = os.environ.get("OPENAI_API_KEY")
 
         if not api_key:
             raise ValueError("No API key provided. Set OPENAI_API_KEY or pass api_key parameter.")
 
-        # Create client with current parameters
-        client = openai.OpenAI(api_key=api_key, base_url=base_url)
+        # Create provider-specific client
+        client = self._get_provider_client(api_key, base_url)
 
         # Load existing conversation if present
         if self.conversation_file and not self.conversation:
@@ -227,69 +393,68 @@
                 model=model,
                 messages=self.conversation,  # type: ignore[arg-type]
                 max_completion_tokens=max_tokens,
-                # max_tokens=max_tokens,
                 temperature=temperature,
             )
-            if hasattr(response, "choices") and len(response.choices) > 0:
-                content = response.choices[0].message.content or "No content in response"
-                finish_reason = response.choices[0].finish_reason
-
-                if finish_reason == "length":
-                    raise ValueError(
-                        f"Output was truncated due to max_tokens limit ({max_tokens}).\n"
-                        "Please increase max_tokens."
-                    )
-
-                generated_explanation = re.findall(
-                    "<explanation>(.*?)</explanation>", content, re.DOTALL
-                )[0]
-                generated_code = re.findall("<code>(.*?)</code>", content, re.DOTALL)[0]
-                if "import vtk" not in generated_code:
-                    generated_code = "import vtk\n" + generated_code
+            # Extract response content using provider-specific method
+            content, finish_reason = self._extract_response_content(response)
+
+            if finish_reason == "length":
+                raise ValueError(
+                    f"Output was truncated due to max_tokens limit ({max_tokens}).\n"
+                    "Please increase max_tokens."
+                )
+
+            generated_explanation = re.findall(
+                "<explanation>(.*?)</explanation>", content, re.DOTALL
+            )[0]
+            generated_code = re.findall("<code>(.*?)</code>", content, re.DOTALL)[0]
+            if "import vtk" not in generated_code:
+                generated_code = "import vtk\n" + generated_code
+            else:
+                pos = generated_code.find("import vtk")
+                if pos != -1:
+                    generated_code = generated_code[pos:]
                 else:
-                    pos = generated_code.find("import vtk")
-                    if pos != -1:
-                        generated_code = generated_code[pos:]
-                    else:
-                        generated_code = generated_code
-
-                is_valid, error_msg = self.validate_code_syntax(generated_code)
-                if is_valid:
-                    if message:
-                        self.conversation.append({"role": "assistant", "content": content})
-                    self.save_conversation()
-                    return generated_explanation, generated_code, response.usage
-
-                elif attempt < retry_attempts - 1:  # Don't log on last attempt
-                    if self.verbose:
-                        logger.warning("AST validation failed: %s. Retrying...", error_msg)
-                    # Add error feedback to context for retry
+                    generated_code = generated_code
+
+            is_valid, error_msg = self.validate_code_syntax(generated_code)
+            if is_valid:
+                if message:
                     self.conversation.append({"role": "assistant", "content": content})
-                    self.conversation.append(
-                        {
-                            "role": "user",
-                            "content": (
-                                f"The generated code has a syntax error: {error_msg}. "
-                                "Please fix the syntax and generate valid Python code."
-                            ),
-                        }
-                    )
-                else:
-                    # Last attempt failed
-                    if self.verbose:
-                        logger.error("Final attempt failed AST validation: %s", error_msg)
-
-                    if message:
-                        self.conversation.append({"role": "assistant", "content": content})
-                    self.save_conversation()
-                    return (
-                        generated_explanation,
-                        generated_code,
-                        response.usage or {},
-                    )  # Return anyway, let caller handle
+                self.save_conversation()
+                return (
+                    generated_explanation,
+                    generated_code,
+                    self._get_usage_info(response),
+                )
+
+            elif attempt < retry_attempts - 1:  # Don't log on last attempt
+                if self.verbose:
+                    logger.warning("AST validation failed: %s. Retrying...", error_msg)
+                # Add error feedback to context for retry
+                self.conversation.append({"role": "assistant", "content": content})
+                self.conversation.append(
+                    {
+                        "role": "user",
+                        "content": (
+                            f"The generated code has a syntax error: {error_msg}. "
+                            "Please fix the syntax and generate valid Python code."
+                        ),
+                    }
+                )
             else:
-                if attempt == retry_attempts - 1:
-                    return ("No response generated", "", response.usage or {})
+                # Last attempt failed
+                if self.verbose:
+                    logger.error("Final attempt failed AST validation: %s", error_msg)
+
+                if message:
+                    self.conversation.append({"role": "assistant", "content": content})
+                self.save_conversation()
+                return (
+                    generated_explanation,
+                    generated_code,
+                    self._get_usage_info(response),
+                )  # Return anyway, let caller handle
 
         return "No response generated"
diff --git a/src/vtk_prompt/vtk_prompt_ui.py b/src/vtk_prompt/vtk_prompt_ui.py
index 4df1964..8a37617 100644
--- a/src/vtk_prompt/vtk_prompt_ui.py
+++ b/src/vtk_prompt/vtk_prompt_ui.py
@@ -310,6 +310,7 @@ def _generate_and_execute_code(self) -> None:
             top_k=int(self.state.top_k),
             rag=self.state.use_rag,
             retry_attempts=int(self.state.retry_attempts),
+            provider=self.state.provider,
         )
         # Keep UI in sync with conversation
         self.state.conversation = self.prompt_client.conversation
@@ -318,8 +319,8 @@
         if isinstance(result, tuple) and len(result) == 3:
             generated_explanation, generated_code, usage = result
             if usage:
-                self.state.input_tokens = usage.prompt_tokens
-                self.state.output_tokens = usage.completion_tokens
+                self.state.input_tokens = usage["prompt_tokens"]
+                self.state.output_tokens = usage["completion_tokens"]
         else:
             # Handle string result
             generated_explanation = str(result)
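
A minimal usage sketch of the provider switch above (illustrative: the keyword
names mirror query()'s signature in this diff, and the model string is a placeholder):

    from vtk_prompt.client import VTKPromptClient

    client = VTKPromptClient(verbose=True)  # provider defaults to "openai"
    result = client.query(
        message="Create a red sphere and render it",
        model="claude-sonnet-4-20250514",  # placeholder model name
        api_key="sk-ant-...",  # query() falls back to OPENAI_API_KEY when omitted
        provider="anthropic",  # needs: pip install 'vtk-prompt[providers-anthropic]'
    )
    if isinstance(result, tuple):
        explanation, code, usage = result  # usage is the normalized token dict
        client.run_code(code)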