diff --git a/pyproject.toml b/pyproject.toml
index e52f371..53a05a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,9 +66,22 @@ test = [
"pytest-cov>=4.0.0",
]
+# Provider-specific dependencies
+providers-anthropic = [
+ "anthropic>=0.22.0",
+]
+
+providers-gemini = [
+ "google-generativeai>=0.8.0",
+]
+
+providers-all = [
+ "vtk-prompt[providers-anthropic,providers-gemini]",
+]
+
# All optional dependencies
all = [
- "vtk-prompt[dev,test]",
+ "vtk-prompt[dev,test,providers-all]",
]
[project.urls]
@@ -132,7 +145,9 @@ module = [
"sentence_transformers.*",
"tree_sitter_languages.*",
"llama_index.*",
- "query_db.*"
+ "query_db.*",
+ "google.*",
+ "anthropic.*",
]
ignore_missing_imports = true
diff --git a/src/vtk_prompt/cli.py b/src/vtk_prompt/cli.py
index 3481b99..0e93f07 100644
--- a/src/vtk_prompt/cli.py
+++ b/src/vtk_prompt/cli.py
@@ -122,16 +122,18 @@ def main(
top_k=top_k,
rag=rag,
retry_attempts=retry_attempts,
+ provider=provider,
)
if isinstance(result, tuple) and len(result) == 3:
explanation, generated_code, usage = result
if verbose and usage:
logger.info(
"Used tokens: input=%d output=%d",
- usage.prompt_tokens,
- usage.completion_tokens,
+ usage["prompt_tokens"],
+ usage["completion_tokens"],
)
+ logger.info("Explanation:\n%s", explanation)
client.run_code(generated_code)
else:
# Handle string result
diff --git a/src/vtk_prompt/client.py b/src/vtk_prompt/client.py
index 65d3267..a00a6b2 100644
--- a/src/vtk_prompt/client.py
+++ b/src/vtk_prompt/client.py
@@ -29,13 +29,14 @@
get_python_role,
get_rag_context,
)
+from .provider_utils import supports_temperature
logger = get_logger(__name__)
@dataclass
class VTKPromptClient:
- """OpenAI client for VTK code generation."""
+ """Multi-provider LLM client for VTK code generation."""
_instance: Optional["VTKPromptClient"] = None
_initialized: bool = False
@@ -44,6 +45,7 @@ class VTKPromptClient:
verbose: bool = False
conversation_file: Optional[str] = None
conversation: Optional[list[dict[str, str]]] = None
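+ # Which LLM backend to use: "openai" (default), "nim", "anthropic", or "gemini"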
+ provider: str = "openai"
def __new__(cls, **kwargs: Any) -> "VTKPromptClient":
"""Create singleton instance of VTKPromptClient."""
@@ -132,6 +134,151 @@ def run_code(self, code_string: str) -> None:
logger.debug("Failed code:\n%s", code_string)
return
+ def _get_provider_client(self, api_key: str, base_url: Optional[str] = None) -> Any:
+ """Create provider-specific client."""
+ if self.provider == "anthropic":
+ try:
+ import anthropic
+
+ return anthropic.Anthropic(api_key=api_key)
+ except ImportError as err:
+ raise ValueError(
+ "Anthropic provider requires 'anthropic' package. "
+ "Install with: pip install 'vtk-prompt[providers-anthropic]'"
+ ) from err
+ elif self.provider == "gemini":
+ try:
+ import google.generativeai as genai
+
+ genai.configure(api_key=api_key)
+ return genai
+ except ImportError as err:
+ raise ValueError(
+ "Gemini provider requires 'google-generativeai' package. "
+ "Install with: pip install 'vtk-prompt[providers-gemini]'"
+ ) from err
+ else:
+ # OpenAI-compatible providers (openai, nim)
+ return openai.OpenAI(api_key=api_key, base_url=base_url)
+
+ def _convert_to_provider_format(self, messages: list[dict[str, str]]) -> Any:
+ """Convert OpenAI format messages to provider-specific format."""
+ if self.provider == "anthropic":
+ system_message = None
+ conversation_messages = []
+
+ for msg in messages:
+ if msg["role"] == "system":
+ system_message = msg["content"]
+ else:
+ conversation_messages.append({"role": msg["role"], "content": msg["content"]})
+
+ return {"system": system_message, "messages": conversation_messages}
+ elif self.provider == "gemini":
+ # Gemini uses a different format - convert to their chat format
+ gemini_messages = []
+ for msg in messages:
+ if msg["role"] == "system":
+ # Gemini doesn't have system role, prepend to first user message
+ continue
+ elif msg["role"] == "user":
+ gemini_messages.append({"role": "user", "parts": [{"text": msg["content"]}]})
+ elif msg["role"] == "assistant":
+ gemini_messages.append({"role": "model", "parts": [{"text": msg["content"]}]})
+ return gemini_messages
+ else:
+ # OpenAI-compatible providers use the same format
+ return messages
+
+ def _make_provider_request(
+ self,
+ client: Any,
+ messages: Any,
+ model: str,
+ max_tokens: int,
+ temperature: float,
+ ) -> Any:
+ """Make provider-specific API request."""
+ if self.provider == "anthropic":
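+ # Anthropic's Messages API takes the system prompt as a separate argument and requires max_tokens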
+ return client.messages.create(
+ model=model,
+ system=messages["system"],
+ messages=messages["messages"],
+ max_tokens=max_tokens,
+ temperature=temperature,
+ )
+ elif self.provider == "gemini":
+ model_instance = client.GenerativeModel(model)
+ # Gemini API is different - convert our format
+ return model_instance.generate_content(
+ messages[-1]["parts"][0]["text"] if messages else "",
+ generation_config=client.types.GenerationConfig(
+ max_output_tokens=max_tokens,
+ temperature=temperature,
+ ),
+ )
+ else:
+ # OpenAI-compatible providers
+ request_params = {
+ "model": model,
+ "messages": messages,
+ "max_completion_tokens": max_tokens,
+ }
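+ # Only pass temperature when the model accepts it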
+ if supports_temperature(model):
+ request_params["temperature"] = temperature
+ return client.chat.completions.create(**request_params)
+
+ def _extract_response_content(self, response: Any) -> tuple[str, str]:
+ """Extract content and finish reason from provider-specific response."""
+ if self.provider == "anthropic":
+ content = response.content[0].text if response.content else "No content in response"
+ finish_reason = response.stop_reason or "unknown"
+ elif self.provider == "gemini":
+ content = response.text if hasattr(response, "text") else "No content in response"
+ finish_reason = "stop" # Gemini doesn't provide finish_reason in the same way
+ else:
+ # OpenAI-compatible providers
+ if hasattr(response, "choices") and len(response.choices) > 0:
+ content = response.choices[0].message.content or "No content in response"
+ finish_reason = response.choices[0].finish_reason or "unknown"
+ else:
+ content = "No content in response"
+ finish_reason = "unknown"
+
+ return content, finish_reason
+
+ def _get_usage_info(self, response: Any) -> dict[str, int]:
+ """Extract usage information from provider-specific response."""
+ if self.provider == "anthropic":
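+ # Normalize Anthropic's input_tokens/output_tokens to OpenAI-style usage keys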
+ usage = response.usage
+ return {
+ "prompt_tokens": usage.input_tokens,
+ "completion_tokens": usage.output_tokens,
+ }
+ elif self.provider == "gemini":
+ if response.usage_metadata:
+ prompt_token_count = response.usage_metadata.prompt_token_count
+ completion_token_count = response.usage_metadata.candidates_token_count
+ return {
+ "prompt_tokens": prompt_token_count,
+ "completion_tokens": completion_token_count,
+ }
+ return {"prompt_tokens": 0, "completion_tokens": 0}
+ else:
+ # OpenAI-compatible providers
+ if hasattr(response, "usage") and response.usage:
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ }
+ return {"prompt_tokens": 0, "completion_tokens": 0}
+
+ def _switch_provider(self, new_provider: str) -> None:
+ """Switch to a new provider. Conversation is already in OpenAI format (normalized)."""
+ if self.verbose:
+ logger.debug(f"Switching provider from {self.provider} to {new_provider}")
+ self.provider = new_provider
+
def query(
self,
message: str = "",
@@ -143,6 +290,7 @@ def query(
top_k: int = 5,
rag: bool = False,
retry_attempts: int = 1,
+ provider: Optional[str] = None,
) -> Union[tuple[str, str, Any], str]:
"""Generate VTK code with optional RAG enhancement and retry logic.
@@ -156,15 +304,19 @@ def query(
top_k: Number of RAG examples to retrieve
rag: Whether to use RAG enhancement
retry_attempts: Number of times to retry if AST validation fails
+ provider: LLM provider to use (overrides instance provider if provided)
"""
+ # Handle provider switching
+ if provider and provider != self.provider:
+ self._switch_provider(provider)
if not api_key:
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
raise ValueError("No API key provided. Set OPENAI_API_KEY or pass api_key parameter.")
- # Create client with current parameters
- client = openai.OpenAI(api_key=api_key, base_url=base_url)
+ # Create provider-specific client
+ client = self._get_provider_client(api_key, base_url)
# Load existing conversation if present
if self.conversation_file and not self.conversation:
@@ -227,69 +379,68 @@ def query(
model=model,
messages=self.conversation, # type: ignore[arg-type]
max_completion_tokens=max_tokens,
- # max_tokens=max_tokens,
temperature=temperature,
)
- if hasattr(response, "choices") and len(response.choices) > 0:
- content = response.choices[0].message.content or "No content in response"
- finish_reason = response.choices[0].finish_reason
-
- if finish_reason == "length":
- raise ValueError(
- f"Output was truncated due to max_tokens limit ({max_tokens}).\n"
- "Please increase max_tokens."
- )
-
- generated_explanation = re.findall(
- "<explanation>(.*?)</explanation>", content, re.DOTALL
- )[0]
- generated_code = re.findall("<python>(.*?)</python>", content, re.DOTALL)[0]
- if "import vtk" not in generated_code:
- generated_code = "import vtk\n" + generated_code
+ # Extract response content using provider-specific method
+ content, finish_reason = self._extract_response_content(response)
+
+ if finish_reason == "length":
+ raise ValueError(
+ f"Output was truncated due to max_tokens limit ({max_tokens}).\n"
+ "Please increase max_tokens."
+ )
+
+ generated_explanation = re.findall(
+ "<explanation>(.*?)</explanation>", content, re.DOTALL
+ )[0]
+ generated_code = re.findall("<python>(.*?)</python>", content, re.DOTALL)[0]
+ if "import vtk" not in generated_code:
+ generated_code = "import vtk\n" + generated_code
+ else:
+ pos = generated_code.find("import vtk")
+ if pos != -1:
+ generated_code = generated_code[pos:]
else:
- pos = generated_code.find("import vtk")
- if pos != -1:
- generated_code = generated_code[pos:]
- else:
- generated_code = generated_code
-
- is_valid, error_msg = self.validate_code_syntax(generated_code)
- if is_valid:
- if message:
- self.conversation.append({"role": "assistant", "content": content})
- self.save_conversation()
- return generated_explanation, generated_code, response.usage
-
- elif attempt < retry_attempts - 1: # Don't log on last attempt
- if self.verbose:
- logger.warning("AST validation failed: %s. Retrying...", error_msg)
- # Add error feedback to context for retry
+ generated_code = generated_code
+
+ is_valid, error_msg = self.validate_code_syntax(generated_code)
+ if is_valid:
+ if message:
self.conversation.append({"role": "assistant", "content": content})
- self.conversation.append(
- {
- "role": "user",
- "content": (
- f"The generated code has a syntax error: {error_msg}. "
- "Please fix the syntax and generate valid Python code."
- ),
- }
- )
- else:
- # Last attempt failed
- if self.verbose:
- logger.error("Final attempt failed AST validation: %s", error_msg)
-
- if message:
- self.conversation.append({"role": "assistant", "content": content})
- self.save_conversation()
- return (
- generated_explanation,
- generated_code,
- response.usage or {},
- ) # Return anyway, let caller handle
+ self.save_conversation()
+ return (
+ generated_explanation,
+ generated_code,
+ self._get_usage_info(response),
+ )
+
+ elif attempt < retry_attempts - 1: # Don't log on last attempt
+ if self.verbose:
+ logger.warning("AST validation failed: %s. Retrying...", error_msg)
+ # Add error feedback to context for retry
+ self.conversation.append({"role": "assistant", "content": content})
+ self.conversation.append(
+ {
+ "role": "user",
+ "content": (
+ f"The generated code has a syntax error: {error_msg}. "
+ "Please fix the syntax and generate valid Python code."
+ ),
+ }
+ )
else:
- if attempt == retry_attempts - 1:
- return ("No response generated", "", response.usage or {})
+ # Last attempt failed
+ if self.verbose:
+ logger.error("Final attempt failed AST validation: %s", error_msg)
+
+ if message:
+ self.conversation.append({"role": "assistant", "content": content})
+ self.save_conversation()
+ return (
+ generated_explanation,
+ generated_code,
+ self._get_usage_info(response),
+ ) # Return anyway, let caller handle
return "No response generated"
diff --git a/src/vtk_prompt/vtk_prompt_ui.py b/src/vtk_prompt/vtk_prompt_ui.py
index 4df1964..8a37617 100644
--- a/src/vtk_prompt/vtk_prompt_ui.py
+++ b/src/vtk_prompt/vtk_prompt_ui.py
@@ -310,6 +310,7 @@ def _generate_and_execute_code(self) -> None:
top_k=int(self.state.top_k),
rag=self.state.use_rag,
retry_attempts=int(self.state.retry_attempts),
+ provider=self.state.provider,
)
# Keep UI in sync with conversation
self.state.conversation = self.prompt_client.conversation
@@ -318,8 +319,8 @@ def _generate_and_execute_code(self) -> None:
if isinstance(result, tuple) and len(result) == 3:
generated_explanation, generated_code, usage = result
if usage:
- self.state.input_tokens = usage.prompt_tokens
- self.state.output_tokens = usage.completion_tokens
+ self.state.input_tokens = usage["prompt_tokens"]
+ self.state.output_tokens = usage["completion_tokens"]
else:
# Handle string result
generated_explanation = str(result)