diff --git a/docs/cbenv.adoc b/docs/cbenv.adoc
index 39d10c38..427b4ea2 100644
--- a/docs/cbenv.adoc
+++ b/docs/cbenv.adoc
@@ -260,6 +260,7 @@ provider = "OpenAI"
 embed_model = "text-embedding-3-small"
 chat_model = "gpt-3.5-turbo"
 api_key = "get-your-own"
+api_base = "http://localhost:2134/v1"
 
 [[llm]]
 identifier = "Bedrock-titan"
@@ -283,6 +284,8 @@ The `embed-model` field is the model that will be used to generate embeddings by
 While the `chat-model` is the model that will be used to answer questions with the <<_ask,ask>> command.
 These models can be any that the provider's API supports, and should be provided in the format given in the provider's API docs.
 
+When using OpenAI, users can provide a custom API base, allowing cbshell to be used with local or custom models that support the OpenAI API format.
+
 The api-keys can also be given separately in the <<_credentials_file_format,credentials file>>, for example:
 
 ```
diff --git a/src/cli/error.rs b/src/cli/error.rs
index 33c0515d..25e5f485 100644
--- a/src/cli/error.rs
+++ b/src/cli/error.rs
@@ -180,6 +180,9 @@ pub enum CBShellError {
         project: String,
         span: Span,
     },
+    CustomBaseNotSupported {
+        provider: String,
+    },
 }
 
 impl From<CBShellError> for ShellError {
@@ -281,6 +284,9 @@ impl From<CBShellError> for ShellError {
             CBShellError::ColumnarClustersNotFound {project, span} => {
                 spanned_shell_error(format!("No columnar clusters found in project {}", project), "You can change the active project with the `cb-env project` command".to_string(), span)
             }
+            CBShellError::CustomBaseNotSupported { provider } => {
+                spanned_shell_error(format!("{} does not support custom api base", provider), "Either remove `api_base` entry from the config file or use the provider OpenAI".to_string(), None)
+            }
         }
     }
 }
@@ -356,6 +362,10 @@ pub fn embed_model_missing() -> ShellError {
     CBShellError::EmbedModelMissing {}.into()
 }
 
+pub fn api_base_unsupported(provider: String) -> ShellError {
+    CBShellError::CustomBaseNotSupported { provider }.into()
+}
+
 pub fn insufficient_columnar_permissions_error(span: Span) -> ShellError {
     CBShellError::InsufficientColumnarPermissions { span }.into()
 }
diff --git a/src/client/bedrock_client.rs b/src/client/bedrock_client.rs
index 6987f0a5..4d0548e7 100644
--- a/src/client/bedrock_client.rs
+++ b/src/client/bedrock_client.rs
@@ -1,4 +1,4 @@
-use crate::cli::generic_error;
+use crate::cli::{api_base_unsupported, generic_error};
 use aws_sdk_bedrockruntime::operation::invoke_model::InvokeModelError;
 use aws_sdk_bedrockruntime::primitives::Blob;
 use nu_protocol::ShellError;
@@ -12,8 +12,12 @@ pub struct BedrockClient {}
 const MAX_RESPONSE_TOKENS: i32 = 8192;
 
 impl BedrockClient {
-    pub fn new() -> Self {
-        Self {}
+    pub fn new(api_base: Option<String>) -> Result<Self, ShellError> {
+        if api_base.is_some() {
+            return Err(api_base_unsupported("Bedrock".into()));
+        }
+
+        Ok(Self {})
     }
 
     pub fn batch_chunks(&self, chunks: Vec<String>) -> Vec<Vec<String>> {
diff --git a/src/client/gemini_client.rs b/src/client/gemini_client.rs
index 541cca45..5eb7615a 100644
--- a/src/client/gemini_client.rs
+++ b/src/client/gemini_client.rs
@@ -1,4 +1,4 @@
-use crate::cli::{generic_error, llm_api_key_missing};
+use crate::cli::{api_base_unsupported, generic_error, llm_api_key_missing};
 use bytes::Bytes;
 use log::info;
 use nu_protocol::ShellError;
@@ -25,7 +25,12 @@ impl GeminiClient {
     pub fn new(
         api_key: Option<String>,
         max_tokens: impl Into<Option<usize>>,
+        api_base: Option<String>,
     ) -> Result<Self, ShellError> {
+        if api_base.is_some() {
+            return Err(api_base_unsupported("Gemini".into()));
+        }
+
         let max_tokens = max_tokens.into().unwrap_or(MAX_FREE_TIER_TOKENS);
         if let Some(api_key) = api_key {
diff --git a/src/client/llm_client.rs b/src/client/llm_client.rs
index 448ac02c..f610e5b6 100644
--- a/src/client/llm_client.rs
+++ b/src/client/llm_client.rs
@@ -52,17 +52,21 @@ impl LLMClients {
         max_tokens: impl Into<Option<usize>>,
     ) -> Result<Self, ShellError> {
         let guard = state.lock().unwrap();
-        let (provider, api_key) = match guard.active_llm() {
-            Some(llm) => (llm.provider(), llm.api_key()),
+        let (provider, api_key, api_base) = match guard.active_llm() {
+            Some(llm) => (llm.provider(), llm.api_key(), llm.api_base()),
             None => {
                 return Err(no_llm_configured());
             }
         };
 
         let client = match provider {
-            Provider::OpenAI => LLMClients::OpenAI(OpenAIClient::new(api_key, max_tokens)?),
-            Provider::Gemini => LLMClients::Gemini(GeminiClient::new(api_key, max_tokens)?),
-            Provider::Bedrock => LLMClients::Bedrock(BedrockClient::new()),
+            Provider::OpenAI => {
+                LLMClients::OpenAI(OpenAIClient::new(api_key, max_tokens, api_base)?)
+            }
+            Provider::Gemini => {
+                LLMClients::Gemini(GeminiClient::new(api_key, max_tokens, api_base)?)
+            }
+            Provider::Bedrock => LLMClients::Bedrock(BedrockClient::new(api_base)?),
         };
 
         Ok(client)
diff --git a/src/client/openai_client.rs b/src/client/openai_client.rs
index e457b318..edad0ac9 100644
--- a/src/client/openai_client.rs
+++ b/src/client/openai_client.rs
@@ -1,4 +1,5 @@
 use crate::cli::{generic_error, llm_api_key_missing};
+use async_openai::config::OPENAI_API_BASE;
 use async_openai::types::{
     ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
     ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
@@ -11,6 +12,7 @@ use tiktoken_rs::p50k_base;
 pub struct OpenAIClient {
     api_key: String,
     max_tokens: usize,
+    api_base: String,
 }
 
 const MAX_FREE_TIER_TOKENS: usize = 150000;
@@ -19,13 +21,16 @@ impl OpenAIClient {
     pub fn new(
         api_key: Option<String>,
         max_tokens: impl Into<Option<usize>>,
+        api_base: Option<String>,
     ) -> Result<Self, ShellError> {
         let max_tokens = max_tokens.into().unwrap_or(MAX_FREE_TIER_TOKENS);
+        let api_base = api_base.unwrap_or(OPENAI_API_BASE.into());
 
         if let Some(api_key) = api_key {
             Ok(Self {
                 api_key,
                 max_tokens,
+                api_base,
             })
         } else {
             Err(llm_api_key_missing("OpenAI".to_string()))
@@ -81,7 +86,9 @@
         model: String,
     ) -> Result<Vec<Vec<f32>>, ShellError> {
         let client = Client::with_config(
-            async_openai::config::OpenAIConfig::default().with_api_key(self.api_key.clone()),
+            async_openai::config::OpenAIConfig::default()
+                .with_api_key(self.api_key.clone())
+                .with_api_base(self.api_base.clone()),
         );
 
         if log::log_enabled!(log::Level::Debug) {
@@ -163,7 +170,9 @@
         );
 
         let client = Client::with_config(
-            async_openai::config::OpenAIConfig::default().with_api_key(self.api_key.clone()),
+            async_openai::config::OpenAIConfig::default()
+                .with_api_key(self.api_key.clone())
+                .with_api_base(self.api_base.clone()),
         );
 
         let request = CreateChatCompletionRequestArgs::default()
diff --git a/src/config.rs b/src/config.rs
index cbe496a6..e700da34 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -278,6 +278,7 @@ pub struct LLMConfig {
     provider: Provider,
     embed_model: Option<String>,
     chat_model: Option<String>,
+    api_base: Option<String>,
 }
 
 impl LLMConfig {
@@ -300,6 +301,10 @@ impl LLMConfig {
     pub fn chat_model(&self) -> Option<String> {
         self.chat_model.clone()
     }
+
+    pub fn api_base(&self) -> Option<String> {
+        self.api_base.clone()
+    }
 }
 
 impl Debug for LLMConfig {
@@ -310,6 +315,7 @@ impl Debug for LLMConfig {
             .field("provider", &self.provider)
             .field("embed_model", &self.embed_model)
             .field("chat_model", &self.chat_model)
+            .field("api_base", &self.api_base)
             .finish()
     }
 }
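For context on the openai_client.rs hunks above: the override is applied by handing the configured base URL to `async_openai` when the client is built, falling back to `OPENAI_API_BASE` when none is set. The sketch below shows that construction in isolation; it is an illustration rather than code from the patch, and the helper name and localhost URL are placeholders of mine (the port matches the docs example).

```rust
// Illustrative sketch only, not part of the patch: passing a custom API base
// to async_openai, as openai_client.rs now does with self.api_base.
use async_openai::config::{OpenAIConfig, OPENAI_API_BASE};
use async_openai::Client;

fn build_openai_client(api_key: String, api_base: Option<String>) -> Client<OpenAIConfig> {
    // When no override is configured, fall back to the public OpenAI endpoint.
    let api_base = api_base.unwrap_or_else(|| OPENAI_API_BASE.to_string());
    Client::with_config(
        OpenAIConfig::default()
            .with_api_key(api_key)
            .with_api_base(api_base),
    )
}

// Example: point at a local OpenAI-compatible server (URL is a placeholder).
// let client = build_openai_client(key, Some("http://localhost:2134/v1".into()));
```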
diff --git a/src/main.rs b/src/main.rs
index 1ad0682b..61cc3b06 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -703,6 +703,7 @@ fn make_state(
                 config.provider(),
                 config.embed_model(),
                 config.chat_model(),
+                config.api_base(),
             );
 
             llms.insert(config.identifier(), llm);
diff --git a/src/state.rs b/src/state.rs
index fbc40caf..161de952 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -37,6 +37,7 @@ pub struct Llm {
     provider: Provider,
     embed_model: Option<String>,
     chat_model: Option<String>,
+    api_base: Option<String>,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
@@ -52,12 +53,14 @@ impl Llm {
         provider: Provider,
         embed_model: Option<String>,
         chat_model: Option<String>,
+        api_base: Option<String>,
     ) -> Self {
         Self {
             api_key,
             provider,
             embed_model,
             chat_model,
+            api_base,
         }
     }
@@ -76,6 +79,10 @@ impl Llm {
     pub fn chat_model(&self) -> Option<String> {
         self.chat_model.clone()
     }
+
+    pub fn api_base(&self) -> Option<String> {
+        self.api_base.clone()
+    }
 }
 
 pub struct State {
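End to end, the new field flows from the `[[llm]]` config entry through `LLMConfig` and `Llm` into the client constructors, and only the OpenAI constructor accepts it. A rough, test-style sketch of that contract follows, using the names introduced in the patch; the import path, placeholder key, and URL are assumptions for illustration, not part of the change.

```rust
// Hypothetical in-crate sketch; the module path and values are assumptions.
#[cfg(test)]
mod api_base_behaviour_sketch {
    use crate::client::{BedrockClient, OpenAIClient}; // assumed re-exports

    #[test]
    fn only_openai_accepts_a_custom_api_base() {
        let base = Some("http://localhost:2134/v1".to_string()); // placeholder local endpoint

        // OpenAI: the override is accepted; with None it falls back to OPENAI_API_BASE.
        assert!(OpenAIClient::new(Some("placeholder-key".into()), None::<usize>, base.clone()).is_ok());

        // Bedrock (and likewise Gemini): a custom base is rejected with CustomBaseNotSupported.
        assert!(BedrockClient::new(base).is_err());
    }
}
```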