19 changes: 17 additions & 2 deletions pyproject.toml
@@ -66,9 +66,22 @@ test = [
"pytest-cov>=4.0.0",
]

# Provider-specific dependencies
providers-anthropic = [
"anthropic>=0.22.0",
]

providers-gemini = [
"google-generativeai>=0.8.0",
]

providers-all = [
"vtk-prompt[providers-anthropic,providers-gemini]",
]

# All optional dependencies
all = [
"vtk-prompt[dev,test]",
"vtk-prompt[dev,test,providers-all]",
]
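
Note: these extras keep the provider SDKs optional; `client.py` imports them lazily and raises a helpful error when one is missing. A minimal sketch (not part of this PR; the helper name is hypothetical) of probing which optional providers are usable in an environment:

```python
# Hypothetical helper: report which provider extras are importable by
# trying the module each extra installs.
import importlib

def available_providers() -> list[str]:
    sdks = {"anthropic": "anthropic", "gemini": "google.generativeai"}
    found = []
    for name, module in sdks.items():
        try:
            importlib.import_module(module)
            found.append(name)
        except ImportError:
            continue
    return found
```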

[project.urls]
@@ -132,7 +145,9 @@ module = [
"sentence_transformers.*",
"tree_sitter_languages.*",
"llama_index.*",
"query_db.*"
"query_db.*",
"google.*",
"anthropic.*",
]
ignore_missing_imports = true

8 changes: 5 additions & 3 deletions src/vtk_prompt/cli.py
@@ -122,16 +122,18 @@ def main(
top_k=top_k,
rag=rag,
retry_attempts=retry_attempts,
provider=provider,
)

if isinstance(result, tuple) and len(result) == 3:
explanation, generated_code, usage = result
_explanation, generated_code, usage = result
if verbose and usage:
logger.info(
"Used tokens: input=%d output=%d",
usage.prompt_tokens,
usage.completion_tokens,
usage["prompt_tokens"],
usage["completion_tokens"],
)
logger.info("Explanation:\n%s", explanation)
client.run_code(generated_code)
else:
# Handle string result
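
Note: `query()` now returns usage as a plain dict normalized by `_get_usage_info()` rather than a provider SDK object, which is why the CLI switches from attribute access to key access. A minimal consumer sketch, assuming the tuple shape shown above:

```python
# Sketch: consuming the (explanation, code, usage) tuple from query();
# `usage` is a plain dict with the same two keys for every provider.
result = client.query("render a red sphere", provider="openai")
if isinstance(result, tuple) and len(result) == 3:
    _explanation, generated_code, usage = result
    print(f"input={usage['prompt_tokens']} output={usage['completion_tokens']}")
    client.run_code(generated_code)
```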
273 changes: 212 additions & 61 deletions src/vtk_prompt/client.py
@@ -29,13 +29,14 @@
get_python_role,
get_rag_context,
)
from .provider_utils import supports_temperature

logger = get_logger(__name__)


@dataclass
class VTKPromptClient:
"""OpenAI client for VTK code generation."""
"""Multi-provider LLM client for VTK code generation."""

_instance: Optional["VTKPromptClient"] = None
_initialized: bool = False
@@ -44,6 +44,7 @@ class VTKPromptClient:
verbose: bool = False
conversation_file: Optional[str] = None
conversation: Optional[list[dict[str, str]]] = None
provider: str = "openai"

def __new__(cls, **kwargs: Any) -> "VTKPromptClient":
"""Create singleton instance of VTKPromptClient."""
@@ -132,6 +134,151 @@ def run_code(self, code_string: str) -> None:
logger.debug("Failed code:\n%s", code_string)
return

def _get_provider_client(self, api_key: str, base_url: Optional[str] = None) -> Any:
"""Create provider-specific client."""
if self.provider == "anthropic":
try:
import anthropic

return anthropic.Anthropic(api_key=api_key)
except ImportError:
raise ValueError(
"Anthropic provider requires 'anthropic' package. "
"Install with: pip install 'vtk-prompt[providers-anthropic]'"
)
elif self.provider == "gemini":
try:
import google.generativeai as genai

genai.configure(api_key=api_key)
return genai
except ImportError:
raise ValueError(
"Gemini provider requires 'google-generativeai' package. "
"Install with: pip install 'vtk-prompt[providers-gemini]'"
)
else:
# OpenAI-compatible providers (openai, nim)
return openai.OpenAI(api_key=api_key, base_url=base_url)
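
Note: because the SDK imports are deferred into this method, a missing package only surfaces when that provider is actually selected. A hedged sketch of the failure mode (the key value is a placeholder):

```python
# Sketch: selecting a provider whose SDK is not installed raises
# ValueError at request time, not at import time of vtk_prompt.
client = VTKPromptClient(provider="anthropic")
try:
    client._get_provider_client(api_key="placeholder-key")
except ValueError as err:
    print(err)  # "Anthropic provider requires 'anthropic' package. ..."
```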

def _convert_to_provider_format(self, messages: list[dict[str, str]]) -> Any:
"""Convert OpenAI format messages to provider-specific format."""
if self.provider == "anthropic":
system_message = None
conversation_messages = []

for msg in messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
conversation_messages.append({"role": msg["role"], "content": msg["content"]})

return {"system": system_message, "messages": conversation_messages}
elif self.provider == "gemini":
# Gemini uses a different format - convert to their chat format
gemini_messages = []
for msg in messages:
if msg["role"] == "system":
# Gemini has no system role; system messages are dropped here
continue
elif msg["role"] == "user":
gemini_messages.append({"role": "user", "parts": [{"text": msg["content"]}]})
elif msg["role"] == "assistant":
gemini_messages.append({"role": "model", "parts": [{"text": msg["content"]}]})
return gemini_messages
else:
# OpenAI-compatible providers use the same format
return messages
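
Note: a concrete example of the same two-message history in each target shape (values illustrative):

```python
messages = [
    {"role": "system", "content": "You write VTK code."},
    {"role": "user", "content": "Draw a cone."},
]

# provider == "anthropic" -> system prompt lifted to a top-level field:
# {"system": "You write VTK code.",
#  "messages": [{"role": "user", "content": "Draw a cone."}]}

# provider == "gemini" -> roles remapped (assistant -> model), content
# wrapped in "parts"; the system message is dropped:
# [{"role": "user", "parts": [{"text": "Draw a cone."}]}]

# any other provider -> returned unchanged (OpenAI format).
```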

def _make_provider_request(
self,
client: Any,
messages: Any,
model: str,
max_tokens: int,
temperature: float,
) -> Any:
"""Make provider-specific API request."""
if self.provider == "anthropic":
return client.messages.create(
model=model,
system=messages["system"],
messages=messages["messages"],
max_tokens=max_tokens,
temperature=temperature,
)
elif self.provider == "gemini":
model_instance = client.GenerativeModel(model)
# Gemini API is different - convert our format
return model_instance.generate_content(
messages[-1]["parts"][0]["text"] if messages else "",
generation_config=client.types.GenerationConfig(
max_output_tokens=max_tokens,
temperature=temperature,
),
)
else:
# OpenAI-compatible providers
request_params = {
"model": model,
"messages": messages,
"max_completion_tokens": max_tokens,
}
if supports_temperature(model):
request_params["temperature"] = temperature
return client.chat.completions.create(**request_params)

def _extract_response_content(self, response: Any) -> tuple[str, str]:
"""Extract content and finish reason from provider-specific response."""
if self.provider == "anthropic":
content = response.content[0].text if response.content else "No content in response"
finish_reason = response.stop_reason or "unknown"
elif self.provider == "gemini":
content = response.text if hasattr(response, "text") else "No content in response"
finish_reason = "stop" # Gemini doesn't provide finish_reason in the same way
else:
# OpenAI-compatible providers
if hasattr(response, "choices") and len(response.choices) > 0:
content = response.choices[0].message.content or "No content in response"
finish_reason = response.choices[0].finish_reason or "unknown"
else:
content = "No content in response"
finish_reason = "unknown"

return content, finish_reason
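
Note: illustrative normalized results per provider (values are examples). One caveat worth knowing: Anthropic's native stop reasons include "end_turn" and "max_tokens", so the `finish_reason == "length"` truncation check later in `query()` only fires for OpenAI-compatible backends.

```python
# (content, finish_reason) as normalized by _extract_response_content:
# anthropic -> ("<explanation>...</explanation><code>...</code>", "end_turn")
# gemini    -> ("<explanation>...", "stop")  # hardcoded; no direct equivalent
# openai    -> ("<explanation>...", "stop")  # from choices[0].finish_reason
```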

def _get_usage_info(self, response: Any) -> dict[str, int]:
"""Extract usage information from provider-specific response."""
if self.provider == "anthropic":
usage = response.usage
return {
"prompt_tokens": usage.input_tokens,
"completion_tokens": usage.output_tokens,
}
elif self.provider == "gemini":
if response.usage_metadata:
prompt_token_count = response.usage_metadata.prompt_token_count
completion_token_count = response.usage_metadata.candidates_token_count
return {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
}
return {"prompt_tokens": 0, "completion_tokens": 0}
else:
# OpenAI-compatible providers
if hasattr(response, "usage") and response.usage:
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
}
return {"prompt_tokens": 0, "completion_tokens": 0}

def _switch_provider(self, new_provider: str) -> None:
"""Switch to a new provider. Conversation is already in OpenAI format (normalized)."""
if self.verbose:
logger.debug(f"Switching provider from {self.provider} to {new_provider}")
self.provider = new_provider
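
Note: a condensed sketch of how the helpers above chain together inside `query()` for any provider:

```python
# Inside VTKPromptClient.query(), conceptually:
client = self._get_provider_client(api_key, base_url)           # SDK client or module
messages = self._convert_to_provider_format(self.conversation)  # per-provider shape
response = self._make_provider_request(client, messages, model, max_tokens, temperature)
content, finish_reason = self._extract_response_content(response)
usage = self._get_usage_info(response)  # {"prompt_tokens": int, "completion_tokens": int}
```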

def query(
self,
message: str = "",
@@ -143,6 +290,7 @@ def query(
top_k: int = 5,
rag: bool = False,
retry_attempts: int = 1,
provider: Optional[str] = None,
) -> Union[tuple[str, str, Any], str]:
"""Generate VTK code with optional RAG enhancement and retry logic.

@@ -156,15 +304,19 @@
top_k: Number of RAG examples to retrieve
rag: Whether to use RAG enhancement
retry_attempts: Number of times to retry if AST validation fails
provider: LLM provider to use (overrides instance provider if provided)
"""
# Handle provider switching
if provider and provider != self.provider:
self._switch_provider(provider)
if not api_key:
api_key = os.environ.get("OPENAI_API_KEY")

if not api_key:
raise ValueError("No API key provided. Set OPENAI_API_KEY or pass api_key parameter.")

# Create client with current parameters
client = openai.OpenAI(api_key=api_key, base_url=base_url)
# Create provider-specific client
client = self._get_provider_client(api_key, base_url)

# Load existing conversation if present
if self.conversation_file and not self.conversation:
@@ -227,69 +379,68 @@ def query(
model=model,
messages=self.conversation, # type: ignore[arg-type]
max_completion_tokens=max_tokens,
# max_tokens=max_tokens,
temperature=temperature,
)

if hasattr(response, "choices") and len(response.choices) > 0:
content = response.choices[0].message.content or "No content in response"
finish_reason = response.choices[0].finish_reason

if finish_reason == "length":
raise ValueError(
f"Output was truncated due to max_tokens limit ({max_tokens}).\n"
"Please increase max_tokens."
)

generated_explanation = re.findall(
"<explanation>(.*?)</explanation>", content, re.DOTALL
)[0]
generated_code = re.findall("<code>(.*?)</code>", content, re.DOTALL)[0]
if "import vtk" not in generated_code:
generated_code = "import vtk\n" + generated_code
# Extract response content using provider-specific method
content, finish_reason = self._extract_response_content(response)

if finish_reason == "length":
raise ValueError(
f"Output was truncated due to max_tokens limit ({max_tokens}).\n"
"Please increase max_tokens."
)

generated_explanation = re.findall(
"<explanation>(.*?)</explanation>", content, re.DOTALL
)[0]
generated_code = re.findall("<code>(.*?)</code>", content, re.DOTALL)[0]
if "import vtk" not in generated_code:
generated_code = "import vtk\n" + generated_code
else:
pos = generated_code.find("import vtk")
if pos != -1:
generated_code = generated_code[pos:]
else:
pos = generated_code.find("import vtk")
if pos != -1:
generated_code = generated_code[pos:]
else:
generated_code = generated_code

is_valid, error_msg = self.validate_code_syntax(generated_code)
if is_valid:
if message:
self.conversation.append({"role": "assistant", "content": content})
self.save_conversation()
return generated_explanation, generated_code, response.usage

elif attempt < retry_attempts - 1: # Don't log on last attempt
if self.verbose:
logger.warning("AST validation failed: %s. Retrying...", error_msg)
# Add error feedback to context for retry
generated_code = generated_code

is_valid, error_msg = self.validate_code_syntax(generated_code)
if is_valid:
if message:
self.conversation.append({"role": "assistant", "content": content})
self.conversation.append(
{
"role": "user",
"content": (
f"The generated code has a syntax error: {error_msg}. "
"Please fix the syntax and generate valid Python code."
),
}
)
else:
# Last attempt failed
if self.verbose:
logger.error("Final attempt failed AST validation: %s", error_msg)

if message:
self.conversation.append({"role": "assistant", "content": content})
self.save_conversation()
return (
generated_explanation,
generated_code,
response.usage or {},
) # Return anyway, let caller handle
self.save_conversation()
return (
generated_explanation,
generated_code,
self._get_usage_info(response),
)

elif attempt < retry_attempts - 1: # Don't log on last attempt
if self.verbose:
logger.warning("AST validation failed: %s. Retrying...", error_msg)
# Add error feedback to context for retry
self.conversation.append({"role": "assistant", "content": content})
self.conversation.append(
{
"role": "user",
"content": (
f"The generated code has a syntax error: {error_msg}. "
"Please fix the syntax and generate valid Python code."
),
}
)
else:
if attempt == retry_attempts - 1:
return ("No response generated", "", response.usage or {})
# Last attempt failed
if self.verbose:
logger.error("Final attempt failed AST validation: %s", error_msg)

if message:
self.conversation.append({"role": "assistant", "content": content})
self.save_conversation()
return (
generated_explanation,
generated_code,
self._get_usage_info(response),
) # Return anyway, let caller handle

return "No response generated"