diff --git a/src/backends/cohere.rs b/src/backends/cohere.rs
index a835556..519e28c 100644
--- a/src/backends/cohere.rs
+++ b/src/backends/cohere.rs
@@ -76,6 +76,7 @@ impl Cohere {
             normalize_response,
             embedding_encoding_format,
             embedding_dimensions,
+            None, // extra_headers - not exposed via Cohere wrapper
         )
     }
 }
diff --git a/src/backends/groq.rs b/src/backends/groq.rs
index 6a7a97f..58d0334 100644
--- a/src/backends/groq.rs
+++ b/src/backends/groq.rs
@@ -87,6 +87,7 @@ impl Groq {
             normalize_response,
             None, // embedding_encoding_format - not supported by Groq
             None, // embedding_dimensions - not supported by Groq
+            None, // extra_headers - not exposed via Groq wrapper
         )
     }
 }
diff --git a/src/backends/huggingface.rs b/src/backends/huggingface.rs
index 233ab85..9eb0d63 100644
--- a/src/backends/huggingface.rs
+++ b/src/backends/huggingface.rs
@@ -73,6 +73,7 @@ impl HuggingFace {
             normalize_response,
             None, // embedding_encoding_format
             None, // embedding_dimensions
+            None, // extra_headers - not exposed via HuggingFace wrapper
         )
     }
 }
diff --git a/src/backends/mistral.rs b/src/backends/mistral.rs
index 761d915..06d24a2 100644
--- a/src/backends/mistral.rs
+++ b/src/backends/mistral.rs
@@ -76,6 +76,7 @@ impl Mistral {
             normalize_response,
             embedding_encoding_format,
             embedding_dimensions,
+            None, // extra_headers - not exposed via Mistral wrapper
         )
     }
 }
