Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@ pub enum GeminiError {
EventSource(#[from] reqwest_eventsource::Error),
#[error("API Error: {0}")]
Api(Value),
#[error("JSON Error: {0}")]
Json(#[from] serde_json::Error),
#[error("JSON Error: {error} (payload: {data})")]
Json {
data: String,
#[source]
error: serde_json::Error,
},
#[error("Function execution error: {0}")]
FunctionExecution(String),
}
Expand Down Expand Up @@ -210,7 +214,10 @@ impl GeminiClient {
Event::Open => (),
Event::Message(event) => yield
serde_json::from_str::<types::GenerateContentResponse>(&event.data)
.map_err(Into::into),
.map_err(|error| GeminiError::Json {
data: event.data,
error,
}),
},
Err(e) => match e {
reqwest_eventsource::Error::StreamEnded => stream.close(),
Expand Down Expand Up @@ -248,19 +255,22 @@ impl GeminiClient {
return Ok(response);
};

let Some(part) = candidate.content.parts.first() else {
let Some(part) = candidate.content.as_ref().and_then(|c| c.parts.first()) else {
return Ok(response);
};

if let ContentData::FunctionCall(function_call) = &part.data {
request.contents.push(candidate.content.clone());
if let Some(content) = candidate.content.clone() {
request.contents.push(content);
}

if let Some(handler) = function_handlers.get(&function_call.name) {
let mut args = function_call.arguments.clone();
match handler.execute(&mut args).await {
Ok(result) => {
request.contents.push(Content {
parts: vec![ContentData::FunctionResponse(FunctionResponse {
id: function_call.id.clone(),
name: function_call.name.clone(),
response: FunctionResponsePayload { content: result },
})
Expand Down
65 changes: 59 additions & 6 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub enum FunctionCallingMode {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Content {
#[serde(default)]
pub parts: Vec<ContentPart>,
// Optional. The producer of the content. Must be either 'user' or 'model'.
// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
Expand Down Expand Up @@ -171,6 +172,7 @@ pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: Option<FunctionParameters>,
pub parameters_json_schema: Option<serde_json::Value>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works fine from my tests but if parameters is set it will error as parameters_json_schema can not be set when parameters is set / present in the json. You'll get an api error from gemini. though parameters can be null but not set when using parameters_json_schema

API call failed: API Error: {"context":null,"message":{"error":{"code":400,"message":"* GenerateContentRequest.tools[0].function_declarations[0].parameters_json_schema: parameters_json_schema must not be set when parameters is set.\n","status":"INVALID_ARGUMENT"}},"status":400}

so just add this and all good
#[serde(skip_serializing_if = "Option::is_none")]

pub response: Option<FunctionParameters>,
}

Expand Down Expand Up @@ -243,8 +245,10 @@ pub struct GenerateContentResponse {
pub candidates: Vec<Candidate>,
pub prompt_feedback: Option<PromptFeedback>,
pub usage_metadata: UsageMetadata,
pub model_version: String,
pub response_id: String,
#[serde(default)]
pub model_version: Option<String>,
#[serde(default)]
pub response_id: Option<String>,
}

/// Specifies the reason why the prompt was blocked.
Expand Down Expand Up @@ -344,14 +348,40 @@ pub struct ThinkingConfig {
pub include_thoughts: bool,
/// The number of thoughts tokens that the model should generate.
pub thinking_budget: Option<u32>,
/// Controls the maximum depth of the model's internal reasoning process
/// before it produces a response. If not specified, the default is HIGH.
/// Recommended for Gemini 3 or later models. Use with earlier models
/// results in an error.
pub thinking_level: Option<ThinkingLevel>,
}

/// Allow user to specify how much to think using enum instead of integer
/// budget.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ThinkingLevel {
/// Unspecified thinking level.
#[default]
ThinkingLevelUnspecified,
/// High thinking level.
High,
/// Low thinking level.
Low,
}

/// A response candidate generated from the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
/// Generated content returned from the model.
pub content: Content,
///
/// This field is not always populated, e.g.:
///
/// ```json
/// {"candidates": [{"finishReason": "UNEXPECTED_TOOL_CALL","index": 0}]}
/// ```
#[serde(default)]
pub content: Option<Content>,
/// The reason why the model stopped generating tokens. If empty, the model
/// has not stopped generating tokens.
pub finish_reason: Option<FinishReason>,
Expand Down Expand Up @@ -696,6 +726,8 @@ pub enum FinishReason {
/// Token generation stopped because generated images contain safety
/// violations.
ImageSafety,
/// Model generated a tool call but no tools were enabled in the request.
UnexpectedToolCall,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
Expand All @@ -707,6 +739,8 @@ pub struct ContentPart {
pub data: ContentData,
#[serde(skip_serializing)]
pub metadata: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub thought_signature: Option<String>,
}

impl ContentPart {
Expand All @@ -715,6 +749,7 @@ impl ContentPart {
data: ContentData::Text(text.to_string()),
thought,
metadata: None,
thought_signature: None,
}
}

Expand All @@ -726,6 +761,7 @@ impl ContentPart {
}),
thought,
metadata: None,
thought_signature: None,
}
}

Expand All @@ -737,17 +773,25 @@ impl ContentPart {
}),
thought: false,
metadata: None,
thought_signature: None,
}
}

pub fn new_function_call(name: &str, arguments: Value, thought: bool) -> Self {
pub fn new_function_call(
id: Option<&str>,
name: &str,
arguments: Value,
thought: bool,
) -> Self {
Self {
data: ContentData::FunctionCall(FunctionCall {
id: id.map(|s| s.to_string()),
name: name.to_string(),
arguments,
}),
thought,
metadata: None,
thought_signature: None,
}
}

Expand All @@ -758,6 +802,7 @@ impl ContentPart {
}),
thought: false,
metadata: None,
thought_signature: None,
}
}

Expand All @@ -766,17 +811,20 @@ impl ContentPart {
data: ContentData::CodeExecutionResult(content),
thought: false,
metadata: None,
thought_signature: None,
}
}

pub fn new_function_response(name: &str, content: Value) -> Self {
pub fn new_function_response(id: Option<&str>, name: &str, content: Value) -> Self {
Self {
data: ContentData::FunctionResponse(FunctionResponse {
id: id.map(|s| s.to_string()),
name: name.to_string(),
response: FunctionResponsePayload { content },
}),
thought: false,
metadata: None,
thought_signature: None,
}
}
}
Expand All @@ -791,6 +839,7 @@ impl From<ContentData> for ContentPart {
data,
thought: false,
metadata: None,
thought_signature: None,
}
}
}
Expand All @@ -810,14 +859,18 @@ pub enum ContentData {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCall {
#[serde(default)]
pub id: Option<String>,
pub name: String,
#[serde(rename = "args")]
#[serde(default, rename = "args")]
pub arguments: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponse {
#[serde(default)]
pub id: Option<String>,
pub name: String,
pub response: FunctionResponsePayload,
}
Expand Down