Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/backends/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl Cohere {
normalize_response,
embedding_encoding_format,
embedding_dimensions,
None, // extra_headers - not exposed via Cohere wrapper
)
}
}
Expand Down
1 change: 1 addition & 0 deletions src/backends/groq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ impl Groq {
normalize_response,
None, // embedding_encoding_format - not supported by Groq
None, // embedding_dimensions - not supported by Groq
None, // extra_headers - not exposed via Groq wrapper
)
}
}
Expand Down
1 change: 1 addition & 0 deletions src/backends/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ impl HuggingFace {
normalize_response,
None, // embedding_encoding_format
None, // embedding_dimensions
None, // extra_headers - not exposed via HuggingFace wrapper
)
}
}
Expand Down
1 change: 1 addition & 0 deletions src/backends/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl Mistral {
normalize_response,
embedding_encoding_format,
embedding_dimensions,
None, // extra_headers - not exposed via Mistral wrapper
)
}
}
Expand Down
21 changes: 21 additions & 0 deletions src/backends/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ impl OpenAI {
web_search_user_location_approximate_country: Option<String>,
web_search_user_location_approximate_city: Option<String>,
web_search_user_location_approximate_region: Option<String>,
extra_headers: Option<std::collections::HashMap<String, String>>,
) -> Result<Self, LLMError> {
let api_key_str = api_key.into();
if api_key_str.is_empty() {
Expand All @@ -242,6 +243,7 @@ impl OpenAI {
normalize_response,
embedding_encoding_format,
embedding_dimensions,
extra_headers,
),
enable_web_search: enable_web_search.unwrap_or(false),
web_search_context_size,
Expand Down Expand Up @@ -340,6 +342,12 @@ impl ChatProvider for OpenAI {
.post(url)
.bearer_auth(&self.provider.api_key)
.json(&body);
// Add runtime extra headers
if let Some(headers) = &self.provider.extra_headers {
for (key, value) in headers {
request = request.header(key, value);
}
}
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!("OpenAI request payload: {}", json);
Expand Down Expand Up @@ -475,6 +483,12 @@ impl ChatProvider for OpenAI {
.post(url)
.bearer_auth(&self.provider.api_key)
.json(&body);
// Add runtime extra headers
if let Some(headers) = &self.provider.extra_headers {
for (key, value) in headers {
request = request.header(key, value);
}
}
if let Some(timeout) = self.provider.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
Expand Down Expand Up @@ -691,6 +705,13 @@ impl OpenAI {
.bearer_auth(&self.provider.api_key)
.json(&body);

// Add runtime extra headers
if let Some(headers) = &self.provider.extra_headers {
for (key, value) in headers {
request = request.header(key, value);
}
}

if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!("OpenAI hosted tools request payload: {}", json);
Expand Down
1 change: 1 addition & 0 deletions src/backends/openrouter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ impl OpenRouter {
normalize_response,
None, // embedding_encoding_format - not supported by OpenRouter
None, // embedding_dimensions - not supported by OpenRouter
None, // extra_headers - not exposed via OpenRouter wrapper
)
}
}
Expand Down
26 changes: 26 additions & 0 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ pub struct LLMBuilder {
resilient_max_delay_ms: Option<u64>,
/// Resilience: jitter toggle
resilient_jitter: Option<bool>,
/// Extra HTTP headers to include in all requests
extra_headers: Option<std::collections::HashMap<String, String>>,
}

impl LLMBuilder {
Expand Down Expand Up @@ -403,6 +405,28 @@ impl LLMBuilder {
self
}

/// Set extra HTTP headers to include in all requests.
///
/// Useful for custom authentication (e.g., Cloudflare Access tokens).
///
/// # Example
/// ```
/// use std::collections::HashMap;
/// use llm::builder::LLMBuilder;
///
/// let headers = HashMap::from([
/// ("CF-Access-Client-Id".to_string(), "my-client-id".to_string()),
/// ("CF-Access-Client-Secret".to_string(), "my-secret".to_string()),
/// ]);
///
/// let builder = LLMBuilder::new()
/// .extra_headers(headers);
/// ```
pub fn extra_headers(mut self, headers: std::collections::HashMap<String, String>) -> Self {
self.extra_headers = Some(headers);
self
}

/// Enable web search
pub fn openai_enable_web_search(mut self, enable: bool) -> Self {
self.openai_enable_web_search = Some(enable);
Expand Down Expand Up @@ -683,6 +707,7 @@ impl LLMBuilder {
self.openai_web_search_user_location_approximate_country,
self.openai_web_search_user_location_approximate_city,
self.openai_web_search_user_location_approximate_region,
self.extra_headers,
)?)
}
}
Expand Down Expand Up @@ -963,6 +988,7 @@ impl LLMBuilder {
self.normalize_response,
self.embedding_encoding_format,
self.embedding_dimensions,
self.extra_headers,
);
Box::new(cohere)
}
Expand Down
112 changes: 112 additions & 0 deletions src/providers/openai_compatible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct OpenAICompatibleProvider<T: OpenAIProviderConfig> {
pub embedding_dimensions: Option<u32>,
pub normalize_response: bool,
pub client: Client,
/// Extra HTTP headers to include in all requests
pub extra_headers: Option<std::collections::HashMap<String, String>>,
_phantom: PhantomData<T>,
}

Expand Down Expand Up @@ -316,6 +318,7 @@ impl<T: OpenAIProviderConfig> OpenAICompatibleProvider<T> {
normalize_response: Option<bool>,
embedding_encoding_format: Option<String>,
embedding_dimensions: Option<u32>,
extra_headers: Option<std::collections::HashMap<String, String>>,
) -> Self {
let mut builder = Client::builder();
if let Some(sec) = timeout_seconds {
Expand Down Expand Up @@ -347,6 +350,7 @@ impl<T: OpenAIProviderConfig> OpenAICompatibleProvider<T> {
embedding_encoding_format,
embedding_dimensions,
client: builder.build().expect("Failed to build reqwest Client"),
extra_headers,
_phantom: PhantomData,
}
}
Expand Down Expand Up @@ -456,6 +460,12 @@ impl<T: OpenAIProviderConfig> ChatProvider for OpenAICompatibleProvider<T> {
request = request.header(key, value);
}
}
// Add runtime extra headers
if let Some(headers) = &self.extra_headers {
for (key, value) in headers {
request = request.header(key, value);
}
}
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!("{} request payload: {}", T::PROVIDER_NAME, json);
Expand Down Expand Up @@ -571,6 +581,12 @@ impl<T: OpenAIProviderConfig> ChatProvider for OpenAICompatibleProvider<T> {
request = request.header(key, value);
}
}
// Add runtime extra headers
if let Some(headers) = &self.extra_headers {
for (key, value) in headers {
request = request.header(key, value);
}
}
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!("{} request payload: {}", T::PROVIDER_NAME, json);
Expand Down Expand Up @@ -668,6 +684,13 @@ impl<T: OpenAIProviderConfig> ChatProvider for OpenAICompatibleProvider<T> {
}
}

