diff --git a/.gitignore b/.gitignore
index 21f50bfa..e606b68c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -168,3 +168,4 @@ benchmark_stress_db/
examples/data/
third_party/agfs/bin/
openviking/_version.py
+specs/
diff --git a/README.md b/README.md
index 7e09c5f2..b8107d29 100644
--- a/README.md
+++ b/README.md
@@ -89,10 +89,115 @@ OpenViking requires the following model capabilities:
- **VLM Model**: For image and content understanding
- **Embedding Model**: For vectorization and semantic retrieval
-OpenViking supports various model services:
-- **OpenAI Models**: Supports GPT-4V and other VLM models, and OpenAI Embedding models.
-- **Volcengine (Doubao Models)**: Recommended for low cost and high performance, with free quotas for new users. For purchase and activation, please refer to: [Volcengine Purchase Guide](./docs/en/guides/02-volcengine-purchase-guide.md).
-- **Other Custom Model Services**: Supports model services compatible with the OpenAI API format.
+#### Supported VLM Providers
+
+OpenViking supports multiple VLM providers:
+
+| Provider | Model | Get API Key |
+|----------|-------|-------------|
+| `volcengine` | doubao | [Volcengine Console](https://console.volcengine.com/ark) |
+| `openai` | gpt | [OpenAI Platform](https://platform.openai.com) |
+| `anthropic` | claude | [Anthropic Console](https://console.anthropic.com) |
+| `deepseek` | deepseek | [DeepSeek Platform](https://platform.deepseek.com) |
+| `gemini` | gemini | [Google AI Studio](https://aistudio.google.com) |
+| `moonshot` | kimi | [Moonshot Platform](https://platform.moonshot.cn) |
+| `zhipu` | glm | [Zhipu Open Platform](https://open.bigmodel.cn) |
+| `dashscope` | qwen | [DashScope Console](https://dashscope.console.aliyun.com) |
+| `minimax` | minimax | [MiniMax Platform](https://platform.minimax.io) |
+| `openrouter` | (any model) | [OpenRouter](https://openrouter.ai) |
+| `vllm` | (local model) | N/A (self-hosted) |
+
+> 💡 **Tip**: OpenViking uses a **Provider Registry** for unified model access. The system automatically detects the provider based on model name keywords, so you can switch between providers seamlessly.
+
+#### Provider-Specific Notes
+
+<details>
+<summary>Volcengine (Doubao)</summary>
+
+Volcengine supports both model names and endpoint IDs. Using model names is recommended for simplicity:
+
+```json
+{
+ "vlm": {
+ "provider": "volcengine",
+ "model": "doubao-seed-1-6-240615",
+ "api_key": "your-api-key"
+ }
+}
+```
+
+You can also use endpoint IDs (found in [Volcengine ARK Console](https://console.volcengine.com/ark)):
+
+```json
+{
+ "vlm": {
+ "provider": "volcengine",
+ "model": "ep-20241220174930-xxxxx",
+ "api_key": "your-api-key"
+ }
+}
+```
+
+</details>
+
+<details>
+<summary>Zhipu AI (智谱)</summary>
+
+If you're on Zhipu's coding plan, use the coding API endpoint:
+
+```json
+{
+ "vlm": {
+ "provider": "zhipu",
+ "model": "glm-4-plus",
+ "api_key": "your-api-key",
+ "api_base": "https://open.bigmodel.cn/api/coding/paas/v4"
+ }
+}
+```
+
+</details>
+
+<details>
+<summary>MiniMax (Mainland China)</summary>
+
+For MiniMax's mainland China platform (minimaxi.com), specify the API base:
+
+```json
+{
+ "vlm": {
+ "provider": "minimax",
+ "model": "abab6.5s-chat",
+ "api_key": "your-api-key",
+ "api_base": "https://api.minimaxi.com/v1"
+ }
+}
+```
+
+</details>
+
+<details>
+<summary>Local Models (vLLM)</summary>
+
+Run OpenViking with your own local models using vLLM:
+
+```bash
+# Start vLLM server
+vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
+```
+
+```json
+{
+ "vlm": {
+ "provider": "vllm",
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
+ "api_key": "dummy",
+ "api_base": "http://localhost:8000/v1"
+ }
+}
+```
+</details>
+
### 3. Environment Configuration
@@ -106,7 +211,7 @@ Create a configuration file `~/.openviking/ov.conf`:
"dense": {
"api_base" : "", // API endpoint address
"api_key" : "", // Model service API Key
- "provider" : "", // Provider type (volcengine or openai)
+ "provider" : "", // Provider type: "volcengine" or "openai" (currently supported)
"dimension": 1024, // Vector dimension
"model" : "" // Embedding model name (e.g., doubao-embedding-vision-250615 or text-embedding-3-large)
}
@@ -114,12 +219,14 @@ Create a configuration file `~/.openviking/ov.conf`:
"vlm": {
"api_base" : "", // API endpoint address
"api_key" : "", // Model service API Key
- "provider" : "", // Provider type (volcengine or openai)
+ "provider" : "", // Provider type (volcengine, openai, deepseek, anthropic, etc.)
"model" : "" // VLM model name (e.g., doubao-seed-1-8-251228 or gpt-4-vision-preview)
}
}
```
+> **Note**: For embedding models, only the `volcengine` (Doubao) and `openai` providers are currently supported. For VLM models, multiple providers are supported, including volcengine, openai, anthropic, deepseek, gemini, moonshot, zhipu, dashscope, minimax, openrouter, and local vLLM (see the table above).
+
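+The `vlm` section also accepts a `providers` map, so several providers can be configured side by side and the matching one is selected from the model name. A minimal sketch (field names follow the new `providers` schema in `VLMConfig`; keys and model names are placeholders):
+
+```json
+{
+  "vlm": {
+    "model": "deepseek-chat",
+    "providers": {
+      "deepseek": { "api_key": "your-deepseek-key" },
+      "openrouter": { "api_key": "your-openrouter-key" }
+    }
+  }
+}
+```
+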
#### Configuration Examples
Expand to see the configuration example for your model service:
diff --git a/openviking/models/vlm/__init__.py b/openviking/models/vlm/__init__.py
index 8e38acea..e9d01f4f 100644
--- a/openviking/models/vlm/__init__.py
+++ b/openviking/models/vlm/__init__.py
@@ -2,13 +2,29 @@
# SPDX-License-Identifier: Apache-2.0
"""VLM (Vision-Language Model) module"""
+from .backends.litellm_vlm import LiteLLMVLMProvider
from .backends.openai_vlm import OpenAIVLM
from .backends.volcengine_vlm import VolcEngineVLM
from .base import VLMBase, VLMFactory
+from .registry import (
+ PROVIDERS,
+ ProviderSpec,
+ find_by_model,
+ find_by_name,
+ find_gateway,
+ get_all_provider_names,
+)
__all__ = [
"VLMBase",
"VLMFactory",
"OpenAIVLM",
"VolcEngineVLM",
+ "LiteLLMVLMProvider",
+ "ProviderSpec",
+ "PROVIDERS",
+ "find_by_model",
+ "find_by_name",
+ "find_gateway",
+ "get_all_provider_names",
]
diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py
new file mode 100644
index 00000000..f1efa562
--- /dev/null
+++ b/openviking/models/vlm/backends/litellm_vlm.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""LiteLLM VLM Provider implementation with multi-provider support."""
+
+import os
+
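+# Must be set before litellm is imported so it uses its bundled local model-cost map
+# instead of fetching the latest map over the network at import time.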
+os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+
+import asyncio
+import base64
+from pathlib import Path
+from typing import Any, Dict, List, Union
+
+import litellm
+from litellm import acompletion, completion
+
+from ..base import VLMBase
+from ..registry import find_by_model, find_gateway
+
+
+class LiteLLMVLMProvider(VLMBase):
+ """
+ Multi-provider VLM implementation based on LiteLLM.
+
+ Supports OpenRouter, Anthropic, OpenAI, Gemini, DeepSeek, VolcEngine and many other providers
+ through a unified interface. Provider-specific logic is driven by the registry
+    (see registry.py in this package); no if-elif chains are needed here.
+ """
+
+ def __init__(self, config: Dict[str, Any]):
+ super().__init__(config)
+
+ self._provider_name = config.get("provider")
+ self._extra_headers = config.get("extra_headers") or {}
+ self._thinking = config.get("thinking", False)
+
+ self._gateway = find_gateway(self._provider_name, self.api_key, self.api_base)
+
+ if self.api_key:
+ self._setup_env(self.api_key, self.api_base, self.model)
+
+ if self.api_base:
+ litellm.api_base = self.api_base
+
+ litellm.suppress_debug_info = True
+ litellm.drop_params = True
+
+ def _setup_env(self, api_key: str, api_base: str | None, model: str | None) -> None:
+ """Set environment variables based on detected provider."""
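+        # e.g. for Moonshot: MOONSHOT_API_KEY is set from api_key, and the MOONSHOT_API_BASE
+        # env_extra resolves to the configured api_base (or the spec's default_api_base).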
+ spec = self._gateway or find_by_model(model or "")
+ if not spec:
+ return
+
+ if self._gateway:
+ os.environ[spec.env_key] = api_key
+ else:
+ os.environ.setdefault(spec.env_key, api_key)
+
+ effective_base = api_base or spec.default_api_base
+ for env_name, env_val in spec.env_extras:
+ resolved = env_val.replace("{api_key}", api_key)
+ resolved = resolved.replace("{api_base}", effective_base or "")
+ os.environ.setdefault(env_name, resolved)
+
+ def _resolve_model(self, model: str) -> str:
+ """Resolve model name by applying provider/gateway prefixes."""
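+        # e.g. "deepseek-chat" -> "deepseek/deepseek-chat"; a gateway (e.g. OpenRouter)
+        # prepends its own prefix instead, optionally stripping an existing provider prefix.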
+ if self._gateway:
+ if self._gateway.strip_model_prefix:
+ model = model.split("/")[-1]
+ prefix = self._gateway.litellm_prefix
+ if prefix and not model.startswith(f"{prefix}/"):
+ model = f"{prefix}/{model}"
+ return model
+
+ spec = find_by_model(model)
+ if spec and spec.litellm_prefix:
+ if not any(model.startswith(s) for s in spec.skip_prefixes):
+ model = f"{spec.litellm_prefix}/{model}"
+ return model
+
+ def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
+ """Apply model-specific parameter overrides from the registry."""
+        model_lower = model.lower()
+
+        # Apply the VolcEngine thinking flag first so registry overrides can still take precedence.
+        if self._provider_name == "volcengine":
+            kwargs["thinking"] = {"type": "enabled" if self._thinking else "disabled"}
+
+        spec = find_by_model(model)
+        if spec:
+            for pattern, overrides in spec.model_overrides:
+                if pattern in model_lower:
+                    kwargs.update(overrides)
+
+ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]:
+ """Prepare image data for vision completion."""
+ if isinstance(image, bytes):
+ b64 = base64.b64encode(image).decode("utf-8")
+ return {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{b64}"},
+ }
+ elif isinstance(image, Path) or (
+ isinstance(image, str) and not image.startswith(("http://", "https://"))
+ ):
+ path = Path(image)
+ suffix = path.suffix.lower()
+ mime_type = {
+ ".png": "image/png",
+ ".jpg": "image/jpeg",
+ ".jpeg": "image/jpeg",
+ ".gif": "image/gif",
+ ".webp": "image/webp",
+ }.get(suffix, "image/png")
+ with open(path, "rb") as f:
+ b64 = base64.b64encode(f.read()).decode("utf-8")
+ return {
+ "type": "image_url",
+ "image_url": {"url": f"data:{mime_type};base64,{b64}"},
+ }
+ else:
+ return {"type": "image_url", "image_url": {"url": image}}
+
+ def _build_kwargs(self, model: str, messages: list) -> dict[str, Any]:
+ """Build kwargs for LiteLLM call."""
+ kwargs: dict[str, Any] = {
+ "model": model,
+ "messages": messages,
+ "temperature": self.temperature,
+ }
+
+ self._apply_model_overrides(model, kwargs)
+
+ if self.api_key:
+ kwargs["api_key"] = self.api_key
+ if self.api_base:
+ kwargs["api_base"] = self.api_base
+ if self._extra_headers:
+ kwargs["extra_headers"] = self._extra_headers
+
+ return kwargs
+
+ def get_completion(self, prompt: str) -> str:
+ """Get text completion synchronously."""
+ model = self._resolve_model(self.model or "gpt-4o-mini")
+ messages = [{"role": "user", "content": prompt}]
+ kwargs = self._build_kwargs(model, messages)
+
+ response = completion(**kwargs)
+ self._update_token_usage_from_response(response)
+ return response.choices[0].message.content or ""
+
+ async def get_completion_async(self, prompt: str, max_retries: int = 0) -> str:
+ """Get text completion asynchronously."""
+ model = self._resolve_model(self.model or "gpt-4o-mini")
+ messages = [{"role": "user", "content": prompt}]
+ kwargs = self._build_kwargs(model, messages)
+
+ last_error = None
+ for attempt in range(max_retries + 1):
+ try:
+ response = await acompletion(**kwargs)
+ self._update_token_usage_from_response(response)
+ return response.choices[0].message.content or ""
+ except Exception as e:
+ last_error = e
+ if attempt < max_retries:
+ await asyncio.sleep(2 ** attempt)
+
+ if last_error:
+ raise last_error
+ raise RuntimeError("Unknown error in async completion")
+
+ def get_vision_completion(
+ self,
+ prompt: str,
+ images: List[Union[str, Path, bytes]],
+ ) -> str:
+ """Get vision completion synchronously."""
+ model = self._resolve_model(self.model or "gpt-4o-mini")
+
+ content = []
+ for img in images:
+ content.append(self._prepare_image(img))
+ content.append({"type": "text", "text": prompt})
+
+ messages = [{"role": "user", "content": content}]
+ kwargs = self._build_kwargs(model, messages)
+
+ response = completion(**kwargs)
+ self._update_token_usage_from_response(response)
+ return response.choices[0].message.content or ""
+
+ async def get_vision_completion_async(
+ self,
+ prompt: str,
+ images: List[Union[str, Path, bytes]],
+ ) -> str:
+ """Get vision completion asynchronously."""
+ model = self._resolve_model(self.model or "gpt-4o-mini")
+
+ content = []
+ for img in images:
+ content.append(self._prepare_image(img))
+ content.append({"type": "text", "text": prompt})
+
+ messages = [{"role": "user", "content": content}]
+ kwargs = self._build_kwargs(model, messages)
+
+ response = await acompletion(**kwargs)
+ self._update_token_usage_from_response(response)
+ return response.choices[0].message.content or ""
+
+ def _update_token_usage_from_response(self, response) -> None:
+ """Update token usage from response."""
+ if hasattr(response, "usage") and response.usage:
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ self.update_token_usage(
+ model_name=self.model or "unknown",
+ provider=self.provider,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py
index 2efeb192..ef55f712 100644
--- a/openviking/models/vlm/base.py
+++ b/openviking/models/vlm/base.py
@@ -123,18 +123,21 @@ def create(config: Dict[str, Any]) -> VLMBase:
"""
provider = config.get("provider") or config.get("backend") or "openai"
- if provider == "openai":
- from .backends.openai_vlm import OpenAIVLM
+ use_litellm = config.get("use_litellm", True)
- return OpenAIVLM(config)
- elif provider == "volcengine":
- from .backends.volcengine_vlm import VolcEngineVLM
+ if not use_litellm:
+ if provider == "openai":
+ from .backends.openai_vlm import OpenAIVLM
+ return OpenAIVLM(config)
+ elif provider == "volcengine":
+ from .backends.volcengine_vlm import VolcEngineVLM
+ return VolcEngineVLM(config)
- return VolcEngineVLM(config)
- else:
- raise ValueError(f"Unsupported VLM provider: {provider}")
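+        # Default path: every provider not handled above is routed through the LiteLLM backend.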
+ from .backends.litellm_vlm import LiteLLMVLMProvider
+ return LiteLLMVLMProvider(config)
@staticmethod
def get_available_providers() -> List[str]:
"""Get list of available providers"""
- return ["openai", "volcengine"]
+ from .registry import get_all_provider_names
+ return get_all_provider_names()
diff --git a/openviking/models/vlm/registry.py b/openviking/models/vlm/registry.py
new file mode 100644
index 00000000..815b637a
--- /dev/null
+++ b/openviking/models/vlm/registry.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Provider Registry: the single source of truth for LLM provider metadata.
+
+Adding a new provider:
+ 1. Add a ProviderSpec to PROVIDERS below.
+ 2. Use it in config with providers["newprovider"] = {"api_key": "xxx"}
+ Done. Env vars, prefixing, config matching, status display all derive from here.
+
+Order matters: it controls match priority and fallback. Gateways come first.
+Every entry writes out all fields so you can copy-paste as a template.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass(frozen=True)
+class ProviderSpec:
+ """VLM Provider metadata definition.
+
+ Placeholders in env_extras values:
+ {api_key} - the user's API key
+ {api_base} - api_base from config, or this spec's default_api_base
+ """
+
+ name: str
+ keywords: tuple[str, ...]
+ env_key: str
+ display_name: str = ""
+
+ litellm_prefix: str = ""
+ skip_prefixes: tuple[str, ...] = ()
+
+ env_extras: tuple[tuple[str, str], ...] = ()
+
+ is_gateway: bool = False
+ is_local: bool = False
+ detect_by_key_prefix: str = ""
+ detect_by_base_keyword: str = ""
+ default_api_base: str = ""
+
+ strip_model_prefix: bool = False
+
+ model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
+
+ @property
+ def label(self) -> str:
+ return self.display_name or self.name.title()
+
+
+PROVIDERS: tuple[ProviderSpec, ...] = (
+ ProviderSpec(
+ name="custom",
+ keywords=(),
+ env_key="OPENAI_API_KEY",
+ display_name="Custom",
+ litellm_prefix="openai",
+ skip_prefixes=("openai/",),
+ is_gateway=True,
+ strip_model_prefix=True,
+ ),
+
+ ProviderSpec(
+ name="openrouter",
+ keywords=("openrouter",),
+ env_key="OPENROUTER_API_KEY",
+ display_name="OpenRouter",
+ litellm_prefix="openrouter",
+ is_gateway=True,
+ detect_by_key_prefix="sk-or-",
+ detect_by_base_keyword="openrouter",
+ default_api_base="https://openrouter.ai/api/v1",
+ ),
+
+ ProviderSpec(
+ name="volcengine",
+ keywords=("doubao", "volcengine", "ep-"),
+ env_key="VOLCENGINE_API_KEY",
+ display_name="VolcEngine",
+ litellm_prefix="volcengine",
+ skip_prefixes=("volcengine/",),
+ default_api_base="https://ark.cn-beijing.volces.com/api/v3",
+ ),
+
+ ProviderSpec(
+ name="openai",
+ keywords=("openai", "gpt"),
+ env_key="OPENAI_API_KEY",
+ display_name="OpenAI",
+ litellm_prefix="",
+ ),
+
+ ProviderSpec(
+ name="anthropic",
+ keywords=("anthropic", "claude"),
+ env_key="ANTHROPIC_API_KEY",
+ display_name="Anthropic",
+ litellm_prefix="",
+ ),
+
+ ProviderSpec(
+ name="deepseek",
+ keywords=("deepseek",),
+ env_key="DEEPSEEK_API_KEY",
+ display_name="DeepSeek",
+ litellm_prefix="deepseek",
+ skip_prefixes=("deepseek/",),
+ ),
+
+ ProviderSpec(
+ name="gemini",
+ keywords=("gemini",),
+ env_key="GEMINI_API_KEY",
+ display_name="Gemini",
+ litellm_prefix="gemini",
+ skip_prefixes=("gemini/",),
+ ),
+
+ ProviderSpec(
+ name="moonshot",
+ keywords=("moonshot", "kimi"),
+ env_key="MOONSHOT_API_KEY",
+ display_name="Moonshot",
+ litellm_prefix="moonshot",
+ skip_prefixes=("moonshot/",),
+ env_extras=(
+ ("MOONSHOT_API_BASE", "{api_base}"),
+ ),
+ default_api_base="https://api.moonshot.ai/v1",
+ model_overrides=(
+ ("kimi-k2.5", {"temperature": 1.0}),
+ ),
+ ),
+
+ ProviderSpec(
+ name="zhipu",
+ keywords=("zhipu", "glm", "zai"),
+ env_key="ZAI_API_KEY",
+ display_name="Zhipu AI",
+ litellm_prefix="zai",
+ skip_prefixes=("zhipu/", "zai/"),
+ env_extras=(
+ ("ZHIPUAI_API_KEY", "{api_key}"),
+ ),
+ ),
+
+ ProviderSpec(
+ name="dashscope",
+ keywords=("qwen", "dashscope"),
+ env_key="DASHSCOPE_API_KEY",
+ display_name="DashScope",
+ litellm_prefix="dashscope",
+ skip_prefixes=("dashscope/",),
+ ),
+
+ ProviderSpec(
+ name="minimax",
+ keywords=("minimax",),
+ env_key="MINIMAX_API_KEY",
+ display_name="MiniMax",
+ litellm_prefix="minimax",
+ skip_prefixes=("minimax/",),
+ default_api_base="https://api.minimax.io/v1",
+ ),
+
+ ProviderSpec(
+ name="vllm",
+ keywords=("vllm",),
+ env_key="HOSTED_VLLM_API_KEY",
+ display_name="vLLM/Local",
+ litellm_prefix="hosted_vllm",
+ is_local=True,
+ ),
+)
+
+
+def find_by_model(model: str) -> ProviderSpec | None:
+ """Match a standard provider by model-name keyword (case-insensitive).
+    Skips gateways/local; those are matched by api_key/api_base instead."""
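+    # e.g. "gpt-4o" matches the "openai" spec via the "gpt" keyword;
+    # "claude-3-5-sonnet" matches "anthropic" via "claude".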
+ model_lower = model.lower()
+ for spec in PROVIDERS:
+ if spec.is_gateway or spec.is_local:
+ continue
+ if any(kw in model_lower for kw in spec.keywords):
+ return spec
+ return None
+
+
+def find_gateway(
+ provider_name: str | None = None,
+ api_key: str | None = None,
+ api_base: str | None = None,
+) -> ProviderSpec | None:
+ """Detect gateway/local provider.
+
+ Priority:
+    1. provider_name: if it maps to a gateway/local spec, use it directly.
+    2. api_key prefix, e.g. "sk-or-" -> OpenRouter.
+    3. api_base keyword, e.g. "openrouter" in the URL -> OpenRouter.
+ """
+ if provider_name:
+ spec = find_by_name(provider_name)
+ if spec and (spec.is_gateway or spec.is_local):
+ return spec
+
+ for spec in PROVIDERS:
+ if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
+ return spec
+
+ for spec in PROVIDERS:
+ if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
+ return spec
+
+ return None
+
+
+def find_by_name(name: str) -> ProviderSpec | None:
+ """Find a provider spec by config field name, e.g. "dashscope"."""
+ for spec in PROVIDERS:
+ if spec.name == name:
+ return spec
+ return None
+
+
+def get_all_provider_names() -> list[str]:
+    """Return the list of all registered provider names."""
+ return [spec.name for spec in PROVIDERS]
diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py
index 88f70024..ad1bea8f 100644
--- a/openviking_cli/utils/config/vlm_config.py
+++ b/openviking_cli/utils/config/vlm_config.py
@@ -1,25 +1,34 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
-from typing import Any, Literal, Optional
+from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
class VLMConfig(BaseModel):
- """VLM configuration, supports multiple backends (openai, volcengine)."""
+ """VLM configuration, supports multiple provider backends."""
model: Optional[str] = Field(default=None, description="Model name")
api_key: Optional[str] = Field(default=None, description="API key")
api_base: Optional[str] = Field(default=None, description="API base URL")
temperature: float = Field(default=0.0, description="Generation temperature")
max_retries: int = Field(default=2, description="Maximum retry attempts")
- provider: Optional[Literal["openai", "volcengine"]] = Field(
- default="volcengine", description="Provider type"
+
+ provider: Optional[str] = Field(default=None, description="Provider type")
+ backend: Optional[str] = Field(default=None, description="Backend provider (Deprecated, use 'provider' instead)")
+
+ providers: Dict[str, Dict[str, Any]] = Field(
+ default_factory=dict,
+ description="Multi-provider configuration, e.g. {'deepseek': {'api_key': 'xxx', 'api_base': 'xxx'}}"
)
- backend: Literal["openai", "volcengine"] = Field(
- default="volcengine", description="Backend provider (Deprecated, use 'provider' instead)"
+
+ default_provider: Optional[str] = Field(
+ default=None,
+ description="Default provider name"
)
+ thinking: bool = Field(default=False, description="Enable thinking mode for VolcEngine models")
+
_vlm_instance: Optional[Any] = None
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
@@ -38,27 +47,125 @@ def sync_provider_backend(cls, data: Any) -> Any:
@model_validator(mode="after")
def validate_config(self):
"""Validate configuration completeness and consistency"""
- if self.backend and not self.provider:
- self.provider = self.backend
+ self._migrate_legacy_config()
- # VLM is optional, but if configured, must have required fields
- if self.api_key or self.model or self.api_base:
- # If any VLM config is provided, require model and api_key
+ if self._has_any_config():
if not self.model:
raise ValueError("VLM configuration requires 'model' to be set")
- if not self.api_key:
+ if not self._get_effective_api_key():
raise ValueError("VLM configuration requires 'api_key' to be set")
return self
- def get_vlm_instance(self) -> Any:
- from openviking.models.vlm import VLMFactory
+ def _migrate_legacy_config(self):
+ """Migrate legacy config to providers structure."""
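+        # e.g. a top-level {"provider": "openai", "api_key": "sk-..."} pair becomes
+        # providers["openai"] = {"api_key": "sk-..."}; existing entries are never overwritten.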
+ if self.api_key and self.provider:
+ if self.provider not in self.providers:
+ self.providers[self.provider] = {}
+ if "api_key" not in self.providers[self.provider]:
+ self.providers[self.provider]["api_key"] = self.api_key
+ if self.api_base and "api_base" not in self.providers[self.provider]:
+ self.providers[self.provider]["api_base"] = self.api_base
+
+ def _has_any_config(self) -> bool:
+ """Check if any config is provided."""
+ if self.api_key or self.model or self.api_base:
+ return True
+ if self.providers:
+ for p in self.providers.values():
+ if p.get("api_key"):
+ return True
+ return False
+
+ def _get_effective_api_key(self) -> str | None:
+ """Get effective API key."""
+ if self.api_key:
+ return self.api_key
+ config, _ = self._match_provider()
+ if config and config.get("api_key"):
+ return config["api_key"]
+ return None
+
+ def _match_provider(self, model: str | None = None) -> tuple[Dict[str, Any] | None, str | None]:
+ """Match provider config by model name.
+
+ Returns:
+ (provider_config_dict, provider_name)
+ """
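+        # Priority: explicit self.provider -> model-name keyword match -> any configured
+        # gateway -> any other configured provider (entries without an api_key are skipped).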
+ from openviking.models.vlm.registry import PROVIDERS
+
+ model_lower = (model or self.model or "").lower()
+
+ if self.provider:
+ p = self.providers.get(self.provider)
+ if p and p.get("api_key"):
+ return p, self.provider
+
+ for spec in PROVIDERS:
+ p = self.providers.get(spec.name)
+ if p and any(kw in model_lower for kw in spec.keywords) and p.get("api_key"):
+ return p, spec.name
+
+ for spec in PROVIDERS:
+ if spec.is_gateway:
+ p = self.providers.get(spec.name)
+ if p and p.get("api_key"):
+ return p, spec.name
+
+ for spec in PROVIDERS:
+ if not spec.is_gateway:
+ p = self.providers.get(spec.name)
+ if p and p.get("api_key"):
+ return p, spec.name
+
+ return None, None
+
+ def get_provider_config(
+ self, model: str | None = None
+ ) -> tuple[Dict[str, Any] | None, str | None, "Any | None"]:
+ """Get provider config and spec.
+
+ Returns:
+ (provider_config_dict, provider_name, ProviderSpec)
+ """
+ from openviking.models.vlm.registry import find_by_name, find_gateway
+
+ config, name = self._match_provider(model)
+ if config and name:
+ spec = find_by_name(name)
+ gateway = find_gateway(name, config.get("api_key"), config.get("api_base"))
+ return config, name, gateway or spec
+ return None, None, None
- """Get VLM instance"""
+ def get_vlm_instance(self) -> Any:
+ """Get VLM instance."""
if self._vlm_instance is None:
- config_dict = self.model_dump()
+ config_dict = self._build_vlm_config_dict()
+ from openviking.models.vlm import VLMFactory
self._vlm_instance = VLMFactory.create(config_dict)
return self._vlm_instance
+ def _build_vlm_config_dict(self) -> Dict[str, Any]:
+ """Build VLM instance config dict."""
+ config, name, spec = self.get_provider_config()
+
+ result = {
+ "model": self.model,
+ "temperature": self.temperature,
+ "max_retries": self.max_retries,
+ "provider": name,
+ "thinking": self.thinking,
+ }
+
+        if config:
+            result["api_key"] = config.get("api_key")
+            result["api_base"] = config.get("api_base")
+            result["extra_headers"] = config.get("extra_headers")
+        else:
+            # No provider entry matched; fall back to the top-level api_key/api_base
+            # so a plain "model + api_key" config still reaches the backend.
+            result["api_key"] = self.api_key
+            result["api_base"] = self.api_base
+
+ if spec and not result.get("api_base") and spec.default_api_base:
+ result["api_base"] = spec.default_api_base
+
+ return result
+
def get_completion(self, prompt: str) -> str:
"""Get LLM completion."""
return self.get_vlm_instance().get_completion(prompt)
@@ -69,7 +176,7 @@ async def get_completion_async(self, prompt: str, max_retries: int = 0) -> str:
def is_available(self) -> bool:
"""Check if LLM is configured."""
- return self.api_key is not None or self.api_base is not None
+ return self._get_effective_api_key() is not None
def get_vision_completion(
self,
diff --git a/pyproject.toml b/pyproject.toml
index 542afd29..b6efd649 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,7 @@ dependencies = [
"protobuf>=6.33.5",
"pdfminer-six>=20251230",
"typer>=0.12.0",
+ "litellm>=1.0.0",
]
[project.optional-dependencies]