From b0b056b2cbf3278ada702adce31f4917d16ee6bf Mon Sep 17 00:00:00 2001
From: Naz Quadri
Date: Wed, 31 Dec 2025 12:10:28 -0500
Subject: [PATCH] feat: add opt-in metrics collection for chat requests

---
 Cargo.lock          |  21 +++
 Cargo.toml          |   1 +
 src/builder.rs      |  32 ++++
 src/chat/mod.rs     |  73 +++++++++
 src/chat/tracked.rs | 381 ++++++++++++++++++++++++++++++++++++++++++++
 src/lib.rs          |  52 ++++++
 src/metrics.rs      | 339 +++++++++++++++++++++++++++++++++++++++
 7 files changed, 899 insertions(+)
 create mode 100644 src/chat/tracked.rs
 create mode 100644 src/metrics.rs

diff --git a/Cargo.lock b/Cargo.lock
index cf58008..a1263b0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1196,6 +1196,7 @@ dependencies = [
  "http-body-util",
  "hyper",
  "log",
+ "pin-project",
  "rand",
  "regex",
  "reqwest",
@@ -1540,6 +1541,26 @@
 version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
+[[package]]
+name = "pin-project"
+version = "1.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a"
+dependencies = [
+ "pin-project-internal",
+]
+
+[[package]]
+name = "pin-project-internal"
+version = "1.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.101",
+]
+
 [[package]]
 name = "pin-project-lite"
 version = "0.2.15"
diff --git a/Cargo.toml b/Cargo.toml
index 3cc58d0..33c19d5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -79,6 +79,7 @@ log = "0.4"
 env_logger = { version = "0.11", optional = true }
 chrono = {version = "0.4", default-features = false, features = ["serde"]}
 rand = "0.8"
+pin-project = "1"
 
 [[bin]]
 name = "llm"
diff --git a/src/builder.rs b/src/builder.rs
index 92033b1..7027fd0 100644
--- a/src/builder.rs
+++ b/src/builder.rs
@@ -198,6 +198,8 @@ pub struct LLMBuilder {
     resilient_max_delay_ms: Option<u64>,
     /// Resilience: jitter toggle
     resilient_jitter: Option<bool>,
+    /// Enable metrics collection (timing, usage) for non-streaming calls
+    enable_metrics: Option<bool>,
 }
 
 impl LLMBuilder {
@@ -476,6 +478,31 @@ impl LLMBuilder {
         self
     }
 
+    /// Enables metrics collection for non-streaming chat calls.
+    ///
+    /// When enabled, `ChatResponse::metrics()` returns timing and usage
+    /// information. For streaming calls, wrap the stream in `Tracked::new()` instead.
+    ///
+    /// # Example
+    ///
+    /// ```rust,ignore
+    /// let llm = LLMBuilder::new()
+    ///     .backend(LLMBackend::OpenAI)
+    ///     .api_key("...")
+    ///     .enable_metrics(true)
+    ///     .build()?;
+    ///
+    /// let response = llm.chat("Hello").await?;
+    /// if let Some(metrics) = response.metrics() {
+    ///     println!("Duration: {:?}", metrics.duration);
+    ///     println!("Tokens/sec: {:?}", metrics.tokens_per_second());
+    /// }
+    /// ```
+    pub fn enable_metrics(mut self, enable: bool) -> Self {
+        self.enable_metrics = Some(enable);
+        self
+    }
+
     #[deprecated(note = "Renamed to `xai_search_mode`.")]
     pub fn search_mode(self, mode: impl Into<String>) -> Self {
         self.xai_search_mode(mode)
@@ -1112,6 +1139,11 @@ impl LLMBuilder {
             final_provider = Box::new(crate::resilient_llm::ResilientLLM::new(final_provider, cfg));
         }
 
+        // Wrap with metrics collection if enabled
+        if self.enable_metrics.unwrap_or(false) {
+            final_provider = Box::new(crate::metrics::MetricsProvider::new(final_provider));
+        }
+
         // Wrap with memory capabilities if memory is configured
         if let Some(memory) = self.memory {
             let memory_arc = Arc::new(RwLock::new(memory));
diff --git a/src/chat/mod.rs b/src/chat/mod.rs
index 331ca95..08e1214 100644
--- a/src/chat/mod.rs
+++ b/src/chat/mod.rs
@@ -1,6 +1,7 @@
 use std::collections::HashMap;
 use std::fmt;
 use std::pin::Pin;
+use std::time::Duration;
 
 use async_trait::async_trait;
 use futures::stream::{Stream, StreamExt};
@@ -9,6 +10,9 @@ use serde_json::Value;
 
 use crate::{error::LLMError, ToolCall};
 
+mod tracked;
+pub use tracked::{Trackable, Tracked};
+
 /// Usage metadata for a chat response.
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct Usage {
@@ -127,6 +131,54 @@ pub struct PromptTokensDetails {
     pub audio_tokens: Option<u32>,
 }
 
+/// Comprehensive metrics for a chat request, including timing and token usage.
+///
+/// This struct is returned by `ChatResponse::metrics()` when metrics collection
+/// is enabled, or by `Tracked::finalize()` for streaming responses.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// // Non-streaming with metrics enabled
+/// let response = llm.chat("Hello").await?;
+/// if let Some(metrics) = response.metrics() {
+///     println!("Duration: {:?}", metrics.duration);
+///     println!("Tokens: {:?}", metrics.usage);
+///     println!("Tokens/sec: {:?}", metrics.tokens_per_second());
+/// }
+///
+/// // Streaming with the Tracked wrapper
+/// let stream = llm.chat_stream_with_tools(messages, None).await?;
+/// let mut tracked = Tracked::new(stream);
+/// while let Some(chunk) = tracked.next().await { /* ... */ }
+/// let metrics = tracked.finalize();
+/// println!("TTFT: {:?}", metrics.time_to_first_token);
+/// ```
+#[derive(Debug, Clone, Default)]
+pub struct ChatMetrics {
+    /// Token usage (prompt, completion, total)
+    pub usage: Option<Usage>,
+    /// Total wall-clock duration of the request
+    pub duration: Duration,
+    /// Time to first token (streaming only; `None` for non-streaming)
+    pub time_to_first_token: Option<Duration>,
+}
+
+impl ChatMetrics {
+    /// Calculate tokens per second (completion tokens / duration).
+    ///
+    /// Returns `None` if usage data is unavailable or the duration is zero.
+    pub fn tokens_per_second(&self) -> Option<f64> {
+        let usage = self.usage.as_ref()?;
+        let secs = self.duration.as_secs_f64();
+        if secs > 0.0 {
+            Some(usage.completion_tokens as f64 / secs)
+        } else {
+            None
+        }
+    }
+}
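+
+// Worked example of the throughput math (an illustrative sketch, not part of
+// this patch): 5 completion tokens over a 2-second request give 2.5
+// tokens/sec. The `Usage` field set mirrors the mock in `src/metrics.rs`:
+//
+//     let metrics = ChatMetrics {
+//         usage: Some(Usage {
+//             prompt_tokens: 10,
+//             completion_tokens: 5,
+//             total_tokens: 15,
+//             completion_tokens_details: None,
+//             prompt_tokens_details: None,
+//         }),
+//         duration: Duration::from_secs(2),
+//         time_to_first_token: None,
+//     };
+//     assert_eq!(metrics.tokens_per_second(), Some(2.5));
+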
 /// Role of a participant in a chat conversation.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum ChatRole {
@@ -356,15 +408,36 @@ impl Serialize for ToolChoice {
     }
 }
 
+/// Trait for chat response types returned by providers.
+///
+/// Provides access to the response content, tool calls, usage statistics,
+/// and optional metrics when enabled.
 pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
+    /// Returns the text content of the response, if any.
     fn text(&self) -> Option<String>;
+
+    /// Returns tool calls requested by the model, if any.
     fn tool_calls(&self) -> Option<Vec<ToolCall>>;
+
+    /// Returns the model's thinking/reasoning output, if available.
     fn thinking(&self) -> Option<String> {
         None
     }
+
+    /// Returns token usage statistics, if available.
     fn usage(&self) -> Option<Usage> {
         None
     }
+
+    /// Returns comprehensive metrics including timing and usage.
+    ///
+    /// This method returns `Some` only when metrics collection is enabled
+    /// via `.enable_metrics(true)` on the builder. Otherwise returns `None`.
+    ///
+    /// For streaming responses, use `Tracked::finalize()` instead.
+    fn metrics(&self) -> Option<ChatMetrics> {
+        None
+    }
 }
 
 /// Trait for providers that support chat-style interactions.
diff --git a/src/chat/tracked.rs b/src/chat/tracked.rs
new file mode 100644
index 0000000..0c1c95e
--- /dev/null
+++ b/src/chat/tracked.rs
@@ -0,0 +1,381 @@
+//! Tracked stream wrapper for collecting metrics during streaming.
+//!
+//! This module provides [`Tracked`], a generic stream wrapper that collects
+//! timing and usage metrics as chunks are consumed.
+//!
+//! # Example
+//!
+//! ```rust,ignore
+//! use llm::chat::Tracked;
+//! use futures::StreamExt;
+//!
+//! let stream = provider.chat_stream_with_tools(messages, None).await?;
+//! let mut tracked = Tracked::new(stream);
+//!
+//! while let Some(chunk) = tracked.next().await {
+//!     match chunk? {
+//!         StreamChunk::Text(text) => print!("{}", text),
+//!         StreamChunk::Done { .. } => break,
+//!         _ => {}
+//!     }
+//! }
+//!
+//! let metrics = tracked.finalize();
+//! println!("Time to first token: {:?}", metrics.time_to_first_token);
+//! println!("Total duration: {:?}", metrics.duration);
+//! println!("Tokens/sec: {:?}", metrics.tokens_per_second());
+//! ```
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::time::Instant;
+
+use futures::Stream;
+use pin_project::pin_project;
+
+use crate::error::LLMError;
+use crate::ToolCall;
+
+use super::{ChatMetrics, StreamChunk, StreamResponse, Usage};
+
+/// Trait for stream items that can be tracked for metrics.
+///
+/// Implement this trait for custom stream item types to enable
+/// tracking with [`Tracked`].
+pub trait Trackable {
+    /// Extract text content from this item, if any.
+    fn extract_text(&self) -> Option<&str>;
+
+    /// Extract a completed tool call from this item, if any.
+    fn extract_tool_call(&self) -> Option<&ToolCall>;
+
+    /// Extract usage statistics from this item, if any.
+    fn extract_usage(&self) -> Option<&Usage>;
+
+    /// Returns true if this item indicates the stream is done.
+    fn is_done(&self) -> bool;
+}
+
+impl Trackable for StreamChunk {
+    fn extract_text(&self) -> Option<&str> {
+        match self {
+            StreamChunk::Text(t) => Some(t),
+            _ => None,
+        }
+    }
+
+    fn extract_tool_call(&self) -> Option<&ToolCall> {
+        match self {
+            StreamChunk::ToolUseComplete { tool_call, .. } => Some(tool_call),
+            _ => None,
+        }
+    }
+
+    fn extract_usage(&self) -> Option<&Usage> {
+        None // StreamChunk doesn't carry usage
+    }
+
+    fn is_done(&self) -> bool {
+        matches!(self, StreamChunk::Done { .. })
+    }
+}
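+
+// A minimal sketch of a custom `Trackable` implementation; the `MyDelta`
+// struct is hypothetical and shown only to illustrate the trait contract:
+//
+//     struct MyDelta { text: Option<String> }
+//
+//     impl Trackable for MyDelta {
+//         fn extract_text(&self) -> Option<&str> { self.text.as_deref() }
+//         fn extract_tool_call(&self) -> Option<&ToolCall> { None }
+//         fn extract_usage(&self) -> Option<&Usage> { None }
+//         fn is_done(&self) -> bool { false }
+//     }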
+
+impl Trackable for StreamResponse {
+    fn extract_text(&self) -> Option<&str> {
+        self.choices.first()?.delta.content.as_deref()
+    }
+
+    fn extract_tool_call(&self) -> Option<&ToolCall> {
+        None // Tool calls come through delta.tool_calls but aren't complete
+    }
+
+    fn extract_usage(&self) -> Option<&Usage> {
+        self.usage.as_ref()
+    }
+
+    fn is_done(&self) -> bool {
+        false // StreamResponse doesn't have an explicit done marker
+    }
+}
+
+impl Trackable for String {
+    fn extract_text(&self) -> Option<&str> {
+        Some(self)
+    }
+
+    fn extract_tool_call(&self) -> Option<&ToolCall> {
+        None
+    }
+
+    fn extract_usage(&self) -> Option<&Usage> {
+        None
+    }
+
+    fn is_done(&self) -> bool {
+        false
+    }
+}
+
+/// A stream wrapper that tracks metrics as chunks are consumed.
+///
+/// `Tracked` wraps any stream and collects timing information and
+/// accumulated content as items are polled. After draining the stream,
+/// call [`finalize()`](Tracked::finalize) to get the collected metrics.
+///
+/// # Type Parameters
+///
+/// * `S` - The inner stream type
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let stream = provider.chat_stream_with_tools(messages, None).await?;
+/// let mut tracked = Tracked::new(stream);
+///
+/// while let Some(chunk) = tracked.next().await {
+///     // Process chunk...
+/// }
+///
+/// let metrics = tracked.finalize();
+/// println!("Duration: {:?}", metrics.duration);
+/// ```
+#[pin_project]
+pub struct Tracked<S> {
+    #[pin]
+    inner: S,
+    start_time: Instant,
+    first_chunk_time: Option<Instant>,
+    accumulated_text: String,
+    tool_calls: Vec<ToolCall>,
+    usage: Option<Usage>,
+    chunk_count: usize,
+}
+
+impl<S> Tracked<S> {
+    /// Create a new tracked stream wrapper.
+    ///
+    /// The timer starts immediately when this is called.
+    pub fn new(inner: S) -> Self {
+        Self {
+            inner,
+            start_time: Instant::now(),
+            first_chunk_time: None,
+            accumulated_text: String::new(),
+            tool_calls: Vec::new(),
+            usage: None,
+            chunk_count: 0,
+        }
+    }
+
+    /// Finalize and get metrics.
+    ///
+    /// Can be called at any time, but metrics are most meaningful
+    /// after the stream has been fully drained.
+    pub fn finalize(&self) -> ChatMetrics {
+        ChatMetrics {
+            usage: self.usage.clone(),
+            duration: self.start_time.elapsed(),
+            time_to_first_token: self
+                .first_chunk_time
+                .map(|t| t.duration_since(self.start_time)),
+        }
+    }
+
+    /// Get the accumulated text so far.
+    pub fn text(&self) -> &str {
+        &self.accumulated_text
+    }
+
+    /// Get the tool calls collected so far.
+    pub fn tool_calls(&self) -> &[ToolCall] {
+        &self.tool_calls
+    }
+
+    /// Get the number of chunks received so far.
+    pub fn chunk_count(&self) -> usize {
+        self.chunk_count
+    }
+
+    /// Get the usage statistics if available.
+    pub fn usage(&self) -> Option<&Usage> {
+        self.usage.as_ref()
+    }
+}
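+
+// Because `finalize()` takes `&self`, it can be called repeatedly for
+// intermediate snapshots while the stream is still live; a usage sketch
+// (assumes `stream` is any compatible chunk stream):
+//
+//     let mut tracked = Tracked::new(stream);
+//     let _ = tracked.next().await;            // first chunk
+//     let early = tracked.finalize();          // snapshot so far
+//     while tracked.next().await.is_some() {}  // drain the rest
+//     let full = tracked.finalize();           // final metrics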
+
+impl<S, T> Stream for Tracked<S>
+where
+    S: Stream<Item = Result<T, LLMError>>,
+    T: Trackable,
+{
+    type Item = Result<T, LLMError>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+
+        match this.inner.poll_next(cx) {
+            Poll::Ready(Some(Ok(item))) => {
+                *this.chunk_count += 1;
+
+                // Extract and accumulate
+                if let Some(text) = item.extract_text() {
+                    // Record time to first token only on actual text content
+                    if !text.is_empty() && this.first_chunk_time.is_none() {
+                        *this.first_chunk_time = Some(Instant::now());
+                    }
+                    this.accumulated_text.push_str(text);
+                }
+                if let Some(tool_call) = item.extract_tool_call() {
+                    this.tool_calls.push(tool_call.clone());
+                }
+                if let Some(usage) = item.extract_usage() {
+                    *this.usage = Some(usage.clone());
+                }
+
+                Poll::Ready(Some(Ok(item)))
+            }
+            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::stream;
+    use futures::StreamExt;
+
+    #[tokio::test]
+    async fn test_tracked_accumulates_text() {
+        let chunks = vec![
+            Ok(StreamChunk::Text("Hello ".to_string())),
+            Ok(StreamChunk::Text("world".to_string())),
+            Ok(StreamChunk::Done {
+                stop_reason: "end_turn".to_string(),
+            }),
+        ];
+        let stream = stream::iter(chunks);
+        let mut tracked = Tracked::new(stream);
+
+        while tracked.next().await.is_some() {}
+
+        assert_eq!(tracked.text(), "Hello world");
+        assert_eq!(tracked.chunk_count(), 3);
+    }
+
+    #[tokio::test]
+    async fn test_tracked_records_first_chunk_time() {
+        let chunks = vec![
+            Ok(StreamChunk::Text("Hi".to_string())),
+            Ok(StreamChunk::Done {
+                stop_reason: "end_turn".to_string(),
+            }),
+        ];
+        let stream = stream::iter(chunks);
+        let mut tracked = Tracked::new(stream);
+
+        // Before any chunks
+        assert!(tracked.first_chunk_time.is_none());
+
+        // Consume first chunk
+        let _ = tracked.next().await;
+        assert!(tracked.first_chunk_time.is_some());
+
+        let metrics = tracked.finalize();
+        assert!(metrics.time_to_first_token.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_tracked_collects_tool_calls() {
+        let tool_call = ToolCall {
+            id: "call_123".to_string(),
+            call_type: "function".to_string(),
+            function: crate::FunctionCall {
+                name: "get_weather".to_string(),
+                arguments: r#"{"location": "Paris"}"#.to_string(),
+            },
+        };
+
+        let chunks = vec![
+            Ok(StreamChunk::ToolUseStart {
+                index: 0,
+                id: "call_123".to_string(),
+                name: "get_weather".to_string(),
+            }),
+            Ok(StreamChunk::ToolUseComplete {
+                index: 0,
+                tool_call: tool_call.clone(),
+            }),
+            Ok(StreamChunk::Done {
+                stop_reason: "tool_use".to_string(),
+            }),
+        ];
+        let stream = stream::iter(chunks);
+        let mut tracked = Tracked::new(stream);
+
+        while tracked.next().await.is_some() {}
+
+        assert_eq!(tracked.tool_calls().len(), 1);
+        assert_eq!(tracked.tool_calls()[0].function.name, "get_weather");
+    }
+
+    #[tokio::test]
+    async fn test_tracked_finalize_returns_metrics() {
+        let chunks = vec![
+            Ok(StreamChunk::Text("Test".to_string())),
+            Ok(StreamChunk::Done {
+                stop_reason: "end_turn".to_string(),
+            }),
+        ];
+        let stream = stream::iter(chunks);
+        let mut tracked = Tracked::new(stream);
+
+        while tracked.next().await.is_some() {}
+
+        let metrics = tracked.finalize();
+        assert!(metrics.duration.as_nanos() > 0);
+        assert!(metrics.time_to_first_token.is_some());
+        // Usage is None since StreamChunk doesn't carry it
+        assert!(metrics.usage.is_none());
+    }
+
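+    // An added illustrative test: TTFT is recorded only when non-empty text
+    // arrives, so a leading empty chunk must not set `first_chunk_time`.
+    #[tokio::test]
+    async fn test_tracked_skips_empty_text_for_ttft() {
+        let chunks = vec![
+            Ok(StreamChunk::Text(String::new())),
+            Ok(StreamChunk::Text("Hi".to_string())),
+        ];
+        let stream = stream::iter(chunks);
+        let mut tracked = Tracked::new(stream);
+
+        let _ = tracked.next().await; // empty chunk: no TTFT yet
+        assert!(tracked.first_chunk_time.is_none());
+
+        let _ = tracked.next().await; // first real token sets TTFT
+        assert!(tracked.first_chunk_time.is_some());
+    }
+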
+    #[tokio::test]
+    async fn test_tracked_with_string_stream() {
+        let chunks: Vec<Result<String, LLMError>> = vec![
+            Ok("Hello ".to_string()),
+            Ok("world".to_string()),
+        ];
+        let stream = stream::iter(chunks);
+        let mut tracked = Tracked::new(stream);
+
+        while tracked.next().await.is_some() {}
+
+        assert_eq!(tracked.text(), "Hello world");
+        assert_eq!(tracked.chunk_count(), 2);
+    }
+
+    #[test]
+    fn test_trackable_stream_chunk_text() {
+        let chunk = StreamChunk::Text("hello".to_string());
+        assert_eq!(chunk.extract_text(), Some("hello"));
+        assert!(chunk.extract_tool_call().is_none());
+        assert!(!chunk.is_done());
+    }
+
+    #[test]
+    fn test_trackable_stream_chunk_done() {
+        let chunk = StreamChunk::Done {
+            stop_reason: "end_turn".to_string(),
+        };
+        assert!(chunk.is_done());
+        assert!(chunk.extract_text().is_none());
+    }
+
+    #[test]
+    fn test_trackable_string() {
+        let s = "test".to_string();
+        assert_eq!(s.extract_text(), Some("test"));
+        assert!(!s.is_done());
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 4edde3c..cc926e9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,6 +13,51 @@
 //! - Embeddings generation
 //! - Multiple providers (OpenAI, Anthropic, Google, etc.)
 //! - Request validation and retry logic
+//! - **Metrics collection** (timing, token usage) for performance monitoring
+//!
+//! ## Metrics Collection
+//!
+//! The library provides opt-in metrics collection for measuring request timing and token usage.
+//!
+//! ### Non-Streaming Requests
+//!
+//! Enable metrics via the builder:
+//!
+//! ```rust,ignore
+//! use llm::builder::{LLMBuilder, LLMBackend};
+//!
+//! let llm = LLMBuilder::new()
+//!     .backend(LLMBackend::OpenAI)
+//!     .api_key("...")
+//!     .enable_metrics(true) // Enable metrics
+//!     .build()?;
+//!
+//! let response = llm.chat("Hello").await?;
+//! if let Some(metrics) = response.metrics() {
+//!     println!("Duration: {:?}", metrics.duration);
+//!     println!("Tokens/sec: {:?}", metrics.tokens_per_second());
+//! }
+//! ```
+//!
+//! ### Streaming Requests
+//!
+//! Use the [`Tracked`] wrapper around any stream:
+//!
+//! ```rust,ignore
+//! use llm::Tracked;
+//! use futures::StreamExt;
+//!
+//! let stream = llm.chat_stream_with_tools(messages, None).await?;
+//! let mut tracked = Tracked::new(stream);
+//!
+//! while let Some(chunk) = tracked.next().await {
+//!     // Process chunk...
+//! }
+//!
+//! let metrics = tracked.finalize();
+//! println!("Time to first token: {:?}", metrics.time_to_first_token);
+//! println!("Total duration: {:?}", metrics.duration);
+//! ```
 //!
 //! ## Examples
 //!
@@ -25,6 +70,10 @@
 // Re-export for convenience
 pub use async_trait::async_trait;
 
+// Re-export metrics types for easy access
+pub use chat::{ChatMetrics, Trackable, Tracked};
+pub use metrics::MetricsProvider;
+
 use chat::Tool;
 use serde::{Deserialize, Serialize};
 
@@ -55,6 +104,9 @@ pub mod error;
 /// Validation wrapper for LLM providers with retry capabilities
 pub mod validated_llm;
 
+/// Metrics wrapper for LLM providers (timing, usage tracking)
+pub mod metrics;
+
 /// Resilience wrapper (retry/backoff) for LLM providers
 pub mod resilient_llm;
diff --git a/src/metrics.rs b/src/metrics.rs
new file mode 100644
index 0000000..614e1e9
--- /dev/null
+++ b/src/metrics.rs
@@ -0,0 +1,339 @@
+//! Metrics provider wrapper for timing non-streaming chat calls.
+//!
+//! This module provides [`MetricsProvider`], a wrapper that adds timing
+//! information to chat responses when metrics collection is enabled.
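+//!
+//! A minimal usage sketch (the builder installs this wrapper automatically
+//! when `.enable_metrics(true)` is set; constructing it by hand is only
+//! needed for custom setups):
+//!
+//! ```rust,ignore
+//! use llm::metrics::MetricsProvider;
+//!
+//! // `inner` is any Box<dyn LLMProvider>; `messages` is a Vec<ChatMessage>.
+//! let provider = MetricsProvider::new(inner);
+//! let response = provider.chat_with_tools(&messages, None).await?;
+//! assert!(response.metrics().is_some());
+//! ```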
+
+use std::fmt;
+use std::pin::Pin;
+use std::time::{Duration, Instant};
+
+use async_trait::async_trait;
+use futures::Stream;
+
+use crate::chat::{
+    ChatMessage, ChatMetrics, ChatProvider, ChatResponse, StreamChunk, StreamResponse, Tool,
+    Usage,
+};
+use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse};
+use crate::embedding::EmbeddingProvider;
+use crate::error::LLMError;
+use crate::models::{ModelListRequest, ModelListResponse as ModelListResponseTrait, ModelsProvider};
+use crate::stt::SpeechToTextProvider;
+use crate::tts::TextToSpeechProvider;
+use crate::{LLMProvider, ToolCall};
+
+/// A provider wrapper that adds timing metrics to chat responses.
+///
+/// This wrapper intercepts `chat_with_tools` calls to measure duration
+/// and wraps the response to include metrics. All other methods are
+/// delegated directly to the inner provider.
+///
+/// Created automatically by the builder when `.enable_metrics(true)` is set.
+pub struct MetricsProvider {
+    inner: Box<dyn LLMProvider>,
+}
+
+impl MetricsProvider {
+    /// Create a new metrics-enabled provider wrapper.
+    pub fn new(inner: Box<dyn LLMProvider>) -> Self {
+        Self { inner }
+    }
+}
+
+#[async_trait]
+impl ChatProvider for MetricsProvider {
+    async fn chat_with_tools(
+        &self,
+        messages: &[ChatMessage],
+        tools: Option<&[Tool]>,
+    ) -> Result<Box<dyn ChatResponse>, LLMError> {
+        let start = Instant::now();
+        let response = self.inner.chat_with_tools(messages, tools).await?;
+        let duration = start.elapsed();
+
+        Ok(Box::new(MetricsResponse {
+            inner: response,
+            duration,
+        }))
+    }
+
+    async fn chat_with_web_search(
+        &self,
+        input: String,
+    ) -> Result<Box<dyn ChatResponse>, LLMError> {
+        let start = Instant::now();
+        let response = self.inner.chat_with_web_search(input).await?;
+        let duration = start.elapsed();
+
+        Ok(Box::new(MetricsResponse {
+            inner: response,
+            duration,
+        }))
+    }
+
+    // Streaming methods are delegated unchanged; callers should use the
+    // `Tracked` wrapper for streaming metrics
+    async fn chat_stream(
+        &self,
+        messages: &[ChatMessage],
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
+        self.inner.chat_stream(messages).await
+    }
+
+    async fn chat_stream_struct(
+        &self,
+        messages: &[ChatMessage],
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, LLMError>
+    {
+        self.inner.chat_stream_struct(messages).await
+    }
+
+    async fn chat_stream_with_tools(
+        &self,
+        messages: &[ChatMessage],
+        tools: Option<&[Tool]>,
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError> {
+        self.inner.chat_stream_with_tools(messages, tools).await
+    }
+}
+
+#[async_trait]
+impl CompletionProvider for MetricsProvider {
+    async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
+        self.inner.complete(request).await
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for MetricsProvider {
+    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
+        self.inner.embed(input).await
+    }
+}
+
+#[async_trait]
+impl SpeechToTextProvider for MetricsProvider {
+    async fn transcribe(&self, audio_data: Vec<u8>) -> Result<String, LLMError> {
+        self.inner.transcribe(audio_data).await
+    }
+}
+
+#[async_trait]
+impl TextToSpeechProvider for MetricsProvider {
+    async fn speech(&self, input: &str) -> Result<Vec<u8>, LLMError> {
+        self.inner.speech(input).await
+    }
+}
+
+#[async_trait]
+impl ModelsProvider for MetricsProvider {
+    async fn list_models(
+        &self,
+        request: Option<&ModelListRequest>,
+    ) -> Result<Box<dyn ModelListResponseTrait>, LLMError> {
+        self.inner.list_models(request).await
+    }
+}
+
+impl LLMProvider for MetricsProvider {
+    fn tools(&self) -> Option<&[Tool]> {
+        self.inner.tools()
+    }
+}
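+
+// Streaming calls pass through this wrapper untouched, so pair it with the
+// `Tracked` stream wrapper when streaming metrics are needed; a sketch
+// (assumes `llm` and `messages` are in scope):
+//
+//     let stream = llm.chat_stream_with_tools(&messages, None).await?;
+//     let mut tracked = Tracked::new(stream);
+//     while let Some(chunk) = tracked.next().await { /* ... */ }
+//     let metrics = tracked.finalize();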
+
+/// A chat response wrapper that includes timing metrics.
+struct MetricsResponse {
+    inner: Box<dyn ChatResponse>,
+    duration: Duration,
+}
+
+impl ChatResponse for MetricsResponse {
+    fn text(&self) -> Option<String> {
+        self.inner.text()
+    }
+
+    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
+        self.inner.tool_calls()
+    }
+
+    fn thinking(&self) -> Option<String> {
+        self.inner.thinking()
+    }
+
+    fn usage(&self) -> Option<Usage> {
+        self.inner.usage()
+    }
+
+    fn metrics(&self) -> Option<ChatMetrics> {
+        Some(ChatMetrics {
+            usage: self.inner.usage(),
+            duration: self.duration,
+            time_to_first_token: None, // N/A for non-streaming
+        })
+    }
+}
+
+impl fmt::Debug for MetricsResponse {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("MetricsResponse")
+            .field("inner", &self.inner)
+            .field("duration", &self.duration)
+            .finish()
+    }
+}
+
+impl fmt::Display for MetricsResponse {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.inner)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Mock provider for testing
+    struct MockProvider;
+
+    struct MockResponse {
+        text: String,
+        usage: Option<Usage>,
+    }
+
+    impl fmt::Debug for MockResponse {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            f.debug_struct("MockResponse")
+                .field("text", &self.text)
+                .finish()
+        }
+    }
+
+    impl fmt::Display for MockResponse {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            write!(f, "{}", self.text)
+        }
+    }
+
+    impl ChatResponse for MockResponse {
+        fn text(&self) -> Option<String> {
+            Some(self.text.clone())
+        }
+
+        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
+            None
+        }
+
+        fn usage(&self) -> Option<Usage> {
+            self.usage.clone()
+        }
+    }
+
+    #[async_trait]
+    impl ChatProvider for MockProvider {
+        async fn chat_with_tools(
+            &self,
+            _messages: &[ChatMessage],
+            _tools: Option<&[Tool]>,
+        ) -> Result<Box<dyn ChatResponse>, LLMError> {
+            // Simulate some work
+            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+            Ok(Box::new(MockResponse {
+                text: "Hello".to_string(),
+                usage: Some(Usage {
+                    prompt_tokens: 10,
+                    completion_tokens: 5,
+                    total_tokens: 15,
+                    completion_tokens_details: None,
+                    prompt_tokens_details: None,
+                }),
+            }))
+        }
+    }
+
+    #[async_trait]
+    impl CompletionProvider for MockProvider {
+        async fn complete(
+            &self,
+            _request: &CompletionRequest,
+        ) -> Result<CompletionResponse, LLMError> {
+            Err(LLMError::Generic("Not implemented".to_string()))
+        }
+    }
+
+    #[async_trait]
+    impl EmbeddingProvider for MockProvider {
+        async fn embed(&self, _input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
+            Err(LLMError::Generic("Not implemented".to_string()))
+        }
+    }
+
+    #[async_trait]
+    impl SpeechToTextProvider for MockProvider {
+        async fn transcribe(&self, _audio_data: Vec<u8>) -> Result<String, LLMError> {
+            Err(LLMError::Generic("Not implemented".to_string()))
+        }
+    }
+
+    #[async_trait]
+    impl TextToSpeechProvider for MockProvider {
+        async fn speech(&self, _input: &str) -> Result<Vec<u8>, LLMError> {
+            Err(LLMError::Generic("Not implemented".to_string()))
+        }
+    }
+
+    #[async_trait]
+    impl ModelsProvider for MockProvider {
+        async fn list_models(
+            &self,
+            _request: Option<&ModelListRequest>,
+        ) -> Result<Box<dyn ModelListResponseTrait>, LLMError> {
+            Err(LLMError::Generic("Not implemented".to_string()))
+        }
+    }
+
+    impl LLMProvider for MockProvider {}
+
+    #[tokio::test]
+    async fn test_metrics_provider_adds_timing() {
+        let provider = MetricsProvider::new(Box::new(MockProvider));
+        let messages = vec![];
+
+        let response = provider.chat_with_tools(&messages, None).await.unwrap();
+
+        // Should have metrics
+        let metrics = response.metrics().unwrap();
+        assert!(metrics.duration.as_millis() >= 10);
+        assert!(metrics.time_to_first_token.is_none()); // Non-streaming
+
+        // Should preserve original response
+        assert_eq!(response.text(), Some("Hello".to_string()));
+        assert!(response.usage().is_some());
+    }
+
+    #[tokio::test]
+    async fn test_metrics_response_includes_usage() {
+        let provider = MetricsProvider::new(Box::new(MockProvider));
+        let messages = vec![];
+
+        let response = provider.chat_with_tools(&messages, None).await.unwrap();
+        let metrics = response.metrics().unwrap();
+
+        // Usage should be passed through
+        let usage = metrics.usage.unwrap();
+        assert_eq!(usage.prompt_tokens, 10);
+        assert_eq!(usage.completion_tokens, 5);
+        assert_eq!(usage.total_tokens, 15);
+    }
+
+    #[test]
+    fn test_metrics_response_display() {
+        let response = MetricsResponse {
+            inner: Box::new(MockResponse {
+                text: "Test".to_string(),
+                usage: None,
+            }),
+            duration: Duration::from_millis(100),
+        };
+
+        assert_eq!(format!("{}", response), "Test");
+    }
+}