// Add runtime extra headers
if let Some(headers) = &self.extra_headers {
for (key, value) in headers {
request = request.header(key, value);
}
}

if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!(
Expand Down Expand Up @@ -1417,4 +1440,93 @@ mod tests {
results[0]
);
}

#[test]
fn test_extra_headers_stored_in_provider() {
// Test that extra_headers are properly stored in the provider
use std::collections::HashMap;

struct TestConfig;
impl OpenAIProviderConfig for TestConfig {
const PROVIDER_NAME: &'static str = "Test";
const DEFAULT_BASE_URL: &'static str = "https://api.test.com/v1/";
const DEFAULT_MODEL: &'static str = "test-model";
}

let mut headers = HashMap::new();
headers.insert("CF-Access-Client-Id".to_string(), "test-id".to_string());
headers.insert(
"CF-Access-Client-Secret".to_string(),
"test-secret".to_string(),
);

let provider = OpenAICompatibleProvider::<TestConfig>::new(
"test-api-key",
None, // base_url
None, // model
None, // max_tokens
None, // temperature
None, // timeout_seconds
None, // system
None, // top_p
None, // top_k
None, // tools
None, // tool_choice
None, // reasoning_effort
None, // json_schema
None, // voice
None, // extra_body
None, // parallel_tool_calls
None, // normalize_response
None, // embedding_encoding_format
None, // embedding_dimensions
Some(headers.clone()),
);

assert!(provider.extra_headers.is_some());
let stored_headers = provider.extra_headers.unwrap();
assert_eq!(
stored_headers.get("CF-Access-Client-Id"),
Some(&"test-id".to_string())
);
assert_eq!(
stored_headers.get("CF-Access-Client-Secret"),
Some(&"test-secret".to_string())
);
}

#[test]
fn test_extra_headers_none_when_not_provided() {
struct TestConfig;
impl OpenAIProviderConfig for TestConfig {
const PROVIDER_NAME: &'static str = "Test";
const DEFAULT_BASE_URL: &'static str = "https://api.test.com/v1/";
const DEFAULT_MODEL: &'static str = "test-model";
}

let provider = OpenAICompatibleProvider::<TestConfig>::new(
"test-api-key",
None, // base_url
None, // model
None, // max_tokens
None, // temperature
None, // timeout_seconds
None, // system
None, // top_p
None, // top_k
None, // tools
None, // tool_choice
None, // reasoning_effort
None, // json_schema
None, // voice
None, // extra_body
None, // parallel_tool_calls
None, // normalize_response
None, // embedding_encoding_format
None, // embedding_dimensions
None, // extra_headers
);

assert!(provider.extra_headers.is_none());
}
}