diff --git a/src/backends/openai.rs b/src/backends/openai.rs
index 026a9af..d754f93 100644
--- a/src/backends/openai.rs
+++ b/src/backends/openai.rs
@@ -216,6 +216,7 @@ impl OpenAI {
         web_search_user_location_approximate_country: Option<String>,
         web_search_user_location_approximate_city: Option<String>,
         web_search_user_location_approximate_region: Option<String>,
+        extra_headers: Option<std::collections::HashMap<String, String>>,
     ) -> Result<Self, LLMError> {
         let api_key_str = api_key.into();
         if api_key_str.is_empty() {
@@ -242,6 +243,7 @@ impl OpenAI {
                 normalize_response,
                 embedding_encoding_format,
                 embedding_dimensions,
+                extra_headers,
             ),
             enable_web_search: enable_web_search.unwrap_or(false),
             web_search_context_size,
@@ -340,6 +342,12 @@ impl ChatProvider for OpenAI {
             .post(url)
             .bearer_auth(&self.provider.api_key)
             .json(&body);
+        // Add runtime extra headers
+        if let Some(headers) = &self.provider.extra_headers {
+            for (key, value) in headers {
+                request = request.header(key, value);
+            }
+        }
         if log::log_enabled!(log::Level::Trace) {
             if let Ok(json) = serde_json::to_string(&body) {
                 log::trace!("OpenAI request payload: {}", json);
@@ -475,6 +483,12 @@ impl ChatProvider for OpenAI {
             .post(url)
             .bearer_auth(&self.provider.api_key)
             .json(&body);
+        // Add runtime extra headers
+        if let Some(headers) = &self.provider.extra_headers {
+            for (key, value) in headers {
+                request = request.header(key, value);
+            }
+        }
         if let Some(timeout) = self.provider.timeout_seconds {
             request = request.timeout(std::time::Duration::from_secs(timeout));
         }
@@ -691,6 +705,13 @@ impl OpenAI {
             .bearer_auth(&self.provider.api_key)
             .json(&body);
 
+        // Add runtime extra headers
+        if let Some(headers) = &self.provider.extra_headers {
+            for (key, value) in headers {
+                request = request.header(key, value);
+            }
+        }
+
         if log::log_enabled!(log::Level::Trace) {
             if let Ok(json) = serde_json::to_string(&body) {
                 log::trace!("OpenAI hosted tools request payload: {}", json);
diff --git a/src/backends/openrouter.rs b/src/backends/openrouter.rs
index c0ca170..86c0fd1 100644
--- a/src/backends/openrouter.rs
+++ b/src/backends/openrouter.rs
@@ -73,6 +73,7 @@ impl OpenRouter {
             normalize_response,
             None, // embedding_encoding_format - not supported by OpenRouter
             None, // embedding_dimensions - not supported by OpenRouter
+            None, // extra_headers - not exposed via OpenRouter wrapper
         )
     }
 }
diff --git a/src/builder.rs b/src/builder.rs
index 92033b1..e019fd6 100644
--- a/src/builder.rs
+++ b/src/builder.rs
@@ -198,6 +198,8 @@ pub struct LLMBuilder {
     resilient_max_delay_ms: Option<u64>,
     /// Resilience: jitter toggle
     resilient_jitter: Option<bool>,
+    /// Extra HTTP headers to include in all requests
+    extra_headers:
+        Option<std::collections::HashMap<String, String>>,
 }
 
 impl LLMBuilder {
@@ -403,6 +405,28 @@ impl LLMBuilder {
         self
     }
 
+    /// Set extra HTTP headers to include in all requests.
+    ///
+    /// Useful for custom authentication (e.g., Cloudflare Access tokens).
+    ///
+    /// # Example
+    /// ```
+    /// use std::collections::HashMap;
+    /// use llm::builder::LLMBuilder;
+    ///
+    /// let headers = HashMap::from([
+    ///     ("CF-Access-Client-Id".to_string(), "my-client-id".to_string()),
+    ///     ("CF-Access-Client-Secret".to_string(), "my-secret".to_string()),
+    /// ]);
+    ///
+    /// let builder = LLMBuilder::new()
+    ///     .extra_headers(headers);
+    /// ```
+    pub fn extra_headers(mut self, headers: std::collections::HashMap<String, String>) -> Self {
+        self.extra_headers = Some(headers);
+        self
+    }
+
     /// Enable web search
     pub fn openai_enable_web_search(mut self, enable: bool) -> Self {
         self.openai_enable_web_search = Some(enable);
@@ -683,6 +707,7 @@ impl LLMBuilder {
             self.openai_web_search_user_location_approximate_country,
             self.openai_web_search_user_location_approximate_city,
             self.openai_web_search_user_location_approximate_region,
+            self.extra_headers,
         )?)
     }
 }
@@ -963,6 +988,7 @@ impl LLMBuilder {
             self.normalize_response,
             self.embedding_encoding_format,
             self.embedding_dimensions,
+            self.extra_headers,
         );
         Box::new(cohere)
     }
diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs
index 8aac07e..6c08dbe 100644
--- a/src/providers/openai_compatible.rs
+++ b/src/providers/openai_compatible.rs
@@ -48,6 +48,8 @@ pub struct OpenAICompatibleProvider<T: OpenAIProviderConfig> {
     pub embedding_dimensions: Option<u32>,
     pub normalize_response: bool,
     pub client: Client,
+    /// Extra HTTP headers to include in all requests
+    pub extra_headers: Option<std::collections::HashMap<String, String>>,
     _phantom: PhantomData<T>,
 }
@@ -316,6 +318,7 @@ impl<T: OpenAIProviderConfig> OpenAICompatibleProvider<T> {
         normalize_response: Option<bool>,
         embedding_encoding_format: Option<String>,
         embedding_dimensions: Option<u32>,
+        extra_headers: Option<std::collections::HashMap<String, String>>,
     ) -> Self {
         let mut builder = Client::builder();
         if let Some(sec) = timeout_seconds {
@@ -347,6 +350,7 @@ impl<T: OpenAIProviderConfig> OpenAICompatibleProvider<T> {
             embedding_encoding_format,
             embedding_dimensions,
             client: builder.build().expect("Failed to build reqwest Client"),
+            extra_headers,
             _phantom: PhantomData,
         }
     }
@@ -456,6 +460,12 @@ impl<T: OpenAIProviderConfig> ChatProvider for OpenAICompatibleProvider<T> {
                 request = request.header(key, value);
             }
         }
+        // Add runtime extra headers
+        if let Some(headers) = &self.extra_headers {
+            for (key, value) in headers {
+                request = request.header(key, value);
+            }
+        }
         if log::log_enabled!(log::Level::Trace) {
             if let Ok(json) = serde_json::to_string(&body) {
                 log::trace!("{} request payload: {}", T::PROVIDER_NAME, json);
@@ -571,6 +581,12 @@ impl<T: OpenAIProviderConfig> ChatProvider for OpenAICompatibleProvider<T> {
                 request = request.header(key, value);
             }
         }
+        // Add runtime extra headers
+        if let Some(headers) = &self.extra_headers {
+            for (key, value) in headers {
+                request = request.header(key, value);
+            }
+        }
         if log::log_enabled!(log::Level::Trace) {
             if let Ok(json) = serde_json::to_string(&body) {
                 log::trace!("{} request payload: {}", T::PROVIDER_NAME, json);
@@ -668,6 +684,13 @@ impl<T: OpenAIProviderConfig> ChatProvider for OpenAICompatibleProvider<T> {
             }
         }
+        // Add runtime extra headers
+        if let Some(headers) = &self.extra_headers {
+            for (key, value) in headers {
+                request = request.header(key, value);
+            }
+        }
+
         if log::log_enabled!(log::Level::Trace) {
             if let Ok(json) = serde_json::to_string(&body) {
                 log::trace!(
@@ -1417,4 +1440,93 @@ mod tests {
             results[0]
         );
     }
+
+    #[test]
+    fn test_extra_headers_stored_in_provider() {
+        // Test that extra_headers are properly stored in the provider
+        use std::collections::HashMap;
+
+        struct TestConfig;
+        impl OpenAIProviderConfig for TestConfig {
+            const PROVIDER_NAME: &'static str = "Test";
+            const DEFAULT_BASE_URL: &'static str = "https://api.test.com/v1/";
+            const DEFAULT_MODEL: &'static str = "test-model";
+        }
+
+        let mut headers = HashMap::new();
+        headers.insert("CF-Access-Client-Id".to_string(), "test-id".to_string());
+        headers.insert(
+            "CF-Access-Client-Secret".to_string(),
+            "test-secret".to_string(),
+        );
+
+        let provider = OpenAICompatibleProvider::<TestConfig>::new(
+            "test-api-key",
+            None, // base_url
+            None, // model
+            None, // max_tokens
+            None, // temperature
+            None, // timeout_seconds
+            None, // system
+            None, // top_p
+            None, // top_k
+            None, // tools
+            None, // tool_choice
+            None, // reasoning_effort
+            None, // json_schema
+            None, // voice
+            None, // extra_body
+            None, // parallel_tool_calls
+            None, // normalize_response
+            None, // embedding_encoding_format
+            None, // embedding_dimensions
+            Some(headers.clone()),
+        );
+
+        assert!(provider.extra_headers.is_some());
+        let stored_headers = provider.extra_headers.unwrap();
+        assert_eq!(
+            stored_headers.get("CF-Access-Client-Id"),
+            Some(&"test-id".to_string())
+        );
+        assert_eq!(
+            stored_headers.get("CF-Access-Client-Secret"),
+            Some(&"test-secret".to_string())
+        );
+    }
+
+    #[test]
+    fn test_extra_headers_none_when_not_provided() {
+        struct TestConfig;
+        impl OpenAIProviderConfig for TestConfig {
+            const PROVIDER_NAME: &'static str = "Test";
+            const DEFAULT_BASE_URL: &'static str = "https://api.test.com/v1/";
+            const
+                DEFAULT_MODEL: &'static str = "test-model";
+        }
+
+        let provider = OpenAICompatibleProvider::<TestConfig>::new(
+            "test-api-key",
+            None, // base_url
+            None, // model
+            None, // max_tokens
+            None, // temperature
+            None, // timeout_seconds
+            None, // system
+            None, // top_p
+            None, // top_k
+            None, // tools
+            None, // tool_choice
+            None, // reasoning_effort
+            None, // json_schema
+            None, // voice
+            None, // extra_body
+            None, // parallel_tool_calls
+            None, // normalize_response
+            None, // embedding_encoding_format
+            None, // embedding_dimensions
+            None, // extra_headers
+        );
+
+        assert!(provider.extra_headers.is_none());
+    }
+}