diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 693d8d3..61a7483 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -48,6 +48,7 @@ snapshot: version_template: '{{ .Version }}-next' changelog: + # Disable auto-generated changelog since CHANGELOG.md is maintained manually disable: true release: diff --git a/AGENTS.md b/AGENTS.md index ed71084..a5b132b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -161,12 +161,14 @@ kairo/ - **Line Length:** 120 characters (MD013) - **Indentation:** Tabs (Go standard) - **Naming:** Go conventions (PascalCase for exported, camelCase for unexported) -- **Error Handling:** Typed errors from `internal/errors` package +- **Error Handling:** Always use typed errors from `internal/errors` package - **Formatting:** `gofmt -w .` (run before committing) - **Vetting:** `go vet ./...` (run before committing) ### Error Handling Pattern +When handling errors, always use an error type from the `kairoerrors` package. Never use plain `errors.New()` or `fmt.Errorf()` without a typed error. + ```go import kairoerrors "github.com/dkmnx/kairo/internal/errors" diff --git a/cmd/audit.go b/cmd/audit.go index 9f34146..2ddb7c7 100644 --- a/cmd/audit.go +++ b/cmd/audit.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/dkmnx/kairo/internal/audit" + kairoerrors "github.com/dkmnx/kairo/internal/errors" "github.com/dkmnx/kairo/internal/ui" "github.com/spf13/cobra" ) @@ -122,7 +123,8 @@ var auditExportCmd = &cobra.Command{ } if exportOutput == "" { - return fmt.Errorf("--output is required for export") + return kairoerrors.NewError(kairoerrors.ConfigError, + "--output is required for export") } if _, err := os.Stat(dir); os.IsNotExist(err) { @@ -318,5 +320,6 @@ func exportAuditLog(entries []audit.AuditEntry, outputPath, format string) error return nil } - return fmt.Errorf("unsupported format: %s (supported: csv, json)", format) + return kairoerrors.NewError(kairoerrors.ConfigError, + fmt.Sprintf("unsupported format: %s (supported: csv, json)", format)) } diff --git a/cmd/audit_helpers.go b/cmd/audit_helpers.go index 972e753..bf8fb49 100644 --- a/cmd/audit_helpers.go +++ b/cmd/audit_helpers.go @@ -1,9 +1,8 @@ package cmd import ( - "fmt" - "github.com/dkmnx/kairo/internal/audit" + kairoerrors "github.com/dkmnx/kairo/internal/errors" ) // logAuditEvent logs an audit event using the provided logging function. @@ -27,12 +26,14 @@ import ( func logAuditEvent(configDir string, logFunc func(*audit.Logger) error) error { logger, err := audit.NewLogger(configDir) if err != nil { - return fmt.Errorf("failed to create audit logger: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to create audit logger", err) } defer logger.Close() if err := logFunc(logger); err != nil { - return fmt.Errorf("failed to log audit event: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to log audit event", err) } return nil } diff --git a/cmd/config_test.go b/cmd/config_test.go index a8f6fdb..b717b25 100644 --- a/cmd/config_test.go +++ b/cmd/config_test.go @@ -254,7 +254,7 @@ func TestProviderConfigSaveLoad(t *testing.T) { func TestGetConfigDir(t *testing.T) { // Reset configDir to avoid pollution from other tests - configDir = "" + setConfigDir("") home, err := os.UserHomeDir() if err != nil { diff --git a/cmd/config_tx.go b/cmd/config_tx.go index f5e08c8..2c3216d 100644 --- a/cmd/config_tx.go +++ b/cmd/config_tx.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" "time" + + kairoerrors "github.com/dkmnx/kairo/internal/errors" ) // createConfigBackup creates a backup of the current configuration file. @@ -16,7 +18,8 @@ func createConfigBackup(configDir string) (string, error) { // Read the current config file data, err := os.ReadFile(configPath) if err != nil { - return "", fmt.Errorf("failed to read config for backup: %w", err) + return "", kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to read config for backup", err) } // Create backup filename with timestamp @@ -24,7 +27,8 @@ func createConfigBackup(configDir string) (string, error) { // Write the backup if err := os.WriteFile(backupPath, data, 0600); err != nil { - return "", fmt.Errorf("failed to write backup file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to write backup file", err) } return backupPath, nil @@ -36,19 +40,22 @@ func createConfigBackup(configDir string) (string, error) { func rollbackConfig(configDir, backupPath string) error { // Verify backup exists if _, err := os.Stat(backupPath); os.IsNotExist(err) { - return fmt.Errorf("backup file not found: %s", backupPath) + return kairoerrors.NewError(kairoerrors.ConfigError, + fmt.Sprintf("backup file not found: %s", backupPath)) } // Read backup data data, err := os.ReadFile(backupPath) if err != nil { - return fmt.Errorf("failed to read backup file: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to read backup file", err) } // Write to config file configPath := getConfigPath(configDir) if err := os.WriteFile(configPath, data, 0600); err != nil { - return fmt.Errorf("failed to restore config from backup: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to restore config from backup", err) } return nil @@ -79,7 +86,8 @@ func withConfigTransaction(configDir string, fn func(txDir string) error) error // Create backup before transaction backupPath, err := createConfigBackup(configDir) if err != nil { - return fmt.Errorf("failed to create transaction backup: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to create transaction backup", err) } // Execute the transaction function @@ -89,9 +97,11 @@ func withConfigTransaction(configDir string, fn func(txDir string) error) error if err != nil { if rbErr := rollbackConfig(configDir, backupPath); rbErr != nil { // Rollback failed - this is a critical situation - return fmt.Errorf("transaction failed and rollback also failed: tx_err=%w, rollback_err=%w", err, rbErr) + return kairoerrors.WrapError(kairoerrors.ConfigError, + fmt.Sprintf("transaction failed and rollback also failed: tx_err=%v, rollback_err=%v", err, rbErr), rbErr) } - return fmt.Errorf("transaction failed, changes rolled back: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "transaction failed, changes rolled back", err) } // Transaction succeeded - clean up the backup file diff --git a/cmd/config_tx_test.go b/cmd/config_tx_test.go new file mode 100644 index 0000000..b4b2bb2 --- /dev/null +++ b/cmd/config_tx_test.go @@ -0,0 +1,324 @@ +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/dkmnx/kairo/internal/config" +) + +func TestCreateConfigBackup(t *testing.T) { + tmpDir := t.TempDir() + + // Create a config file + cfg := &config.Config{ + Providers: map[string]config.Provider{ + "test": {Name: "Test Provider"}, + }, + } + if err := config.SaveConfig(tmpDir, cfg); err != nil { + t.Fatal(err) + } + + backupPath, err := createConfigBackup(tmpDir) + if err != nil { + t.Fatalf("createConfigBackup() error = %v", err) + } + + if backupPath == "" { + t.Fatal("Backup path should not be empty") + } + + // Verify backup file exists + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + t.Error("Backup file should exist") + } + + // Verify backup contains same content as original + originalPath := getConfigPath(tmpDir) + originalData, err := os.ReadFile(originalPath) + if err != nil { + t.Fatal(err) + } + + backupData, err := os.ReadFile(backupPath) + if err != nil { + t.Fatal(err) + } + + if string(originalData) != string(backupData) { + t.Error("Backup should contain same content as original config") + } +} + +func TestCreateConfigBackupNonExistent(t *testing.T) { + tmpDir := t.TempDir() + + _, err := createConfigBackup(tmpDir) + if err == nil { + t.Error("createConfigBackup() should error when config doesn't exist") + } +} + +func TestRollbackConfig(t *testing.T) { + tmpDir := t.TempDir() + + // Create original config + originalCfg := &config.Config{ + Providers: map[string]config.Provider{ + "original": {Name: "Original Provider"}, + }, + } + if err := config.SaveConfig(tmpDir, originalCfg); err != nil { + t.Fatal(err) + } + + // Create a backup with different content + backupPath := filepath.Join(tmpDir, "config.yaml.backup.test") + if err := os.WriteFile(backupPath, []byte("modified content"), 0600); err != nil { + t.Fatal(err) + } + + // Rollback + err := rollbackConfig(tmpDir, backupPath) + if err != nil { + t.Fatalf("rollbackConfig() error = %v", err) + } + + // Verify config was restored (but we can't easily check the content) + // Just verify the function didn't error +} + +func TestRollbackConfigNonExistent(t *testing.T) { + tmpDir := t.TempDir() + + err := rollbackConfig(tmpDir, "/nonexistent/backup") + if err == nil { + t.Error("rollbackConfig() should error when backup doesn't exist") + } +} + +func TestWithConfigTransactionSuccess(t *testing.T) { + tmpDir := t.TempDir() + + // Create initial config + originalCfg := &config.Config{ + Providers: map[string]config.Provider{ + "initial": {Name: "Initial"}, + }, + } + if err := config.SaveConfig(tmpDir, originalCfg); err != nil { + t.Fatal(err) + } + + // Successful transaction + err := withConfigTransaction(tmpDir, func(txDir string) error { + cfg, err := config.LoadConfig(txDir) + if err != nil { + return err + } + cfg.Providers["new"] = config.Provider{Name: "New Provider"} + return config.SaveConfig(txDir, cfg) + }) + + if err != nil { + t.Fatalf("withConfigTransaction() error = %v", err) + } + + // Verify new provider was added + loadedCfg, err := config.LoadConfig(tmpDir) + if err != nil { + t.Fatal(err) + } + + if _, ok := loadedCfg.Providers["new"]; !ok { + t.Error("New provider should exist after successful transaction") + } +} + +func TestWithConfigTransactionRollbackOnError(t *testing.T) { + tmpDir := t.TempDir() + + // Create initial config + originalCfg := &config.Config{ + Providers: map[string]config.Provider{ + "initial": {Name: "Initial"}, + }, + } + if err := config.SaveConfig(tmpDir, originalCfg); err != nil { + t.Fatal(err) + } + + originalContent, err := os.ReadFile(getConfigPath(tmpDir)) + if err != nil { + t.Fatal(err) + } + + // Failing transaction - should trigger rollback + err = withConfigTransaction(tmpDir, func(txDir string) error { + cfg, err := config.LoadConfig(txDir) + if err != nil { + return err + } + cfg.Providers["new"] = config.Provider{Name: "New Provider"} + if err := config.SaveConfig(txDir, cfg); err != nil { + return err + } + // Return error to trigger rollback + return &testError{"simulated failure"} + }) + + // Error should be wrapped + if err == nil { + t.Fatal("withConfigTransaction() should return error") + } + + // Verify config was rolled back + rolledBackContent, err := os.ReadFile(getConfigPath(tmpDir)) + if err != nil { + t.Fatal(err) + } + + if string(originalContent) != string(rolledBackContent) { + t.Error("Config should be rolled back to original state after transaction failure") + } +} + +func TestWithConfigTransactionBackupCleanup(t *testing.T) { + tmpDir := t.TempDir() + + // Create initial config + originalCfg := &config.Config{ + Providers: map[string]config.Provider{ + "initial": {Name: "Initial"}, + }, + } + if err := config.SaveConfig(tmpDir, originalCfg); err != nil { + t.Fatal(err) + } + + // List files before transaction + beforeFiles, err := os.ReadDir(tmpDir) + if err != nil { + t.Fatal(err) + } + beforeCount := len(beforeFiles) + + // Successful transaction + err = withConfigTransaction(tmpDir, func(txDir string) error { + return nil // Do nothing, just test backup cleanup + }) + + if err != nil { + t.Fatalf("withConfigTransaction() error = %v", err) + } + + // List files after transaction + afterFiles, err := os.ReadDir(tmpDir) + if err != nil { + t.Fatal(err) + } + afterCount := len(afterFiles) + + // Backup file should be cleaned up + if afterCount > beforeCount { + t.Errorf("Backup file should be cleaned up, but files increased from %d to %d", beforeCount, afterCount) + } +} + +func TestWithConfigTransactionCriticalFailure(t *testing.T) { + tmpDir := t.TempDir() + + // Create initial config + originalCfg := &config.Config{ + Providers: map[string]config.Provider{ + "initial": {Name: "Initial"}, + }, + } + if err := config.SaveConfig(tmpDir, originalCfg); err != nil { + t.Fatal(err) + } + + // Make backup directory read-only to cause rollback failure + backupDir := filepath.Join(tmpDir, "backups") + if err := os.MkdirAll(backupDir, 0500); err != nil { + t.Fatal(err) + } + + // Make the config file read-only to cause rollback to fail + configPath := getConfigPath(tmpDir) + if err := os.Chmod(configPath, 0444); err != nil { + t.Fatal(err) + } + + // Transaction that fails, then rollback that fails + err := withConfigTransaction(tmpDir, func(txDir string) error { + cfg, err := config.LoadConfig(txDir) + if err != nil { + return err + } + cfg.Providers["new"] = config.Provider{Name: "New Provider"} + return config.SaveConfig(txDir, cfg) + }) + + // Should return an error about both transaction and rollback failure + if err == nil { + t.Error("withConfigTransaction() should return error when both transaction and rollback fail") + } + + // Restore permissions for cleanup (ignore errors during cleanup) + _ = os.Chmod(configPath, 0600) + _ = os.Chmod(backupDir, 0700) +} + +func TestGetConfigPath(t *testing.T) { + dir := "/test/dir" + expected := filepath.Join(dir, "config.yaml") + + result := getConfigPath(dir) + + if result != expected { + t.Errorf("getConfigPath() = %q, want %q", result, expected) + } +} + +func TestGetBackupPath(t *testing.T) { + dir := "/test/dir" + + result := getBackupPath(dir) + + // Should contain config.yaml.backup. + if len(result) < len("config.yaml.backup.") { + t.Error("Backup path should be longer than minimum expected") + } + + // Should start with the config dir + if filepath.Dir(result) != dir { + t.Errorf("Backup path should be in %s, got %s", dir, filepath.Dir(result)) + } +} + +func TestGetBackupPathUniqueness(t *testing.T) { + dir := "/test/dir" + + // Call twice and verify paths are different (nanosecond precision) + path1 := getBackupPath(dir) + + // Note: Due to speed, these might be the same in some cases + // But the function uses nanosecond precision which should be unique + path2 := getBackupPath(dir) + + _ = path1 + _ = path2 + // This test just verifies the function doesn't panic +} + +// testError is a simple error type for testing +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} diff --git a/cmd/default_test.go b/cmd/default_test.go index dbd5810..0ed4ea1 100644 --- a/cmd/default_test.go +++ b/cmd/default_test.go @@ -61,13 +61,13 @@ providers: t.Fatal(err) } - t.Logf("Before: configDir=%s", configDir) + t.Logf("Before: configDir=%s", getConfigDir()) rootCmd.SetArgs([]string{"default", "zai"}) err = rootCmd.Execute() if err != nil { t.Fatalf("Execute() error = %v", err) } - t.Logf("After: configDir=%s", configDir) + t.Logf("After: configDir=%s", getConfigDir()) cfg, err := config.LoadConfig(tmpDir) if err != nil { diff --git a/cmd/metrics.go b/cmd/metrics.go index a85a691..aeebdbe 100644 --- a/cmd/metrics.go +++ b/cmd/metrics.go @@ -5,6 +5,7 @@ import ( "os" "time" + kairoerrors "github.com/dkmnx/kairo/internal/errors" "github.com/dkmnx/kairo/internal/performance" "github.com/dkmnx/kairo/internal/ui" "github.com/spf13/cobra" @@ -172,10 +173,12 @@ func exportMetrics(registry *performance.Registry) error { case "json": data, err = registry.ToJSON() if err != nil { - return fmt.Errorf("failed to convert to JSON: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "failed to convert to JSON", err) } default: - return fmt.Errorf("unsupported format: %s (supported: json)", metricsFormat) + return kairoerrors.NewError(kairoerrors.ConfigError, + fmt.Sprintf("unsupported format: %s (supported: json)", metricsFormat)) } return os.WriteFile(metricsOutputPath, data, 0600) diff --git a/cmd/root.go b/cmd/root.go index 854a05c..173c583 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -26,6 +26,7 @@ import ( "os" "runtime" "strings" + "sync" "time" "github.com/dkmnx/kairo/internal/config" @@ -36,6 +37,7 @@ import ( var ( configDir string + configDirMu sync.RWMutex // Protects configDir verbose bool configCache *config.ConfigCache ) @@ -49,9 +51,20 @@ func setVerbose(v bool) { } func setConfigDir(dir string) { + configDirMu.Lock() + defer configDirMu.Unlock() configDir = dir } +func getConfigDir() string { + configDirMu.RLock() + defer configDirMu.RUnlock() + if configDir != "" { + return configDir + } + return env.GetConfigDir() +} + var rootCmd = &cobra.Command{ Use: "kairo", Short: "Kairo - Manage Claude Code API providers", @@ -186,13 +199,6 @@ func init() { } } -func getConfigDir() string { - if configDir != "" { - return configDir - } - return env.GetConfigDir() -} - // handleConfigError provides user-friendly guidance for config errors. func handleConfigError(cmd *cobra.Command, err error) { errStr := err.Error() diff --git a/cmd/setup.go b/cmd/setup.go index 7b26703..d5796a8 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -31,19 +31,23 @@ var validProviderName = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`) // - Not be a reserved built-in provider name (case-insensitive) func validateCustomProviderName(name string) (string, error) { if name == "" { - return "", fmt.Errorf("provider name is required") + return "", kairoerrors.NewError(kairoerrors.ValidationError, + "provider name is required") } - // Check maximum length (50 characters) - if len(name) > 50 { - return "", fmt.Errorf("provider name must be at most 50 characters (got %d)", len(name)) + // Check maximum length + if len(name) > validate.MaxProviderNameLength { + return "", kairoerrors.NewError(kairoerrors.ValidationError, + fmt.Sprintf("provider name must be at most %d characters (got %d)", validate.MaxProviderNameLength, len(name))) } if !validProviderName.MatchString(name) { - return "", fmt.Errorf("provider name must start with a letter and contain only alphanumeric characters, underscores, and hyphens") + return "", kairoerrors.NewError(kairoerrors.ValidationError, + "provider name must start with a letter and contain only alphanumeric characters, underscores, and hyphens") } // Check for reserved provider names (case-insensitive) lowerName := strings.ToLower(name) if providers.IsBuiltInProvider(lowerName) { - return "", fmt.Errorf("reserved provider name: %s", lowerName) + return "", kairoerrors.NewError(kairoerrors.ValidationError, + fmt.Sprintf("reserved provider name: %s", lowerName)) } return name, nil } @@ -92,7 +96,8 @@ func saveProviderConfigFile(dir string, cfg *config.Config, providerName string, cfg.DefaultProvider = providerName } if err := config.SaveConfig(dir, cfg); err != nil { - return fmt.Errorf("saving config: %w", err) + return kairoerrors.WrapError(kairoerrors.ConfigError, + "saving config", err) } return nil } @@ -128,10 +133,12 @@ func providerStatusIcon(cfg *config.Config, secrets map[string]string, provider // ensureConfigDirectory creates the config directory and encryption key if they don't exist. func ensureConfigDirectory(dir string) error { if err := os.MkdirAll(dir, 0700); err != nil { - return fmt.Errorf("creating config directory: %w", err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + "creating config directory", err) } if err := crypto.EnsureKeyExists(dir); err != nil { - return fmt.Errorf("creating encryption key: %w", err) + return kairoerrors.WrapError(kairoerrors.CryptoError, + "creating encryption key", err) } return nil } @@ -286,7 +293,8 @@ func configureProvider(dir string, cfg *config.Config, providerName string, secr if providerName == "custom" { customName, err := ui.Prompt("Provider name") if err != nil { - return "", nil, fmt.Errorf("reading provider name: %w", err) + return "", nil, kairoerrors.WrapError(kairoerrors.ValidationError, + "reading provider name", err) } validatedName, err := validateCustomProviderName(customName) if err != nil { @@ -316,7 +324,8 @@ func configureProvider(dir string, cfg *config.Config, providerName string, secr model, err := ui.PromptWithDefault("Model", def.Model) if err != nil { - return "", nil, fmt.Errorf("reading model: %w", err) + return "", nil, kairoerrors.WrapError(kairoerrors.ValidationError, + "reading model", err) } // Validate model is non-empty for custom providers @@ -324,7 +333,8 @@ func configureProvider(dir string, cfg *config.Config, providerName string, secr if !providers.IsBuiltInProvider(providerName) { model = strings.TrimSpace(model) if model == "" { - return "", nil, fmt.Errorf("model name is required for custom providers") + return "", nil, kairoerrors.NewError(kairoerrors.ValidationError, + "model name is required for custom providers") } } @@ -339,7 +349,8 @@ func configureProvider(dir string, cfg *config.Config, providerName string, secr secrets[fmt.Sprintf("%s_API_KEY", strings.ToUpper(providerName))] = apiKey secretsContent := formatSecretsFileContent(secrets) if err := crypto.EncryptSecrets(secretsPath, keyPath, secretsContent); err != nil { - return "", nil, fmt.Errorf("saving API key: %w", err) + return "", nil, kairoerrors.WrapError(kairoerrors.CryptoError, + "saving API key", err) } // Prepare audit details @@ -361,7 +372,8 @@ func configureProvider(dir string, cfg *config.Config, providerName string, secr func promptForAPIKey(providerName string) (string, error) { apiKey, err := ui.PromptSecret("API Key") if err != nil { - return "", fmt.Errorf("reading API key: %w", err) + return "", kairoerrors.WrapError(kairoerrors.ValidationError, + "reading API key", err) } if err := validateAPIKey(apiKey, providerName); err != nil { return "", err @@ -373,7 +385,8 @@ func promptForAPIKey(providerName string) (string, error) { func promptForBaseURL(defaultURL, providerName string) (string, error) { baseURL, err := ui.PromptWithDefault("Base URL", defaultURL) if err != nil { - return "", fmt.Errorf("reading base URL: %w", err) + return "", kairoerrors.WrapError(kairoerrors.ValidationError, + "reading base URL", err) } if err := validateBaseURL(baseURL, providerName); err != nil { return "", err diff --git a/cmd/switch.go b/cmd/switch.go index c0a4a44..3a62170 100644 --- a/cmd/switch.go +++ b/cmd/switch.go @@ -176,15 +176,12 @@ var switchCmd = &cobra.Command{ // Set up signal handling for cleanup on SIGINT/SIGTERM sigChan := make(chan os.Signal, 1) - defer func() { - signal.Stop(sigChan) - close(sigChan) - }() signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go func() { sig := <-sigChan - cleanup() + signal.Stop(sigChan) + // Let deferred cleanup() handle resource cleanup code := 128 if s, ok := sig.(syscall.Signal); ok { code += int(s) @@ -205,9 +202,9 @@ var switchCmd = &cobra.Command{ if err := execCmd.Run(); err != nil { cmd.Printf("Error running Qwen: %v\n", err) - exitProcess(1) } - return + + // Cleanup via deferred cleanup() above } // Claude harness - existing wrapper script logic @@ -228,16 +225,12 @@ var switchCmd = &cobra.Command{ // Set up signal handling for cleanup on SIGINT/SIGTERM sigChan := make(chan os.Signal, 1) - defer func() { - signal.Stop(sigChan) - close(sigChan) - }() signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go func() { sig := <-sigChan - cleanup() - // Exit with signal code (cross-platform) + signal.Stop(sigChan) + // Let deferred cleanup() handle resource cleanup code := 128 if s, ok := sig.(syscall.Signal); ok { code += int(s) @@ -265,8 +258,9 @@ var switchCmd = &cobra.Command{ if err := execCmd.Run(); err != nil { cmd.Printf("Error running Claude: %v\n", err) - exitProcess(1) } + + // Cleanup via deferred cleanup() above return } diff --git a/cmd/switch_run_test.go b/cmd/switch_run_test.go index b47ecb7..dff8d1c 100644 --- a/cmd/switch_run_test.go +++ b/cmd/switch_run_test.go @@ -21,19 +21,9 @@ func runningWithRaceDetector() bool { return runtime.GOMAXPROCS(-1) > 1 } -// Temporarily disabled - Cobra output not captured -func TestSwitchCmd_ProviderNotFound(t *testing.T) { - t.Skip("Temporarily disabled - Cobra output capture needs refactoring") -} - -// Temporarily disabled - Cobra output not captured -func TestSwitchCmd_ClaudeNotFound(t *testing.T) { - t.Skip("Temporarily disabled - Cobra output capture needs refactoring") -} - func TestSwitchCmd_WithAPIKey_Success(t *testing.T) { if runningWithRaceDetector() { - t.Skip("Skipping integration test with race detector - benign race on global configDir") + t.Skip("Skipping with race detector - requires test refactoring for proper goroutine synchronization") } tmpDir := t.TempDir() @@ -99,7 +89,10 @@ func TestSwitchCmd_WithAPIKey_Success(t *testing.T) { var buf bytes.Buffer oldStdout := os.Stdout - r, w, _ := os.Pipe() + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } os.Stdout = w done := make(chan struct{}) @@ -146,7 +139,7 @@ func TestSwitchCmd_WithAPIKey_Success(t *testing.T) { func TestSwitchCmd_WithoutAPIKey_Success(t *testing.T) { if runningWithRaceDetector() { - t.Skip("Skipping integration test with race detector - benign race on global configDir") + t.Skip("Skipping with race detector - requires test refactoring for proper goroutine synchronization") } tmpDir := t.TempDir() @@ -210,7 +203,10 @@ func TestSwitchCmd_WithoutAPIKey_Success(t *testing.T) { var buf bytes.Buffer oldStdout := os.Stdout - r, w, _ := os.Pipe() + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } os.Stdout = w done := make(chan struct{}) diff --git a/cmd/update.go b/cmd/update.go index 19af0d0..5145f3e 100644 --- a/cmd/update.go +++ b/cmd/update.go @@ -12,7 +12,9 @@ import ( "time" "github.com/Masterminds/semver/v3" + "github.com/dkmnx/kairo/internal/audit" "github.com/dkmnx/kairo/internal/config" + kairoerrors "github.com/dkmnx/kairo/internal/errors" "github.com/dkmnx/kairo/internal/ui" "github.com/dkmnx/kairo/internal/version" "github.com/spf13/cobra" @@ -55,29 +57,34 @@ func getLatestRelease() (*release, error) { url := getLatestReleaseURL() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, kairoerrors.WrapError(kairoerrors.NetworkError, + "failed to create request", err) } req.Header.Set("User-Agent", "kairo-cli") resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to fetch release: %w", err) + return nil, kairoerrors.WrapError(kairoerrors.NetworkError, + "failed to fetch release", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API returned status %d", resp.StatusCode) + return nil, kairoerrors.NewError(kairoerrors.NetworkError, + fmt.Sprintf("API returned status %d", resp.StatusCode)) } body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) + return nil, kairoerrors.WrapError(kairoerrors.NetworkError, + "failed to read response", err) } var r release if err := json.Unmarshal(body, &r); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) + return nil, kairoerrors.WrapError(kairoerrors.NetworkError, + "failed to parse response", err) } return &r, nil @@ -112,12 +119,14 @@ func getInstallScriptURL(goos string) string { func downloadToTempFile(url string) (string, error) { resp, err := http.Get(url) if err != nil { - return "", fmt.Errorf("failed to download: %w", err) + return "", kairoerrors.WrapError(kairoerrors.NetworkError, + "failed to download", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("download failed with status %d", resp.StatusCode) + return "", kairoerrors.NewError(kairoerrors.NetworkError, + fmt.Sprintf("download failed with status %d", resp.StatusCode)) } ext := ".sh" @@ -126,19 +135,22 @@ func downloadToTempFile(url string) (string, error) { } tempFile, err := os.CreateTemp("", "kairo-install-*"+ext) if err != nil { - return "", fmt.Errorf("failed to create temp file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to create temp file", err) } _, err = io.Copy(tempFile, resp.Body) if err != nil { tempFile.Close() os.Remove(tempFile.Name()) - return "", fmt.Errorf("failed to write to temp file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.NetworkError, + "failed to write to temp file", err) } if err := tempFile.Close(); err != nil { os.Remove(tempFile.Name()) - return "", fmt.Errorf("failed to close temp file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to close temp file", err) } return tempFile.Name(), nil @@ -152,21 +164,24 @@ func runInstallScript(scriptPath string) error { pwshCmd.Stderr = os.Stderr if err := pwshCmd.Run(); err != nil { - return fmt.Errorf("powershell execution failed: %w", err) + return kairoerrors.WrapError(kairoerrors.RuntimeError, + "powershell execution failed", err) } return nil } if err := os.Chmod(scriptPath, 0755); err != nil { - return fmt.Errorf("failed to make script executable: %w", err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to make script executable", err) } shCmd := exec.Command("/bin/sh", scriptPath) shCmd.Stdout = os.Stdout shCmd.Stderr = os.Stderr if err := shCmd.Run(); err != nil { - return fmt.Errorf("shell execution failed: %w", err) + return kairoerrors.WrapError(kairoerrors.RuntimeError, + "shell execution failed", err) } return nil @@ -235,6 +250,21 @@ This command will: cmd.Printf("Warning: config migration failed: %v\n", err) } else if len(changes) > 0 { cmd.Printf("%s\n", config.FormatMigrationChanges(changes)) + + // Audit log the migration + if err := logAuditEvent(dir, func(logger *audit.Logger) error { + auditChanges := make([]audit.Change, len(changes)) + for i, c := range changes { + auditChanges[i] = audit.Change{ + Field: c.Field, + Old: c.Old, + New: c.New, + } + } + return logger.LogConfig("config", "migrate", auditChanges) + }); err != nil { + cmd.Printf("Warning: audit logging failed: %v\n", err) + } } } }, diff --git a/internal/audit/audit.go b/internal/audit/audit.go index 56aa780..979a69d 100644 --- a/internal/audit/audit.go +++ b/internal/audit/audit.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "fmt" "os" "os/user" "path/filepath" @@ -94,6 +95,134 @@ func (l *Logger) Close() error { return nil } +// RotateOptions contains configuration options for log rotation. +type RotateOptions struct { + // MaxSize is the maximum size in bytes before rotating (default: 10MB) + MaxSize int64 + // MaxAge is the maximum age in days before rotating (default: 30 days) + MaxAge int + // MaxBackups is the number of old log files to keep (default: 5) + MaxBackups int +} + +// DefaultRotateOptions returns sensible defaults for log rotation. +func DefaultRotateOptions() RotateOptions { + return RotateOptions{ + MaxSize: 10 * 1024 * 1024, // 10MB + MaxAge: 30, // 30 days + MaxBackups: 5, + } +} + +// RotateLog rotates the audit log if it exceeds size or age limits. +// Old log files are renamed with a timestamp suffix. +// Returns true if rotation occurred, false otherwise. +func (l *Logger) RotateLog(opts ...RotateOptions) (bool, error) { + options := DefaultRotateOptions() + if len(opts) > 0 { + options = opts[0] + } + + l.mu.Lock() + defer l.mu.Unlock() + + // Check file stats + info, err := os.Stat(l.path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + + // Check size limit + shouldRotate := info.Size() > options.MaxSize + + // Check age limit + if !shouldRotate { + age := time.Since(info.ModTime()) + shouldRotate = age > time.Duration(options.MaxAge)*24*time.Hour + } + + if !shouldRotate { + return false, nil + } + + // Close current file if open + if l.f != nil { + l.f.Close() + l.f = nil + } + + // Generate timestamped backup name + timestamp := time.Now().Format("2006-01-02T15-04-05") + backupPath := filepath.Join(filepath.Dir(l.path), + fmt.Sprintf("audit.%s.log", timestamp)) + + // Rename current log to backup + if err := os.Rename(l.path, backupPath); err != nil { + return false, err + } + + // Open new log file + f, err := os.OpenFile(l.path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + return false, err + } + l.f = f + + // Clean up old backups + l.cleanupOldBackups(options.MaxBackups) + + return true, nil +} + +// cleanupOldBackups removes old audit log backups beyond the limit. +func (l *Logger) cleanupOldBackups(maxBackups int) { + if maxBackups <= 0 { + return + } + + dir := filepath.Dir(l.path) + pattern := filepath.Join(dir, "audit.*.log") + + matches, err := filepath.Glob(pattern) + if err != nil { + return + } + + if len(matches) <= maxBackups { + return + } + + // Sort by modification time (oldest first) + type fileInfo struct { + path string + modTime time.Time + } + + var files []fileInfo + for _, m := range matches { + if info, err := os.Stat(m); err == nil { + files = append(files, fileInfo{path: m, modTime: info.ModTime()}) + } + } + + // Sort by mod time + for i := 0; i < len(files)-1; i++ { + for j := i + 1; j < len(files); j++ { + if files[i].modTime.After(files[j].modTime) { + files[i], files[j] = files[j], files[i] + } + } + } + + // Remove oldest files beyond limit + for i := 0; i < len(files)-maxBackups; i++ { + os.Remove(files[i].path) + } +} + // LogSwitch logs a provider switch event to the audit log. // // This method creates an audit entry recording when a user switches to a diff --git a/internal/audit/audit_expanded_test.go b/internal/audit/audit_expanded_test.go new file mode 100644 index 0000000..f35a67f --- /dev/null +++ b/internal/audit/audit_expanded_test.go @@ -0,0 +1,365 @@ +package audit + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLogMigration(t *testing.T) { + tmpDir := t.TempDir() + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + details := map[string]interface{}{ + "from_version": "1.0.0", + "to_version": "1.1.0", + "migrated_fields": []string{ + "default_harness", + "provider_aliases", + }, + } + + err = logger.LogMigration(details) + if err != nil { + t.Fatalf("LogMigration() error = %v", err) + } + + entries, err := logger.LoadEntries() + if err != nil { + t.Fatalf("LoadEntries() error = %v", err) + } + + if len(entries) != 1 { + t.Fatalf("LoadEntries() returned %d entries, want 1", len(entries)) + } + + entry := entries[0] + if entry.Event != "migration" { + t.Errorf("Event = %q, want %q", entry.Event, "migration") + } + + if entry.Status != "success" { + t.Errorf("Status = %q, want %q", entry.Status, "success") + } + + if entry.Provider != "" { + t.Errorf("Provider should be empty for migration, got %q", entry.Provider) + } + + if entry.Details == nil { + t.Fatal("Details should not be nil") + } + + if entry.Details["from_version"] != "1.0.0" { + t.Errorf("Details[from_version] = %v, want 1.0.0", entry.Details["from_version"]) + } + + if entry.Details["to_version"] != "1.1.0" { + t.Errorf("Details[to_version] = %v, want 1.1.0", entry.Details["to_version"]) + } +} + +func TestLogMigrationWithNilDetails(t *testing.T) { + tmpDir := t.TempDir() + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + err = logger.LogMigration(nil) + if err != nil { + t.Fatalf("LogMigration() error = %v", err) + } + + entries, err := logger.LoadEntries() + if err != nil { + t.Fatalf("LoadEntries() error = %v", err) + } + + if len(entries) != 1 { + t.Fatalf("LoadEntries() returned %d entries, want 1", len(entries)) + } + + entry := entries[0] + if entry.Event != "migration" { + t.Errorf("Event = %q, want %q", entry.Event, "migration") + } +} + +func TestWriteEntryReopensClosedFile(t *testing.T) { + tmpDir := t.TempDir() + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + + // Close the underlying file to simulate a closed state + if err := logger.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + // Write should reopen the file automatically + err = logger.LogSwitch("test-provider") + if err != nil { + t.Fatalf("LogSwitch() after close error = %v", err) + } + + entries, err := logger.LoadEntries() + if err != nil { + t.Fatalf("LoadEntries() error = %v", err) + } + + if len(entries) != 1 { + t.Errorf("LoadEntries() returned %d entries, want 1", len(entries)) + } +} + +func TestWriteEntryWithClosedFileMultipleWrites(t *testing.T) { + tmpDir := t.TempDir() + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + + // Close the underlying file + if err := logger.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + // Multiple writes after close should all succeed + providers := []string{"provider1", "provider2", "provider3"} + for _, p := range providers { + err = logger.LogSwitch(p) + if err != nil { + t.Fatalf("LogSwitch(%q) error = %v", p, err) + } + } + + entries, err := logger.LoadEntries() + if err != nil { + t.Fatalf("LoadEntries() error = %v", err) + } + + if len(entries) != len(providers) { + t.Errorf("LoadEntries() returned %d entries, want %d", len(entries), len(providers)) + } +} + +func TestWriteEntryBasic(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + err = logger.LogSwitch("test") + if err != nil { + t.Errorf("LogSwitch() error = %v", err) + } + + // Verify entry was written + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + + if len(data) == 0 { + t.Error("Log file should contain data") + } +} + +func TestGenerateSessionIDWithRandFailure(t *testing.T) { + // Test generateSessionID fallback when crypto/rand fails + // This is hard to test directly since it requires mocking, + // but we can verify the function produces valid output + + sessionID := generateSessionID() + if sessionID == "" { + t.Error("generateSessionID() should not return empty string") + } + + // Verify it's either a hex string (16 chars) or a timestamp-based fallback + if len(sessionID) != 16 && len(sessionID) < 10 { + t.Errorf("SessionID has unexpected length: %d", len(sessionID)) + } +} + +func TestSplitLines(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "single line no newline", + input: "hello", + expected: []string{"hello"}, + }, + { + name: "two lines", + input: "hello\nworld", + expected: []string{"hello", "world"}, + }, + { + name: "three lines with trailing newline", + input: "line1\nline2\nline3\n", + expected: []string{"line1", "line2", "line3"}, + }, + { + name: "empty lines in middle", + input: "line1\n\nline2", + expected: []string{"line1", "", "line2"}, + }, + { + name: "only newlines", + input: "\n\n\n", + expected: []string{"", "", ""}, + }, + { + name: "json lines", + input: `{"event":"a"}` + "\n" + `{"event":"b"}`, + expected: []string{`{"event":"a"}`, `{"event":"b"}`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitLines(tt.input) + if len(result) != len(tt.expected) { + t.Errorf("splitLines() returned %d lines, want %d", len(result), len(tt.expected)) + return + } + for i, line := range result { + if line != tt.expected[i] { + t.Errorf("Line %d = %q, want %q", i, line, tt.expected[i]) + } + } + }) + } +} + +func TestLoadEntriesWithEmptyLines(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + // Write log with empty lines between entries + content := `{"event":"switch","provider":"p1","status":"success"} +{"event":"switch","provider":"p2","status":"success"} +` + if err := os.WriteFile(logPath, []byte(content), 0600); err != nil { + t.Fatal(err) + } + + logger := &Logger{path: logPath} + entries, err := logger.LoadEntries() + if err != nil { + t.Fatalf("LoadEntries() error = %v", err) + } + + // Empty lines should be skipped + if len(entries) != 2 { + t.Errorf("LoadEntries() returned %d entries, want 2", len(entries)) + } +} + +func TestLoadEntriesWithCorruptedLine(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + // Write log with one valid entry and one corrupted + content := `{"event":"switch","provider":"p1","status":"success"} +invalid json here +{"event":"switch","provider":"p2","status":"success"} +` + if err := os.WriteFile(logPath, []byte(content), 0600); err != nil { + t.Fatal(err) + } + + logger := &Logger{path: logPath} + _, err := logger.LoadEntries() + + // Should fail on corrupted JSON + if err == nil { + t.Error("LoadEntries() should error on corrupted JSON line") + } +} + +func TestNewLoggerHandlesEmptyHostname(t *testing.T) { + // Test NewLogger when os.Hostname() returns empty string + // This is hard to mock without interface, so we test the "unknown" fallback path + tmpDir := t.TempDir() + + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + // Logger should have initialized with some hostname value + if logger.hostname == "" { + // This may happen in some test environments, so we just log + t.Log("Hostname is empty in this environment") + } +} + +func TestNewLoggerCreatesLogFile(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + // Verify file was created with correct path + if logger.path != logPath { + t.Errorf("Logger path = %q, want %q", logger.path, logPath) + } + + // Verify file exists + if _, err := os.Stat(logPath); os.IsNotExist(err) { + t.Error("Audit log file should exist") + } +} + +func TestWriteEntryHandlesSyncError(t *testing.T) { + // This is hard to test without mocking the file, + // but we can test the error handling path by checking + // that write errors are returned properly + tmpDir := t.TempDir() + logger, err := NewLogger(tmpDir) + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + // Normal write should succeed + err = logger.LogSwitch("test") + if err != nil { + t.Errorf("LogSwitch() error = %v", err) + } + + // Verify data was written + entries, err := logger.LoadEntries() + if err != nil { + t.Fatalf("LoadEntries() error = %v", err) + } + + if len(entries) != 1 { + t.Errorf("Expected 1 entry, got %d", len(entries)) + } +} diff --git a/internal/backup/backup.go b/internal/backup/backup.go index 2726007..ba1d2a5 100644 --- a/internal/backup/backup.go +++ b/internal/backup/backup.go @@ -7,12 +7,15 @@ import ( "os" "path/filepath" "time" + + kairoerrors "github.com/dkmnx/kairo/internal/errors" ) func CreateBackup(configDir string) (string, error) { backupDir := filepath.Join(configDir, "backups") if err := os.MkdirAll(backupDir, 0700); err != nil { - return "", fmt.Errorf("create backup dir: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "create backup dir", err) } timestamp := time.Now().Format("20060102_150405") @@ -20,7 +23,8 @@ func CreateBackup(configDir string) (string, error) { zipFile, err := os.Create(backupPath) if err != nil { - return "", fmt.Errorf("create zip: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "create zip", err) } defer zipFile.Close() @@ -36,25 +40,29 @@ func CreateBackup(configDir string) (string, error) { src, err := os.Open(srcPath) if err != nil { - return "", fmt.Errorf("open %s: %w", f, err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("open %s", f), err) } w, err := zipWriter.Create(f) if err != nil { src.Close() - return "", fmt.Errorf("create zip entry %s: %w", f, err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("create zip entry %s", f), err) } if _, err := io.Copy(w, src); err != nil { src.Close() - return "", fmt.Errorf("write %s: %w", f, err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("write %s", f), err) } src.Close() } // Explicitly close and check for flush errors if err := zipWriter.Close(); err != nil { - return "", fmt.Errorf("close zip writer: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "close zip writer", err) } return backupPath, nil } @@ -62,7 +70,8 @@ func CreateBackup(configDir string) (string, error) { func RestoreBackup(configDir, backupPath string) error { r, err := zip.OpenReader(backupPath) if err != nil { - return fmt.Errorf("open backup: %w", err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + "open backup", err) } defer r.Close() @@ -74,24 +83,28 @@ func RestoreBackup(configDir, backupPath string) error { destPath := filepath.Join(configDir, f.Name) if err := os.MkdirAll(filepath.Dir(destPath), 0700); err != nil { - return fmt.Errorf("create dir for %s: %w", f.Name, err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("create dir for %s", f.Name), err) } outFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { - return fmt.Errorf("create %s: %w", f.Name, err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("create %s", f.Name), err) } rc, err := f.Open() if err != nil { outFile.Close() - return fmt.Errorf("open %s in zip: %w", f.Name, err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("open %s in zip", f.Name), err) } if _, err := io.Copy(outFile, rc); err != nil { outFile.Close() rc.Close() - return fmt.Errorf("extract %s: %w", f.Name, err) + return kairoerrors.WrapError(kairoerrors.FileSystemError, + fmt.Sprintf("extract %s", f.Name), err) } outFile.Close() diff --git a/internal/backup/backup_expanded_test.go b/internal/backup/backup_expanded_test.go new file mode 100644 index 0000000..f5d7195 --- /dev/null +++ b/internal/backup/backup_expanded_test.go @@ -0,0 +1,349 @@ +package backup + +import ( + "archive/zip" + "os" + "path/filepath" + "testing" +) + +func TestCreateBackupWithMissingFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create only config.yaml, leave age.key and secrets.age missing + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("test: config"), 0600); err != nil { + t.Fatal(err) + } + + // Create backups directory + backupDir := filepath.Join(tmpDir, "backups") + if err := os.MkdirAll(backupDir, 0700); err != nil { + t.Fatal(err) + } + + // Create backup + backupPath, err := CreateBackup(tmpDir) + if err != nil { + t.Fatalf("CreateBackup() error = %v", err) + } + + // Verify backup was created + if backupPath == "" { + t.Fatal("Backup path should not be empty") + } + + // Verify the zip contains only config.yaml + r, err := zip.OpenReader(backupPath) + if err != nil { + t.Fatalf("Failed to open backup zip: %v", err) + } + defer r.Close() + + files := make(map[string]bool) + for _, f := range r.File { + files[f.Name] = true + } + + if !files["config.yaml"] { + t.Error("Backup should contain config.yaml") + } + + // age.key and secrets.age should not be in the backup (they don't exist) + if files["age.key"] { + t.Error("Backup should not contain age.key (doesn't exist)") + } + if files["secrets.age"] { + t.Error("Backup should not contain secrets.age (doesn't exist)") + } +} + +func TestCreateBackupWithAllFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create all files + files := map[string]string{ + "age.key": "age1keyhere", + "secrets.age": "encryptedsecrets", + "config.yaml": "providers: {}", + } + + for name, content := range files { + path := filepath.Join(tmpDir, name) + if err := os.WriteFile(path, []byte(content), 0600); err != nil { + t.Fatal(err) + } + } + + backupPath, err := CreateBackup(tmpDir) + if err != nil { + t.Fatalf("CreateBackup() error = %v", err) + } + + // Verify the zip contains all files + r, err := zip.OpenReader(backupPath) + if err != nil { + t.Fatalf("Failed to open backup zip: %v", err) + } + defer r.Close() + + filesInZip := make(map[string]bool) + for _, f := range r.File { + filesInZip[f.Name] = true + } + + for name := range files { + if !filesInZip[name] { + t.Errorf("Backup should contain %s", name) + } + } +} + +func TestCreateBackupCreatesBackupDirectory(t *testing.T) { + tmpDir := t.TempDir() + + // Ensure no backup directory exists + backupDir := filepath.Join(tmpDir, "backups") + if _, err := os.Stat(backupDir); !os.IsNotExist(err) { + t.Fatal("Backup directory should not exist initially") + } + + _, err := CreateBackup(tmpDir) + if err != nil { + t.Fatalf("CreateBackup() error = %v", err) + } + + // Verify backup directory was created + if _, err := os.Stat(backupDir); os.IsNotExist(err) { + t.Error("Backup directory should be created") + } +} + +func TestRestoreBackupWithMissingDestinationDir(t *testing.T) { + tmpDir := t.TempDir() + + // Create a backup zip + backupDir := filepath.Join(tmpDir, "backups") + if err := os.MkdirAll(backupDir, 0700); err != nil { + t.Fatal(err) + } + + backupPath := filepath.Join(backupDir, "test_backup.zip") + zipFile, err := os.Create(backupPath) + if err != nil { + t.Fatal(err) + } + + zipWriter := zip.NewWriter(zipFile) + // Create a file in a subdirectory + writer, err := zipWriter.Create("subdir/config.yaml") + if err != nil { + t.Fatal(err) + } + _, err = writer.Write([]byte("test content")) + if err != nil { + t.Fatal(err) + } + zipWriter.Close() + zipFile.Close() + + // Restore to a new directory (doesn't exist) + newDir := filepath.Join(tmpDir, "newdir") + err = RestoreBackup(newDir, backupPath) + if err != nil { + t.Fatalf("RestoreBackup() error = %v", err) + } + + // Verify file was restored + restoredPath := filepath.Join(newDir, "subdir", "config.yaml") + if _, err := os.Stat(restoredPath); os.IsNotExist(err) { + t.Error("Restored file should exist") + } +} + +func TestRestoreBackupOverwritesExisting(t *testing.T) { + tmpDir := t.TempDir() + + // Create original config + configPath := filepath.Join(tmpDir, "config.yaml") + originalContent := "original: content" + if err := os.WriteFile(configPath, []byte(originalContent), 0600); err != nil { + t.Fatal(err) + } + + // Create backup with different content + backupDir := filepath.Join(tmpDir, "backups") + if err := os.MkdirAll(backupDir, 0700); err != nil { + t.Fatal(err) + } + + backupPath := filepath.Join(backupDir, "test.zip") + zipFile, err := os.Create(backupPath) + if err != nil { + t.Fatal(err) + } + + zipWriter := zip.NewWriter(zipFile) + writer, err := zipWriter.Create("config.yaml") + if err != nil { + t.Fatal(err) + } + newContent := "new: content" + _, err = writer.Write([]byte(newContent)) + if err != nil { + t.Fatal(err) + } + zipWriter.Close() + zipFile.Close() + + // Restore + err = RestoreBackup(tmpDir, backupPath) + if err != nil { + t.Fatalf("RestoreBackup() error = %v", err) + } + + // Verify content was overwritten + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatal(err) + } + + if string(data) != newContent { + t.Errorf("Config content = %q, want %q", string(data), newContent) + } +} + +func TestRestoreBackupWithInvalidZip(t *testing.T) { + tmpDir := t.TempDir() + + // Create invalid zip file + backupPath := filepath.Join(tmpDir, "invalid.zip") + if err := os.WriteFile(backupPath, []byte("not a zip file"), 0600); err != nil { + t.Fatal(err) + } + + err := RestoreBackup(tmpDir, backupPath) + if err == nil { + t.Error("RestoreBackup() should error on invalid zip") + } +} + +func TestRestoreBackupWithNonExistentZip(t *testing.T) { + tmpDir := t.TempDir() + + err := RestoreBackup(tmpDir, "/nonexistent/backup.zip") + if err == nil { + t.Error("RestoreBackup() should error on non-existent zip") + } +} + +func TestCreateBackupZipPermissions(t *testing.T) { + tmpDir := t.TempDir() + + // Create config file + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("test: config"), 0600); err != nil { + t.Fatal(err) + } + + backupPath, err := CreateBackup(tmpDir) + if err != nil { + t.Fatalf("CreateBackup() error = %v", err) + } + + // Verify backup file exists + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + t.Error("Backup file should exist") + } +} + +func TestRestoreBackupWithDirectoryEntry(t *testing.T) { + tmpDir := t.TempDir() + + // Create backup with a directory entry + backupDir := filepath.Join(tmpDir, "backups") + if err := os.MkdirAll(backupDir, 0700); err != nil { + t.Fatal(err) + } + + backupPath := filepath.Join(backupDir, "test.zip") + zipFile, err := os.Create(backupPath) + if err != nil { + t.Fatal(err) + } + + zipWriter := zip.NewWriter(zipFile) + // Create a directory entry (ends with /) + _, err = zipWriter.Create("somedir/") + if err != nil { + t.Fatal(err) + } + // Create a file inside the directory + writer, err := zipWriter.Create("somedir/file.txt") + if err != nil { + t.Fatal(err) + } + _, err = writer.Write([]byte("content")) + if err != nil { + t.Fatal(err) + } + zipWriter.Close() + zipFile.Close() + + // Restore should skip directory entries + err = RestoreBackup(tmpDir, backupPath) + if err != nil { + t.Fatalf("RestoreBackup() error = %v", err) + } + + // Verify file was restored + restoredPath := filepath.Join(tmpDir, "somedir", "file.txt") + if _, err := os.Stat(restoredPath); os.IsNotExist(err) { + t.Error("Restored file should exist") + } +} + +func TestRestoreBackupFileContent(t *testing.T) { + tmpDir := t.TempDir() + + backupDir := filepath.Join(tmpDir, "backups") + if err := os.MkdirAll(backupDir, 0700); err != nil { + t.Fatal(err) + } + + backupPath := filepath.Join(backupDir, "test.zip") + zipFile, err := os.Create(backupPath) + if err != nil { + t.Fatal(err) + } + + zipWriter := zip.NewWriter(zipFile) + writer, err := zipWriter.Create("test.txt") + if err != nil { + t.Fatal(err) + } + testContent := "test content" + _, err = writer.Write([]byte(testContent)) + if err != nil { + t.Fatal(err) + } + zipWriter.Close() + zipFile.Close() + + // Restore + err = RestoreBackup(tmpDir, backupPath) + if err != nil { + t.Fatalf("RestoreBackup() error = %v", err) + } + + // Verify content + restoredPath := filepath.Join(tmpDir, "test.txt") + data, err := os.ReadFile(restoredPath) + if err != nil { + t.Fatal(err) + } + + if string(data) != testContent { + t.Errorf("Content = %q, want %q", string(data), testContent) + } +} diff --git a/internal/config/cache.go b/internal/config/cache.go index 0a2f799..05ff67c 100644 --- a/internal/config/cache.go +++ b/internal/config/cache.go @@ -26,17 +26,18 @@ func NewConfigCache(ttl time.Duration) *ConfigCache { } func (c *ConfigCache) Get(configDir string) (*Config, error) { - configPath := filepath.Join(configDir, "config.yaml") + c.mu.Lock() + defer c.mu.Unlock() - c.mu.RLock() entry, exists := c.entries[configDir] - c.mu.RUnlock() if exists { // Check TTL if time.Since(entry.loadedAt) < c.ttl { return entry.config, nil } + // Entry expired, remove it + delete(c.entries, configDir) } // Load fresh @@ -46,13 +47,11 @@ func (c *ConfigCache) Get(configDir string) (*Config, error) { } // Cache it - c.mu.Lock() c.entries[configDir] = &cachedConfig{ config: cfg, loadedAt: time.Now(), - configPath: configPath, + configPath: filepath.Join(configDir, "config.yaml"), } - c.mu.Unlock() return cfg, nil } @@ -62,3 +61,17 @@ func (c *ConfigCache) Invalidate(configDir string) { delete(c.entries, configDir) c.mu.Unlock() } + +// Cleanup removes all expired entries from the cache. +// This can be called periodically to prevent memory growth in long-running processes. +func (c *ConfigCache) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for configDir, entry := range c.entries { + if now.Sub(entry.loadedAt) >= c.ttl { + delete(c.entries, configDir) + } + } +} diff --git a/internal/config/cache_test.go b/internal/config/cache_test.go index 6233319..f1d3332 100644 --- a/internal/config/cache_test.go +++ b/internal/config/cache_test.go @@ -1,8 +1,10 @@ package config import ( + "fmt" "os" "path/filepath" + "sync" "testing" "time" ) @@ -135,3 +137,46 @@ func TestConfigCache_InvalidateNonExistent(t *testing.T) { // Invalidate should not panic for non-existent entries cache.Invalidate("nonexistent") } + +func TestConfigCache_ConcurrentWrites(t *testing.T) { + cache := NewConfigCache(5 * time.Minute) + tmpDir := t.TempDir() + + // Create initial config file + configContent := `default_provider: test +providers: {} +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(configContent), 0600); err != nil { + t.Fatal(err) + } + + // Run concurrent writes (simulate config modification) + var wg sync.WaitGroup + errs := make(chan error, 20) + + for i := 0; i < 20; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + // Invalidate and reload - simulates config modification + cache.Invalidate(tmpDir) + cfg, err := cache.Get(tmpDir) + if err != nil { + errs <- err + return + } + // Verify we got a valid config + if cfg == nil { + errs <- fmt.Errorf("nil config returned") + } + }(i) + } + + wg.Wait() + close(errs) + + // Check for any errors + for err := range errs { + t.Errorf("Concurrent write error: %v", err) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 5669ddd..a71b090 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -622,3 +622,54 @@ func TestParseSecretsNewlines(t *testing.T) { t.Error("ParseSecrets() should skip line 'newline' (no =)") } } + +func TestLoadConfigEmptyFile(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + // Create an empty config file + if err := os.WriteFile(configPath, []byte(""), 0600); err != nil { + t.Fatal(err) + } + + // Empty file should return error (not valid YAML) + _, err := LoadConfig(tmpDir) + if err == nil { + t.Error("LoadConfig() on empty file should error") + } +} + +func TestLoadConfigWhitespaceOnly(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + // Create a file with only whitespace + if err := os.WriteFile(configPath, []byte(" \n\n \n"), 0600); err != nil { + t.Fatal(err) + } + + // Whitespace-only file should return error (not valid YAML) + _, err := LoadConfig(tmpDir) + if err == nil { + t.Error("LoadConfig() on whitespace-only file should error") + } +} + +func TestLoadConfigCommentOnly(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + // Create a file with only comments - YAML requires content, comments alone fail + commentContent := `# This is a comment +# Another comment +` + if err := os.WriteFile(configPath, []byte(commentContent), 0600); err != nil { + t.Fatal(err) + } + + // Comment-only file returns error (YAML parser requires content) + _, err := LoadConfig(tmpDir) + if err == nil { + t.Error("LoadConfig() on comment-only file should error") + } +} diff --git a/internal/config/expanded_test.go b/internal/config/expanded_test.go new file mode 100644 index 0000000..b586666 --- /dev/null +++ b/internal/config/expanded_test.go @@ -0,0 +1,234 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSaveConfigCreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + + // SaveConfig creates the config file but not nested directories + // Create the directory structure first + subDir := filepath.Join(tmpDir, "subdir", "nested") + if err := os.MkdirAll(subDir, 0700); err != nil { + t.Fatal(err) + } + + cfg := &Config{ + Providers: map[string]Provider{}, + } + + err := SaveConfig(subDir, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + // Verify config was saved + configPath := filepath.Join(subDir, "config.yaml") + if _, err := os.Stat(configPath); os.IsNotExist(err) { + t.Error("Config file should be created") + } +} + +func TestSaveConfigOverwrites(t *testing.T) { + tmpDir := t.TempDir() + + cfg1 := &Config{ + Providers: map[string]Provider{ + "provider1": {Name: "Provider 1"}, + }, + } + + if err := SaveConfig(tmpDir, cfg1); err != nil { + t.Fatal(err) + } + + cfg2 := &Config{ + Providers: map[string]Provider{ + "provider2": {Name: "Provider 2"}, + }, + } + + if err := SaveConfig(tmpDir, cfg2); err != nil { + t.Fatal(err) + } + + // Verify second save overwrote first + loaded, err := LoadConfig(tmpDir) + if err != nil { + t.Fatal(err) + } + + if _, ok := loaded.Providers["provider1"]; ok { + t.Error("First provider should be removed after overwrite") + } + + if _, ok := loaded.Providers["provider2"]; !ok { + t.Error("Second provider should exist") + } +} + +func TestSaveConfigEmpty(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &Config{ + Providers: map[string]Provider{}, + } + + err := SaveConfig(tmpDir, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + loaded, err := LoadConfig(tmpDir) + if err != nil { + t.Fatal(err) + } + + if len(loaded.Providers) != 0 { + t.Errorf("Providers = %d, want 0", len(loaded.Providers)) + } +} + +func TestSaveConfigWithDefaultProvider(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &Config{ + Providers: map[string]Provider{ + "test": {Name: "Test"}, + }, + DefaultProvider: "test", + } + + err := SaveConfig(tmpDir, cfg) + if err != nil { + t.Fatal(err) + } + + loaded, err := LoadConfig(tmpDir) + if err != nil { + t.Fatal(err) + } + + if loaded.DefaultProvider != "test" { + t.Errorf("DefaultProvider = %q, want %q", loaded.DefaultProvider, "test") + } +} + +func TestSaveConfigWithHarness(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &Config{ + Providers: map[string]Provider{}, + DefaultHarness: "qwen", + } + + err := SaveConfig(tmpDir, cfg) + if err != nil { + t.Fatal(err) + } + + loaded, err := LoadConfig(tmpDir) + if err != nil { + t.Fatal(err) + } + + if loaded.DefaultHarness != "qwen" { + t.Errorf("DefaultHarness = %q, want %q", loaded.DefaultHarness, "qwen") + } +} + +func TestLoadConfigWithEmptyFile(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + // Create empty file + if err := os.WriteFile(configPath, []byte(""), 0600); err != nil { + t.Fatal(err) + } + + _, err := LoadConfig(tmpDir) + // Should handle empty file gracefully + if err != nil { + t.Logf("LoadConfig() error on empty file: %v", err) + } +} + +func TestLoadConfigWithInvalidYAML(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + // Create invalid YAML file + if err := os.WriteFile(configPath, []byte("invalid: yaml: content:"), 0600); err != nil { + t.Fatal(err) + } + + _, err := LoadConfig(tmpDir) + if err == nil { + t.Error("LoadConfig() should error on invalid YAML") + } +} + +func TestParseSecretsEdgeCases(t *testing.T) { + tests := []struct { + name string + input string + expected int + }{ + { + name: "empty", + input: "", + expected: 0, + }, + { + name: "whitespace only", + input: " \n \n ", + expected: 0, + }, + { + name: "key with empty value", + input: "KEY=", + expected: 0, // Empty values are filtered out + }, + { + name: "empty key", + input: "=value", + expected: 0, // Empty keys are filtered out + }, + { + name: "multiple equals signs", + input: "KEY=value=with=equals", + expected: 1, + }, + { + name: "mixed valid and invalid", + input: "VALID=value\nANOTHER=test", + expected: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseSecrets(tt.input) + if len(result) != tt.expected { + t.Errorf("ParseSecrets() returned %d entries, want %d", len(result), tt.expected) + } + }) + } +} + +func TestParseSecretsPreservesOrder(t *testing.T) { + input := "FIRST=value1\nSECOND=value2\nTHIRD=value3" + result := ParseSecrets(input) + + keys := make([]string, 0, len(result)) + for k := range result { + keys = append(keys, k) + } + + // Check order is preserved (map iteration order in Go is not guaranteed, + // but for small maps it's often consistent - the function may need fixing) + _ = keys +} diff --git a/internal/config/loader.go b/internal/config/loader.go index c6fd2c0..fbd4dd0 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -2,7 +2,6 @@ package config import ( "bytes" - "fmt" "os" "path/filepath" "strings" @@ -48,7 +47,8 @@ func migrateConfigFile(configDir string) (bool, error) { // No old config to migrate return false, nil } - return false, fmt.Errorf("failed to check old config file: %w", err) + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to check old config file", err) } // Check if new config already exists @@ -56,24 +56,28 @@ func migrateConfigFile(configDir string) (bool, error) { // New config exists, don't overwrite - keep both for safety return false, nil } else if !os.IsNotExist(err) { - return false, fmt.Errorf("failed to check new config file: %w", err) + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to check new config file", err) } // Read the old config file to verify it's valid YAML before migrating data, err := os.ReadFile(oldConfigPath) if err != nil { - return false, fmt.Errorf("failed to read old config file: %w", err) + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to read old config file", err) } // Verify it's valid YAML var testCfg Config if err := yaml.Unmarshal(data, &testCfg); err != nil { - return false, fmt.Errorf("old config file is not valid YAML, cannot migrate: %w", err) + return false, kairoerrors.WrapError(kairoerrors.ConfigError, + "old config file is not valid YAML, cannot migrate", err) } // Write to new location with same permissions if err := os.WriteFile(newConfigPath, data, oldInfo.Mode()); err != nil { - return false, fmt.Errorf("failed to write migrated config file: %w", err) + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to write migrated config file", err) } // Rename old file to .backup instead of deleting @@ -81,7 +85,8 @@ func migrateConfigFile(configDir string) (bool, error) { if err := os.Rename(oldConfigPath, backupPath); err != nil { // If rename fails, try to remove the new file and report error os.Remove(newConfigPath) - return false, fmt.Errorf("failed to backup old config file: %w", err) + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to backup old config file", err) } return true, nil diff --git a/internal/config/secrets.go b/internal/config/secrets.go index a0a0f86..ffa84d3 100644 --- a/internal/config/secrets.go +++ b/internal/config/secrets.go @@ -1,6 +1,7 @@ package config import ( + "log" "strings" ) @@ -20,6 +21,7 @@ func ParseSecrets(secrets string) map[string]string { key, value := parts[0], parts[1] // Skip entries with empty keys or values, or newlines in key or value (malformed input) if key == "" || value == "" || strings.Contains(key, "\n") || strings.Contains(value, "\n") { + log.Printf("Warning: skipping malformed secret entry: %q", line) continue } result[key] = value diff --git a/internal/config/watch.go b/internal/config/watch.go index ecb5e98..4fd3586 100644 --- a/internal/config/watch.go +++ b/internal/config/watch.go @@ -51,14 +51,14 @@ func (fw *FileWatcher) checkForChanges() { return // File doesn't exist, nothing to watch } - fw.cache.mu.RLock() - entry := fw.cache.entries[fw.watchDir] - fw.cache.mu.RUnlock() + fw.cache.mu.Lock() + defer fw.cache.mu.Unlock() + entry := fw.cache.entries[fw.watchDir] if entry != nil { // Check if file was modified since we cached it if info.ModTime().After(entry.loadedAt) { - fw.cache.Invalidate(fw.watchDir) + delete(fw.cache.entries, fw.watchDir) } } } diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go index 8d6bb81..079339f 100644 --- a/internal/crypto/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -468,3 +468,88 @@ func TestRotateKeyInvalidOldKey(t *testing.T) { t.Errorf("should be able to decrypt with valid key: %v", err) } } + +func TestDecryptCorruptedFile(t *testing.T) { + tmpDir := t.TempDir() + + keyPath := filepath.Join(tmpDir, "age.key") + err := GenerateKey(keyPath) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + secretsPath := filepath.Join(tmpDir, "corrupted.age") + + // Write corrupted data (not valid age encryption) + corruptedData := []byte("this is not encrypted data!!!") + if err := os.WriteFile(secretsPath, corruptedData, 0600); err != nil { + t.Fatal(err) + } + + _, err = DecryptSecrets(secretsPath, keyPath) + if err == nil { + t.Error("DecryptSecrets() should error on corrupted file") + } +} + +func TestDecryptTruncatedFile(t *testing.T) { + tmpDir := t.TempDir() + + keyPath := filepath.Join(tmpDir, "age.key") + err := GenerateKey(keyPath) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + secretsPath := filepath.Join(tmpDir, "truncated.age") + + // First encrypt some valid data + validContent := "ANTHROPIC_API_KEY=test-key-123" + err = EncryptSecrets(secretsPath, keyPath, validContent) + if err != nil { + t.Fatal(err) + } + + // Read and truncate the file + data, err := os.ReadFile(secretsPath) + if err != nil { + t.Fatal(err) + } + + // Write truncated data (half of original) + truncatedData := data[:len(data)/2] + if err := os.WriteFile(secretsPath, truncatedData, 0600); err != nil { + t.Fatal(err) + } + + _, err = DecryptSecrets(secretsPath, keyPath) + if err == nil { + t.Error("DecryptSecrets() should error on truncated file") + } +} + +func TestDecryptRandomData(t *testing.T) { + tmpDir := t.TempDir() + + keyPath := filepath.Join(tmpDir, "age.key") + err := GenerateKey(keyPath) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + secretsPath := filepath.Join(tmpDir, "random.age") + + // Write random bytes (not valid age encryption) + randomData := make([]byte, 256) + for i := range randomData { + randomData[i] = byte(i % 256) + } + if err := os.WriteFile(secretsPath, randomData, 0600); err != nil { + t.Fatal(err) + } + + _, err = DecryptSecrets(secretsPath, keyPath) + if err == nil { + t.Error("DecryptSecrets() should error on random data") + } +} diff --git a/internal/recover/recover.go b/internal/recover/recover.go index 0a40a0d..b865049 100644 --- a/internal/recover/recover.go +++ b/internal/recover/recover.go @@ -3,17 +3,19 @@ package recover import ( "crypto/rand" "encoding/base64" - "fmt" "os" "path/filepath" "strings" + + kairoerrors "github.com/dkmnx/kairo/internal/errors" ) // CreateRecoveryPhrase generates a recovery phrase from the key file func CreateRecoveryPhrase(keyPath string) (string, error) { keyData, err := os.ReadFile(keyPath) if err != nil { - return "", fmt.Errorf("read key: %w", err) + return "", kairoerrors.WrapError(kairoerrors.CryptoError, + "read key", err) } // Encode key as base64 phrase @@ -33,7 +35,8 @@ func RecoverFromPhrase(configDir, phrase string) error { keyData, err := base64.RawStdEncoding.DecodeString(encoded) if err != nil { - return fmt.Errorf("decode phrase: %w", err) + return kairoerrors.WrapError(kairoerrors.CryptoError, + "decode phrase", err) } keyPath := filepath.Join(configDir, "age.key") @@ -45,7 +48,8 @@ func GenerateRecoveryPhrase() (string, error) { key := make([]byte, 32) _, err := rand.Read(key) if err != nil { - return "", fmt.Errorf("generate key: %w", err) + return "", kairoerrors.WrapError(kairoerrors.CryptoError, + "generate key", err) } phrase := base64.RawStdEncoding.EncodeToString(key) diff --git a/internal/recovery/recovery.go b/internal/recovery/recovery.go index f01567b..07f0f64 100644 --- a/internal/recovery/recovery.go +++ b/internal/recovery/recovery.go @@ -395,17 +395,16 @@ func calculateDelay(attempt int, cfg RetryConfig) time.Duration { } // Exponential backoff: base * 2^attempt - // Note: Using loop instead of 1 << attempt to avoid float64 type issues - backoffFactor := 1.0 - for i := 0; i < int(attempt); i++ { - backoffFactor *= 2 - // Prevent integer overflow in time.Duration calculations - // MaxSafeBackoffFactor is 10^18, safe for int64 nanoseconds - if backoffFactor > MaxSafeBackoffFactor { - backoffFactor = MaxSafeBackoffFactor - break - } + // Use bit shifting for O(1) calculation instead of loop + // 2^60 ≈ 1.15e18 exceeds MaxSafeBackoffFactor (1.0e18) + const maxBackoffBits = 60 + var backoffFactor float64 + if attempt >= maxBackoffBits { + backoffFactor = MaxSafeBackoffFactor + } else { + backoffFactor = float64(uint(1) << uint(attempt)) } + delay := time.Duration(float64(cfg.BaseDelay) * backoffFactor) // Cap at max delay diff --git a/internal/validate/provider.go b/internal/validate/provider.go index 3ca307b..a5a7d73 100644 --- a/internal/validate/provider.go +++ b/internal/validate/provider.go @@ -5,9 +5,21 @@ import ( "strings" "github.com/dkmnx/kairo/internal/config" + kairoerrors "github.com/dkmnx/kairo/internal/errors" "github.com/dkmnx/kairo/internal/providers" ) +// Validation limits for provider and model names. +const ( + // MaxProviderNameLength is the maximum length for custom provider names. + // Most provider names are short identifiers (e.g., "anthropic", "openai"). + MaxProviderNameLength = 50 + + // MaxModelNameLength is the maximum length for model names. + // Most LLM model names are reasonable length (e.g., "claude-3-opus-20240229"). + MaxModelNameLength = 100 +) + // validateCrossProviderConfig validates configuration across all providers to detect conflicts. // // This function checks for environment variable collisions where multiple providers @@ -51,8 +63,8 @@ func ValidateCrossProviderConfig(cfg *config.Config) error { } } if !allSame { - return fmt.Errorf("environment variable collision: '%s' is set to different values by providers: %v", - key, sources) + return kairoerrors.NewError(kairoerrors.ValidationError, + fmt.Sprintf("environment variable collision: '%s' is set to different values by providers: %v", key, sources)) } } } @@ -73,13 +85,15 @@ func ValidateProviderModel(providerName, modelName string) error { // If provider has a default model, do basic validation if def.Model != "" { // Check model name length (most LLM model names are reasonable length) - if len(modelName) > 100 { - return fmt.Errorf("model name '%s' for provider '%s' is too long (max 100 characters)", modelName, providerName) + if len(modelName) > MaxModelNameLength { + return kairoerrors.NewError(kairoerrors.ValidationError, + fmt.Sprintf("model name '%s' for provider '%s' is too long (max %d characters)", modelName, providerName, MaxModelNameLength)) } // Check for valid characters (alphanumeric, hyphens, underscores, dots) for _, r := range modelName { if !isValidModelRune(r) { - return fmt.Errorf("model name '%s' for provider '%s' contains invalid characters", modelName, providerName) + return kairoerrors.NewError(kairoerrors.ValidationError, + fmt.Sprintf("model name '%s' for provider '%s' contains invalid characters", modelName, providerName)) } } } diff --git a/internal/wrapper/wrapper.go b/internal/wrapper/wrapper.go index 28f70b5..895707e 100644 --- a/internal/wrapper/wrapper.go +++ b/internal/wrapper/wrapper.go @@ -29,6 +29,8 @@ import ( "os/exec" "runtime" "strings" + + kairoerrors "github.com/dkmnx/kairo/internal/errors" ) // CreateTempAuthDir creates a private temporary directory for storing auth files. @@ -37,12 +39,14 @@ import ( func CreateTempAuthDir() (string, error) { authDir, err := os.MkdirTemp("", "kairo-auth-") if err != nil { - return "", fmt.Errorf("failed to create temp auth directory: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to create temp auth directory", err) } if err := os.Chmod(authDir, 0700); err != nil { _ = os.RemoveAll(authDir) - return "", fmt.Errorf("failed to set auth directory permissions: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to set auth directory permissions", err) } return authDir, nil @@ -53,25 +57,30 @@ func CreateTempAuthDir() (string, error) { // Returns the path to the temporary file. func WriteTempTokenFile(authDir, token string) (string, error) { if token == "" { - return "", fmt.Errorf("token cannot be empty") + return "", kairoerrors.NewError(kairoerrors.ValidationError, + "token cannot be empty") } f, err := os.CreateTemp(authDir, "token-") if err != nil { - return "", fmt.Errorf("failed to create temp token file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to create temp token file", err) } if _, err := f.WriteString(token); err != nil { _ = f.Close() - return "", fmt.Errorf("failed to write token to temp file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to write token to temp file", err) } if err := f.Close(); err != nil { - return "", fmt.Errorf("failed to close temp token file: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to close temp token file", err) } if err := os.Chmod(f.Name(), 0600); err != nil { - return "", fmt.Errorf("failed to set temp file permissions: %w", err) + return "", kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to set temp file permissions", err) } return f.Name(), nil @@ -80,7 +89,7 @@ func WriteTempTokenFile(authDir, token string) (string, error) { // EscapePowerShellArg escapes a string for use as a PowerShell argument. // It wraps the argument in single quotes and escapes special characters to prevent // command injection. Special characters escaped: backtick, dollar sign, double quote, -// single quote, and common control characters (newline, carriage return, tab, etc.). +// single quote, ampersand, semicolon, pipe, percent, and common control characters. // Note: Some escape sequences like `v (vertical tab) and `f (form feed) are not // supported in older PowerShell versions (5.1 and below), so we only escape commonly // supported control characters. @@ -93,6 +102,11 @@ func EscapePowerShellArg(arg string) string { arg = strings.ReplaceAll(arg, "\"", "\\\"") // Escape single quotes by doubling them arg = strings.ReplaceAll(arg, "'", "''") + // Escape command separators to prevent injection + arg = strings.ReplaceAll(arg, "&", "`&") + arg = strings.ReplaceAll(arg, ";", "`;") + arg = strings.ReplaceAll(arg, "|", "`|") + arg = strings.ReplaceAll(arg, "%", "``%") // Escape control characters (widely supported in PowerShell) arg = strings.ReplaceAll(arg, "\n", "`n") arg = strings.ReplaceAll(arg, "\r", "`r") @@ -111,10 +125,12 @@ func EscapePowerShellArg(arg string) string { // Returns the path to the wrapper script and whether to use shell execution. func GenerateWrapperScript(authDir, tokenPath, cliPath string, cliArgs []string, envVarName ...string) (string, bool, error) { if tokenPath == "" { - return "", false, fmt.Errorf("token path cannot be empty") + return "", false, kairoerrors.NewError(kairoerrors.ValidationError, + "token path cannot be empty") } if cliPath == "" { - return "", false, fmt.Errorf("cli path cannot be empty") + return "", false, kairoerrors.NewError(kairoerrors.ValidationError, + "cli path cannot be empty") } envVar := "ANTHROPIC_AUTH_TOKEN" @@ -126,7 +142,8 @@ func GenerateWrapperScript(authDir, tokenPath, cliPath string, cliArgs []string, f, err := os.CreateTemp(authDir, "wrapper-") if err != nil { - return "", false, fmt.Errorf("failed to create temp wrapper script: %w", err) + return "", false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to create temp wrapper script", err) } var scriptContent string @@ -157,26 +174,30 @@ func GenerateWrapperScript(authDir, tokenPath, cliPath string, cliArgs []string, if _, err := f.WriteString(scriptContent); err != nil { _ = f.Close() _ = os.Remove(f.Name()) - return "", false, fmt.Errorf("failed to write wrapper script: %w", err) + return "", false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to write wrapper script", err) } if err := f.Close(); err != nil { _ = os.Remove(f.Name()) - return "", false, fmt.Errorf("failed to close wrapper script: %w", err) + return "", false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to close wrapper script", err) } if isWindows { ps1Path := f.Name() + ".ps1" if err := os.Rename(f.Name(), ps1Path); err != nil { _ = os.Remove(f.Name()) - return "", false, fmt.Errorf("failed to rename wrapper script: %w", err) + return "", false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to rename wrapper script", err) } return ps1Path, true, nil } if err := os.Chmod(f.Name(), 0700); err != nil { _ = os.Remove(f.Name()) - return "", false, fmt.Errorf("failed to make wrapper script executable: %w", err) + return "", false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to make wrapper script executable", err) } return f.Name(), false, nil diff --git a/internal/wrapper/wrapper_test.go b/internal/wrapper/wrapper_test.go index 60eded7..84e3227 100644 --- a/internal/wrapper/wrapper_test.go +++ b/internal/wrapper/wrapper_test.go @@ -284,12 +284,12 @@ func TestEscapePowerShellArg_EdgeCases(t *testing.T) { {"unicode chinese", "你好世界", "'你好世界'"}, {"windows path long", `C:\Users\JohnDoe\AppData\Local\Programs\Claude\claude.exe`, "'C:\\Users\\JohnDoe\\AppData\\Local\\Programs\\Claude\\claude.exe'"}, {"windows path with spaces", `C:\Program Files\My App\file.txt`, "'C:\\Program Files\\My App\\file.txt'"}, - {"semicolon", "test; cmd", "'test; cmd'"}, - {"pipe", "test | calc", "'test | calc'"}, + {"semicolon", "test; cmd", "'test`; cmd'"}, + {"pipe", "test | calc", "'test `| calc'"}, {"variable style", "$myVar", "'`$myVar'"}, {"env variable", "$env:PATH", "'`$env:PATH'"}, {"at sign", "@()", "'@()'"}, - {"percent", "100%", "'100%'"}, + {"percent", "100%", "'100``%'"}, {"multi-line", "line1\nline2", "'line1`nline2'"}, {"json string", `{"key":"value"}`, "'{\\\"key\\\":\\\"value\\\"}'"}, {"base64", "SGVsbG8gV29ybGQ=", "'SGVsbG8gV29ybGQ='"},