Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions .github/ci-configs/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,25 @@ jobs:
continue-on-error: false
name: "cli test"
run: |
cargo run -- help
cargo run -- --help
cargo run -- version
cargo run -- --version
cargo run -- validate
cargo run -- v
cargo run -- validate -d
cargo run -- --debug validate
cargo run -- validate --trace dependabot
cargo run -- --trace dependabot validate
cargo run -- v --format json
cargo run -- --format json v
cargo run -- g
cargo run -- generate
cargo build
export PATH=$PWD/target/debug:$PATH

kg help
kg --help
kg version
kg --version
kg validate
kg v
kg validate -d
kg --debug validate
kg validate --trace dependabot
kg --trace dependabot validate
kg v --format json
kg --format json v
kg g
kg generate
kg init
kg validate --global

matrix:
os: []
Expand Down
39 changes: 39 additions & 0 deletions src/agent/custom_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,42 @@ pub struct CustomToolConfig {
pub fn tool_default_timeout() -> u64 {
120 * 1000
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn transport_type_default() {
assert_eq!(tool_default_timeout(), 120 * 1000);
assert_eq!(TransportType::default(), TransportType::Stdio);
}

#[test]
fn custom_tool_config_serde() {
let config = CustomToolConfig {
r#type: TransportType::Http,
url: "http://test".into(),
headers: HashMap::new(),
oauth: None,
command: "cmd".into(),
args: vec!["arg1".into()],
env: HashMap::new(),
timeout: 5000,
disabled: false,
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: CustomToolConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config, deserialized);
}

#[test]
fn oauth_config_serde() {
let oauth = OAuthConfig {
redirect_uri: Some("localhost:8080".into()),
};
let json = serde_json::to_string(&oauth).unwrap();
let deserialized: OAuthConfig = serde_json::from_str(&json).unwrap();
assert_eq!(oauth, deserialized);
}
}
32 changes: 32 additions & 0 deletions src/agent/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,35 @@ impl Hook {
DEFAULT_CACHE_TTL_SECONDS
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn hook_trigger_display() {
assert_eq!(HookTrigger::AgentSpawn.to_string(), "agentSpawn");
assert_eq!(HookTrigger::Stop.to_string(), "stop");
}

#[test]
fn hook_defaults() {
assert_eq!(Hook::default_timeout_ms(), 30_000);
assert_eq!(Hook::default_max_output_size(), 10_240);
assert_eq!(Hook::default_cache_ttl_seconds(), 0);
}

#[test]
fn hook_serde() {
let hook = Hook {
command: "test".into(),
timeout_ms: 1000,
max_output_size: 500,
cache_ttl_seconds: 10,
matcher: Some("*.rs".into()),
};
let json = serde_json::to_string(&hook).unwrap();
let deserialized: Hook = serde_json::from_str(&json).unwrap();
assert_eq!(hook, deserialized);
}
}
30 changes: 30 additions & 0 deletions src/agent/mcp_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,33 @@ use {
pub struct McpServerConfig {
pub mcp_servers: HashMap<String, CustomToolConfig>,
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn mcp_server_config_default() {
let config = McpServerConfig::default();
assert!(config.mcp_servers.is_empty());
}

#[test]
fn mcp_server_config_serde() {
let mut config = McpServerConfig::default();
config.mcp_servers.insert("test".into(), CustomToolConfig {
r#type: Default::default(),
url: String::new(),
headers: HashMap::new(),
oauth: None,
command: "cmd".into(),
args: vec![],
env: HashMap::new(),
timeout: 120_000,
disabled: false,
});
let json = serde_json::to_string(&config).unwrap();
let deserialized: McpServerConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config, deserialized);
}
}
44 changes: 35 additions & 9 deletions src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct Agent {
/// actual schema differs by tools and is documented in detail in our
/// documentation
#[serde(default)]
pub tools_settings: HashMap<ToolTarget, serde_json::Value>,
pub tools_settings: HashMap<String, serde_json::Value>,
/// The model ID to use for this agent. If not specified, uses the default
/// model.
#[serde(default, skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -93,33 +93,59 @@ impl Agent {
}
}

