Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async-trait = "0.1.88"
clap = { version = "4.5.43", features = ["derive"] }
futures-util = "0.3.31"
gbnf-rs = { version = "0.1.0", path = "../gbnf-rs" }
gemini-rust = "1.3.1"
gemini-rust = "1.4.0"
ollama-rs = { git = "https://github.com/dstoc/ollama-rs", branch = "RobJellinghaus/streaming-tools", version = "0.3.2", features = ["macros", "stream"] }
openai-harmony = { git = "https://github.com/openai/harmony", tag = "v0.0.4", version = "0.0.4" }
reqwest = { version = "0.12.15", default-features = false, features = ["json", "rustls-tls"] }
Expand Down
40 changes: 26 additions & 14 deletions crates/llm/src/gemini_rust.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::error::Error;

use async_trait::async_trait;
use futures_util::StreamExt;
use futures_util::{StreamExt, TryStreamExt};
use gemini_rust::{
Content, FunctionCallingMode, FunctionDeclaration, FunctionParameters, Gemini, Message, Part,
Role,
};
use reqwest::Client as HttpClient;
use reqwest::{Client as HttpClient, Url};
use serde_json::Value;
use uuid::Uuid;

Expand Down Expand Up @@ -52,8 +52,8 @@ impl LlmClient for GeminiRustClient {
let gemini = Gemini::with_model_and_base_url(
self.api_key.clone(),
format!("models/{}", request.model_name),
self.base_url.clone(),
);
Url::parse(self.base_url.as_str()).unwrap(),
)?;
let mut builder = gemini.generate_content();

let mut system_instruction: Option<String> = None;
Expand All @@ -70,6 +70,7 @@ impl LlmClient for GeminiRustClient {
parts_vec.push(Part::Text {
text,
thought: None,
thought_signature: None,
});
}
AssistantPart::ToolCall(tc) => {
Expand All @@ -79,6 +80,7 @@ impl LlmClient for GeminiRustClient {
};
parts_vec.push(Part::FunctionCall {
function_call: gemini_rust::FunctionCall::new(tc.name, args),
thought_signature: None,
});
}
AssistantPart::Thinking { .. } => {}
Expand Down Expand Up @@ -143,33 +145,43 @@ impl LlmClient for GeminiRustClient {
let mut input_tokens = 0u32;
let mut output_tokens = 0u32;
let stream = builder.execute_stream().await?;
let mapped = stream.flat_map(move |res| match res {
let mapped = stream.into_stream().flat_map(move |res| match res {
Ok(chunk) => {
let mut out: Vec<Result<ResponseChunk, Box<dyn Error + Send + Sync>>> = Vec::new();
if let Some(usage) = chunk.usage_metadata {
let input_delta = usage.prompt_token_count as u32 - input_tokens;
let input_delta =
usage.prompt_token_count.unwrap_or_default() as u32 - input_tokens;
input_tokens += input_delta;
let output_delta = usage.total_token_count as u32
- usage.prompt_token_count as u32
let output_delta = usage.total_token_count.unwrap_or_default() as u32
- usage.prompt_token_count.unwrap_or_default() as u32
- output_tokens;
output_tokens += output_delta;
out.push(Ok(ResponseChunk::Usage {
input_tokens: input_delta,
output_tokens: output_delta,
}));
if input_delta > 0 || output_delta > 0 {
out.push(Ok(ResponseChunk::Usage {
input_tokens: input_delta,
output_tokens: output_delta,
}));
}
}
if let Some(candidate) = chunk.candidates.first() {
if let Some(parts) = &candidate.content.parts {
for part in parts {
match part {
Part::Text { text, thought } => {
Part::Text {
text,
thought,
thought_signature: _,
} => {
if thought.unwrap_or(false) {
out.push(Ok(ResponseChunk::Thinking(text.clone())));
} else if !text.is_empty() {
out.push(Ok(ResponseChunk::Content(text.clone())));
}
}
Part::FunctionCall { function_call } => {
Part::FunctionCall {
function_call,
thought_signature: _,
} => {
out.push(Ok(ResponseChunk::ToolCall(ToolCall {
id: Uuid::new_v4().to_string(),
name: function_call.name.clone(),
Expand Down