3 changes: 3 additions & 0 deletions docs/cbenv.adoc
@@ -260,6 +260,7 @@ provider = "OpenAI"
embed_model = "text-embedding-3-small"
chat_model = "gpt-3.5-turbo"
api_key = "get-your-own"
api_base = "http://localhost:2134/v1"

[[llm]]
identifier = "Bedrock-titan"
@@ -283,6 +284,8 @@ The `embed-model` field is the model that will be used to generate embeddings by
While the `chat-model` is the model that will be used to answer questions with the <<_ask,ask>> command.
These models can be any that the provider's API supports, and should be provided in the format given in the provider's API docs.

When using OpenAI, users can provide a custom API base, allowing cbshell to be used with local or custom models that support the OpenAI API format.
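
For example, the following entry would point cbshell at a local OpenAI-compatible server (the identifier and model names here are illustrative):

```
[[llm]]
identifier = "local-model"
provider = "OpenAI"
embed_model = "text-embedding-3-small"
chat_model = "gpt-3.5-turbo"
api_key = "unused-by-local-servers"
api_base = "http://localhost:2134/v1"
```

Note that the OpenAI client still requires an `api_key` entry even when a local server ignores it.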

The api-keys can also be given separately in the <<_credentials_file_format,credentials file>>, for example:

```
10 changes: 10 additions & 0 deletions src/cli/error.rs
@@ -180,6 +180,9 @@ pub enum CBShellError {
project: String,
span: Span,
},
CustomBaseNotSupported {
provider: String,
},
}

impl From<CBShellError> for ShellError {
@@ -281,6 +284,9 @@ impl From<CBShellError> for ShellError {
CBShellError::ColumnarClustersNotFound {project, span} => {
spanned_shell_error(format!("No columnar clusters found in project {}", project), "You can change the active project with the `cb-env project` command".to_string(), span)
}
CBShellError::CustomBaseNotSupported { provider } => {
spanned_shell_error(format!("{} does not support custom api base", provider), "Either remove `api_base` entry from the config file or use the provider OpenAI".to_string(), None)
}
}
}
}
@@ -356,6 +362,10 @@ pub fn embed_model_missing() -> ShellError {
CBShellError::EmbedModelMissing {}.into()
}

pub fn api_base_unsupported(provider: String) -> ShellError {
CBShellError::CustomBaseNotSupported { provider }.into()
}

pub fn insufficient_columnar_permissions_error(span: Span) -> ShellError {
CBShellError::InsufficientColumnarPermissions { span }.into()
}
10 changes: 7 additions & 3 deletions src/client/bedrock_client.rs
@@ -1,4 +1,4 @@
use crate::cli::generic_error;
use crate::cli::{api_base_unsupported, generic_error};
use aws_sdk_bedrockruntime::operation::invoke_model::InvokeModelError;
use aws_sdk_bedrockruntime::primitives::Blob;
use nu_protocol::ShellError;
@@ -12,8 +12,12 @@ pub struct BedrockClient {}
const MAX_RESPONSE_TOKENS: i32 = 8192;

impl BedrockClient {
pub fn new() -> Self {
Self {}
pub fn new(api_base: Option<String>) -> Result<Self, ShellError> {
if api_base.is_some() {
return Err(api_base_unsupported("Bedrock".into()));
}

Ok(Self {})
}

pub fn batch_chunks(&self, chunks: Vec<String>) -> Vec<Vec<String>> {
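
A quick sketch of the new constructor contract (a hypothetical test, not part of the change itself):

```rust
#[test]
fn bedrock_rejects_custom_api_base() {
    // Bedrock has no OpenAI-compatible endpoint, so any configured base fails fast.
    assert!(BedrockClient::new(Some("http://localhost:2134/v1".into())).is_err());
    // With no api_base configured, construction succeeds as before.
    assert!(BedrockClient::new(None).is_ok());
}
```

`GeminiClient::new` below applies the same guard.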
7 changes: 6 additions & 1 deletion src/client/gemini_client.rs
@@ -1,4 +1,4 @@
use crate::cli::{generic_error, llm_api_key_missing};
use crate::cli::{api_base_unsupported, generic_error, llm_api_key_missing};
use bytes::Bytes;
use log::info;
use nu_protocol::ShellError;
@@ -25,7 +25,12 @@ impl GeminiClient {
pub fn new(
api_key: Option<String>,
max_tokens: impl Into<Option<usize>>,
api_base: Option<String>,
) -> Result<Self, ShellError> {
if api_base.is_some() {
return Err(api_base_unsupported("Gemini".into()));
}

let max_tokens = max_tokens.into().unwrap_or(MAX_FREE_TIER_TOKENS);

if let Some(api_key) = api_key {
14 changes: 9 additions & 5 deletions src/client/llm_client.rs
@@ -52,17 +52,21 @@ impl LLMClients {
max_tokens: impl Into<Option<usize>>,
) -> Result<LLMClients, ShellError> {
let guard = state.lock().unwrap();
let (provider, api_key) = match guard.active_llm() {
Some(llm) => (llm.provider(), llm.api_key()),
let (provider, api_key, api_base) = match guard.active_llm() {
Some(llm) => (llm.provider(), llm.api_key(), llm.api_base()),
None => {
return Err(no_llm_configured());
}
};

let client = match provider {
Provider::OpenAI => LLMClients::OpenAI(OpenAIClient::new(api_key, max_tokens)?),
Provider::Gemini => LLMClients::Gemini(GeminiClient::new(api_key, max_tokens)?),
Provider::Bedrock => LLMClients::Bedrock(BedrockClient::new()),
Provider::OpenAI => {
LLMClients::OpenAI(OpenAIClient::new(api_key, max_tokens, api_base)?)
}
Provider::Gemini => {
LLMClients::Gemini(GeminiClient::new(api_key, max_tokens, api_base)?)
}
Provider::Bedrock => LLMClients::Bedrock(BedrockClient::new(api_base)?),
};

Ok(client)
13 changes: 11 additions & 2 deletions src/client/openai_client.rs
@@ -1,4 +1,5 @@
use crate::cli::{generic_error, llm_api_key_missing};
use async_openai::config::OPENAI_API_BASE;
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
@@ -11,6 +12,7 @@ use tiktoken_rs::p50k_base;
pub struct OpenAIClient {
api_key: String,
max_tokens: usize,
api_base: String,
}

const MAX_FREE_TIER_TOKENS: usize = 150000;
@@ -19,13 +21,16 @@ impl OpenAIClient {
pub fn new(
api_key: Option<String>,
max_tokens: impl Into<Option<usize>>,
api_base: Option<String>,
) -> Result<Self, ShellError> {
let max_tokens = max_tokens.into().unwrap_or(MAX_FREE_TIER_TOKENS);
let api_base = api_base.unwrap_or(OPENAI_API_BASE.into());

if let Some(api_key) = api_key {
Ok(Self {
api_key,
max_tokens,
api_base,
})
} else {
Err(llm_api_key_missing("OpenAI".to_string()))
@@ -81,7 +86,9 @@ impl OpenAIClient {
model: String,
) -> Result<Vec<Vec<f32>>, ShellError> {
let client = Client::with_config(
async_openai::config::OpenAIConfig::default().with_api_key(self.api_key.clone()),
async_openai::config::OpenAIConfig::default()
.with_api_key(self.api_key.clone())
.with_api_base(self.api_base.clone()),
);

if log::log_enabled!(log::Level::Debug) {
@@ -163,7 +170,9 @@ impl OpenAIClient {
);

let client = Client::with_config(
async_openai::config::OpenAIConfig::default().with_api_key(self.api_key.clone()),
async_openai::config::OpenAIConfig::default()
.with_api_key(self.api_key.clone())
.with_api_base(self.api_base.clone()),
);

let request = CreateChatCompletionRequestArgs::default()
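
For context, `OPENAI_API_BASE` is the `async_openai` constant for the default endpoint (`https://api.openai.com/v1`), so the fallback keeps existing behaviour when no override is configured. A condensed sketch of the resolution and client construction (paraphrasing the diff, not a verbatim excerpt):

```rust
use async_openai::config::{OpenAIConfig, OPENAI_API_BASE};
use async_openai::Client;

fn make_client(api_key: String, api_base: Option<String>) -> Client<OpenAIConfig> {
    // Fall back to the official endpoint when no custom base is configured.
    let api_base = api_base.unwrap_or(OPENAI_API_BASE.into());
    Client::with_config(
        OpenAIConfig::default()
            .with_api_key(api_key)
            .with_api_base(api_base),
    )
}
```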
6 changes: 6 additions & 0 deletions src/config.rs
@@ -278,6 +278,7 @@ pub struct LLMConfig {
provider: Provider,
embed_model: Option<String>,
chat_model: Option<String>,
api_base: Option<String>,
}

impl LLMConfig {
@@ -300,6 +301,10 @@ impl LLMConfig {
pub fn chat_model(&self) -> Option<String> {
self.chat_model.clone()
}

pub fn api_base(&self) -> Option<String> {
self.api_base.clone()
}
}

impl Debug for LLMConfig {
@@ -310,6 +315,7 @@ impl Debug for LLMConfig {
.field("provider", &self.provider)
.field("embed_model", &self.embed_model)
.field("chat_model", &self.chat_model)
.field("api_base", &self.api_base)
.finish()
}
}
1 change: 1 addition & 0 deletions src/main.rs
@@ -703,6 +703,7 @@ fn make_state(
config.provider(),
config.embed_model(),
config.chat_model(),
config.api_base(),
);
llms.insert(config.identifier(), llm);

7 changes: 7 additions & 0 deletions src/state.rs
@@ -37,6 +37,7 @@ pub struct Llm {
provider: Provider,
embed_model: Option<String>,
chat_model: Option<String>,
api_base: Option<String>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
@@ -52,12 +53,14 @@ impl Llm {
provider: Provider,
embed_model: Option<String>,
chat_model: Option<String>,
api_base: Option<String>,
) -> Self {
Self {
api_key,
provider,
embed_model,
chat_model,
api_base,
}
}

@@ -76,6 +79,10 @@ pub fn chat_model(&self) -> Option<String> {
pub fn chat_model(&self) -> Option<String> {
self.chat_model.clone()
}

pub fn api_base(&self) -> Option<String> {
self.api_base.clone()
}
}

pub struct State {
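
Tying it together, a minimal sketch of the extended state entry (a hypothetical test; the `api_key` parameter position is inferred from the struct fields, and the values echo the docs example):

```rust
#[test]
fn llm_state_carries_api_base() {
    let llm = Llm::new(
        Some("get-your-own".into()),             // api_key
        Provider::OpenAI,
        Some("text-embedding-3-small".into()),   // embed_model
        Some("gpt-3.5-turbo".into()),            // chat_model
        Some("http://localhost:2134/v1".into()), // new: api_base
    );
    assert_eq!(llm.api_base(), Some("http://localhost:2134/v1".to_string()));
}
```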