impl From<&KdlAgent> for Agent {
fn from(value: &KdlAgent) -> Self {
impl TryFrom<&KdlAgent> for Agent {
type Error = color_eyre::Report;

fn try_from(value: &KdlAgent) -> std::result::Result<Self, Self::Error> {
let native_tools = &value.native_tool;
let mut tools_settings = HashMap::new();

let tool: AwsTool = native_tools.into();
let tool_name = ToolTarget::Aws.to_string();
if tool != AwsTool::default() {
tools_settings.insert(ToolTarget::Aws, serde_json::to_value(&tool).unwrap());
tools_settings.insert(
tool_name.to_string(),
serde_json::to_value(&tool)
.map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?,
);
}
let tool: ReadTool = native_tools.into();
let tool_name = ToolTarget::Read.to_string();
if tool != ReadTool::default() {
tools_settings.insert(ToolTarget::Read, serde_json::to_value(&tool).unwrap());
tools_settings.insert(
tool_name.to_string(),
serde_json::to_value(&tool)
.map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?,
);
}
let tool: WriteTool = native_tools.into();
let tool_name = ToolTarget::Write.to_string();
if tool != WriteTool::default() {
tools_settings.insert(ToolTarget::Write, serde_json::to_value(&tool).unwrap());
tools_settings.insert(
tool_name.to_string(),
serde_json::to_value(&tool)
.map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?,
);
}
let tool: ExecuteShellTool = native_tools.into();
let tool_name = ToolTarget::Shell.to_string();
if tool != ExecuteShellTool::default() {
tools_settings.insert(ToolTarget::Shell, serde_json::to_value(&tool).unwrap());
tools_settings.insert(
tool_name.to_string(),
serde_json::to_value(&tool)
.map_err(|e| eyre!("Failed to serialize {tool_name} tool configuration {e}"))?,
);
}
let default_agent = Self::default();
let tools = value.tools().clone();
let allowed_tools = value.allowed_tools().clone();
let resources: HashSet<String> = value.resources().map(|s| s.to_string()).collect();

Self {
// Extra tool settings override native tools
let extra_tool_settings = value.extra_tool_settings()?;
tools_settings.extend(extra_tool_settings);

Ok(Self {
name: value.name.clone(),
description: value.description.clone(),
prompt: value.prompt.clone(),
Expand All @@ -146,7 +172,7 @@ impl From<&KdlAgent> for Agent {
tools_settings,
model: value.model.clone(),
include_mcp_json: value.include_mcp_json(),
}
})
}
}

Expand Down
47 changes: 47 additions & 0 deletions src/agent/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,50 @@ pub struct WriteTool {
#[serde(default, skip_serializing_if = "HashSet::is_empty")]
pub denied_paths: HashSet<String>,
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn tool_target_display() {
assert_eq!(ToolTarget::Aws.to_string(), "aws");
assert_eq!(ToolTarget::Shell.to_string(), "shell");
}

#[test]
fn tool_target_as_ref() {
assert_eq!(ToolTarget::Read.as_ref(), "read");
assert_eq!(ToolTarget::Write.as_ref(), "write");
}

#[test]
fn aws_tool_default() {
let tool = AwsTool::default();
assert!(tool.auto_allow_readonly);
assert!(tool.allowed_services.is_empty());
}

#[test]
fn execute_shell_tool_default() {
let tool = ExecuteShellTool::default();
assert!(!tool.deny_by_default);
assert!(!tool.auto_allow_readonly);
}

#[test]
fn read_tool_serde() {
let tool = ReadTool::default();
let json = serde_json::to_string(&tool).unwrap();
let deserialized: ReadTool = serde_json::from_str(&json).unwrap();
assert_eq!(tool, deserialized);
}

#[test]
fn write_tool_serde() {
let tool = WriteTool::default();
let json = serde_json::to_string(&tool).unwrap();
let deserialized: WriteTool = serde_json::from_str(&json).unwrap();
assert_eq!(tool, deserialized);
}
}
26 changes: 26 additions & 0 deletions src/agent/wrapper_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,29 @@ impl Borrow<str> for OriginalToolName {
self.0.as_str()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn original_tool_name_deref() {
let name = OriginalToolName("test".into());
assert_eq!(&*name, "test");
}

#[test]
fn original_tool_name_borrow() {
let name = OriginalToolName("test".into());
let borrowed: &str = name.borrow();
assert_eq!(borrowed, "test");
}

#[test]
fn original_tool_name_serde() {
let name = OriginalToolName("test".into());
let json = serde_json::to_string(&name).unwrap();
let deserialized: OriginalToolName = serde_json::from_str(&json).unwrap();
assert_eq!(name, deserialized);
}
}
4 changes: 2 additions & 2 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl Cli {
/// Return home dir and ~/.kiro/generators/kg.kdl
pub fn config(&self) -> crate::Result<(PathBuf, PathBuf)> {
let home_dir = dirs::home_dir().ok_or(eyre!("cannot locate home directory"))?;
let cfg = home_dir.join(".kiro").join("generators").join("kg.kdl");
let cfg = home_dir.join(".kiro").join("generators");
Ok((home_dir, cfg))
}
}
Expand Down Expand Up @@ -248,7 +248,7 @@ mod tests {
let result = cli.config();
assert!(result.is_ok());
let (home, cfg) = result.unwrap();
assert!(cfg.ends_with(".kiro/generators/kg.kdl"));
assert!(cfg.ends_with(".kiro/generators"));
assert!(cfg.starts_with(&home));
}
}
4 changes: 4 additions & 0 deletions src/generator/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ mod tests {
assert!(aws.allow.list.contains("s3"));
assert!(aws.deny.list.contains("iam"));

// check try_from
let results = generator.write_all(true).await?;
assert!(!results.is_empty());

Ok(())
}
}
2 changes: 1 addition & 1 deletion src/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Generator {
pub(crate) async fn write(&self, agent: KdlAgent, dry_run: bool) -> Result<AgentResult> {
let destination = self.destination_dir(&agent.name);
let result = AgentResult {
kiro_agent: Agent::from(&agent),
kiro_agent: Agent::try_from(&agent)?,
writable: !agent.is_template(),
destination,
agent,
Expand Down
Loading