From 90dec3fc8595d6197fc5a0e265616d9876b1604c Mon Sep 17 00:00:00 2001 From: Benedick Montales Date: Sun, 15 Feb 2026 01:54:00 +0800 Subject: [PATCH 1/2] fix: show platform-specific installation guide for outdated binary errors When users have a newer config file with fields that don't exist in their installed kairo binary, they now see a helpful guide to reinstall kairo using the appropriate script for their platform. Changes: - Added handleConfigError() to detect outdated binary errors - Show curl | sh script for Linux/macOS - Show irm | iex script for Windows - Link to manual installation docs - Updated all config loading sites to use the new error handler - Added tests for the new error handling - Added unknown field detection in config loader --- cmd/config.go | 2 +- cmd/default.go | 2 +- cmd/harness.go | 4 +- cmd/integration_test.go | 9 ---- cmd/list.go | 2 +- cmd/reset.go | 2 +- cmd/root.go | 54 ++++++++++++++++++- cmd/root_test.go | 111 ++++++++++++++++++++++++++++++++++++++ cmd/status.go | 2 +- cmd/switch.go | 2 +- cmd/test.go | 2 +- internal/config/loader.go | 26 +++++++++ 12 files changed, 199 insertions(+), 19 deletions(-) diff --git a/cmd/config.go b/cmd/config.go index 32afd31..43ffd3f 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -51,7 +51,7 @@ var configCmd = &cobra.Command{ cfg, err := configCache.Get(dir) if err != nil && !os.IsNotExist(err) { - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } if err != nil { diff --git a/cmd/default.go b/cmd/default.go index acb6fb1..c015535 100644 --- a/cmd/default.go +++ b/cmd/default.go @@ -32,7 +32,7 @@ var defaultCmd = &cobra.Command{ } return } - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } diff --git a/cmd/harness.go b/cmd/harness.go index 50a8fe6..3b79bd7 100644 --- a/cmd/harness.go +++ b/cmd/harness.go @@ -26,7 +26,7 @@ var harnessGetCmd = &cobra.Command{ cfg, err := configCache.Get(dir) if err != nil { - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } @@ -66,7 +66,7 @@ var harnessSetCmd = &cobra.Command{ cfg, err := configCache.Get(dir) if err != nil && !errors.Is(err, kairoerrors.ErrConfigNotFound) { - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } if err != nil { diff --git a/cmd/integration_test.go b/cmd/integration_test.go index 2369e97..b03481c 100644 --- a/cmd/integration_test.go +++ b/cmd/integration_test.go @@ -234,15 +234,6 @@ func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) } -func containsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - // TestE2ESetupToSwitchWorkflow tests the complete end-to-end workflow // from initial setup through provider switching. func TestE2ESetupToSwitchWorkflow(t *testing.T) { diff --git a/cmd/list.go b/cmd/list.go index 036106f..af1ed9d 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -30,7 +30,7 @@ var listCmd = &cobra.Command{ ui.PrintInfo("Run 'kairo setup' to get started") return } - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } diff --git a/cmd/reset.go b/cmd/reset.go index aa9ef91..0a3131d 100644 --- a/cmd/reset.go +++ b/cmd/reset.go @@ -37,7 +37,7 @@ var resetCmd = &cobra.Command{ ui.PrintWarn("No providers configured") return } - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } diff --git a/cmd/root.go b/cmd/root.go index d7d08ff..546b934 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -24,6 +24,7 @@ package cmd import ( "fmt" "os" + "runtime" "time" "github.com/dkmnx/kairo/internal/config" @@ -71,7 +72,7 @@ Version: %s (commit: %s, date: %s)`, kairoversion.Version, kairoversion.Commit, cmd.Println("No providers configured. Run 'kairo setup' to get started.") return } - cmd.Printf("Error loading config: %v\n", err) + handleConfigError(cmd, err) return } @@ -190,3 +191,54 @@ func getConfigDir() string { } return env.GetConfigDir() } + +// handleConfigError provides user-friendly guidance for config errors. +func handleConfigError(cmd *cobra.Command, err error) { + errStr := err.Error() + + // Check for unknown field error (outdated binary) + // This can appear in two forms: + // 1. Raw YAML error: "field X not found in type config.Config" + // 2. Wrapped error: "configuration file contains field(s) not recognized" + if (containsSubstring(errStr, "field") && containsSubstring(errStr, "not found in type")) || + containsSubstring(errStr, "configuration file contains field(s) not recognized") || + containsSubstring(errStr, "your installed kairo binary is outdated") { + cmd.Println("Error: Your kairo binary is outdated and cannot read your configuration file.") + cmd.Println() + cmd.Println("The configuration file contains newer fields that this version doesn't recognize.") + cmd.Println() + cmd.Println("How to fix:") + cmd.Println(" Run the installation script for your platform:") + cmd.Println() + + // Display platform-specific installation script + switch runtime.GOOS { + case "windows": + cmd.Println(" irm https://raw.githubusercontent.com/dkmnx/kairo/main/scripts/install.ps1 | iex") + default: // linux, darwin (macOS) + cmd.Println(" curl -sSL https://raw.githubusercontent.com/dkmnx/kairo/main/scripts/install.sh | sh") + } + + cmd.Println() + cmd.Println(" For manual installation, see:") + cmd.Println(" https://github.com/dkmnx/kairo/blob/main/docs/guides/user-guide.md#manual-installation") + cmd.Println() + if verbose { + cmd.Printf("Technical details: %v\n", err) + } + return + } + + // Default error handling + cmd.Printf("Error loading config: %v\n", err) +} + +// containsSubstring checks if a string contains a substring. +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/cmd/root_test.go b/cmd/root_test.go index 478fb6d..e000242 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "fmt" "os" "os/exec" "path/filepath" @@ -436,3 +437,113 @@ func createConfigFile(t *testing.T, dir string, cfg *config.Config) string { } return configPath } + +func TestHandleConfigError(t *testing.T) { + t.Run("unknown field error shows helpful guide", func(t *testing.T) { + output := &bytes.Buffer{} + rootCmd.SetOut(output) + + // Simulate the error from outdated binary + err := fmt.Errorf("field default_harness not found in type config.Config (path=/home/user/.config/kairo/config.yaml)") + + originalVerbose := verbose + setVerbose(false) + defer func() { setVerbose(originalVerbose) }() + + handleConfigError(rootCmd, err) + + result := output.String() + + // Verify the helpful message is shown + expectedMessages := []string{ + "Your kairo binary is outdated", + "configuration file contains newer fields", + "installation script", + "github.com/dkmnx/kairo", + "install.sh", + } + + for _, msg := range expectedMessages { + if !containsString(result, msg) { + t.Errorf("Expected message %q not found in output:\n%s", msg, result) + } + } + }) + + t.Run("unknown field error with verbose shows technical details", func(t *testing.T) { + output := &bytes.Buffer{} + rootCmd.SetOut(output) + + // Simulate the error from outdated binary + err := fmt.Errorf("field default_harness not found in type config.Config") + + originalVerbose := verbose + setVerbose(true) + defer func() { setVerbose(originalVerbose) }() + + handleConfigError(rootCmd, err) + + result := output.String() + + // Verify technical details are shown in verbose mode + if !containsString(result, "Technical details:") { + t.Errorf("Expected 'Technical details:' in verbose output:\n%s", result) + } + if !containsString(result, "field default_harness") { + t.Errorf("Expected error details in verbose output:\n%s", result) + } + }) + + t.Run("other errors show default message", func(t *testing.T) { + output := &bytes.Buffer{} + rootCmd.SetOut(output) + + // Simulate a different error + err := fmt.Errorf("some other config error") + + originalVerbose := verbose + setVerbose(false) + defer func() { setVerbose(originalVerbose) }() + + handleConfigError(rootCmd, err) + + result := output.String() + + // Verify default error message + if !containsString(result, "Error loading config:") { + t.Errorf("Expected default error message, got:\n%s", result) + } + if !containsString(result, "some other config error") { + t.Errorf("Expected error text in output:\n%s", result) + } + }) +} + +func TestContainsSubstring(t *testing.T) { + tests := []struct { + name string + s string + substr string + expected bool + }{ + {"substring exists", "hello world", "world", true}, + {"substring at start", "hello world", "hello", true}, + {"substring at end", "hello world", "world", true}, + {"substring in middle", "hello world test", "world", true}, + {"exact match", "hello", "hello", true}, + {"empty substring", "hello", "", true}, + {"substring not found", "hello world", "goodbye", false}, + {"case sensitive", "Hello World", "hello", false}, + {"longer substring than string", "hi", "hello", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsSubstring(tt.s, tt.substr) + if result != tt.expected { + t.Errorf("containsSubstring(%q, %q) = %v, want %v", + tt.s, tt.substr, result, tt.expected) + } + }) + } +} diff --git a/cmd/status.go b/cmd/status.go index 4ddeb0e..747e543 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -31,7 +31,7 @@ var statusCmd = &cobra.Command{ ui.PrintInfo("Run 'kairo setup' to get started") return } - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } diff --git a/cmd/switch.go b/cmd/switch.go index 45e50ef..9b57cb7 100644 --- a/cmd/switch.go +++ b/cmd/switch.go @@ -118,7 +118,7 @@ var switchCmd = &cobra.Command{ cfg, err := configCache.Get(dir) if err != nil { - cmd.Printf("Error loading config: %v\n", err) + handleConfigError(cmd, err) return } diff --git a/cmd/test.go b/cmd/test.go index 8192e90..05c4764 100644 --- a/cmd/test.go +++ b/cmd/test.go @@ -32,7 +32,7 @@ var testCmd = &cobra.Command{ ui.PrintInfo("Run 'kairo config " + providerName + "' to configure") return } - ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) + handleConfigError(cmd, err) return } diff --git a/internal/config/loader.go b/internal/config/loader.go index ead9377..37c902e 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -116,6 +116,14 @@ func LoadConfig(configDir string) (*Config, error) { decoder := yaml.NewDecoder(bytes.NewReader(data)) decoder.KnownFields(true) if err := decoder.Decode(&cfg); err != nil { + // Check for unknown field errors and provide helpful guidance + errStr := err.Error() + if containsUnknownField(errStr) { + return nil, kairoerrors.WrapError(kairoerrors.ConfigError, + "configuration file contains field(s) not recognized by this kairo version", err). + WithContext("path", configPath). + WithContext("hint", "your installed kairo binary is outdated - rebuild and reinstall from source") + } return nil, kairoerrors.WrapError(kairoerrors.ConfigError, "failed to parse configuration file (invalid YAML)", err). WithContext("path", configPath). @@ -152,3 +160,21 @@ func SaveConfig(configDir string, cfg *Config) error { return nil } + +// containsUnknownField checks if the error message indicates an unknown YAML field. +// This pattern appears when the config file contains fields that don't exist +// in the current Config struct, typically due to an outdated binary. +func containsUnknownField(errStr string) bool { + return containsSubstring(errStr, "field") && + containsSubstring(errStr, "not found in type") +} + +// containsSubstring is a simple substring checker to avoid importing strings package. +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} From 2acad501a1a66196a9b14fc92d1883b38b034240 Mon Sep 17 00:00:00 2001 From: Benedick Montales Date: Sun, 15 Feb 2026 12:35:22 +0800 Subject: [PATCH 2/2] refactor: extract validation and transaction logic into separate modules - Move cross-provider config validation to internal/validate/provider.go - Move transaction/backup logic to cmd/config_tx.go - Move env var merging to cmd/env.go - Add unit tests for ValidateCrossProviderConfig and ValidateProviderModel - Optimize mergeEnvVars with O(1) index map for duplicate removal - Add backup cleanup on successful transactions - Upgrade Go to 1.25.7 to fix TLS vulnerability (GO-2026-4337) - Fix redundant nil check in errors test - Improve getHarness parameter naming for clarity --- cmd/config.go | 203 ------------------ cmd/config_test.go | 11 +- cmd/config_tx.go | 114 ++++++++++ cmd/env.go | 45 ++++ cmd/harness.go | 28 +++ cmd/harness_test.go | 10 +- cmd/integration_test.go | 3 +- cmd/root.go | 17 +- cmd/root_test.go | 5 +- cmd/switch.go | 67 +----- go.mod | 2 +- internal/config/loader.go | 15 +- internal/errors/errors_test.go | 81 +++----- internal/validate/provider.go | 97 +++++++++ internal/validate/provider_test.go | 321 +++++++++++++++++++++++++++++ 15 files changed, 659 insertions(+), 360 deletions(-) create mode 100644 cmd/config_tx.go create mode 100644 cmd/env.go create mode 100644 internal/validate/provider.go create mode 100644 internal/validate/provider_test.go diff --git a/cmd/config.go b/cmd/config.go index 43ffd3f..b37d4a9 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -4,9 +4,7 @@ import ( "errors" "fmt" "os" - "path/filepath" "strings" - "time" "github.com/dkmnx/kairo/internal/audit" "github.com/dkmnx/kairo/internal/config" @@ -209,204 +207,3 @@ var configCmd = &cobra.Command{ func init() { rootCmd.AddCommand(configCmd) } - -// createConfigBackup creates a backup of the current configuration file. -// Returns the path to the backup file or an error if the backup fails. -// The backup file is named with a timestamp to allow for multiple backups. -func createConfigBackup(configDir string) (string, error) { - configPath := getConfigPath(configDir) - - // Read the current config file - data, err := os.ReadFile(configPath) - if err != nil { - return "", fmt.Errorf("failed to read config for backup: %w", err) - } - - // Create backup filename with timestamp - backupPath := getBackupPath(configDir) - - // Write the backup - if err := os.WriteFile(backupPath, data, 0600); err != nil { - return "", fmt.Errorf("failed to write backup file: %w", err) - } - - return backupPath, nil -} - -// rollbackConfig restores the configuration from a backup file. -// If successful, the current config is replaced with the backup. -// The backup file is preserved after rollback for safety. -func rollbackConfig(configDir, backupPath string) error { - // Verify backup exists - if _, err := os.Stat(backupPath); os.IsNotExist(err) { - return fmt.Errorf("backup file not found: %s", backupPath) - } - - // Read backup data - data, err := os.ReadFile(backupPath) - if err != nil { - return fmt.Errorf("failed to read backup file: %w", err) - } - - // Write to config file - configPath := getConfigPath(configDir) - if err := os.WriteFile(configPath, data, 0600); err != nil { - return fmt.Errorf("failed to restore config from backup: %w", err) - } - - return nil -} - -// withConfigTransaction executes a function within a transaction-like context. -// -// This function creates a backup of the configuration before executing the -// provided function. If the function returns an error, the configuration -// is automatically rolled back to the backup. This provides atomic-like -// behavior for configuration updates. -// -// Parameters: -// - configDir: Directory containing the configuration file -// - fn: Function to execute within the transaction context -// -// Returns: -// - error: Returns error if transaction fails or rollback fails -// -// Error conditions: -// - Returns error when unable to create configuration backup -// - Returns error when fn returns an error (after attempting rollback) -// - Returns error if rollback fails after transaction failure (critical error) -// -// Thread Safety: Not thread-safe due to file I/O operations -// Security Notes: Backup files retain same permissions as original config (0600) -func withConfigTransaction(configDir string, fn func(txDir string) error) error { - // Create backup before transaction - backupPath, err := createConfigBackup(configDir) - if err != nil { - return fmt.Errorf("failed to create transaction backup: %w", err) - } - - // Execute the transaction function - err = fn(configDir) - - // If transaction failed, rollback - if err != nil { - if rbErr := rollbackConfig(configDir, backupPath); rbErr != nil { - // Rollback failed - this is a critical situation - return fmt.Errorf("transaction failed and rollback also failed: tx_err=%w, rollback_err=%w", err, rbErr) - } - return fmt.Errorf("transaction failed, changes rolled back: %w", err) - } - - return nil -} - -// getConfigPath returns the full path to the config file. -func getConfigPath(configDir string) string { - return filepath.Join(configDir, "config.yaml") -} - -// getBackupPath returns a backup file path with timestamp. -func getBackupPath(configDir string) string { - timestamp := time.Now().Format("20060102-150405") - return filepath.Join(configDir, fmt.Sprintf("config.yaml.backup.%s", timestamp)) -} - -// validateCrossProviderConfig validates configuration across all providers to detect conflicts. -// -// This function checks for environment variable collisions where multiple providers -// attempt to set the same environment variable with different values. Collisions -// with identical values are allowed (idempotent). -// -// Parameters: -// - cfg: Configuration object containing all provider definitions -// -// Returns: -// - error: Returns error if conflicting environment variables are detected -// -// Error conditions: -// - Returns error when same environment variable is set by multiple providers -// with different values (e.g., "API_KEY" set to "key1" by provider A and "key2" by provider B) -// -// Thread Safety: Thread-safe (no shared state, read-only access to config) -func validateCrossProviderConfig(cfg *config.Config) error { - // Build a map of env var names to their values and which providers set them - type envVarSource struct { - provider string - value string - } - envVarMap := make(map[string][]envVarSource) - - for providerName, provider := range cfg.Providers { - for _, envVar := range provider.EnvVars { - // Parse env var to get key and value - parts := strings.SplitN(envVar, "=", 2) - if len(parts) != 2 { - continue - } - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - envVarMap[key] = append(envVarMap[key], envVarSource{ - provider: providerName, - value: value, - }) - } - } - - // Check for collisions - env vars set by multiple providers with different values - for key, sources := range envVarMap { - if len(sources) > 1 { - // Check if all sources have the same value - firstValue := sources[0].value - allSame := true - for _, s := range sources { - if s.value != firstValue { - allSame = false - break - } - } - if !allSame { - return fmt.Errorf("environment variable collision: '%s' is set to different values by providers: %v", - key, sources) - } - } - } - - return nil -} - -// validateProviderModel validates a model name against provider capabilities. -// For built-in providers with default models, this ensures the model is reasonable. -// Returns an error if the model name is invalid. -func validateProviderModel(providerName, modelName string) error { - if modelName == "" { - return nil // Empty model is allowed (will use provider default) - } - - // Check if this is a built-in provider - if def, ok := providers.GetBuiltInProvider(providerName); ok { - // If provider has a default model, do basic validation - if def.Model != "" { - // Check model name length (most LLM model names are reasonable length) - if len(modelName) > 100 { - return fmt.Errorf("model name '%s' for provider '%s' is too long (max 100 characters)", modelName, providerName) - } - // Check for valid characters (alphanumeric, hyphens, underscores, dots) - for _, r := range modelName { - if !isValidModelRune(r) { - return fmt.Errorf("model name '%s' for provider '%s' contains invalid characters", modelName, providerName) - } - } - } - } - - return nil -} - -// isValidModelRune returns true if the rune is valid in a model name. -func isValidModelRune(r rune) bool { - return (r >= 'a' && r <= 'z') || - (r >= 'A' && r <= 'Z') || - (r >= '0' && r <= '9') || - r == '-' || r == '_' || r == '.' -} diff --git a/cmd/config_test.go b/cmd/config_test.go index 98dffaa..a8f6fdb 100644 --- a/cmd/config_test.go +++ b/cmd/config_test.go @@ -10,6 +10,7 @@ import ( "github.com/dkmnx/kairo/internal/config" "github.com/dkmnx/kairo/internal/providers" + "github.com/dkmnx/kairo/internal/validate" ) func TestProviderDefaults(t *testing.T) { @@ -470,7 +471,7 @@ func TestConfig_CrossProviderValidation(t *testing.T) { } // This should detect the collision - err := validateCrossProviderConfig(cfg) + err := validate.ValidateCrossProviderConfig(cfg) if err == nil { t.Error("Expected error for env var collision, got nil") } @@ -483,18 +484,18 @@ func TestConfig_CrossProviderValidation(t *testing.T) { t.Run("ModelValidation", func(t *testing.T) { // Test with a provider that has a default model (zai has "glm-4.7") // This should validate the model name - err := validateProviderModel("zai", "invalid@model#name!") + err := validate.ValidateProviderModel("zai", "invalid@model#name!") if err == nil { t.Error("Expected error for invalid model with special characters, got nil") } // Test with a model that's too long longModel := strings.Repeat("a", 101) - err = validateProviderModel("zai", longModel) + err = validate.ValidateProviderModel("zai", longModel) if err == nil { t.Error("Expected error for model name that's too long, got nil") } // Test with a valid model - should not error - err = validateProviderModel("zai", "valid-model-name.123") + err = validate.ValidateProviderModel("zai", "valid-model-name.123") if err != nil { t.Errorf("Expected valid model to pass validation, got error: %v", err) } @@ -520,7 +521,7 @@ func TestConfig_CrossProviderValidation(t *testing.T) { DefaultProvider: "zai", } - err := validateCrossProviderConfig(cfg) + err := validate.ValidateCrossProviderConfig(cfg) if err != nil { t.Errorf("Expected valid config to pass validation, got error: %v", err) } diff --git a/cmd/config_tx.go b/cmd/config_tx.go new file mode 100644 index 0000000..f5e08c8 --- /dev/null +++ b/cmd/config_tx.go @@ -0,0 +1,114 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "time" +) + +// createConfigBackup creates a backup of the current configuration file. +// Returns the path to the backup file or an error if the backup fails. +// The backup file is named with a timestamp to allow for multiple backups. +func createConfigBackup(configDir string) (string, error) { + configPath := getConfigPath(configDir) + + // Read the current config file + data, err := os.ReadFile(configPath) + if err != nil { + return "", fmt.Errorf("failed to read config for backup: %w", err) + } + + // Create backup filename with timestamp + backupPath := getBackupPath(configDir) + + // Write the backup + if err := os.WriteFile(backupPath, data, 0600); err != nil { + return "", fmt.Errorf("failed to write backup file: %w", err) + } + + return backupPath, nil +} + +// rollbackConfig restores the configuration from a backup file. +// If successful, the current config is replaced with the backup. +// The backup file is preserved after rollback for safety. +func rollbackConfig(configDir, backupPath string) error { + // Verify backup exists + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + return fmt.Errorf("backup file not found: %s", backupPath) + } + + // Read backup data + data, err := os.ReadFile(backupPath) + if err != nil { + return fmt.Errorf("failed to read backup file: %w", err) + } + + // Write to config file + configPath := getConfigPath(configDir) + if err := os.WriteFile(configPath, data, 0600); err != nil { + return fmt.Errorf("failed to restore config from backup: %w", err) + } + + return nil +} + +// withConfigTransaction executes a function within a transaction-like context. +// +// This function creates a backup of the configuration before executing the +// provided function. If the function returns an error, the configuration +// is automatically rolled back to the backup. This provides atomic-like +// behavior for configuration updates. +// +// Parameters: +// - configDir: Directory containing the configuration file +// - fn: Function to execute within the transaction context +// +// Returns: +// - error: Returns error if transaction fails or rollback fails +// +// Error conditions: +// - Returns error when unable to create configuration backup +// - Returns error when fn returns an error (after attempting rollback) +// - Returns error if rollback fails after transaction failure (critical error) +// +// Thread Safety: Not thread-safe due to file I/O operations +// Security Notes: Backup files retain same permissions as original config (0600) +func withConfigTransaction(configDir string, fn func(txDir string) error) error { + // Create backup before transaction + backupPath, err := createConfigBackup(configDir) + if err != nil { + return fmt.Errorf("failed to create transaction backup: %w", err) + } + + // Execute the transaction function + err = fn(configDir) + + // If transaction failed, rollback + if err != nil { + if rbErr := rollbackConfig(configDir, backupPath); rbErr != nil { + // Rollback failed - this is a critical situation + return fmt.Errorf("transaction failed and rollback also failed: tx_err=%w, rollback_err=%w", err, rbErr) + } + return fmt.Errorf("transaction failed, changes rolled back: %w", err) + } + + // Transaction succeeded - clean up the backup file + // Best-effort cleanup, ignore errors + _ = os.Remove(backupPath) + + return nil +} + +// getConfigPath returns the full path to the config file. +func getConfigPath(configDir string) string { + return filepath.Join(configDir, "config.yaml") +} + +// getBackupPath returns a backup file path with timestamp. +func getBackupPath(configDir string) string { + // Use nanosecond precision to avoid filename conflicts with rapid successive operations + timestamp := time.Now().Format("20060102-150405.000000000") + return filepath.Join(configDir, fmt.Sprintf("config.yaml.backup.%s", timestamp)) +} diff --git a/cmd/env.go b/cmd/env.go new file mode 100644 index 0000000..5ae7288 --- /dev/null +++ b/cmd/env.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "strings" +) + +// mergeEnvVars merges environment variable slices, deduplicating by key. +// If duplicate keys are found, the last value wins (preserves order of precedence). +// Env vars should be in "KEY=VALUE" format. +func mergeEnvVars(envs ...[]string) []string { + // Use map to track key -> index in result for O(1) lookup and removal + seen := make(map[string]int) + var result []string + + for _, envSlice := range envs { + for _, env := range envSlice { + // Extract the key (everything before the first '=') + idx := strings.IndexByte(env, '=') + // Check for invalid format: no '=' or '=' is first or last char + if idx <= 0 || idx == len(env)-1 { + // Invalid format (no key or no value), skip + continue + } + key := env[:idx] + + // Remove any previous occurrence of this key + if prevIdx, exists := seen[key]; exists { + // Remove the previous entry by swapping with last and truncating + lastIdx := len(result) - 1 + result[prevIdx] = result[lastIdx] + // Update the index for the moved entry + if prevIdx != lastIdx { + seen[result[prevIdx][:strings.IndexByte(result[prevIdx], '=')]] = prevIdx + } + result = result[:lastIdx] + } + + // Add the new entry and record its index + result = append(result, env) + seen[key] = len(result) - 1 + } + } + + return result +} diff --git a/cmd/harness.go b/cmd/harness.go index 3b79bd7..1b6d39e 100644 --- a/cmd/harness.go +++ b/cmd/harness.go @@ -97,3 +97,31 @@ func init() { harnessCmd.AddCommand(harnessSetCmd) rootCmd.AddCommand(harnessCmd) } + +// getHarness returns the harness to use, checking flag then config then defaulting to claude. +func getHarness(flagHarness, configHarness string) string { + harness := flagHarness + if harness == "" { + harness = configHarness + } + if harness == "" { + return "claude" + } + if harness != "claude" && harness != "qwen" { + ui.PrintWarn(fmt.Sprintf("Unknown harness '%s', using 'claude'", harness)) + return "claude" + } + return harness +} + +// getHarnessBinary returns the CLI binary name for a given harness. +func getHarnessBinary(harness string) string { + switch harness { + case "qwen": + return "qwen" + case "claude": + return "claude" + default: + return "claude" + } +} diff --git a/cmd/harness_test.go b/cmd/harness_test.go index 304ad44..1902764 100644 --- a/cmd/harness_test.go +++ b/cmd/harness_test.go @@ -157,13 +157,7 @@ func TestGetHarness(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg := &config.Config{ - Providers: make(map[string]config.Provider), - DefaultModels: make(map[string]string), - DefaultHarness: tt.configHarness, - } - - result := getHarness(cfg, tt.flagHarness) + result := getHarness(tt.flagHarness, tt.configHarness) if result != tt.expected { t.Errorf("getHarness() = %q, want %q", result, tt.expected) } @@ -215,7 +209,7 @@ func TestGetHarnessWithExistingConfig(t *testing.T) { t.Fatalf("LoadConfig() error = %v", err) } - result := getHarness(loadedCfg, "") + result := getHarness("", loadedCfg.DefaultHarness) if result != "qwen" { t.Errorf("getHarness() = %q, want %q", result, "qwen") } diff --git a/cmd/integration_test.go b/cmd/integration_test.go index b03481c..798d747 100644 --- a/cmd/integration_test.go +++ b/cmd/integration_test.go @@ -2,6 +2,7 @@ package cmd import ( "path/filepath" + "strings" "testing" "github.com/dkmnx/kairo/internal/config" @@ -231,7 +232,7 @@ func TestCustomProviderSecretsAfterRotation(t *testing.T) { } func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) + return strings.Contains(s, substr) } // TestE2ESetupToSwitchWorkflow tests the complete end-to-end workflow diff --git a/cmd/root.go b/cmd/root.go index 546b934..854a05c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -25,6 +25,7 @@ import ( "fmt" "os" "runtime" + "strings" "time" "github.com/dkmnx/kairo/internal/config" @@ -200,9 +201,9 @@ func handleConfigError(cmd *cobra.Command, err error) { // This can appear in two forms: // 1. Raw YAML error: "field X not found in type config.Config" // 2. Wrapped error: "configuration file contains field(s) not recognized" - if (containsSubstring(errStr, "field") && containsSubstring(errStr, "not found in type")) || - containsSubstring(errStr, "configuration file contains field(s) not recognized") || - containsSubstring(errStr, "your installed kairo binary is outdated") { + if (strings.Contains(errStr, "field") && strings.Contains(errStr, "not found in type")) || + strings.Contains(errStr, "configuration file contains field(s) not recognized") || + strings.Contains(errStr, "your installed kairo binary is outdated") { cmd.Println("Error: Your kairo binary is outdated and cannot read your configuration file.") cmd.Println() cmd.Println("The configuration file contains newer fields that this version doesn't recognize.") @@ -232,13 +233,3 @@ func handleConfigError(cmd *cobra.Command, err error) { // Default error handling cmd.Printf("Error loading config: %v\n", err) } - -// containsSubstring checks if a string contains a substring. -func containsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/cmd/root_test.go b/cmd/root_test.go index e000242..fa28871 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync/atomic" "testing" @@ -539,9 +540,9 @@ func TestContainsSubstring(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := containsSubstring(tt.s, tt.substr) + result := strings.Contains(tt.s, tt.substr) if result != tt.expected { - t.Errorf("containsSubstring(%q, %q) = %v, want %v", + t.Errorf("strings.Contains(%q, %q) = %v, want %v", tt.s, tt.substr, result, tt.expected) } }) diff --git a/cmd/switch.go b/cmd/switch.go index 9b57cb7..c0a4a44 100644 --- a/cmd/switch.go +++ b/cmd/switch.go @@ -37,71 +37,6 @@ var ( harnessFlag string ) -// getHarness returns the harness to use, checking flag then config then defaulting to claude. -func getHarness(cfg *config.Config, flagHarness string) string { - harness := flagHarness - if harness == "" { - harness = cfg.DefaultHarness - } - if harness == "" { - return "claude" - } - if harness != "claude" && harness != "qwen" { - ui.PrintWarn(fmt.Sprintf("Unknown harness '%s', using 'claude'", harness)) - return "claude" - } - return harness -} - -// getHarnessBinary returns the CLI binary name for a given harness. -func getHarnessBinary(harness string) string { - switch harness { - case "qwen": - return "qwen" - case "claude": - return "claude" - default: - return "claude" - } -} - -// mergeEnvVars merges environment variable slices, deduplicating by key. -// If duplicate keys are found, the last value wins (preserves order of precedence). -// Env vars should be in "KEY=VALUE" format. -func mergeEnvVars(envs ...[]string) []string { - seen := make(map[string]bool) - var result []string - - for _, envSlice := range envs { - for _, env := range envSlice { - // Extract the key (everything before the first '=') - idx := strings.IndexByte(env, '=') - if idx <= 0 { - // Invalid format, skip - continue - } - key := env[:idx] - - // Remove any previous occurrence of this key - if seen[key] { - // Find and remove previous entry with this key - for i, e := range result { - if strings.HasPrefix(e, key+"=") { - result = append(result[:i], result[i+1:]...) - break - } - } - } - - // Add the new entry - result = append(result, env) - seen[key] = true - } - } - - return result -} - var switchCmd = &cobra.Command{ Use: "switch [args]", Short: "Switch to a provider and execute Claude", @@ -134,7 +69,7 @@ var switchCmd = &cobra.Command{ ui.PrintWarn(fmt.Sprintf("Audit logging failed: %v", err)) } - harnessToUse := getHarness(cfg, harnessFlag) + harnessToUse := getHarness(harnessFlag, cfg.DefaultHarness) harnessBinary := getHarnessBinary(harnessToUse) // Environment variable name constants for model configuration diff --git a/go.mod b/go.mod index c69b0ba..401bcda 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/dkmnx/kairo -go 1.25.6 +go 1.25.7 require ( filippo.io/age v1.2.1 diff --git a/internal/config/loader.go b/internal/config/loader.go index 37c902e..c6fd2c0 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" kairoerrors "github.com/dkmnx/kairo/internal/errors" "gopkg.in/yaml.v3" @@ -165,16 +166,6 @@ func SaveConfig(configDir string, cfg *Config) error { // This pattern appears when the config file contains fields that don't exist // in the current Config struct, typically due to an outdated binary. func containsUnknownField(errStr string) bool { - return containsSubstring(errStr, "field") && - containsSubstring(errStr, "not found in type") -} - -// containsSubstring is a simple substring checker to avoid importing strings package. -func containsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false + return strings.Contains(errStr, "field") && + strings.Contains(errStr, "not found in type") } diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go index 9f3b72b..bfad9e4 100644 --- a/internal/errors/errors_test.go +++ b/internal/errors/errors_test.go @@ -2,6 +2,7 @@ package errors import ( "errors" + "strings" "testing" ) @@ -12,11 +13,6 @@ func TestNewError(t *testing.T) { t.Fatal("NewError() returned nil") } - // Defensive check for static analysis - if err == nil { - return - } - if err.Type != ConfigError { t.Errorf("Type = %v, want %v", err.Type, ConfigError) } @@ -69,13 +65,13 @@ func TestWithContext(t *testing.T) { errMsg := err.Error() // Map order is not guaranteed, just check values are present - if !containsSubstring(errMsg, "provider=zai") { + if !strings.Contains(errMsg, "provider=zai") { t.Errorf("Error() should contain provider context, got: %v", errMsg) } - if !containsSubstring(errMsg, "action=switch") { + if !strings.Contains(errMsg, "action=switch") { t.Errorf("Error() should contain action context, got: %v", errMsg) } - if !containsSubstring(errMsg, "provider not configured") { + if !strings.Contains(errMsg, "provider not configured") { t.Errorf("Error() should contain message, got: %v", errMsg) } }) @@ -87,13 +83,13 @@ func TestWithContext(t *testing.T) { errMsg := err.Error() // Map order is not guaranteed, just check both values are present - if !containsSubstring(errMsg, "file=/path/to/config") { + if !strings.Contains(errMsg, "file=/path/to/config") { t.Errorf("Error() should contain file context, got: %v", errMsg) } - if !containsSubstring(errMsg, "line=42") { + if !strings.Contains(errMsg, "line=42") { t.Errorf("Error() should contain line context, got: %v", errMsg) } - if !containsSubstring(errMsg, "invalid config") { + if !strings.Contains(errMsg, "invalid config") { t.Errorf("Error() should contain message, got: %v", errMsg) } }) @@ -196,13 +192,13 @@ func TestErrorStringFormatting(t *testing.T) { WithContext("attempt", "3") errMsg := err.Error() // Map order is not guaranteed, just check values are present - if !containsSubstring(errMsg, "provider unavailable") { + if !strings.Contains(errMsg, "provider unavailable") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "name=anthropic") { + if !strings.Contains(errMsg, "name=anthropic") { t.Errorf("Error() should contain name context, got: %v", errMsg) } - if !containsSubstring(errMsg, "attempt=3") { + if !strings.Contains(errMsg, "attempt=3") { t.Errorf("Error() should contain attempt context, got: %v", errMsg) } }) @@ -214,31 +210,18 @@ func TestErrorStringFormatting(t *testing.T) { WithContext("port", "443") errMsg := err.Error() - if !containsSubstring(errMsg, "failed to connect") { + if !strings.Contains(errMsg, "failed to connect") { t.Error("Error message should contain original message") } - if !containsSubstring(errMsg, "connection timeout") { + if !strings.Contains(errMsg, "connection timeout") { t.Error("Error message should contain cause") } - if !containsSubstring(errMsg, "host=api.example.com") { + if !strings.Contains(errMsg, "host=api.example.com") { t.Error("Error message should contain host context") } }) } -func containsSubstring(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr)) -} - -func findSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - func TestFileError(t *testing.T) { t.Run("creates file error with path context", func(t *testing.T) { cause := errors.New("permission denied") @@ -261,13 +244,13 @@ func TestFileError(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "failed to write config") { + if !strings.Contains(errMsg, "failed to write config") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "permission denied") { + if !strings.Contains(errMsg, "permission denied") { t.Errorf("Error() should contain cause, got: %v", errMsg) } - if !containsSubstring(errMsg, "path=/path/to/config.yaml") { + if !strings.Contains(errMsg, "path=/path/to/config.yaml") { t.Errorf("Error() should contain path context, got: %v", errMsg) } }) @@ -284,10 +267,10 @@ func TestFileError(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "file not found") { + if !strings.Contains(errMsg, "file not found") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "path=/missing/file.txt") { + if !strings.Contains(errMsg, "path=/missing/file.txt") { t.Errorf("Error() should contain path context, got: %v", errMsg) } }) @@ -324,13 +307,13 @@ func TestConfigFileError(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "failed to parse config") { + if !strings.Contains(errMsg, "failed to parse config") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "path=/home/user/.config/kairo/config") { + if !strings.Contains(errMsg, "path=/home/user/.config/kairo/config") { t.Errorf("Error() should contain path context, got: %v", errMsg) } - if !containsSubstring(errMsg, "hint=Check YAML syntax and indentation") { + if !strings.Contains(errMsg, "hint=Check YAML syntax and indentation") { t.Errorf("Error() should contain hint context, got: %v", errMsg) } }) @@ -356,7 +339,7 @@ func TestConfigFileError(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "config file not found") { + if !strings.Contains(errMsg, "config file not found") { t.Errorf("Error() should contain message, got: %v", errMsg) } }) @@ -388,13 +371,13 @@ func TestCryptoErrorWithHint(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "failed to decrypt secrets") { + if !strings.Contains(errMsg, "failed to decrypt secrets") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "X25519 key not found") { + if !strings.Contains(errMsg, "X25519 key not found") { t.Errorf("Error() should contain cause, got: %v", errMsg) } - if !containsSubstring(errMsg, "hint=Ensure 'age.key' exists in config directory") { + if !strings.Contains(errMsg, "hint=Ensure 'age.key' exists in config directory") { t.Errorf("Error() should contain hint context, got: %v", errMsg) } }) @@ -411,10 +394,10 @@ func TestCryptoErrorWithHint(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "encryption failed") { + if !strings.Contains(errMsg, "encryption failed") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "hint=Check file permissions on 'secrets.age'") { + if !strings.Contains(errMsg, "hint=Check file permissions on 'secrets.age'") { t.Errorf("Error() should contain hint context, got: %v", errMsg) } }) @@ -442,13 +425,13 @@ func TestProviderErr(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "API request failed") { + if !strings.Contains(errMsg, "API request failed") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "HTTP 401 Unauthorized") { + if !strings.Contains(errMsg, "HTTP 401 Unauthorized") { t.Errorf("Error() should contain cause, got: %v", errMsg) } - if !containsSubstring(errMsg, "provider=zai") { + if !strings.Contains(errMsg, "provider=zai") { t.Errorf("Error() should contain provider context, got: %v", errMsg) } }) @@ -481,10 +464,10 @@ func TestProviderErr(t *testing.T) { } errMsg := err.Error() - if !containsSubstring(errMsg, "provider not configured") { + if !strings.Contains(errMsg, "provider not configured") { t.Errorf("Error() should contain message, got: %v", errMsg) } - if !containsSubstring(errMsg, "provider=minimax") { + if !strings.Contains(errMsg, "provider=minimax") { t.Errorf("Error() should contain provider context, got: %v", errMsg) } }) diff --git a/internal/validate/provider.go b/internal/validate/provider.go new file mode 100644 index 0000000..3ca307b --- /dev/null +++ b/internal/validate/provider.go @@ -0,0 +1,97 @@ +package validate + +import ( + "fmt" + "strings" + + "github.com/dkmnx/kairo/internal/config" + "github.com/dkmnx/kairo/internal/providers" +) + +// validateCrossProviderConfig validates configuration across all providers to detect conflicts. +// +// This function checks for environment variable collisions where multiple providers +// attempt to set the same environment variable with different values. Collisions +// with identical values are allowed (idempotent). +func ValidateCrossProviderConfig(cfg *config.Config) error { + // Build a map of env var names to their values and which providers set them + type envVarSource struct { + provider string + value string + } + envVarMap := make(map[string][]envVarSource) + + for providerName, provider := range cfg.Providers { + for _, envVar := range provider.EnvVars { + // Parse env var to get key and value + parts := strings.SplitN(envVar, "=", 2) + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + envVarMap[key] = append(envVarMap[key], envVarSource{ + provider: providerName, + value: value, + }) + } + } + + // Check for collisions - env vars set by multiple providers with different values + for key, sources := range envVarMap { + if len(sources) > 1 { + // Check if all sources have the same value + firstValue := sources[0].value + allSame := true + for _, s := range sources { + if s.value != firstValue { + allSame = false + break + } + } + if !allSame { + return fmt.Errorf("environment variable collision: '%s' is set to different values by providers: %v", + key, sources) + } + } + } + + return nil +} + +// ValidateProviderModel validates a model name against provider capabilities. +// For built-in providers with default models, this ensures the model is reasonable. +// Returns an error if the model name is invalid. +func ValidateProviderModel(providerName, modelName string) error { + if modelName == "" { + return nil // Empty model is allowed (will use provider default) + } + + // Check if this is a built-in provider + if def, ok := providers.GetBuiltInProvider(providerName); ok { + // If provider has a default model, do basic validation + if def.Model != "" { + // Check model name length (most LLM model names are reasonable length) + if len(modelName) > 100 { + return fmt.Errorf("model name '%s' for provider '%s' is too long (max 100 characters)", modelName, providerName) + } + // Check for valid characters (alphanumeric, hyphens, underscores, dots) + for _, r := range modelName { + if !isValidModelRune(r) { + return fmt.Errorf("model name '%s' for provider '%s' contains invalid characters", modelName, providerName) + } + } + } + } + + return nil +} + +// isValidModelRune returns true if the rune is valid in a model name. +func isValidModelRune(r rune) bool { + return (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '-' || r == '_' || r == '.' +} diff --git a/internal/validate/provider_test.go b/internal/validate/provider_test.go new file mode 100644 index 0000000..c108cf1 --- /dev/null +++ b/internal/validate/provider_test.go @@ -0,0 +1,321 @@ +package validate + +import ( + "strings" + "testing" + + "github.com/dkmnx/kairo/internal/config" +) + +func TestValidateCrossProviderConfig(t *testing.T) { + tests := []struct { + name string + cfg *config.Config + wantErr bool + errMsg string // substring to check in error message + }{ + { + name: "empty providers", + cfg: &config.Config{ + Providers: map[string]config.Provider{}, + }, + wantErr: false, + }, + { + name: "single provider", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "test": { + EnvVars: []string{"KEY=value"}, + }, + }, + }, + wantErr: false, + }, + { + name: "multiple providers same env var same value", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"SHARED_VAR=value1"}, + }, + "provider2": { + EnvVars: []string{"SHARED_VAR=value1"}, + }, + }, + }, + wantErr: false, + }, + { + name: "multiple providers same env var different values", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"SHARED_VAR=value1"}, + }, + "provider2": { + EnvVars: []string{"SHARED_VAR=value2"}, + }, + }, + }, + wantErr: true, + errMsg: "environment variable collision", + }, + { + name: "multiple providers different env vars", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"VAR1=value1", "VAR2=value2"}, + }, + "provider2": { + EnvVars: []string{"VAR3=value3"}, + }, + }, + }, + wantErr: false, + }, + { + name: "three providers collision", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"COMMON=test"}, + }, + "provider2": { + EnvVars: []string{"COMMON=test"}, + }, + "provider3": { + EnvVars: []string{"COMMON=different"}, + }, + }, + }, + wantErr: true, + errMsg: "environment variable collision", + }, + { + name: "env var with equals in value", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"JSON_DATA={\"key\":\"value\"}"}, + }, + }, + }, + wantErr: false, + }, + { + name: "malformed env var (no equals)", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"INVALID_VAR"}, + }, + }, + }, + wantErr: false, // malformed vars are skipped + }, + { + name: "malformed env var (empty key)", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{"=value"}, + }, + }, + }, + wantErr: false, // malformed vars are skipped + }, + { + name: "whitespace in key and value - same after trim", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{" KEY = value "}, + }, + "provider2": { + EnvVars: []string{"KEY=value"}, + }, + }, + }, + wantErr: false, // Keys and values match after trimming - no collision + }, + { + name: "whitespace in key - different values after trim", + cfg: &config.Config{ + Providers: map[string]config.Provider{ + "provider1": { + EnvVars: []string{" KEY = value1 "}, + }, + "provider2": { + EnvVars: []string{"KEY=value2"}, + }, + }, + }, + wantErr: true, // Keys match after trim, but values differ -> collision + errMsg: "environment variable collision", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCrossProviderConfig(tt.cfg) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateCrossProviderConfig() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMsg != "" && err != nil { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ValidateCrossProviderConfig() error = %v, should contain %q", err, tt.errMsg) + } + } + }) + } +} + +func TestValidateProviderModel(t *testing.T) { + tests := []struct { + name string + provider string + model string + wantErr bool + errContains string + }{ + // Empty model tests + { + name: "empty model", + provider: "anthropic", + model: "", + wantErr: false, + }, + + // Valid model names + { + name: "valid model for built-in provider", + provider: "zai", + model: "valid-model-name", + wantErr: false, + }, + { + name: "valid model with dots", + provider: "zai", + model: "model.v1.0", + wantErr: false, + }, + { + name: "valid model with numbers", + provider: "zai", + model: "claude-3-5-sonnet-20241022", + wantErr: false, + }, + + // Invalid model names + { + name: "model too long", + provider: "zai", + model: strings.Repeat("a", 101), + wantErr: true, + errContains: "too long", + }, + { + name: "model with invalid character @", + provider: "zai", + model: "invalid@model", + wantErr: true, + errContains: "invalid characters", + }, + { + name: "model with invalid character #", + provider: "zai", + model: "model#name", + wantErr: true, + errContains: "invalid characters", + }, + { + name: "model with invalid character space", + provider: "zai", + model: "invalid model", + wantErr: true, + errContains: "invalid characters", + }, + { + name: "model with invalid character !", + provider: "zai", + model: "model!name", + wantErr: true, + errContains: "invalid characters", + }, + + // Non-built-in providers (should skip validation) + { + name: "non-built-in provider with any model", + provider: "custom", + model: "any-model-name-@#$%", + wantErr: false, + }, + { + name: "non-built-in provider with empty model", + provider: "unknown", + model: "", + wantErr: false, + }, + + // Built-in providers without default models + { + name: "anthropic (no default model) valid", + provider: "anthropic", + model: "any-model", + wantErr: false, + }, + { + name: "custom (no default model) valid", + provider: "custom", + model: "model@#$", + wantErr: false, // no validation for providers without default model + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateProviderModel(tt.provider, tt.model) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateProviderModel() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errContains != "" && err != nil { + if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("ValidateProviderModel() error = %v, should contain %q", err, tt.errContains) + } + } + }) + } +} + +func TestIsValidModelRune(t *testing.T) { + tests := []struct { + rune rune + valid bool + }{ + {'a', true}, + {'z', true}, + {'A', true}, + {'Z', true}, + {'0', true}, + {'9', true}, + {'-', true}, + {'_', true}, + {'.', true}, + {'@', false}, + {'#', false}, + {'!', false}, + {' ', false}, + {'/', false}, + {':', false}, + } + + for _, tt := range tests { + t.Run(string(tt.rune), func(t *testing.T) { + if got := isValidModelRune(tt.rune); got != tt.valid { + t.Errorf("isValidModelRune(%q) = %v, want %v", tt.rune, got, tt.valid) + } + }) + } +}