diff --git a/.github/ci-configs/rust.yml b/.github/ci-configs/rust.yml index a392a44..c26e10d 100644 --- a/.github/ci-configs/rust.yml +++ b/.github/ci-configs/rust.yml @@ -24,20 +24,25 @@ jobs: continue-on-error: false name: "cli test" run: | - cargo run -- help - cargo run -- --help - cargo run -- version - cargo run -- --version - cargo run -- validate - cargo run -- v - cargo run -- validate -d - cargo run -- --debug validate - cargo run -- validate --trace dependabot - cargo run -- --trace dependabot validate - cargo run -- v --format json - cargo run -- --format json v - cargo run -- g - cargo run -- generate + cargo build + export PATH=$PWD/target/debug:$PATH + + kg help + kg --help + kg version + kg --version + kg validate + kg v + kg validate -d + kg --debug validate + kg validate --trace dependabot + kg --trace dependabot validate + kg v --format json + kg --format json v + kg g + kg generate + kg init + kg validate --global matrix: os: [] diff --git a/src/agent/custom_tool.rs b/src/agent/custom_tool.rs index ea280a1..6796dd7 100644 --- a/src/agent/custom_tool.rs +++ b/src/agent/custom_tool.rs @@ -57,3 +57,42 @@ pub struct CustomToolConfig { pub fn tool_default_timeout() -> u64 { 120 * 1000 } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn transport_type_default() { + assert_eq!(tool_default_timeout(), 120 * 1000); + assert_eq!(TransportType::default(), TransportType::Stdio); + } + + #[test] + fn custom_tool_config_serde() { + let config = CustomToolConfig { + r#type: TransportType::Http, + url: "http://test".into(), + headers: HashMap::new(), + oauth: None, + command: "cmd".into(), + args: vec!["arg1".into()], + env: HashMap::new(), + timeout: 5000, + disabled: false, + }; + let json = serde_json::to_string(&config).unwrap(); + let deserialized: CustomToolConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(config, deserialized); + } + + #[test] + fn oauth_config_serde() { + let oauth = OAuthConfig { + redirect_uri: Some("localhost:8080".into()), + }; + let json = serde_json::to_string(&oauth).unwrap(); + let deserialized: OAuthConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(oauth, deserialized); + } +} diff --git a/src/agent/hook.rs b/src/agent/hook.rs index ce8b1bb..886495b 100644 --- a/src/agent/hook.rs +++ b/src/agent/hook.rs @@ -72,3 +72,35 @@ impl Hook { DEFAULT_CACHE_TTL_SECONDS } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hook_trigger_display() { + assert_eq!(HookTrigger::AgentSpawn.to_string(), "agentSpawn"); + assert_eq!(HookTrigger::Stop.to_string(), "stop"); + } + + #[test] + fn hook_defaults() { + assert_eq!(Hook::default_timeout_ms(), 30_000); + assert_eq!(Hook::default_max_output_size(), 10_240); + assert_eq!(Hook::default_cache_ttl_seconds(), 0); + } + + #[test] + fn hook_serde() { + let hook = Hook { + command: "test".into(), + timeout_ms: 1000, + max_output_size: 500, + cache_ttl_seconds: 10, + matcher: Some("*.rs".into()), + }; + let json = serde_json::to_string(&hook).unwrap(); + let deserialized: Hook = serde_json::from_str(&json).unwrap(); + assert_eq!(hook, deserialized); + } +} diff --git a/src/agent/mcp_config.rs b/src/agent/mcp_config.rs index 2c9dd87..743e948 100644 --- a/src/agent/mcp_config.rs +++ b/src/agent/mcp_config.rs @@ -9,3 +9,33 @@ use { pub struct McpServerConfig { pub mcp_servers: HashMap, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mcp_server_config_default() { + let config = McpServerConfig::default(); + assert!(config.mcp_servers.is_empty()); + } + + #[test] + fn mcp_server_config_serde() { + let mut config = McpServerConfig::default(); + config.mcp_servers.insert("test".into(), CustomToolConfig { + r#type: Default::default(), + url: String::new(), + headers: HashMap::new(), + oauth: None, + command: "cmd".into(), + args: vec![], + env: HashMap::new(), + timeout: 120_000, + disabled: false, + }); + let json = serde_json::to_string(&config).unwrap(); + let deserialized: McpServerConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(config, deserialized); + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 593ab42..62be2ae 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -61,7 +61,7 @@ pub struct Agent { /// actual schema differs by tools and is documented in detail in our /// documentation #[serde(default)] - pub tools_settings: HashMap, + pub tools_settings: HashMap, /// The model ID to use for this agent. If not specified, uses the default /// model. #[serde(default, skip_serializing_if = "Option::is_none")] @@ -93,33 +93,59 @@ impl Agent { } } -impl From<&KdlAgent> for Agent { - fn from(value: &KdlAgent) -> Self { +impl TryFrom<&KdlAgent> for Agent { + type Error = color_eyre::Report; + + fn try_from(value: &KdlAgent) -> std::result::Result { let native_tools = &value.native_tool; let mut tools_settings = HashMap::new(); let tool: AwsTool = native_tools.into(); + let tool_name = ToolTarget::Aws.to_string(); if tool != AwsTool::default() { - tools_settings.insert(ToolTarget::Aws, serde_json::to_value(&tool).unwrap()); + tools_settings.insert( + tool_name.to_string(), + serde_json::to_value(&tool) + .map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?, + ); } let tool: ReadTool = native_tools.into(); + let tool_name = ToolTarget::Read.to_string(); if tool != ReadTool::default() { - tools_settings.insert(ToolTarget::Read, serde_json::to_value(&tool).unwrap()); + tools_settings.insert( + tool_name.to_string(), + serde_json::to_value(&tool) + .map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?, + ); } let tool: WriteTool = native_tools.into(); + let tool_name = ToolTarget::Write.to_string(); if tool != WriteTool::default() { - tools_settings.insert(ToolTarget::Write, serde_json::to_value(&tool).unwrap()); + tools_settings.insert( + tool_name.to_string(), + serde_json::to_value(&tool) + .map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?, + ); } let tool: ExecuteShellTool = native_tools.into(); + let tool_name = ToolTarget::Shell.to_string(); if tool != ExecuteShellTool::default() { - tools_settings.insert(ToolTarget::Shell, serde_json::to_value(&tool).unwrap()); + tools_settings.insert( + tool_name.to_string(), + serde_json::to_value(&tool) + .map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?, + ); } let default_agent = Self::default(); let tools = value.tools().clone(); let allowed_tools = value.allowed_tools().clone(); let resources: HashSet = value.resources().map(|s| s.to_string()).collect(); - Self { + // Extra tool settings override native tools + let extra_tool_settings = value.extra_tool_settings()?; + tools_settings.extend(extra_tool_settings); + + Ok(Self { name: value.name.clone(), description: value.description.clone(), prompt: value.prompt.clone(), @@ -146,7 +172,7 @@ impl From<&KdlAgent> for Agent { tools_settings, model: value.model.clone(), include_mcp_json: value.include_mcp_json(), - } + }) } } diff --git a/src/agent/tools.rs b/src/agent/tools.rs index 3fe4859..4d41cd2 100644 --- a/src/agent/tools.rs +++ b/src/agent/tools.rs @@ -106,3 +106,50 @@ pub struct WriteTool { #[serde(default, skip_serializing_if = "HashSet::is_empty")] pub denied_paths: HashSet, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tool_target_display() { + assert_eq!(ToolTarget::Aws.to_string(), "aws"); + assert_eq!(ToolTarget::Shell.to_string(), "shell"); + } + + #[test] + fn tool_target_as_ref() { + assert_eq!(ToolTarget::Read.as_ref(), "read"); + assert_eq!(ToolTarget::Write.as_ref(), "write"); + } + + #[test] + fn aws_tool_default() { + let tool = AwsTool::default(); + assert!(tool.auto_allow_readonly); + assert!(tool.allowed_services.is_empty()); + } + + #[test] + fn execute_shell_tool_default() { + let tool = ExecuteShellTool::default(); + assert!(!tool.deny_by_default); + assert!(!tool.auto_allow_readonly); + } + + #[test] + fn read_tool_serde() { + let tool = ReadTool::default(); + let json = serde_json::to_string(&tool).unwrap(); + let deserialized: ReadTool = serde_json::from_str(&json).unwrap(); + assert_eq!(tool, deserialized); + } + + #[test] + fn write_tool_serde() { + let tool = WriteTool::default(); + let json = serde_json::to_string(&tool).unwrap(); + let deserialized: WriteTool = serde_json::from_str(&json).unwrap(); + assert_eq!(tool, deserialized); + } +} diff --git a/src/agent/wrapper_types.rs b/src/agent/wrapper_types.rs index 26e9938..7b6d7e1 100644 --- a/src/agent/wrapper_types.rs +++ b/src/agent/wrapper_types.rs @@ -21,3 +21,29 @@ impl Borrow for OriginalToolName { self.0.as_str() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn original_tool_name_deref() { + let name = OriginalToolName("test".into()); + assert_eq!(&*name, "test"); + } + + #[test] + fn original_tool_name_borrow() { + let name = OriginalToolName("test".into()); + let borrowed: &str = name.borrow(); + assert_eq!(borrowed, "test"); + } + + #[test] + fn original_tool_name_serde() { + let name = OriginalToolName("test".into()); + let json = serde_json::to_string(&name).unwrap(); + let deserialized: OriginalToolName = serde_json::from_str(&json).unwrap(); + assert_eq!(name, deserialized); + } +} diff --git a/src/commands.rs b/src/commands.rs index 53a31db..0788c12 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -122,7 +122,7 @@ impl Cli { /// Return home dir and ~/.kiro/generators/kg.kdl pub fn config(&self) -> crate::Result<(PathBuf, PathBuf)> { let home_dir = dirs::home_dir().ok_or(eyre!("cannot locate home directory"))?; - let cfg = home_dir.join(".kiro").join("generators").join("kg.kdl"); + let cfg = home_dir.join(".kiro").join("generators"); Ok((home_dir, cfg)) } } @@ -248,7 +248,7 @@ mod tests { let result = cli.config(); assert!(result.is_ok()); let (home, cfg) = result.unwrap(); - assert!(cfg.ends_with(".kiro/generators/kg.kdl")); + assert!(cfg.ends_with(".kiro/generators")); assert!(cfg.starts_with(&home)); } } diff --git a/src/generator/merge.rs b/src/generator/merge.rs index 570d5b2..8e8329e 100644 --- a/src/generator/merge.rs +++ b/src/generator/merge.rs @@ -139,6 +139,10 @@ mod tests { assert!(aws.allow.list.contains("s3")); assert!(aws.deny.list.contains("iam")); + // check try_from + let results = generator.write_all(true).await?; + assert!(!results.is_empty()); + Ok(()) } } diff --git a/src/generator/mod.rs b/src/generator/mod.rs index e6d228f..964bac8 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -142,7 +142,7 @@ impl Generator { pub(crate) async fn write(&self, agent: KdlAgent, dry_run: bool) -> Result { let destination = self.destination_dir(&agent.name); let result = AgentResult { - kiro_agent: Agent::from(&agent), + kiro_agent: Agent::try_from(&agent)?, writable: !agent.is_template(), destination, agent, diff --git a/src/kdl/agent.rs b/src/kdl/agent.rs index 84b778f..b2a1331 100644 --- a/src/kdl/agent.rs +++ b/src/kdl/agent.rs @@ -8,6 +8,7 @@ use { }, kdl::native::{AwsTool, ExecuteShellTool, NativeTools, ReadTool, WriteTool}, }, + color_eyre::eyre::WrapErr, knuffel::Decode, std::{ collections::{HashMap, HashSet}, @@ -61,6 +62,37 @@ pub(super) struct ToolAliasKdl { to: String, } +/// Raw JSON tool settings for forward compatibility. +/// +/// Allows users to configure tool settings not yet supported by kg's schema. +/// The JSON must be a valid object (not array or primitive). +/// +/// See https://kiro.dev/docs/cli/custom-agents/configuration-reference/#toolssettings-field +#[derive(Decode, Clone, Debug)] +pub struct ToolSetting { + #[knuffel(argument)] + name: String, + #[knuffel(child, unwrap(argument))] + json: String, +} + +impl ToolSetting { + fn to_value(&self) -> crate::Result<(String, serde_json::Value)> { + let v: serde_json::Value = serde_json::from_str(&self.json) + .wrap_err_with(|| format!("Failed to parse JSON for tool-setting '{}'", self.name))?; + + if !v.is_object() { + return Err(color_eyre::eyre::eyre!( + "tool-setting '{}' must be a JSON object, got: {}", + self.name, + v + )); + } + + Ok((self.name.clone(), v)) + } +} + #[derive(Decode, Clone, Default)] pub struct KdlAgent { /// Name of the agent @@ -105,6 +137,9 @@ pub struct KdlAgent { /// Tools builtin to kiro #[knuffel(child, default)] pub native_tool: NativeTools, + + #[knuffel(children(name = "tool-setting"), default)] + pub(super) tool_settings: Vec, } impl Debug for KdlAgent { @@ -187,4 +222,23 @@ impl KdlAgent { .map(|m| (m.name.clone(), m.into())) .collect() } + + /// Parse raw JSON tool settings into a map. + /// + /// This allows users to configure tools not yet supported by kg's schema. + pub fn extra_tool_settings(&self) -> crate::Result> { + let mut result = HashMap::new(); + for setting in &self.tool_settings { + let (name, value) = setting.to_value()?; + if result.contains_key(&name) { + return Err(color_eyre::eyre::eyre!( + "[{self}] - Duplicate tool-setting '{}' found. Each tool-setting name must be \ + unique.", + name + )); + } + result.insert(name, value); + } + Ok(result) + } } diff --git a/src/kdl/agent_file.rs b/src/kdl/agent_file.rs index 4637f98..0f6edc5 100644 --- a/src/kdl/agent_file.rs +++ b/src/kdl/agent_file.rs @@ -30,6 +30,8 @@ pub struct KdlAgentFileSource { pub(super) tool_aliases: HashSet, #[knuffel(child, default)] pub(super) native_tool: NativeTools, + #[knuffel(children(name = "tool-setting"), default)] + pub(super) tool_settings: Vec, } impl KdlAgent { @@ -70,6 +72,7 @@ impl KdlAgent { mcp: file_source.mcp, tool_aliases: file_source.tool_aliases, native_tool: file_source.native_tool, + tool_settings: file_source.tool_settings, } } } diff --git a/src/kdl/merge.rs b/src/kdl/merge.rs index 29a60e6..7586fc3 100644 --- a/src/kdl/merge.rs +++ b/src/kdl/merge.rs @@ -18,6 +18,7 @@ impl KdlAgent { self.tool_aliases.extend(other.tool_aliases); self.mcp.extend(other.mcp); self.inherits.parents.extend(other.inherits.parents); + self.tool_settings.extend(other.tool_settings); // Hooks are deep merged self.hook = match (self.hook, other.hook) { diff --git a/src/kdl/mod.rs b/src/kdl/mod.rs index 7a2371b..b3eb23e 100644 --- a/src/kdl/mod.rs +++ b/src/kdl/mod.rs @@ -4,7 +4,7 @@ mod hook; mod mcp; mod merge; mod native; -use std::collections::HashSet; +use std::{collections::HashSet, fmt::Debug}; pub use agent::KdlAgent; @@ -14,6 +14,12 @@ pub struct GeneratorConfig { pub agents: Vec, } +impl Debug for GeneratorConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "agents={}", self.agents.len()) + } +} + impl GeneratorConfig { pub fn names(&self) -> HashSet { self.agents.iter().map(|a| a.name.clone()).collect() @@ -95,6 +101,10 @@ mod tests { override "git pull .*" } } + + tool-setting "@git/status" { + json "{ \"git_user\": \"$GIT_USER\" }" + } } "#; @@ -152,6 +162,14 @@ mod tests { assert!(aws_docs.oauth.is_some()); assert_eq!(agent.tool_aliases().len(), 1); + + let extra = agent.extra_tool_settings()?; + assert_eq!(extra.len(), 1); + assert!(extra.contains_key("@git/status")); + let git_status = extra.get("@git/status").unwrap(); + assert!(git_status.is_object()); + assert_eq!(git_status["git_user"], "$GIT_USER"); + Ok(()) } @@ -169,11 +187,13 @@ mod tests { return Err(eyre!("failed to parse {kdl_agents}")); } }; + assert!(!format!("{config:?}").is_empty()); assert_eq!(config.agents.len(), 1); let agent = config.agents[0].clone(); assert_eq!(agent.name, "test"); assert!(agent.model.is_none()); assert!(agent.is_template()); + Ok(()) } @@ -250,4 +270,25 @@ mod tests { assert_eq!(agent.description.unwrap_or_default(), "agent from file"); Ok(()) } + + #[test_log::test] + fn test_tool_setting_invalid_json() -> crate::Result<()> { + let kdl = r#" + agent "test" { + tool-setting "bad" { + json "{ invalid json }" + } + } + "#; + let config: GeneratorConfig = parse("test.kdl", kdl)?; + let result = config.agents[0].extra_tool_settings(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to parse JSON") + ); + Ok(()) + } }