diff --git a/.cargo-husky/hooks/pre-commit b/.cargo-husky/hooks/pre-commit index ab4914f..3788163 100755 --- a/.cargo-husky/hooks/pre-commit +++ b/.cargo-husky/hooks/pre-commit @@ -1,3 +1,2 @@ #!/bin/sh -set -e -make ci +exec make ci diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..7b3f3af --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,69 @@ +name: Release + +on: + push: + tags: + - "v[0-9]+.[0-9]+.[0-9]+*" + +permissions: + contents: write + +env: + CARGO_TERM_COLOR: always + +jobs: + release: + name: Release + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Extract version from tag + id: version + run: echo "version=${GITHUB_REF_NAME#v}" >> "$GITHUB_OUTPUT" + + - name: Extract changelog + id: changelog + run: | + notes=$(awk '/^## \[${{ steps.version.outputs.version }}\]/{found=1; next} /^## \[/{if(found) exit} found{print}' CHANGELOG.md) + { + echo "notes<> "$GITHUB_OUTPUT" + + - name: Create GitHub Release + run: | + gh release create "$GITHUB_REF_NAME" \ + --title "$GITHUB_REF_NAME" \ + --notes "${{ steps.changelog.outputs.notes }}" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + + - name: Publish mixtape-anthropic-sdk + run: cargo publish -p mixtape-anthropic-sdk + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + - name: Publish mixtape-core + run: cargo publish -p mixtape-core + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + - name: Publish mixtape-tools + run: cargo publish -p mixtape-tools + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + - name: Publish mixtape-cli + run: cargo publish -p mixtape-cli + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + - name: Publish mixtape-server + run: cargo publish -p mixtape-server + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.gitignore b/.gitignore index 7d53108..c9e90ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ /target Cargo.lock +# Rust artifacts +**/*.rs.bk +*.pdb + # IDE .idea/ .vscode/ @@ -15,3 +19,6 @@ Cargo.lock *.profraw *.profdata lcov.info + +# cargo-mutants artifacts +**/mutants.out*/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 022afa4..f7eccf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **mixtape-server** *(experimental)*: HTTP server with AG-UI protocol support. API surface may change in future releases. - Claude Opus 4.6 model (flagship, 200K context, 128K output) - Claude Opus 4.1 model (200K context, 32K output) - Nova 2 Sonic model (1M context, 65K output) @@ -38,6 +39,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Claude Opus 4.5 output token limit: 32,000 → 64,000 - Stale `--providers` CLI flag in model_verification example docs (correct flag is `--vendors`) +## [0.2.1] - 2026-01-05 + +### Fixed + +- Use i64 for rusqlite in session store for cross-platform compatibility +- Use i64 for rusqlite COUNT queries for cross-platform compatibility + +## [0.2.0] - 2026-01-05 + +### Added + +- Claude Sonnet 4.5 1M model support +- Tool grouping exports for filesystem and process modules +- `builder.add_trusted_tool()` convenience method +- Animated spinner for thinking indicator in CLI +- Improved tool execution event model for approval + +### Changed + +- Updated non-interactive examples to use `add_trusted_tool()` + ## [0.1.1] - 2026-01-04 ### Added @@ -59,5 +81,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **mixtape-cli**: Session storage and REPL utilities for interactive agents [Unreleased]: https://github.com/adlio/mixtape/compare/v0.3.0...HEAD -[0.3.0]: https://github.com/adlio/mixtape/compare/v0.1.1...v0.3.0 +[0.3.0]: https://github.com/adlio/mixtape/compare/v0.2.1...v0.3.0 +[0.2.1]: https://github.com/adlio/mixtape/compare/v0.2.0...v0.2.1 +[0.2.0]: https://github.com/adlio/mixtape/compare/v0.1.1...v0.2.0 [0.1.1]: https://github.com/adlio/mixtape/releases/tag/v0.1.1 diff --git a/Cargo.toml b/Cargo.toml index 1245810..ec31954 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,12 @@ [workspace] -members = ["mixtape-core", "mixtape-anthropic-sdk", "mixtape-tools", "mixtape-cli"] +members = ["mixtape-core", "mixtape-anthropic-sdk", "mixtape-tools", "mixtape-cli", "mixtape-server"] resolver = "2" [workspace.package] version = "0.3.0" edition = "2021" license = "MIT" +homepage = "https://github.com/adlio/mixtape" repository = "https://github.com/adlio/mixtape" [workspace.dependencies] @@ -14,14 +15,21 @@ mixtape-core = { version = "0.3.0", path = "./mixtape-core" } mixtape-anthropic-sdk = { version = "0.3.0", path = "./mixtape-anthropic-sdk" } mixtape-tools = { path = "./mixtape-tools" } mixtape-cli = { path = "./mixtape-cli" } +mixtape-server = { path = "./mixtape-server" } # Async runtime tokio = { version = "1.41", features = ["full"] } tokio-test = "0.4" +tokio-stream = "0.1" async-trait = "0.1" async-stream = "0.3" futures = "0.3" +# HTTP Server +axum = { version = "0.7", features = ["macros"] } +tower = { version = "0.5", features = ["util"] } +tower-http = { version = "0.6", features = ["cors", "trace"] } + # Serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/Makefile b/Makefile index cee7816..25ad815 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ .DEFAULT_GOAL := help -.PHONY: help test coverage coverage-html build build-release clean fmt fmt-check lint check doc doc-check all ci ensure-tools +.PHONY: help test coverage coverage-html coverage-ci build build-release clean fmt fmt-check clippy clippy-fix lint check doc doc-check all ci ensure-tools # Tool installation helpers CARGO_NEXTEST := $(shell command -v cargo-nextest 2>/dev/null) @@ -31,6 +31,9 @@ coverage: ensure-tools ## Show coverage summary in console coverage-html: ensure-tools ## Generate HTML coverage report and open cargo llvm-cov nextest --workspace --all-features --html --open +coverage-ci: ensure-tools ## Generate LCOV coverage for CI upload + cargo llvm-cov nextest --workspace --all-features --lcov --output-path lcov.info + build: ## Build debug cargo build --workspace --all-targets --all-features @@ -49,9 +52,14 @@ fmt: ## Format code fmt-check: ## Check formatting cargo fmt --all -- --check -lint: ## Run clippy +clippy: ## Run clippy with warnings as errors cargo clippy --workspace --all-targets --all-features -- -D warnings +clippy-fix: ## Run clippy and auto-fix + cargo clippy --workspace --all-targets --all-features --fix --allow-dirty -- -D warnings + +lint: clippy ## Alias for clippy + clean: ## Clean build artifacts cargo clean @@ -61,6 +69,6 @@ doc: ## Generate docs doc-check: ## Check docs build without warnings RUSTDOCFLAGS="-D warnings" cargo doc --workspace --no-deps -all: ensure-tools fmt lint build test ## Format, lint, build, and test +all: ensure-tools fmt clippy build test ## Format, lint, build, and test -ci: ensure-tools fmt-check lint build doc-check test ## Check formatting, lint, build, docs, test (for CI/hooks) +ci: ensure-tools fmt-check clippy build doc-check test ## Check formatting, lint, build, docs, test (for CI/hooks) diff --git a/README.md b/README.md index c9b17ed..79cb0f3 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,10 @@ # Mixtape +[![Crates.io](https://img.shields.io/crates/v/mixtape-core.svg)](https://crates.io/crates/mixtape-core) +[![Documentation](https://docs.rs/mixtape-core/badge.svg)](https://docs.rs/mixtape-core) [![CI](https://github.com/adlio/mixtape/actions/workflows/ci.yml/badge.svg)](https://github.com/adlio/mixtape/actions/workflows/ci.yml) -[![Coverage](https://codecov.io/gh/adlio/mixtape/branch/main/graph/badge.svg)](https://codecov.io/gh/adlio/mixtape) -[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) +[![codecov](https://codecov.io/gh/adlio/mixtape/graph/badge.svg)](https://codecov.io/gh/adlio/mixtape) +[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) An agentic AI framework for Rust. @@ -51,13 +53,14 @@ Add `mcp` for MCP server integration, `session` for conversation persistence. ## Workspace Crates -This repository contains four crates: +This repository contains five crates: | Crate | Purpose | |---------------------------|--------------------------------------------------------| | **mixtape-core** | Core agent framework | | **mixtape-tools** | Pre-built filesystem, process, web, and database tools | | **mixtape-cli** | Session storage and interactive REPL features | +| **mixtape-server** | HTTP server with AG-UI protocol support *(experimental)* | | **mixtape-anthropic-sdk** | Low-level Anthropic API client (used internally) | Most projects need only `mixtape-core`. Add `mixtape-tools` for ready-to-use tools. diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..2a32abb --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + target: auto + threshold: 1% + patch: + default: + target: 80% +ignore: + - "mixtape-server/**" diff --git a/mixtape-anthropic-sdk/Cargo.toml b/mixtape-anthropic-sdk/Cargo.toml index 8522026..429bbd3 100644 --- a/mixtape-anthropic-sdk/Cargo.toml +++ b/mixtape-anthropic-sdk/Cargo.toml @@ -4,11 +4,13 @@ version.workspace = true edition.workspace = true license.workspace = true repository.workspace = true +homepage.workspace = true description = "Minimal Anthropic API client for the mixtape agent framework" documentation = "https://docs.rs/mixtape-anthropic-sdk" readme = "README.md" keywords = ["anthropic", "claude", "llm", "ai", "api"] categories = ["api-bindings", "asynchronous"] +exclude = [".cargo-husky/", ".claude/", ".github/", ".idea/"] [features] default = [] diff --git a/mixtape-cli/Cargo.toml b/mixtape-cli/Cargo.toml index 41652ed..d76ab58 100644 --- a/mixtape-cli/Cargo.toml +++ b/mixtape-cli/Cargo.toml @@ -4,11 +4,13 @@ version.workspace = true edition.workspace = true license.workspace = true repository.workspace = true +homepage.workspace = true description = "Session storage and REPL utilities for the mixtape agent framework" documentation = "https://docs.rs/mixtape-cli" readme = "README.md" keywords = ["ai", "agents", "cli", "repl", "session"] categories = ["command-line-utilities", "development-tools"] +exclude = [".cargo-husky/", ".claude/", ".github/", ".idea/"] [dependencies] mixtape-core = { workspace = true, features = ["session"] } diff --git a/mixtape-core/Cargo.toml b/mixtape-core/Cargo.toml index 8763798..e5cfa91 100644 --- a/mixtape-core/Cargo.toml +++ b/mixtape-core/Cargo.toml @@ -4,11 +4,13 @@ version.workspace = true edition.workspace = true license.workspace = true repository.workspace = true +homepage.workspace = true description = "An agentic AI framework for Rust" documentation = "https://docs.rs/mixtape-core" readme = "../README.md" keywords = ["ai", "agents", "llm", "anthropic", "tools"] categories = ["development-tools", "api-bindings", "asynchronous"] +exclude = [".cargo-husky/", ".claude/", ".github/", ".idea/"] [features] default = [] @@ -16,6 +18,7 @@ session = [] bedrock = ["dep:aws-config", "dep:aws-sdk-bedrockruntime", "dep:aws-smithy-types"] anthropic = ["dep:mixtape-anthropic-sdk", "dep:base64"] mcp = ["dep:rmcp", "dep:reqwest", "dep:shellexpand"] +test-utils = [] [dependencies] async-stream.workspace = true diff --git a/mixtape-core/src/agent/builder.rs b/mixtape-core/src/agent/builder.rs index 4c6f435..5b71ac9 100644 --- a/mixtape-core/src/agent/builder.rs +++ b/mixtape-core/src/agent/builder.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; +use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::time::Duration; use tokio::sync::RwLock; @@ -559,7 +560,8 @@ impl AgentBuilder { system_prompt: self.system_prompt, max_concurrent_tools: self.max_concurrent_tools, tools: self.tools, - hooks: Arc::new(parking_lot::RwLock::new(Vec::new())), + hooks: Arc::new(parking_lot::RwLock::new(HashMap::new())), + next_hook_id: AtomicU64::new(0), authorizer: Arc::new(RwLock::new(authorizer)), authorization_timeout: self.authorization_timeout, pending_authorizations: Arc::new(RwLock::new(HashMap::new())), diff --git a/mixtape-core/src/agent/mod.rs b/mixtape-core/src/agent/mod.rs index 0b9e107..29945c5 100644 --- a/mixtape-core/src/agent/mod.rs +++ b/mixtape-core/src/agent/mod.rs @@ -29,12 +29,13 @@ pub use types::{ pub use types::SessionInfo; use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; use tokio::sync::{mpsc, RwLock}; use crate::conversation::BoxedConversationManager; -use crate::events::{AgentEvent, AgentHook}; +use crate::events::{AgentEvent, AgentHook, HookId}; use crate::permission::{AuthorizationResponse, ToolCallAuthorizer}; use crate::provider::ModelProvider; use crate::tool::DynTool; @@ -68,7 +69,8 @@ pub struct Agent { pub(super) system_prompt: Option, pub(super) max_concurrent_tools: usize, pub(super) tools: Vec>, - pub(super) hooks: Arc>>>, + pub(super) hooks: Arc>>>, + pub(super) next_hook_id: AtomicU64, /// Tool call authorizer (always present, uses MemoryGrantStore by default) pub(super) authorizer: Arc>, /// Timeout for authorization requests @@ -95,7 +97,9 @@ pub struct Agent { } impl Agent { - /// Add an event hook to observe agent execution + /// Add an event hook to observe agent execution. + /// + /// Returns a [`HookId`] that can be used to remove the hook later via [`remove_hook`](Self::remove_hook). /// /// Hooks receive notifications about agent lifecycle, model calls, /// and tool executions in real-time. @@ -116,16 +120,28 @@ impl Agent { /// .bedrock(ClaudeSonnet4_5) /// .build() /// .await?; - /// agent.add_hook(Logger); + /// let hook_id = agent.add_hook(Logger); + /// + /// // Later, remove the hook + /// agent.remove_hook(hook_id); /// ``` - pub fn add_hook(&self, hook: impl AgentHook + 'static) { - self.hooks.write().push(Arc::new(hook)); + pub fn add_hook(&self, hook: impl AgentHook + 'static) -> HookId { + let id = HookId(self.next_hook_id.fetch_add(1, Ordering::SeqCst)); + self.hooks.write().insert(id, Arc::new(hook)); + id + } + + /// Remove a previously registered hook. + /// + /// Returns `true` if the hook was found and removed, `false` otherwise. + pub fn remove_hook(&self, id: HookId) -> bool { + self.hooks.write().remove(&id).is_some() } /// Emit an event to all registered hooks pub(crate) fn emit_event(&self, event: AgentEvent) { let hooks = self.hooks.read(); - for hook in hooks.iter() { + for hook in hooks.values() { hook.on_event(&event); } } diff --git a/mixtape-core/src/events.rs b/mixtape-core/src/events.rs index e6c3f07..01c3c49 100644 --- a/mixtape-core/src/events.rs +++ b/mixtape-core/src/events.rs @@ -218,6 +218,12 @@ where } } +/// Unique identifier for a registered hook. +/// +/// Used to remove hooks via [`crate::Agent::remove_hook`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct HookId(pub(crate) u64); + #[cfg(test)] mod tests { use super::*; diff --git a/mixtape-core/src/lib.rs b/mixtape-core/src/lib.rs index 6dcdf72..f04b472 100644 --- a/mixtape-core/src/lib.rs +++ b/mixtape-core/src/lib.rs @@ -143,6 +143,9 @@ pub mod mcp; #[cfg(feature = "session")] pub mod session; +#[cfg(feature = "test-utils")] +pub mod test_utils; + pub use agent::{ Agent, AgentBuilder, AgentError, AgentResponse, ContextConfig, ContextError, ContextLoadResult, ContextSource, PermissionError, TokenUsageStats, ToolCallInfo, ToolInfo, @@ -154,7 +157,7 @@ pub use conversation::{ TokenEstimator, }; pub use error::{Error, Result}; -pub use events::{AgentEvent, AgentHook, TokenUsage}; +pub use events::{AgentEvent, AgentHook, HookId, TokenUsage}; pub use model::{ AnthropicModel, BedrockModel, InferenceProfile, Model, ModelRequest, ModelResponse, diff --git a/mixtape-core/src/test_utils.rs b/mixtape-core/src/test_utils.rs new file mode 100644 index 0000000..2da2b72 --- /dev/null +++ b/mixtape-core/src/test_utils.rs @@ -0,0 +1,364 @@ +//! Test utilities for mixtape-core. +//! +//! This module provides mock implementations for testing agents without +//! requiring real LLM provider credentials. +//! +//! Enable with the `test-utils` feature: +//! +//! ```toml +//! [dev-dependencies] +//! mixtape-core = { version = "...", features = ["test-utils"] } +//! ``` +//! +//! # Example +//! +//! ```rust +//! use mixtape_core::{Agent, test_utils::MockProvider}; +//! +//! # async fn example() -> mixtape_core::Result<()> { +//! let provider = MockProvider::new() +//! .with_text("Hello from mock!"); +//! +//! let agent = Agent::builder() +//! .provider(provider) +//! .build() +//! .await?; +//! +//! let response = agent.run("Hi").await?; +//! assert_eq!(response.text(), "Hello from mock!"); +//! # Ok(()) +//! # } +//! ``` + +use std::sync::{Arc, Mutex}; + +use crate::events::AgentEvent; +use crate::model::ModelResponse; +use crate::provider::{ModelProvider, ProviderError}; +use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition, ToolUseBlock}; + +/// A mock model provider for testing. +/// +/// Returns pre-programmed responses in order. Useful for testing agent behavior +/// without making real API calls. +/// +/// # Example +/// +/// ```rust +/// use mixtape_core::test_utils::MockProvider; +/// use serde_json::json; +/// +/// // Simple text response +/// let provider = MockProvider::new() +/// .with_text("Hello!"); +/// +/// // Tool use followed by final response +/// let provider = MockProvider::new() +/// .with_tool_use("calculator", json!({"expr": "2+2"})) +/// .with_text("The answer is 4"); +/// ``` +#[derive(Clone)] +pub struct MockProvider { + responses: Arc>>, + call_count: Arc>, +} + +impl MockProvider { + /// Create a new mock provider with no responses. + pub fn new() -> Self { + Self { + responses: Arc::new(Mutex::new(Vec::new())), + call_count: Arc::new(Mutex::new(0)), + } + } + + /// Add a text response to the queue. + /// + /// The response will have `StopReason::EndTurn`. + pub fn with_text(self, text: impl Into) -> Self { + let message = Message::assistant(text); + + let response = ModelResponse { + message, + stop_reason: StopReason::EndTurn, + usage: None, + }; + + self.responses.lock().unwrap().push(response); + self + } + + /// Add a tool use response to the queue. + /// + /// The response will have `StopReason::ToolUse`. + pub fn with_tool_use( + self, + tool_name: impl Into, + tool_input: serde_json::Value, + ) -> Self { + let tool_use = ToolUseBlock { + id: format!("tool_{}", uuid::Uuid::new_v4()), + name: tool_name.into(), + input: tool_input, + }; + + let message = Message { + role: Role::Assistant, + content: vec![ContentBlock::ToolUse(tool_use)], + }; + + let response = ModelResponse { + message, + stop_reason: StopReason::ToolUse, + usage: None, + }; + + self.responses.lock().unwrap().push(response); + self + } + + /// Get the number of times `generate` was called. + pub fn call_count(&self) -> usize { + *self.call_count.lock().unwrap() + } +} + +impl Default for MockProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait::async_trait] +impl ModelProvider for MockProvider { + fn name(&self) -> &str { + "MockProvider" + } + + fn max_context_tokens(&self) -> usize { + 200_000 + } + + fn max_output_tokens(&self) -> usize { + 8_192 + } + + async fn generate( + &self, + _messages: Vec, + _tools: Vec, + _system_prompt: Option, + ) -> Result { + let mut count = self.call_count.lock().unwrap(); + *count += 1; + + let mut responses = self.responses.lock().unwrap(); + if responses.is_empty() { + return Err(ProviderError::Other( + "MockProvider: No more responses configured".to_string(), + )); + } + + Ok(responses.remove(0)) + } +} + +/// Collects agent events for verification in tests. +/// +/// Stores full [`AgentEvent`] objects and provides convenience methods +/// for inspecting event types. +/// +/// # Example +/// +/// ```rust +/// use mixtape_core::{Agent, test_utils::{MockProvider, EventCollector}}; +/// +/// # async fn example() -> mixtape_core::Result<()> { +/// let provider = MockProvider::new().with_text("Hello!"); +/// let collector = EventCollector::new(); +/// +/// let agent = Agent::builder() +/// .provider(provider) +/// .build() +/// .await?; +/// +/// agent.add_hook(collector.clone()); +/// agent.run("Hi").await?; +/// +/// assert!(collector.has_event("run_started")); +/// assert!(collector.has_event("run_completed")); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone)] +pub struct EventCollector { + events: Arc>>, +} + +impl EventCollector { + /// Create a new event collector. + pub fn new() -> Self { + Self { + events: Arc::new(Mutex::new(Vec::new())), + } + } + + /// Get all collected events. + pub fn events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + + /// Get all collected event type names. + pub fn event_types(&self) -> Vec { + self.events + .lock() + .unwrap() + .iter() + .map(|e| Self::event_type_name(e).to_string()) + .collect() + } + + /// Clear all collected events. + pub fn clear(&self) { + self.events.lock().unwrap().clear(); + } + + /// Check if a specific event type was collected. + pub fn has_event(&self, event_type: &str) -> bool { + self.events + .lock() + .unwrap() + .iter() + .any(|e| Self::event_type_name(e) == event_type) + } + + /// Count occurrences of a specific event type. + pub fn count_event(&self, event_type: &str) -> usize { + self.events + .lock() + .unwrap() + .iter() + .filter(|e| Self::event_type_name(e) == event_type) + .count() + } + + /// Get the number of collected events. + pub fn len(&self) -> usize { + self.events.lock().unwrap().len() + } + + /// Check if no events have been collected. + pub fn is_empty(&self) -> bool { + self.events.lock().unwrap().is_empty() + } + + fn event_type_name(event: &AgentEvent) -> &'static str { + match event { + AgentEvent::RunStarted { .. } => "run_started", + AgentEvent::RunCompleted { .. } => "run_completed", + AgentEvent::RunFailed { .. } => "run_failed", + AgentEvent::ModelCallStarted { .. } => "model_call_started", + AgentEvent::ModelCallStreaming { .. } => "model_streaming", + AgentEvent::ModelCallCompleted { .. } => "model_call_completed", + AgentEvent::ToolRequested { .. } => "tool_requested", + AgentEvent::ToolExecuting { .. } => "tool_executing", + AgentEvent::ToolCompleted { .. } => "tool_completed", + AgentEvent::ToolFailed { .. } => "tool_failed", + AgentEvent::PermissionRequired { .. } => "permission_required", + AgentEvent::PermissionGranted { .. } => "permission_granted", + AgentEvent::PermissionDenied { .. } => "permission_denied", + #[cfg(feature = "session")] + AgentEvent::SessionResumed { .. } => "session_resumed", + #[cfg(feature = "session")] + AgentEvent::SessionSaved { .. } => "session_saved", + } + } +} + +impl Default for EventCollector { + fn default() -> Self { + Self::new() + } +} + +impl crate::events::AgentHook for EventCollector { + fn on_event(&self, event: &AgentEvent) { + self.events.lock().unwrap().push(event.clone()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mock_provider_text_response() { + let provider = MockProvider::new().with_text("Hello!"); + assert_eq!(provider.call_count(), 0); + } + + #[test] + fn test_mock_provider_chained_responses() { + let provider = MockProvider::new() + .with_tool_use("calculator", serde_json::json!({"expr": "2+2"})) + .with_text("The answer is 4"); + + // Verify both responses were queued + assert_eq!(provider.call_count(), 0); + } + + #[tokio::test] + async fn test_mock_provider_generate() { + let provider = MockProvider::new() + .with_text("Response 1") + .with_text("Response 2"); + + let response1 = provider.generate(vec![], vec![], None).await.unwrap(); + assert_eq!(provider.call_count(), 1); + assert!(response1.message.text().contains("Response 1")); + + let response2 = provider.generate(vec![], vec![], None).await.unwrap(); + assert_eq!(provider.call_count(), 2); + assert!(response2.message.text().contains("Response 2")); + + // Should error when exhausted + let result = provider.generate(vec![], vec![], None).await; + assert!(result.is_err()); + } + + #[test] + fn test_event_collector() { + let collector = EventCollector::new(); + assert!(collector.is_empty()); + + // Simulate adding events directly for testing + collector + .events + .lock() + .unwrap() + .push(AgentEvent::RunStarted { + input: "test".to_string(), + timestamp: std::time::Instant::now(), + }); + collector + .events + .lock() + .unwrap() + .push(AgentEvent::RunCompleted { + output: "done".to_string(), + duration: std::time::Duration::from_secs(1), + }); + + assert_eq!(collector.len(), 2); + assert!(collector.has_event("run_started")); + assert!(collector.has_event("run_completed")); + assert!(!collector.has_event("run_failed")); + assert_eq!(collector.count_event("run_started"), 1); + + let types = collector.event_types(); + assert_eq!(types, vec!["run_started", "run_completed"]); + + collector.clear(); + assert!(collector.is_empty()); + } +} diff --git a/mixtape-server/Cargo.toml b/mixtape-server/Cargo.toml new file mode 100644 index 0000000..96d066b --- /dev/null +++ b/mixtape-server/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "mixtape-server" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +description = "HTTP server and AG-UI protocol support for the mixtape agent framework" +documentation = "https://docs.rs/mixtape-server" +readme = "README.md" +keywords = ["ai", "agents", "http", "sse", "ag-ui"] +categories = ["web-programming::http-server", "development-tools"] +exclude = [".cargo-husky/", ".claude/", ".github/", ".idea/"] + +[features] +default = [] +agui = [] + +[dependencies] +mixtape-core.workspace = true +axum.workspace = true +tokio.workspace = true +tokio-stream.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +uuid.workspace = true +chrono.workspace = true +tower.workspace = true +tower-http.workspace = true +parking_lot.workspace = true + +[dev-dependencies] +tokio-test.workspace = true +cargo-husky.workspace = true +mixtape-core = { workspace = true, features = ["bedrock", "test-utils"] } +axum-test = "16" + +[[example]] +name = "basic_server" +required-features = ["agui"] diff --git a/mixtape-server/examples/basic_server.rs b/mixtape-server/examples/basic_server.rs new file mode 100644 index 0000000..2615938 --- /dev/null +++ b/mixtape-server/examples/basic_server.rs @@ -0,0 +1,45 @@ +//! Basic mixtape server example with AG-UI support. +//! +//! This example creates an HTTP server that exposes an agent via AG-UI protocol. +//! +//! Run with: +//! ```sh +//! cargo run -p mixtape-server --example basic_server --features agui +//! ``` +//! +//! Test with curl: +//! ```sh +//! curl -X POST http://localhost:3000/api/copilotkit \ +//! -H "Content-Type: application/json" \ +//! -d '{"message": "Hello!"}' \ +//! -N +//! ``` + +use mixtape_core::{Agent, ClaudeHaiku4_5}; +use mixtape_server::MixtapeRouter; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Create the agent + let agent = Agent::builder() + .bedrock(ClaudeHaiku4_5) + .with_system_prompt("You are a helpful assistant.") + .interactive() // Enable permission prompts (for demonstration) + .build() + .await?; + + // Build the router with AG-UI endpoint + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") // SSE endpoint + .build()?; + + // Start the server + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; + println!("Server running at http://localhost:3000"); + println!("AG-UI endpoint: POST http://localhost:3000/api/copilotkit"); + println!("Interrupt endpoint: POST http://localhost:3000/api/copilotkit/interrupt"); + + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/mixtape-server/src/agui/convert.rs b/mixtape-server/src/agui/convert.rs new file mode 100644 index 0000000..b34bda7 --- /dev/null +++ b/mixtape-server/src/agui/convert.rs @@ -0,0 +1,209 @@ +//! Conversion from mixtape AgentEvent to AG-UI events. + +use mixtape_core::events::AgentEvent; + +use super::events::{AguiEvent, InterruptData, InterruptType, MessageRole}; + +/// Context for converting AgentEvent to AG-UI events. +/// +/// Maintains state across events to properly track message boundaries +/// and generate consistent IDs. +pub struct ConversionContext { + /// Thread ID for conversation continuity. + pub thread_id: String, + /// Run ID for this execution. + pub run_id: String, + /// Current message ID being built (for streaming). + current_message_id: Option, +} + +impl ConversionContext { + /// Create a new conversion context. + pub fn new(thread_id: String, run_id: String) -> Self { + Self { + thread_id, + run_id, + current_message_id: None, + } + } + + /// Get the current message ID, if any. + pub fn current_message_id(&self) -> Option<&str> { + self.current_message_id.as_deref() + } + + /// Set the current message ID. + pub fn set_current_message_id(&mut self, id: String) { + self.current_message_id = Some(id); + } + + /// Clear and return the current message ID. + pub fn take_current_message_id(&mut self) -> Option { + self.current_message_id.take() + } +} + +/// Convert an AgentEvent to AG-UI events. +/// +/// Some AgentEvents map to multiple AG-UI events, so this returns a Vec. +/// The context is mutated to track state across events. +pub fn convert_event(event: &AgentEvent, ctx: &mut ConversionContext) -> Vec { + match event { + // ===== Lifecycle Events ===== + AgentEvent::RunStarted { .. } => { + vec![AguiEvent::RunStarted { + thread_id: ctx.thread_id.clone(), + run_id: ctx.run_id.clone(), + }] + } + + AgentEvent::RunCompleted { .. } => { + let mut events = Vec::new(); + + // End any current message + if let Some(msg_id) = ctx.take_current_message_id() { + events.push(AguiEvent::TextMessageEnd { message_id: msg_id }); + } + + events.push(AguiEvent::RunFinished { + thread_id: ctx.thread_id.clone(), + run_id: ctx.run_id.clone(), + }); + + events + } + + AgentEvent::RunFailed { error, .. } => { + vec![AguiEvent::RunError { + message: error.clone(), + code: None, + }] + } + + // ===== Model Streaming (Text Messages) ===== + AgentEvent::ModelCallStarted { .. } => { + // Start a new assistant message + let message_id = uuid::Uuid::new_v4().to_string(); + ctx.set_current_message_id(message_id.clone()); + + vec![AguiEvent::TextMessageStart { + message_id, + role: MessageRole::Assistant, + }] + } + + AgentEvent::ModelCallStreaming { delta, .. } => { + if let Some(message_id) = ctx.current_message_id() { + vec![AguiEvent::TextMessageContent { + message_id: message_id.to_string(), + delta: delta.clone(), + }] + } else { + vec![] + } + } + + AgentEvent::ModelCallCompleted { .. } => { + // Don't end the message here - wait for RunCompleted or next ModelCallStarted + // This handles the case where the model continues after tool use + vec![] + } + + // ===== Tool Events ===== + AgentEvent::ToolRequested { + tool_use_id, + name, + input, + } => { + // End current message before tool call + let mut events = Vec::new(); + if let Some(msg_id) = ctx.take_current_message_id() { + events.push(AguiEvent::TextMessageEnd { message_id: msg_id }); + } + + // Tool call events + events.push(AguiEvent::ToolCallStart { + tool_call_id: tool_use_id.clone(), + tool_call_name: name.clone(), + parent_message_id: None, + }); + events.push(AguiEvent::ToolCallArgs { + tool_call_id: tool_use_id.clone(), + delta: serde_json::to_string(input).unwrap_or_default(), + }); + events.push(AguiEvent::ToolCallEnd { + tool_call_id: tool_use_id.clone(), + }); + + events + } + + AgentEvent::ToolExecuting { .. } => { + // No AG-UI equivalent - tool execution status is implicit + vec![] + } + + AgentEvent::ToolCompleted { + tool_use_id, + output, + .. + } => { + vec![AguiEvent::ToolCallResult { + message_id: uuid::Uuid::new_v4().to_string(), + tool_call_id: tool_use_id.clone(), + content: output.as_text(), + role: Some(MessageRole::Tool), + }] + } + + AgentEvent::ToolFailed { + tool_use_id, error, .. + } => { + vec![AguiEvent::ToolCallResult { + message_id: uuid::Uuid::new_v4().to_string(), + tool_call_id: tool_use_id.clone(), + content: format!("Error: {}", error), + role: Some(MessageRole::Tool), + }] + } + + // ===== Permission Events ===== + AgentEvent::PermissionRequired { + proposal_id, + tool_name, + params, + params_hash, + } => { + vec![AguiEvent::Interrupt { + interrupt_id: proposal_id.clone(), + interrupt_type: InterruptType::ToolApproval, + data: InterruptData { + tool_use_id: proposal_id.clone(), + tool_name: tool_name.clone(), + params: params.clone(), + params_hash: params_hash.clone(), + }, + }] + } + + AgentEvent::PermissionGranted { .. } => { + // Silent - the tool will execute and emit ToolCompleted + vec![] + } + + AgentEvent::PermissionDenied { .. } => { + // Silent - covered by subsequent ToolFailed event + vec![] + } + + // ===== Session Events ===== + // These are feature-gated in mixtape-core, but we handle them here + // regardless since the enum variant exists when session is enabled + #[allow(unreachable_patterns)] + _ => vec![], + } +} + +#[cfg(test)] +#[path = "convert_tests.rs"] +mod tests; diff --git a/mixtape-server/src/agui/convert_tests.rs b/mixtape-server/src/agui/convert_tests.rs new file mode 100644 index 0000000..a1877ec --- /dev/null +++ b/mixtape-server/src/agui/convert_tests.rs @@ -0,0 +1,461 @@ +//! Comprehensive tests for AgentEvent to AG-UI event conversion. +//! +//! These tests verify message boundary management and edge cases. + +use super::*; +use mixtape_core::events::AgentEvent; +use mixtape_core::tool::ToolResult; +use std::time::{Duration, Instant}; + +#[test] +fn test_multiple_sequential_model_calls() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // First model call + let start1 = AgentEvent::ModelCallStarted { + message_count: 1, + tool_count: 0, + timestamp: Instant::now(), + }; + let events = convert_event(&start1, &mut ctx); + assert_eq!(events.len(), 1); + let _first_msg_id = if let AguiEvent::TextMessageStart { message_id, .. } = &events[0] { + message_id.clone() + } else { + panic!("Expected TextMessageStart"); + }; + + // Second model call should end the first message + let start2 = AgentEvent::ModelCallStarted { + message_count: 2, + tool_count: 0, + timestamp: Instant::now(), + }; + let events = convert_event(&start2, &mut ctx); + + // Should produce: TextMessageEnd for first message, TextMessageStart for second + assert_eq!(events.len(), 1, "Second ModelCallStarted should only start new message, not end previous (that's handled by tool calls or RunCompleted)"); + + // Verify new message was started + assert!(matches!(&events[0], AguiEvent::TextMessageStart { .. })); +} + +#[test] +fn test_tool_requested_ends_current_message() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // Start a message + let start = AgentEvent::ModelCallStarted { + message_count: 1, + tool_count: 0, + timestamp: Instant::now(), + }; + convert_event(&start, &mut ctx); + assert!(ctx.current_message_id().is_some()); + + // Tool request should end the message + let tool_req = AgentEvent::ToolRequested { + tool_use_id: "tc-1".to_string(), + name: "echo".to_string(), + input: serde_json::json!({"text": "hello"}), + }; + let events = convert_event(&tool_req, &mut ctx); + + // Should end message, then emit 3 tool events + assert!(events.len() >= 4); + assert!(matches!(&events[0], AguiEvent::TextMessageEnd { .. })); + assert!(matches!(&events[1], AguiEvent::ToolCallStart { .. })); + assert!(ctx.current_message_id().is_none()); +} + +#[test] +fn test_tool_requested_without_current_message() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // Tool request with no active message (edge case) + let tool_req = AgentEvent::ToolRequested { + tool_use_id: "tc-1".to_string(), + name: "echo".to_string(), + input: serde_json::json!({"text": "hello"}), + }; + let events = convert_event(&tool_req, &mut ctx); + + // Should only emit tool events, no TextMessageEnd + assert_eq!(events.len(), 3); // Start, Args, End + assert!(matches!(&events[0], AguiEvent::ToolCallStart { .. })); + assert!(matches!(&events[1], AguiEvent::ToolCallArgs { .. })); + assert!(matches!(&events[2], AguiEvent::ToolCallEnd { .. })); +} + +#[test] +fn test_run_completed_ends_current_message() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // Start a message + let start = AgentEvent::ModelCallStarted { + message_count: 1, + tool_count: 0, + timestamp: Instant::now(), + }; + convert_event(&start, &mut ctx); + assert!(ctx.current_message_id().is_some()); + + // RunCompleted should end the message + let completed = AgentEvent::RunCompleted { + output: "Done".to_string(), + duration: Duration::from_secs(1), + }; + let events = convert_event(&completed, &mut ctx); + + assert_eq!(events.len(), 2); + assert!(matches!(&events[0], AguiEvent::TextMessageEnd { .. })); + assert!(matches!(&events[1], AguiEvent::RunFinished { .. })); + assert!(ctx.current_message_id().is_none()); +} + +#[test] +fn test_run_completed_without_current_message() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // RunCompleted with no active message + let completed = AgentEvent::RunCompleted { + output: "Done".to_string(), + duration: Duration::from_secs(1), + }; + let events = convert_event(&completed, &mut ctx); + + // Should only emit RunFinished + assert_eq!(events.len(), 1); + assert!(matches!(&events[0], AguiEvent::RunFinished { .. })); +} + +#[test] +fn test_streaming_without_current_message() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // Streaming event without active message (shouldn't happen, but test graceful handling) + let streaming = AgentEvent::ModelCallStreaming { + delta: "Hello".to_string(), + accumulated_length: 5, + }; + let events = convert_event(&streaming, &mut ctx); + + // Should return empty vec, not crash + assert_eq!(events.len(), 0); +} + +#[test] +fn test_empty_streaming_delta() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // Start message + let start = AgentEvent::ModelCallStarted { + message_count: 1, + tool_count: 0, + timestamp: Instant::now(), + }; + convert_event(&start, &mut ctx); + + // Empty delta should still produce event + let streaming = AgentEvent::ModelCallStreaming { + delta: "".to_string(), + accumulated_length: 0, + }; + let events = convert_event(&streaming, &mut ctx); + + assert_eq!(events.len(), 1); + if let AguiEvent::TextMessageContent { delta, .. } = &events[0] { + assert_eq!(delta, ""); + } else { + panic!("Expected TextMessageContent"); + } +} + +#[test] +fn test_tool_completed_with_different_result_types() { + let test_cases = [ + (ToolResult::Text("Success".to_string()), "Success"), + ( + ToolResult::Json(serde_json::json!({"status": "ok", "count": 42})), + r#"{"count":42,"status":"ok"}"#, // Note: JSON objects are sorted by key + ), + ( + ToolResult::Image { + format: mixtape_core::tool::ImageFormat::Png, + data: vec![0x89, 0x50, 0x4E, 0x47], // PNG header + }, + "[Image: Png, 4 bytes]", + ), + ( + ToolResult::Document { + format: mixtape_core::tool::DocumentFormat::Pdf, + data: vec![0x25, 0x50, 0x44, 0x46], // PDF header + name: Some("report.pdf".to_string()), + }, + "[Document: Pdf, report.pdf, 4 bytes]", + ), + ( + ToolResult::Document { + format: mixtape_core::tool::DocumentFormat::Txt, + data: vec![0x48, 0x69], // "Hi" + name: None, + }, + "[Document: Txt, unnamed, 2 bytes]", + ), + ]; + + for (result, expected_content) in test_cases { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + let event = AgentEvent::ToolCompleted { + tool_use_id: "tc-1".to_string(), + name: "test_tool".to_string(), + output: result, + duration: Duration::from_millis(100), + }; + + let events = convert_event(&event, &mut ctx); + assert_eq!(events.len(), 1); + + if let AguiEvent::ToolCallResult { content, .. } = &events[0] { + assert_eq!(content, expected_content); + } else { + panic!("Expected ToolCallResult"); + } + } +} + +#[test] +fn test_tool_failed_error_formatting() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + let event = AgentEvent::ToolFailed { + tool_use_id: "tc-1".to_string(), + name: "dangerous_tool".to_string(), + error: "Permission denied".to_string(), + duration: Duration::from_millis(10), + }; + + let events = convert_event(&event, &mut ctx); + assert_eq!(events.len(), 1); + + if let AguiEvent::ToolCallResult { content, .. } = &events[0] { + assert_eq!(content, "Error: Permission denied"); + } else { + panic!("Expected ToolCallResult"); + } +} + +#[test] +fn test_tool_call_args_with_complex_json() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + let complex_input = serde_json::json!({ + "nested": { + "array": [1, 2, 3], + "object": {"key": "value"} + }, + "string": "test", + "number": 42, + "boolean": true, + "null": null + }); + + let event = AgentEvent::ToolRequested { + tool_use_id: "tc-1".to_string(), + name: "complex_tool".to_string(), + input: complex_input.clone(), + }; + + let events = convert_event(&event, &mut ctx); + + // Find the ToolCallArgs event + let args_event = events + .iter() + .find(|e| matches!(e, AguiEvent::ToolCallArgs { .. })); + assert!(args_event.is_some()); + + if let AguiEvent::ToolCallArgs { delta, .. } = args_event.unwrap() { + // Should be valid JSON + let parsed: serde_json::Value = serde_json::from_str(delta).unwrap(); + assert_eq!(parsed, complex_input); + } +} + +#[test] +fn test_tool_call_args_serialization_fallback() { + // This tests the unwrap_or_default() on line 132 of convert.rs + // We can't easily trigger a serialization error with Value, but we verify + // the happy path works correctly + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + let event = AgentEvent::ToolRequested { + tool_use_id: "tc-1".to_string(), + name: "tool".to_string(), + input: serde_json::json!({"key": "value"}), + }; + + let events = convert_event(&event, &mut ctx); + + // Should produce valid JSON delta + let args_event = events + .iter() + .find(|e| matches!(e, AguiEvent::ToolCallArgs { .. })); + if let Some(AguiEvent::ToolCallArgs { delta, .. }) = args_event { + assert!(serde_json::from_str::(delta).is_ok()); + } +} + +#[test] +fn test_permission_required_with_special_characters_in_params() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + let event = AgentEvent::PermissionRequired { + proposal_id: "prop-1".to_string(), + tool_name: "shell".to_string(), + params: serde_json::json!({ + "cmd": "echo \"Hello\\nWorld\"", + "special": "chars: \t\r\n" + }), + params_hash: "hash123".to_string(), + }; + + let events = convert_event(&event, &mut ctx); + assert_eq!(events.len(), 1); + + if let AguiEvent::Interrupt { data, .. } = &events[0] { + // Verify special characters are preserved + assert_eq!(data.params["cmd"], "echo \"Hello\\nWorld\""); + assert_eq!(data.params["special"], "chars: \t\r\n"); + } else { + panic!("Expected Interrupt event"); + } +} + +#[test] +fn test_multiple_tools_in_sequence() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + // Start message + convert_event( + &AgentEvent::ModelCallStarted { + message_count: 1, + tool_count: 2, + timestamp: Instant::now(), + }, + &mut ctx, + ); + + // First tool + let tool1 = AgentEvent::ToolRequested { + tool_use_id: "tc-1".to_string(), + name: "tool1".to_string(), + input: serde_json::json!({}), + }; + convert_event(&tool1, &mut ctx); + assert!(ctx.current_message_id().is_none()); // Message ended by tool + + // First tool completes + let complete1 = AgentEvent::ToolCompleted { + tool_use_id: "tc-1".to_string(), + name: "tool1".to_string(), + output: ToolResult::Text("Result 1".to_string()), + duration: Duration::from_millis(100), + }; + convert_event(&complete1, &mut ctx); + + // Second tool + let tool2 = AgentEvent::ToolRequested { + tool_use_id: "tc-2".to_string(), + name: "tool2".to_string(), + input: serde_json::json!({}), + }; + let events = convert_event(&tool2, &mut ctx); + + // Should not try to end a message (none is active) + let has_message_end = events + .iter() + .any(|e| matches!(e, AguiEvent::TextMessageEnd { .. })); + assert!( + !has_message_end, + "Should not emit TextMessageEnd when no message is active" + ); +} + +#[test] +fn test_silent_events_produce_no_output() { + let mut ctx = ConversionContext::new("thread-1".to_string(), "run-1".to_string()); + + let silent_events = vec![ + AgentEvent::ToolExecuting { + tool_use_id: "tc-1".to_string(), + name: "tool".to_string(), + }, + AgentEvent::ModelCallCompleted { + response_content: "Done".to_string(), + tokens: None, + duration: Duration::from_secs(1), + stop_reason: None, + }, + AgentEvent::PermissionGranted { + tool_use_id: "tc-1".to_string(), + tool_name: "tool".to_string(), + scope: None, + }, + AgentEvent::PermissionDenied { + tool_use_id: "tc-1".to_string(), + tool_name: "tool".to_string(), + reason: "denied".to_string(), + }, + ]; + + for event in silent_events { + let events = convert_event(&event, &mut ctx); + assert_eq!( + events.len(), + 0, + "Event {:?} should produce no AG-UI events", + event + ); + } +} + +#[test] +fn test_context_thread_and_run_ids() { + let mut ctx = ConversionContext::new("custom-thread".to_string(), "custom-run".to_string()); + + let event = AgentEvent::RunStarted { + input: "test".to_string(), + timestamp: Instant::now(), + }; + + let events = convert_event(&event, &mut ctx); + + if let AguiEvent::RunStarted { thread_id, run_id } = &events[0] { + assert_eq!(thread_id, "custom-thread"); + assert_eq!(run_id, "custom-run"); + } else { + panic!("Expected RunStarted"); + } +} + +#[test] +fn test_conversion_context_message_id_operations() { + let mut ctx = ConversionContext::new("t1".to_string(), "r1".to_string()); + + // Initially no message ID + assert!(ctx.current_message_id().is_none()); + + // Set message ID + ctx.set_current_message_id("msg-1".to_string()); + assert_eq!(ctx.current_message_id(), Some("msg-1")); + + // Take message ID (removes it) + let id = ctx.take_current_message_id(); + assert_eq!(id, Some("msg-1".to_string())); + assert!(ctx.current_message_id().is_none()); + + // Take from empty returns None + let id = ctx.take_current_message_id(); + assert!(id.is_none()); +} diff --git a/mixtape-server/src/agui/events.rs b/mixtape-server/src/agui/events.rs new file mode 100644 index 0000000..5c03aa3 --- /dev/null +++ b/mixtape-server/src/agui/events.rs @@ -0,0 +1,214 @@ +//! AG-UI protocol event types. +//! +//! These types represent the ~17 standard AG-UI event types used for +//! agent-to-frontend communication. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// AG-UI protocol events. +/// +/// Events are serialized with a `type` field in SCREAMING_SNAKE_CASE +/// as per the AG-UI specification. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "SCREAMING_SNAKE_CASE")] +pub enum AguiEvent { + // ===== Lifecycle Events ===== + /// Agent run started. + RunStarted { + /// Thread ID for conversation continuity. + thread_id: String, + /// Unique run ID for this execution. + run_id: String, + }, + + /// Agent run finished successfully. + RunFinished { + /// Thread ID for conversation continuity. + thread_id: String, + /// Unique run ID for this execution. + run_id: String, + }, + + /// Agent run failed with an error. + RunError { + /// Error message describing the failure. + message: String, + /// Optional error code. + #[serde(skip_serializing_if = "Option::is_none")] + code: Option, + }, + + // ===== Text Message Events ===== + /// Start of a new text message. + TextMessageStart { + /// Unique message ID. + message_id: String, + /// Role of the message author. + role: MessageRole, + }, + + /// Incremental content for a text message. + TextMessageContent { + /// Message ID this content belongs to. + message_id: String, + /// Text delta to append. + delta: String, + }, + + /// End of a text message. + TextMessageEnd { + /// Message ID that is complete. + message_id: String, + }, + + // ===== Tool Call Events ===== + /// Start of a tool call. + ToolCallStart { + /// Unique tool call ID. + tool_call_id: String, + /// Name of the tool being called. + tool_call_name: String, + /// Optional parent message ID. + #[serde(skip_serializing_if = "Option::is_none")] + parent_message_id: Option, + }, + + /// Incremental arguments for a tool call. + ToolCallArgs { + /// Tool call ID this belongs to. + tool_call_id: String, + /// JSON argument delta. + delta: String, + }, + + /// End of tool call arguments. + ToolCallEnd { + /// Tool call ID that is complete. + tool_call_id: String, + }, + + /// Result from a tool call. + ToolCallResult { + /// Unique message ID for this result. + message_id: String, + /// Tool call ID this result is for. + tool_call_id: String, + /// Result content (text or JSON string). + content: String, + /// Role (typically Tool). + #[serde(skip_serializing_if = "Option::is_none")] + role: Option, + }, + + // ===== State Management Events ===== + /// Complete state snapshot. + StateSnapshot { + /// The complete state object. + snapshot: Value, + }, + + /// Incremental state update (JSON Patch). + StateDelta { + /// JSON Patch operations (RFC 6902). + delta: Vec, + }, + + // ===== Interrupt Events (Human-in-the-Loop) ===== + /// Interrupt requiring user action. + /// + /// Used for permission requests and other human-in-the-loop interactions. + Interrupt { + /// Unique interrupt ID. + interrupt_id: String, + /// Type of interrupt. + interrupt_type: InterruptType, + /// Data associated with the interrupt. + data: InterruptData, + }, +} + +/// Message author role. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum MessageRole { + /// User message. + User, + /// Assistant message. + Assistant, + /// System message. + System, + /// Tool result message. + Tool, +} + +/// Type of interrupt. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum InterruptType { + /// Tool requires user approval before execution. + ToolApproval, +} + +/// Data associated with an interrupt. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InterruptData { + /// Tool use ID / proposal ID. + pub tool_use_id: String, + /// Name of the tool requiring approval. + pub tool_name: String, + /// Tool input parameters. + pub params: Value, + /// Hash of parameters for exact-match grants. + pub params_hash: String, +} + +/// Response to an interrupt from the client. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "action", rename_all = "snake_case")] +pub enum InterruptResponse { + /// Approve this single call without saving a grant. + ApproveOnce, + /// Trust this tool entirely and save a grant. + TrustTool { + /// Scope for the grant. + scope: GrantScope, + }, + /// Trust this exact call (matching parameters) and save a grant. + TrustExact { + /// Scope for the grant. + scope: GrantScope, + }, + /// Deny the request. + Deny { + /// Optional reason for denial. + #[serde(skip_serializing_if = "Option::is_none")] + reason: Option, + }, +} + +/// Scope for permission grants. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum GrantScope { + /// Grant lives in memory only, cleared when process exits. + Session, + /// Grant persists to storage. + Persistent, +} + +/// JSON Patch operation (RFC 6902). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonPatchOp { + /// Operation type (add, remove, replace, move, copy, test). + pub op: String, + /// JSON Pointer path. + pub path: String, + /// Value for add/replace operations. + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +#[cfg(test)] +#[path = "events_tests.rs"] +mod tests; diff --git a/mixtape-server/src/agui/events_tests.rs b/mixtape-server/src/agui/events_tests.rs new file mode 100644 index 0000000..b6f8261 --- /dev/null +++ b/mixtape-server/src/agui/events_tests.rs @@ -0,0 +1,448 @@ +//! Comprehensive tests for AG-UI event serialization/deserialization. +//! +//! These tests verify the external API contract between the server and frontend. + +use super::*; +use serde_json::json; + +#[test] +fn test_all_lifecycle_events_serialization() { + let cases = [ + ( + AguiEvent::RunStarted { + thread_id: "t1".to_string(), + run_id: "r1".to_string(), + }, + "RUN_STARTED", + ), + ( + AguiEvent::RunFinished { + thread_id: "t1".to_string(), + run_id: "r1".to_string(), + }, + "RUN_FINISHED", + ), + ( + AguiEvent::RunError { + message: "failure".to_string(), + code: None, + }, + "RUN_ERROR", + ), + ( + AguiEvent::RunError { + message: "failure".to_string(), + code: Some("E001".to_string()), + }, + "RUN_ERROR", + ), + ]; + + for (event, expected_type) in cases { + let json = serde_json::to_string(&event).unwrap(); + assert!( + json.contains(&format!("\"type\":\"{}\"", expected_type)), + "Event {:?} should serialize with type {}", + event, + expected_type + ); + } +} + +#[test] +fn test_all_message_events_roundtrip() { + let events = vec![ + AguiEvent::TextMessageStart { + message_id: "msg-1".to_string(), + role: MessageRole::Assistant, + }, + AguiEvent::TextMessageContent { + message_id: "msg-1".to_string(), + delta: "Hello world".to_string(), + }, + AguiEvent::TextMessageEnd { + message_id: "msg-1".to_string(), + }, + ]; + + for event in events { + let json = serde_json::to_string(&event).unwrap(); + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + + // Verify type preservation through roundtrip + match (&event, &deserialized) { + (AguiEvent::TextMessageStart { .. }, AguiEvent::TextMessageStart { .. }) => {} + (AguiEvent::TextMessageContent { .. }, AguiEvent::TextMessageContent { .. }) => {} + (AguiEvent::TextMessageEnd { .. }, AguiEvent::TextMessageEnd { .. }) => {} + _ => panic!("Event type changed during roundtrip"), + } + } +} + +#[test] +fn test_tool_call_events_complete_sequence() { + let cases = [ + ( + AguiEvent::ToolCallStart { + tool_call_id: "tc-1".to_string(), + tool_call_name: "echo".to_string(), + parent_message_id: None, + }, + "TOOL_CALL_START", + ), + ( + AguiEvent::ToolCallStart { + tool_call_id: "tc-1".to_string(), + tool_call_name: "echo".to_string(), + parent_message_id: Some("msg-1".to_string()), + }, + "TOOL_CALL_START", + ), + ( + AguiEvent::ToolCallArgs { + tool_call_id: "tc-1".to_string(), + delta: r#"{"arg":"value"}"#.to_string(), + }, + "TOOL_CALL_ARGS", + ), + ( + AguiEvent::ToolCallEnd { + tool_call_id: "tc-1".to_string(), + }, + "TOOL_CALL_END", + ), + ( + AguiEvent::ToolCallResult { + message_id: "result-1".to_string(), + tool_call_id: "tc-1".to_string(), + content: "Success".to_string(), + role: Some(MessageRole::Tool), + }, + "TOOL_CALL_RESULT", + ), + ( + AguiEvent::ToolCallResult { + message_id: "result-1".to_string(), + tool_call_id: "tc-1".to_string(), + content: "Success".to_string(), + role: None, + }, + "TOOL_CALL_RESULT", + ), + ]; + + for (event, expected_type) in cases { + let json = serde_json::to_string(&event).unwrap(); + assert!( + json.contains(&format!("\"type\":\"{}\"", expected_type)), + "Event {:?} should serialize with type {}", + event, + expected_type + ); + } +} + +#[test] +fn test_state_events_with_complex_data() { + // Snapshot with nested JSON + let snapshot = json!({ + "users": [ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"} + ], + "count": 2 + }); + + let event = AguiEvent::StateSnapshot { + snapshot: snapshot.clone(), + }; + + let json = serde_json::to_string(&event).unwrap(); + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + + if let AguiEvent::StateSnapshot { + snapshot: deser_snapshot, + } = deserialized + { + assert_eq!(snapshot, deser_snapshot); + } else { + panic!("Wrong event type after deserialization"); + } +} + +#[test] +fn test_state_delta_with_json_patch_ops() { + let delta_ops = vec![ + JsonPatchOp { + op: "add".to_string(), + path: "/users/2".to_string(), + value: Some(json!({"id": 3, "name": "Charlie"})), + }, + JsonPatchOp { + op: "remove".to_string(), + path: "/users/0".to_string(), + value: None, + }, + JsonPatchOp { + op: "replace".to_string(), + path: "/count".to_string(), + value: Some(json!(2)), + }, + ]; + + let event = AguiEvent::StateDelta { delta: delta_ops }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"STATE_DELTA\"")); + + // Verify roundtrip + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + if let AguiEvent::StateDelta { delta } = deserialized { + assert_eq!(delta.len(), 3); + assert_eq!(delta[0].op, "add"); + assert_eq!(delta[1].op, "remove"); + assert_eq!(delta[2].op, "replace"); + assert!(delta[1].value.is_none()); // remove has no value + } else { + panic!("Wrong event type after deserialization"); + } +} + +#[test] +fn test_interrupt_event_serialization() { + let event = AguiEvent::Interrupt { + interrupt_id: "int-1".to_string(), + interrupt_type: InterruptType::ToolApproval, + data: InterruptData { + tool_use_id: "tu-1".to_string(), + tool_name: "dangerous_cmd".to_string(), + params: json!({"cmd": "rm -rf /"}), + params_hash: "abc123".to_string(), + }, + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"INTERRUPT\"")); + assert!(json.contains("\"interrupt_type\":\"tool_approval\"")); + assert!(json.contains("dangerous_cmd")); +} + +#[test] +fn test_message_role_all_variants() { + let cases = [ + (MessageRole::User, "user"), + (MessageRole::Assistant, "assistant"), + (MessageRole::System, "system"), + (MessageRole::Tool, "tool"), + ]; + + for (role, expected_str) in cases { + let json = serde_json::to_string(&role).unwrap(); + assert_eq!(json, format!("\"{}\"", expected_str)); + + // Roundtrip + let deserialized: MessageRole = serde_json::from_str(&json).unwrap(); + assert_eq!(role, deserialized); + } +} + +#[test] +fn test_interrupt_response_all_variants() { + // Test approve_once + let json = r#"{"action":"approve_once"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!(response, InterruptResponse::ApproveOnce)); + + // Test trust_tool with session scope + let json = r#"{"action":"trust_tool","scope":"session"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!( + response, + InterruptResponse::TrustTool { + scope: GrantScope::Session + } + )); + + // Test trust_tool with persistent scope + let json = r#"{"action":"trust_tool","scope":"persistent"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!( + response, + InterruptResponse::TrustTool { + scope: GrantScope::Persistent + } + )); + + // Test trust_exact with session scope + let json = r#"{"action":"trust_exact","scope":"session"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!( + response, + InterruptResponse::TrustExact { + scope: GrantScope::Session + } + )); + + // Test trust_exact with persistent scope + let json = r#"{"action":"trust_exact","scope":"persistent"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!( + response, + InterruptResponse::TrustExact { + scope: GrantScope::Persistent + } + )); + + // Test deny without reason + let json = r#"{"action":"deny"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!(response, InterruptResponse::Deny { reason: None })); + + // Test deny with reason + let json = r#"{"action":"deny","reason":"Too dangerous"}"#; + let response: InterruptResponse = serde_json::from_str(json).unwrap(); + assert!(matches!( + response, + InterruptResponse::Deny { reason: Some(_) } + )); +} + +#[test] +fn test_grant_scope_all_variants() { + let cases = [ + (GrantScope::Session, "session"), + (GrantScope::Persistent, "persistent"), + ]; + + for (scope, expected_str) in cases { + let json = serde_json::to_string(&scope).unwrap(); + assert_eq!(json, format!("\"{}\"", expected_str)); + + // Roundtrip + let deserialized: GrantScope = serde_json::from_str(&json).unwrap(); + assert_eq!(scope, deserialized); + } +} + +#[test] +fn test_event_with_empty_strings() { + // Empty strings should be valid + let event = AguiEvent::TextMessageContent { + message_id: "".to_string(), + delta: "".to_string(), + }; + + let json = serde_json::to_string(&event).unwrap(); + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + + if let AguiEvent::TextMessageContent { message_id, delta } = deserialized { + assert_eq!(message_id, ""); + assert_eq!(delta, ""); + } else { + panic!("Wrong event type"); + } +} + +#[test] +fn test_event_with_special_characters() { + // Test special characters that need escaping in JSON + let special_chars = "Hello \"world\"\n\t\r\\slash/forward"; + + let event = AguiEvent::TextMessageContent { + message_id: "msg-1".to_string(), + delta: special_chars.to_string(), + }; + + let json = serde_json::to_string(&event).unwrap(); + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + + if let AguiEvent::TextMessageContent { delta, .. } = deserialized { + assert_eq!(delta, special_chars); + } else { + panic!("Wrong event type"); + } +} + +#[test] +fn test_event_with_unicode() { + // Test Unicode characters + let unicode_text = "Hello 世界 🌍 Привет مرحبا"; + + let event = AguiEvent::TextMessageContent { + message_id: "msg-1".to_string(), + delta: unicode_text.to_string(), + }; + + let json = serde_json::to_string(&event).unwrap(); + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + + if let AguiEvent::TextMessageContent { delta, .. } = deserialized { + assert_eq!(delta, unicode_text); + } else { + panic!("Wrong event type"); + } +} + +#[test] +fn test_event_with_very_long_strings() { + // Test handling of large content + let large_delta = "x".repeat(10_000); + + let event = AguiEvent::TextMessageContent { + message_id: "msg-1".to_string(), + delta: large_delta.clone(), + }; + + let json = serde_json::to_string(&event).unwrap(); + let deserialized: AguiEvent = serde_json::from_str(&json).unwrap(); + + if let AguiEvent::TextMessageContent { delta, .. } = deserialized { + assert_eq!(delta.len(), 10_000); + assert_eq!(delta, large_delta); + } else { + panic!("Wrong event type"); + } +} + +#[test] +fn test_skip_serializing_if_behavior() { + // Test that None fields are omitted from JSON + let event = AguiEvent::RunError { + message: "error".to_string(), + code: None, + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!( + !json.contains("\"code\""), + "None code should be omitted from JSON" + ); + + // Test that Some fields are included + let event = AguiEvent::RunError { + message: "error".to_string(), + code: Some("E001".to_string()), + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!( + json.contains("\"code\":\"E001\""), + "Some code should be included in JSON" + ); +} + +#[test] +fn test_malformed_interrupt_response_fails_gracefully() { + let bad_json_cases = [ + r#"{"action":"unknown_action"}"#, // Invalid action + r#"{"action":"trust_tool"}"#, // Missing required scope + r#"{"action":"trust_exact"}"#, // Missing required scope + r#"{}"#, // Missing action + r#"{"scope":"session"}"#, // Action without scope + ]; + + for bad_json in bad_json_cases { + let result: Result = serde_json::from_str(bad_json); + assert!(result.is_err(), "Should fail to deserialize: {}", bad_json); + } +} diff --git a/mixtape-server/src/agui/handler.rs b/mixtape-server/src/agui/handler.rs new file mode 100644 index 0000000..90eaaa1 --- /dev/null +++ b/mixtape-server/src/agui/handler.rs @@ -0,0 +1,197 @@ +//! HTTP handlers for AG-UI protocol endpoints. + +use std::convert::Infallible; +use std::sync::Arc; + +use axum::{ + extract::State, + response::sse::{Event, KeepAlive, Sse}, + Json, +}; +use futures::stream::Stream; +use mixtape_core::events::AgentEvent; +use mixtape_core::permission::{AuthorizationResponse, Grant, Scope}; +use serde::Deserialize; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; + +use super::convert::{convert_event, ConversionContext}; +use super::events::{AguiEvent, GrantScope, InterruptResponse}; +use crate::error::ServerError; +use crate::state::AppState; + +/// Request body for running an agent. +#[derive(Debug, Deserialize)] +pub struct AgentRequest { + /// User message to send to the agent. + pub message: String, + /// Thread ID for conversation continuity. + #[serde(default)] + pub thread_id: Option, + /// Run ID for this specific run. + #[serde(default)] + pub run_id: Option, + /// Optional run options (included for AG-UI protocol compatibility). + #[serde(default)] + #[allow(dead_code)] + pub options: RunOptions, +} + +/// Options for agent run. +#[derive(Debug, Deserialize)] +pub struct RunOptions { + /// Whether to stream responses (always true for AG-UI, included for compatibility). + #[serde(default = "default_true")] + #[allow(dead_code)] + pub stream: bool, +} + +impl Default for RunOptions { + fn default() -> Self { + Self { stream: true } + } +} + +fn default_true() -> bool { + true +} + +/// Request body for responding to an interrupt (permission request). +#[derive(Debug, Deserialize)] +pub struct InterruptRequest { + /// The interrupt ID to respond to. + pub interrupt_id: String, + /// Tool name (from interrupt data). + pub tool_name: String, + /// Params hash (from interrupt data, for exact grants). + #[serde(default)] + pub params_hash: Option, + /// The response action. + pub response: InterruptResponse, +} + +/// Handle AG-UI protocol requests. +/// +/// Accepts POST with AgentRequest body, returns SSE stream of AG-UI events. +pub async fn agui_handler( + State(state): State, + Json(request): Json, +) -> Sse>> { + let agent = state.agent.clone(); + let thread_id = request + .thread_id + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let run_id = request + .run_id + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let message = request.message; + + // Create channel for AG-UI events + let (tx, rx) = mpsc::channel::(100); + + // Spawn agent run task + let tx_for_task = tx.clone(); + let thread_id_clone = thread_id.clone(); + let run_id_clone = run_id.clone(); + + tokio::spawn(async move { + // Create conversion context with shared state + let ctx = Arc::new(parking_lot::Mutex::new(ConversionContext::new( + thread_id_clone, + run_id_clone, + ))); + + // Add hook to forward events (capture hook ID for cleanup) + let ctx_for_hook = ctx.clone(); + let tx_for_hook = tx_for_task.clone(); + let hook_id = agent.add_hook(move |event: &AgentEvent| { + let mut ctx_guard = ctx_for_hook.lock(); + let agui_events = convert_event(event, &mut ctx_guard); + for agui_event in agui_events { + // Non-blocking send - drop events if channel is full + let _ = tx_for_hook.try_send(agui_event); + } + }); + + // Run the agent + match agent.run(&message).await { + Ok(_response) => { + // RunCompleted event is already emitted via hook + } + Err(e) => { + let _ = tx_for_task.try_send(AguiEvent::RunError { + message: e.to_string(), + code: None, + }); + } + } + + // Clean up: remove the hook after the run completes + agent.remove_hook(hook_id); + }); + + // Convert channel to SSE stream + let stream = ReceiverStream::new(rx).map(|event| { + let json = serde_json::to_string(&event).unwrap_or_else(|e| { + serde_json::json!({ + "type": "RUN_ERROR", + "message": format!("Failed to serialize event: {}", e) + }) + .to_string() + }); + Ok::<_, Infallible>(Event::default().data(json)) + }); + + Sse::new(stream).keep_alive(KeepAlive::default()) +} + +/// Handle interrupt responses (permission decisions). +/// +/// This endpoint receives permission decisions from the frontend and +/// forwards them to the agent. +pub async fn interrupt_handler( + State(state): State, + Json(request): Json, +) -> Result, ServerError> { + // Convert InterruptResponse to AuthorizationResponse + let auth_response = match request.response { + InterruptResponse::ApproveOnce => AuthorizationResponse::Once, + InterruptResponse::TrustTool { scope } => { + let core_scope = convert_scope(scope); + AuthorizationResponse::Trust { + grant: Grant::tool(&request.tool_name).with_scope(core_scope), + } + } + InterruptResponse::TrustExact { scope } => { + let core_scope = convert_scope(scope); + let hash = request.params_hash.ok_or_else(|| { + ServerError::InvalidRequest("params_hash required for TrustExact".to_string()) + })?; + AuthorizationResponse::Trust { + grant: Grant::exact(&request.tool_name, &hash).with_scope(core_scope), + } + } + InterruptResponse::Deny { reason } => AuthorizationResponse::Deny { reason }, + }; + + state + .agent + .respond_to_authorization(&request.interrupt_id, auth_response) + .await + .map_err(|e| ServerError::Permission(e.to_string()))?; + + Ok(Json(serde_json::json!({ "status": "ok" }))) +} + +/// Convert AG-UI GrantScope to mixtape-core Scope. +fn convert_scope(scope: GrantScope) -> Scope { + match scope { + GrantScope::Session => Scope::Session, + GrantScope::Persistent => Scope::Persistent, + } +} + +#[cfg(test)] +#[path = "handler_tests.rs"] +mod tests; diff --git a/mixtape-server/src/agui/handler_tests.rs b/mixtape-server/src/agui/handler_tests.rs new file mode 100644 index 0000000..ee4f72a --- /dev/null +++ b/mixtape-server/src/agui/handler_tests.rs @@ -0,0 +1,335 @@ +//! Tests for AG-UI HTTP handlers focusing on error paths and edge cases. + +use super::*; + +#[test] +fn test_agent_request_all_fields() { + let json = r#"{ + "message": "Hello", + "thread_id": "thread-123", + "run_id": "run-456", + "options": {"stream": false} + }"#; + + let request: AgentRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.message, "Hello"); + assert_eq!(request.thread_id, Some("thread-123".to_string())); + assert_eq!(request.run_id, Some("run-456".to_string())); + assert!(!request.options.stream); +} + +#[test] +fn test_agent_request_minimal() { + let json = r#"{"message": "Hello"}"#; + let request: AgentRequest = serde_json::from_str(json).unwrap(); + + assert_eq!(request.message, "Hello"); + assert!(request.thread_id.is_none()); + assert!(request.run_id.is_none()); + assert!(request.options.stream); // default is true +} + +#[test] +fn test_agent_request_empty_message() { + // Empty message is valid (let agent decide if it's an error) + let json = r#"{"message": ""}"#; + let request: AgentRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.message, ""); +} + +#[test] +fn test_agent_request_missing_message_field() { + let json = r#"{"thread_id": "thread-123"}"#; + let result: Result = serde_json::from_str(json); + assert!( + result.is_err(), + "Should fail without required message field" + ); +} + +#[test] +fn test_agent_request_with_special_characters() { + let json = r#"{"message": "Hello\n\"World\"\t\r"}"#; + let request: AgentRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.message, "Hello\n\"World\"\t\r"); +} + +#[test] +fn test_agent_request_with_unicode() { + let json = r#"{"message": "Hello 世界 🌍"}"#; + let request: AgentRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.message, "Hello 世界 🌍"); +} + +#[test] +fn test_run_options_default() { + let options = RunOptions::default(); + assert!(options.stream); +} + +#[test] +fn test_interrupt_request_approve_once() { + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "echo", + "response": {"action": "approve_once"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.interrupt_id, "int-1"); + assert_eq!(request.tool_name, "echo"); + assert!(request.params_hash.is_none()); + assert!(matches!(request.response, InterruptResponse::ApproveOnce)); +} + +#[test] +fn test_interrupt_request_trust_tool_session() { + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "safe_tool", + "response": {"action": "trust_tool", "scope": "session"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert!(matches!( + request.response, + InterruptResponse::TrustTool { + scope: GrantScope::Session + } + )); +} + +#[test] +fn test_interrupt_request_trust_tool_persistent() { + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "safe_tool", + "response": {"action": "trust_tool", "scope": "persistent"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert!(matches!( + request.response, + InterruptResponse::TrustTool { + scope: GrantScope::Persistent + } + )); +} + +#[test] +fn test_interrupt_request_trust_exact_with_hash() { + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "cmd", + "params_hash": "abc123", + "response": {"action": "trust_exact", "scope": "session"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.params_hash, Some("abc123".to_string())); + assert!(matches!( + request.response, + InterruptResponse::TrustExact { + scope: GrantScope::Session + } + )); +} + +#[test] +fn test_interrupt_request_trust_exact_without_hash() { + // This should deserialize fine - the handler will catch the missing hash + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "cmd", + "response": {"action": "trust_exact", "scope": "session"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert!(request.params_hash.is_none()); + assert!(matches!( + request.response, + InterruptResponse::TrustExact { .. } + )); +} + +#[test] +fn test_interrupt_request_deny_without_reason() { + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "dangerous", + "response": {"action": "deny"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert!(matches!( + request.response, + InterruptResponse::Deny { reason: None } + )); +} + +#[test] +fn test_interrupt_request_deny_with_reason() { + let json = r#"{ + "interrupt_id": "int-1", + "tool_name": "dangerous", + "response": {"action": "deny", "reason": "Too risky"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + if let InterruptResponse::Deny { reason } = request.response { + assert_eq!(reason, Some("Too risky".to_string())); + } else { + panic!("Expected Deny response"); + } +} + +#[test] +fn test_interrupt_request_malformed() { + let bad_cases = vec![ + r#"{}"#, // Missing all fields + r#"{"interrupt_id": "int-1"}"#, // Missing tool_name and response + r#"{"interrupt_id": "int-1", "tool_name": "echo"}"#, // Missing response + r#"{"response": {"action": "approve_once"}}"#, // Missing interrupt_id and tool_name + ]; + + for bad_json in bad_cases { + let result: Result = serde_json::from_str(bad_json); + assert!( + result.is_err(), + "Should reject malformed request: {}", + bad_json + ); + } +} + +#[test] +fn test_scope_conversion_all_variants() { + use super::convert_scope; + + let cases = [ + ( + GrantScope::Session, + mixtape_core::permission::Scope::Session, + ), + ( + GrantScope::Persistent, + mixtape_core::permission::Scope::Persistent, + ), + ]; + + for (agui_scope, core_scope) in cases { + let converted = convert_scope(agui_scope); + assert!( + matches!( + (&converted, &core_scope), + ( + mixtape_core::permission::Scope::Session, + mixtape_core::permission::Scope::Session + ) | ( + mixtape_core::permission::Scope::Persistent, + mixtape_core::permission::Scope::Persistent + ) + ), + "Scope conversion failed for {:?}", + agui_scope + ); + } +} + +#[test] +fn test_interrupt_request_empty_strings() { + // Empty strings should be valid (though semantically wrong) + let json = r#"{ + "interrupt_id": "", + "tool_name": "", + "response": {"action": "approve_once"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.interrupt_id, ""); + assert_eq!(request.tool_name, ""); +} + +#[test] +fn test_interrupt_request_with_special_characters() { + let json = r#"{ + "interrupt_id": "int-\"123\"", + "tool_name": "tool\nname", + "params_hash": "hash\t123", + "response": {"action": "approve_once"} + }"#; + + let request: InterruptRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.interrupt_id, "int-\"123\""); + assert_eq!(request.tool_name, "tool\nname"); + assert_eq!(request.params_hash, Some("hash\t123".to_string())); +} + +#[test] +fn test_agent_request_large_message() { + // Test with very large message + let large_message = "x".repeat(100_000); + let json = format!(r#"{{"message": "{}"}}"#, large_message); + + let request: AgentRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(request.message.len(), 100_000); + assert_eq!(request.message, large_message); +} + +#[test] +fn test_interrupt_request_deny_with_long_reason() { + let long_reason = "r".repeat(10_000); + let json = format!( + r#"{{ + "interrupt_id": "int-1", + "tool_name": "tool", + "response": {{"action": "deny", "reason": "{}"}} + }}"#, + long_reason + ); + + let request: InterruptRequest = serde_json::from_str(&json).unwrap(); + if let InterruptResponse::Deny { reason } = request.response { + assert_eq!(reason.unwrap().len(), 10_000); + } else { + panic!("Expected Deny response"); + } +} + +#[test] +fn test_run_options_explicit_stream_false() { + let json = r#"{"stream": false}"#; + let options: RunOptions = serde_json::from_str(json).unwrap(); + assert!(!options.stream); +} + +#[test] +fn test_run_options_explicit_stream_true() { + let json = r#"{"stream": true}"#; + let options: RunOptions = serde_json::from_str(json).unwrap(); + assert!(options.stream); +} + +#[test] +fn test_run_options_empty_object() { + let json = r#"{}"#; + let options: RunOptions = serde_json::from_str(json).unwrap(); + assert!(options.stream); // default is true +} + +#[test] +fn test_agent_request_null_optional_fields() { + // Test with null for Option fields (these work with null) + let json = r#"{ + "message": "Hello", + "thread_id": null, + "run_id": null + }"#; + + let request: AgentRequest = serde_json::from_str(json).unwrap(); + assert_eq!(request.message, "Hello"); + assert!(request.thread_id.is_none()); + assert!(request.run_id.is_none()); + // options should use default when omitted + assert!(request.options.stream); +} diff --git a/mixtape-server/src/agui/mod.rs b/mixtape-server/src/agui/mod.rs new file mode 100644 index 0000000..66b6cb8 --- /dev/null +++ b/mixtape-server/src/agui/mod.rs @@ -0,0 +1,30 @@ +//! AG-UI protocol support for CopilotKit integration. +//! +//! This module provides SSE streaming endpoints that implement the AG-UI protocol, +//! enabling integration with CopilotKit and other AG-UI compatible frontends. +//! +//! # Overview +//! +//! AG-UI (Agent-User Interaction) is an open, event-based protocol that standardizes +//! how AI agents connect to user-facing applications. It uses Server-Sent Events (SSE) +//! to stream agent events to the frontend in real-time. +//! +//! # Event Mapping +//! +//! Mixtape's `AgentEvent`s are mapped to AG-UI events: +//! +//! | AgentEvent | AG-UI Event(s) | +//! |------------|----------------| +//! | `RunStarted` | `RUN_STARTED` | +//! | `RunCompleted` | `TEXT_MESSAGE_END`, `RUN_FINISHED` | +//! | `RunFailed` | `RUN_ERROR` | +//! | `ModelCallStarted` | `TEXT_MESSAGE_START` | +//! | `ModelCallStreaming` | `TEXT_MESSAGE_CONTENT` | +//! | `ToolRequested` | `TOOL_CALL_START`, `TOOL_CALL_ARGS`, `TOOL_CALL_END` | +//! | `ToolCompleted` | `TOOL_CALL_RESULT` | +//! | `ToolFailed` | `TOOL_CALL_RESULT` (with error) | +//! | `PermissionRequired` | `INTERRUPT` | + +pub mod convert; +pub mod events; +pub mod handler; diff --git a/mixtape-server/src/error.rs b/mixtape-server/src/error.rs new file mode 100644 index 0000000..5076ab1 --- /dev/null +++ b/mixtape-server/src/error.rs @@ -0,0 +1,60 @@ +//! Error types for the mixtape server. + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; + +/// Errors that can occur when building a router. +#[derive(Debug, thiserror::Error)] +pub enum BuildError { + /// No endpoints were configured. + #[error("No endpoints configured. Call .with_agui() before .build()")] + NoEndpoints, +} + +/// Errors that can occur in the mixtape server. +#[derive(Debug, thiserror::Error)] +pub enum ServerError { + /// Error from the agent during execution. + #[error("Agent error: {0}")] + Agent(#[from] mixtape_core::AgentError), + + /// Permission-related error. + #[error("Permission error: {0}")] + Permission(String), + + /// Invalid request from client. + #[error("Invalid request: {0}")] + InvalidRequest(String), + + /// Internal server error. + #[error("Internal error: {0}")] + Internal(String), +} + +impl IntoResponse for ServerError { + fn into_response(self) -> Response { + let (status, message) = match &self { + ServerError::Agent(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + ServerError::Permission(e) => (StatusCode::FORBIDDEN, e.clone()), + ServerError::InvalidRequest(e) => (StatusCode::BAD_REQUEST, e.clone()), + ServerError::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.clone()), + }; + + let body = Json(serde_json::json!({ + "error": message, + "code": status.as_u16(), + })); + + (status, body).into_response() + } +} + +/// Result type alias for server operations. +pub type ServerResult = Result; + +#[cfg(test)] +#[path = "error_tests.rs"] +mod tests; diff --git a/mixtape-server/src/error_tests.rs b/mixtape-server/src/error_tests.rs new file mode 100644 index 0000000..57f400f --- /dev/null +++ b/mixtape-server/src/error_tests.rs @@ -0,0 +1,208 @@ +//! Tests for error handling and IntoResponse implementation. + +use crate::error::*; +use axum::{http::StatusCode, response::IntoResponse}; + +#[test] +fn test_server_error_agent_variant() { + // Create a mock agent error + let agent_error = mixtape_core::AgentError::NoResponse; + let server_error = ServerError::Agent(agent_error); + + let response = server_error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR); + + // Verify body format (need to consume body to check) + // For now, just verify status code +} + +#[test] +fn test_server_error_permission_variant() { + let error = ServerError::Permission("Access denied".to_string()); + + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::FORBIDDEN); +} + +#[test] +fn test_server_error_invalid_request_variant() { + let error = ServerError::InvalidRequest("Bad input".to_string()); + + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::BAD_REQUEST); +} + +#[test] +fn test_server_error_internal_variant() { + let error = ServerError::Internal("Something went wrong".to_string()); + + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR); +} + +#[test] +fn test_server_error_display() { + let cases = [ + ( + ServerError::Permission("denied".to_string()), + "Permission error: denied", + ), + ( + ServerError::InvalidRequest("bad".to_string()), + "Invalid request: bad", + ), + ( + ServerError::Internal("oops".to_string()), + "Internal error: oops", + ), + ]; + + for (error, expected) in cases { + assert_eq!(error.to_string(), expected); + } +} + +#[test] +fn test_server_error_from_agent_error() { + let agent_error = mixtape_core::AgentError::EmptyResponse; + let server_error: ServerError = agent_error.into(); + + assert!(matches!(server_error, ServerError::Agent(_))); +} + +#[test] +fn test_server_error_permission_with_empty_message() { + let error = ServerError::Permission("".to_string()); + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::FORBIDDEN); +} + +#[test] +fn test_server_error_invalid_request_with_special_characters() { + let error = ServerError::InvalidRequest("Bad input: \"value\"\n\t".to_string()); + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::BAD_REQUEST); + // The message should be properly escaped in the JSON response +} + +#[test] +fn test_server_error_with_very_long_message() { + let long_message = "x".repeat(10_000); + let error = ServerError::Internal(long_message.clone()); + + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR); + // Should handle large error messages without panicking +} + +#[test] +fn test_server_error_with_unicode() { + let error = ServerError::InvalidRequest("错误的输入 🚫".to_string()); + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::BAD_REQUEST); +} + +#[test] +fn test_status_code_correctness() { + // Verify status codes match HTTP semantics + let test_cases = [ + ( + ServerError::Permission("".to_string()), + StatusCode::FORBIDDEN, + 403, + ), + ( + ServerError::InvalidRequest("".to_string()), + StatusCode::BAD_REQUEST, + 400, + ), + ( + ServerError::Internal("".to_string()), + StatusCode::INTERNAL_SERVER_ERROR, + 500, + ), + ]; + + for (error, expected_status, expected_code) in test_cases { + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, expected_status); + assert_eq!(parts.status.as_u16(), expected_code); + } +} + +#[test] +fn test_agent_error_conversion_preserves_message() { + let agent_error = mixtape_core::AgentError::ToolDenied("Permission denied".to_string()); + let server_error: ServerError = agent_error.into(); + + // The display should contain information about the error + let display = server_error.to_string(); + assert!( + display.contains("Tool execution denied") || display.contains("Permission denied"), + "Error message should be preserved" + ); +} + +#[test] +fn test_error_types_are_send_sync() { + // Verify error types can be sent across threads + fn is_send() {} + fn is_sync() {} + + is_send::(); + is_sync::(); +} + +#[test] +fn test_multiple_permission_errors() { + // Verify consistent behavior across multiple instances + let errors = vec![ + ServerError::Permission("Error 1".to_string()), + ServerError::Permission("Error 2".to_string()), + ServerError::Permission("Error 3".to_string()), + ]; + + for error in errors { + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + assert_eq!(parts.status, StatusCode::FORBIDDEN); + } +} + +#[test] +fn test_error_debug_output() { + let error = ServerError::InvalidRequest("test".to_string()); + let debug_str = format!("{:?}", error); + + // Should contain variant name and message + assert!(debug_str.contains("InvalidRequest")); + assert!(debug_str.contains("test")); +} + +#[test] +fn test_error_nested_quotes() { + let error = ServerError::InvalidRequest(r#"Field "name" has invalid value "test""#.to_string()); + let response = error.into_response(); + let (parts, _body) = response.into_parts(); + + assert_eq!(parts.status, StatusCode::BAD_REQUEST); + // JSON encoding should properly escape nested quotes +} diff --git a/mixtape-server/src/lib.rs b/mixtape-server/src/lib.rs new file mode 100644 index 0000000..4f11a96 --- /dev/null +++ b/mixtape-server/src/lib.rs @@ -0,0 +1,47 @@ +//! HTTP server and AG-UI protocol support for mixtape agents. +//! +//! This crate provides HTTP endpoints for running mixtape agents via web services, +//! with optional support for the AG-UI protocol used by CopilotKit. +//! +//! # Features +//! +//! - `agui` - Enable AG-UI protocol support for CopilotKit integration +//! +//! # Example +//! +//! ```rust,no_run +//! use mixtape_server::MixtapeRouter; +//! use mixtape_core::Agent; +//! +//! # async fn example() -> Result<(), Box> { +//! // Create your agent (requires provider feature in mixtape-core) +//! # let agent: Agent = todo!(); +//! +//! // Build the router with AG-UI support +//! let app = MixtapeRouter::new(agent) +//! .with_agui("/api/copilotkit") +//! .build()?; +//! +//! // Serve with axum +//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; +//! axum::serve(listener, app).await?; +//! # Ok(()) +//! # } +//! ``` + +pub mod error; +pub mod router; +pub(crate) mod state; + +#[cfg(feature = "agui")] +pub(crate) mod agui; + +// Re-exports +pub use error::{BuildError, ServerError, ServerResult}; +pub use router::MixtapeRouter; + +// AG-UI protocol types (for consumers who need to reference the event types) +#[cfg(feature = "agui")] +pub use agui::events::{ + AguiEvent, GrantScope, InterruptData, InterruptResponse, InterruptType, MessageRole, +}; diff --git a/mixtape-server/src/router.rs b/mixtape-server/src/router.rs new file mode 100644 index 0000000..7018cf1 --- /dev/null +++ b/mixtape-server/src/router.rs @@ -0,0 +1,188 @@ +//! Router builder for mixtape HTTP endpoints. + +use std::sync::Arc; + +use axum::Router; +use mixtape_core::Agent; + +use crate::error::BuildError; +use crate::state::AppState; + +/// Builder for configuring mixtape HTTP endpoints. +/// +/// # Example +/// +/// ```rust,no_run +/// use mixtape_server::MixtapeRouter; +/// use mixtape_core::Agent; +/// +/// # async fn example() -> Result<(), Box> { +/// # let agent: Agent = todo!(); +/// // Simple setup with AG-UI endpoint +/// let app = MixtapeRouter::new(agent) +/// .with_agui("/api/copilotkit") +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub struct MixtapeRouter { + agent: Arc, + #[cfg(feature = "agui")] + agui_path: Option, + #[cfg(feature = "agui")] + interrupt_path: Option, +} + +impl MixtapeRouter { + /// Create a new router builder with the given agent. + /// + /// The agent will be wrapped in an `Arc` for sharing across handlers. + pub fn new(agent: Agent) -> Self { + Self { + agent: Arc::new(agent), + #[cfg(feature = "agui")] + agui_path: None, + #[cfg(feature = "agui")] + interrupt_path: None, + } + } + + /// Create a new router builder from an existing `Arc`. + /// + /// Use this when you need to share the agent with other parts of your application. + pub fn from_arc(agent: Arc) -> Self { + Self { + agent, + #[cfg(feature = "agui")] + agui_path: None, + #[cfg(feature = "agui")] + interrupt_path: None, + } + } + + /// Enable AG-UI protocol endpoint at the specified path. + /// + /// This also enables an interrupt endpoint at `{path}/interrupt` for handling + /// permission responses. Use [`interrupt_path`](Self::interrupt_path) to customize + /// the interrupt endpoint path. + /// + /// # Example + /// + /// ```rust,no_run + /// # use mixtape_server::MixtapeRouter; + /// # use mixtape_core::Agent; + /// # async fn example() -> Result<(), Box> { + /// # let agent: Agent = todo!(); + /// let app = MixtapeRouter::new(agent) + /// .with_agui("/api/copilotkit") // SSE endpoint at /api/copilotkit + /// .build()?; // Interrupt at /api/copilotkit/interrupt + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "agui")] + pub fn with_agui(mut self, path: impl Into) -> Self { + let path = path.into(); + self.interrupt_path = Some(format!("{}/interrupt", path)); + self.agui_path = Some(path); + self + } + + /// Set a custom path for the interrupt endpoint. + /// + /// By default, the interrupt endpoint is at `{agui_path}/interrupt`. + /// Use this method to override that default. + /// + /// # Example + /// + /// ```rust,no_run + /// # use mixtape_server::MixtapeRouter; + /// # use mixtape_core::Agent; + /// # async fn example() -> Result<(), Box> { + /// # let agent: Agent = todo!(); + /// let app = MixtapeRouter::new(agent) + /// .with_agui("/api/copilotkit") + /// .interrupt_path("/api/approve") // Custom interrupt path + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "agui")] + pub fn interrupt_path(mut self, path: impl Into) -> Self { + self.interrupt_path = Some(path.into()); + self + } + + /// Build the router with all configured endpoints. + /// + /// Returns an axum `Router` that can be served directly or merged + /// with other routes. + /// + /// # Errors + /// + /// Returns [`BuildError::NoEndpoints`] if no endpoints were configured. + /// Call `.with_agui()` before `.build()`. + pub fn build(self) -> Result { + // Validate that at least one endpoint is configured + #[cfg(feature = "agui")] + let has_endpoints = self.agui_path.is_some(); + #[cfg(not(feature = "agui"))] + let has_endpoints = false; + + if !has_endpoints { + return Err(BuildError::NoEndpoints); + } + + let state = AppState::from_arc(self.agent); + let mut router = Router::new(); + + // Add AG-UI endpoints if enabled and configured + #[cfg(feature = "agui")] + if let Some(agui_path) = self.agui_path { + use crate::agui::handler::{agui_handler, interrupt_handler}; + use axum::routing::post; + + router = router.route(&agui_path, post(agui_handler)); + + if let Some(interrupt_path) = self.interrupt_path { + router = router.route(&interrupt_path, post(interrupt_handler)); + } + } + + Ok(router.with_state(state)) + } + + /// Build the router and nest it under a prefix path. + /// + /// This is useful when integrating with an existing application router. + /// + /// # Errors + /// + /// Returns [`BuildError::NoEndpoints`] if no endpoints were configured. + /// + /// # Example + /// + /// ```rust,no_run + /// # use mixtape_server::MixtapeRouter; + /// # use mixtape_core::Agent; + /// # use axum::Router; + /// # async fn example() -> Result<(), Box> { + /// # let agent: Agent = todo!(); + /// // Nest mixtape routes under /agent + /// let mixtape = MixtapeRouter::new(agent) + /// .with_agui("/stream") // Will be at /agent/stream + /// .build_nested("/agent")?; + /// + /// // Merge with existing routes + /// let app = Router::new() + /// .merge(mixtape); + /// # Ok(()) + /// # } + /// ``` + pub fn build_nested(self, prefix: impl Into) -> Result { + Ok(Router::new().nest(&prefix.into(), self.build()?)) + } +} + +#[cfg(test)] +#[path = "router_tests.rs"] +mod tests; diff --git a/mixtape-server/src/router_tests.rs b/mixtape-server/src/router_tests.rs new file mode 100644 index 0000000..332acc6 --- /dev/null +++ b/mixtape-server/src/router_tests.rs @@ -0,0 +1,120 @@ +//! Tests for the router builder. +//! +//! These tests verify the builder pattern and path configuration. +//! Note: Most tests require an actual agent instance which needs async initialization. +//! The tests below focus on testing the types and basic construction patterns. + +use crate::router::MixtapeRouter; +use std::sync::Arc; + +// Note: Full integration tests with real agents would require: +// 1. Async test infrastructure +// 2. Provider credentials/mocking +// 3. Complex setup +// +// These tests focus on the builder API surface and type safety. + +#[test] +fn test_router_builder_type_signature() { + // Verify that the builder accepts the correct types + // This is a compile-time test + fn _accepts_agent_owned(_: impl FnOnce(mixtape_core::Agent) -> MixtapeRouter) {} + fn _accepts_agent_arc(_: impl FnOnce(Arc) -> MixtapeRouter) {} + + _accepts_agent_owned(MixtapeRouter::new); + _accepts_agent_arc(MixtapeRouter::from_arc); +} + +#[cfg(feature = "agui")] +#[test] +fn test_router_builder_fluent_api() { + use crate::error::BuildError; + + // Test that the builder methods return Self for chaining + // This is a compile-time test of the builder pattern + + // Verify the builder pattern compiles with method chaining + fn _test_chaining(f: F) + where + F: FnOnce(MixtapeRouter) -> Result, + { + drop(f); + } + + #[cfg(feature = "agui")] + _test_chaining(|builder| { + builder + .with_agui("/api/stream") + .interrupt_path("/api/interrupt") + .build() + }); + + _test_chaining(|builder| builder.with_agui("/api").build()); +} + +#[cfg(feature = "agui")] +#[test] +fn test_router_into_variants() { + use crate::error::BuildError; + + // Test that both `build()` and `build_nested()` return Result + fn _returns_result(_: impl FnOnce(MixtapeRouter) -> Result) {} + + _returns_result(|b| b.with_agui("/api").build()); + _returns_result(|b| b.with_agui("/api").build_nested("/prefix")); +} + +#[cfg(feature = "agui")] +#[test] +fn test_router_path_types() { + // Test that path methods accept Into + fn _test_with_agui>(path: S) { + // This would be: MixtapeRouter::new(agent).with_agui(path) + // We're just testing the type signature + drop(path.into()); + } + + _test_with_agui("/api/stream"); + _test_with_agui(String::from("/api/stream")); + _test_with_agui("api/stream"); // No leading slash +} + +#[cfg(feature = "agui")] +#[test] +fn test_router_builder_consumes_self() { + // Test move semantics - compile-time verification + // If this compiles, the builder correctly consumes self + fn _consume_builder(f: F) + where + F: FnOnce(MixtapeRouter), + { + drop(f); + } + + _consume_builder(|router| { + let _app = router.with_agui("/api").build(); + // router is moved here and can't be used again + }); +} + +#[test] +fn test_app_state_construction() { + // Test that AppState can be constructed from Arc + // This is used internally by the router + use crate::state::AppState; + + fn _from_arc(agent: Arc) -> AppState { + AppState::from_arc(agent) + } + + let _ = _from_arc; +} + +// Note: The following tests would require actual Agent instances: +// - test_router_new_wraps_agent_in_arc +// - test_router_from_arc +// - test_router_build_empty +// - test_router_with_agui_* +// - test_router_build_nested_* +// +// These would be better suited for integration tests with proper async setup. diff --git a/mixtape-server/src/state.rs b/mixtape-server/src/state.rs new file mode 100644 index 0000000..fa7ad1f --- /dev/null +++ b/mixtape-server/src/state.rs @@ -0,0 +1,22 @@ +//! Application state for the mixtape server. + +use std::sync::Arc; + +use mixtape_core::Agent; + +/// Shared application state containing the agent. +/// +/// This state is cloned for each request handler and provides +/// access to the shared agent instance. +#[derive(Clone)] +pub struct AppState { + /// The shared agent instance. + pub agent: Arc, +} + +impl AppState { + /// Create new application state from an Arc. + pub fn from_arc(agent: Arc) -> Self { + Self { agent } + } +} diff --git a/mixtape-server/tests/integration_tests.rs b/mixtape-server/tests/integration_tests.rs new file mode 100644 index 0000000..8433302 --- /dev/null +++ b/mixtape-server/tests/integration_tests.rs @@ -0,0 +1,277 @@ +//! Integration tests for mixtape-server. +//! +//! These tests verify the full request→hook→agent→events→SSE flow. + +use axum::body::Body; +use axum::http::{Request, StatusCode}; +use mixtape_core::test_utils::MockProvider; +use mixtape_core::Agent; +use mixtape_server::MixtapeRouter; +use tower::ServiceExt; + +/// Helper to build an agent with a mock provider. +async fn build_mock_agent(provider: MockProvider) -> Agent { + Agent::builder() + .provider(provider) + .build() + .await + .expect("Failed to build agent") +} + +/// Helper to create SSE request body. +fn sse_request(message: &str) -> Request { + Request::builder() + .method("POST") + .uri("/api/copilotkit") + .header("Content-Type", "application/json") + .body(Body::from(format!(r#"{{"message": "{}"}}"#, message))) + .unwrap() +} + +/// Collect SSE events from response body. +async fn collect_sse_events(body: Body) -> Vec { + let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); + let text = String::from_utf8_lossy(&bytes); + text.lines() + .filter(|line| line.starts_with("data: ")) + .map(|line| line.strip_prefix("data: ").unwrap().to_string()) + .collect() +} + +/// Extract event type names from SSE event JSON strings. +fn extract_event_types(events: &[String]) -> Vec { + events + .iter() + .filter_map(|e| { + serde_json::from_str::(e) + .ok() + .and_then(|v| v.get("type").and_then(|t| t.as_str().map(String::from))) + }) + .collect() +} + +// ============================================================================ +// Hook Lifecycle Tests +// ============================================================================ + +#[tokio::test] +async fn test_hooks_receive_events_during_request() { + let provider = MockProvider::new().with_text("Hello!"); + let agent = build_mock_agent(provider).await; + + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("Hi")).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let events = collect_sse_events(response.into_body()).await; + let event_types = extract_event_types(&events); + + assert!(event_types.contains(&"RUN_STARTED".to_string())); + assert!(event_types.contains(&"RUN_FINISHED".to_string())); +} + +#[tokio::test] +async fn test_multiple_requests_produce_consistent_events() { + let mut event_counts = Vec::new(); + + for i in 0..3 { + let provider = MockProvider::new().with_text(format!("Response {}", i)); + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("Hi")).await.unwrap(); + let events = collect_sse_events(response.into_body()).await; + event_counts.push(events.len()); + } + + // All requests should produce the same number of events + assert!( + event_counts.iter().all(|&c| c == event_counts[0]), + "Event counts should be consistent: {:?}", + event_counts + ); +} + +// ============================================================================ +// SSE Stream Tests +// ============================================================================ + +#[tokio::test] +async fn test_sse_stream_format() { + let provider = MockProvider::new().with_text("Hello, world!"); + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("Hi")).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + let events = collect_sse_events(response.into_body()).await; + for event in &events { + assert!( + serde_json::from_str::(event).is_ok(), + "Event should be valid JSON: {}", + event + ); + } +} + +#[tokio::test] +async fn test_sse_event_sequence() { + let provider = MockProvider::new().with_text("Test response"); + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("Hello")).await.unwrap(); + let events = collect_sse_events(response.into_body()).await; + let event_types = extract_event_types(&events); + + assert_eq!(event_types.first(), Some(&"RUN_STARTED".to_string())); + assert_eq!(event_types.last(), Some(&"RUN_FINISHED".to_string())); + assert!(event_types.contains(&"TEXT_MESSAGE_START".to_string())); + assert!(event_types.contains(&"TEXT_MESSAGE_END".to_string())); +} + +#[tokio::test] +async fn test_sse_tool_call_events() { + let provider = MockProvider::new() + .with_tool_use("calculator", serde_json::json!({"expression": "2+2"})) + .with_text("The answer is 4"); + + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("What is 2+2?")).await.unwrap(); + let events = collect_sse_events(response.into_body()).await; + let event_types = extract_event_types(&events); + + assert!(event_types.contains(&"TOOL_CALL_START".to_string())); + assert!(event_types.contains(&"TOOL_CALL_ARGS".to_string())); + assert!(event_types.contains(&"TOOL_CALL_END".to_string())); +} + +#[tokio::test] +async fn test_sse_uses_provided_thread_and_run_ids() { + let provider = MockProvider::new().with_text("Hello!"); + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let request = Request::builder() + .method("POST") + .uri("/api/copilotkit") + .header("Content-Type", "application/json") + .body(Body::from( + r#"{"message": "Hi", "thread_id": "thread-123", "run_id": "run-456"}"#, + )) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + let events = collect_sse_events(response.into_body()).await; + + let run_started = events + .iter() + .find(|e| e.contains("RUN_STARTED")) + .expect("Should have RUN_STARTED"); + + let parsed: serde_json::Value = serde_json::from_str(run_started).unwrap(); + assert_eq!(parsed["thread_id"], "thread-123"); + assert_eq!(parsed["run_id"], "run-456"); +} + +#[tokio::test] +async fn test_sse_generates_ids_when_not_provided() { + let provider = MockProvider::new().with_text("Hello!"); + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("Hi")).await.unwrap(); + let events = collect_sse_events(response.into_body()).await; + + let run_started = events + .iter() + .find(|e| e.contains("RUN_STARTED")) + .expect("Should have RUN_STARTED"); + + let parsed: serde_json::Value = serde_json::from_str(run_started).unwrap(); + let thread_id = parsed["thread_id"].as_str().unwrap(); + let run_id = parsed["run_id"].as_str().unwrap(); + + assert!( + uuid::Uuid::parse_str(thread_id).is_ok(), + "thread_id should be valid UUID" + ); + assert!( + uuid::Uuid::parse_str(run_id).is_ok(), + "run_id should be valid UUID" + ); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +#[tokio::test] +async fn test_sse_error_event_on_provider_failure() { + let provider = MockProvider::new(); // No responses = will error + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let response = app.oneshot(sse_request("Hi")).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); // SSE streams errors as events + + let events = collect_sse_events(response.into_body()).await; + assert!( + events.iter().any(|e| e.contains("RUN_ERROR")), + "Should have RUN_ERROR event: {:?}", + events + ); +} + +#[tokio::test] +async fn test_invalid_request_body_returns_error() { + let provider = MockProvider::new().with_text("Hello!"); + let agent = build_mock_agent(provider).await; + let app = MixtapeRouter::new(agent) + .with_agui("/api/copilotkit") + .build() + .unwrap(); + + let request = Request::builder() + .method("POST") + .uri("/api/copilotkit") + .header("Content-Type", "application/json") + .body(Body::from("not valid json")) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert!(response.status().is_client_error()); +} diff --git a/mixtape-tools/Cargo.toml b/mixtape-tools/Cargo.toml index 1181787..299666f 100644 --- a/mixtape-tools/Cargo.toml +++ b/mixtape-tools/Cargo.toml @@ -4,11 +4,13 @@ version.workspace = true edition.workspace = true license.workspace = true repository.workspace = true +homepage.workspace = true description = "Ready-to-use tool implementations for the mixtape agent framework" documentation = "https://docs.rs/mixtape-tools" readme = "README.md" keywords = ["ai", "agents", "tools", "filesystem", "llm"] categories = ["development-tools", "filesystem", "asynchronous"] +exclude = [".cargo-husky/", ".claude/", ".github/", ".idea/"] [features] default = ["filesystem", "process", "edit", "search", "fetch", "aws", "sqlite"]