Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 1 addition & 204 deletions cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/dkmnx/kairo/internal/audit"
"github.com/dkmnx/kairo/internal/config"
Expand Down Expand Up @@ -51,7 +49,7 @@ var configCmd = &cobra.Command{

cfg, err := configCache.Get(dir)
if err != nil && !os.IsNotExist(err) {
ui.PrintError(fmt.Sprintf("Error loading config: %v", err))
handleConfigError(cmd, err)
return
}
if err != nil {
Expand Down Expand Up @@ -209,204 +207,3 @@ var configCmd = &cobra.Command{
func init() {
rootCmd.AddCommand(configCmd)
}

// createConfigBackup creates a backup of the current configuration file.
// Returns the path to the backup file or an error if the backup fails.
// The backup file is named with a timestamp to allow for multiple backups.
func createConfigBackup(configDir string) (string, error) {
configPath := getConfigPath(configDir)

// Read the current config file
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("failed to read config for backup: %w", err)
}

// Create backup filename with timestamp
backupPath := getBackupPath(configDir)

// Write the backup
if err := os.WriteFile(backupPath, data, 0600); err != nil {
return "", fmt.Errorf("failed to write backup file: %w", err)
}

return backupPath, nil
}

// rollbackConfig restores the configuration from a backup file.
// If successful, the current config is replaced with the backup.
// The backup file is preserved after rollback for safety.
func rollbackConfig(configDir, backupPath string) error {
// Verify backup exists
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
return fmt.Errorf("backup file not found: %s", backupPath)
}

// Read backup data
data, err := os.ReadFile(backupPath)
if err != nil {
return fmt.Errorf("failed to read backup file: %w", err)
}

// Write to config file
configPath := getConfigPath(configDir)
if err := os.WriteFile(configPath, data, 0600); err != nil {
return fmt.Errorf("failed to restore config from backup: %w", err)
}

return nil
}

// withConfigTransaction executes a function within a transaction-like context.
//
// This function creates a backup of the configuration before executing the
// provided function. If the function returns an error, the configuration
// is automatically rolled back to the backup. This provides atomic-like
// behavior for configuration updates.
//
// Parameters:
// - configDir: Directory containing the configuration file
// - fn: Function to execute within the transaction context
//
// Returns:
// - error: Returns error if transaction fails or rollback fails
//
// Error conditions:
// - Returns error when unable to create configuration backup
// - Returns error when fn returns an error (after attempting rollback)
// - Returns error if rollback fails after transaction failure (critical error)
//
// Thread Safety: Not thread-safe due to file I/O operations
// Security Notes: Backup files retain same permissions as original config (0600)
func withConfigTransaction(configDir string, fn func(txDir string) error) error {
// Create backup before transaction
backupPath, err := createConfigBackup(configDir)
if err != nil {
return fmt.Errorf("failed to create transaction backup: %w", err)
}

// Execute the transaction function
err = fn(configDir)

// If transaction failed, rollback
if err != nil {
if rbErr := rollbackConfig(configDir, backupPath); rbErr != nil {
// Rollback failed - this is a critical situation
return fmt.Errorf("transaction failed and rollback also failed: tx_err=%w, rollback_err=%w", err, rbErr)
}
return fmt.Errorf("transaction failed, changes rolled back: %w", err)
}

return nil
}

// getConfigPath returns the full path to the config file.
func getConfigPath(configDir string) string {
return filepath.Join(configDir, "config.yaml")
}

// getBackupPath returns a backup file path with timestamp.
func getBackupPath(configDir string) string {
timestamp := time.Now().Format("20060102-150405")
return filepath.Join(configDir, fmt.Sprintf("config.yaml.backup.%s", timestamp))
}

// validateCrossProviderConfig validates configuration across all providers to detect conflicts.
//
// This function checks for environment variable collisions where multiple providers
// attempt to set the same environment variable with different values. Collisions
// with identical values are allowed (idempotent).
//
// Parameters:
// - cfg: Configuration object containing all provider definitions
//
// Returns:
// - error: Returns error if conflicting environment variables are detected
//
// Error conditions:
// - Returns error when same environment variable is set by multiple providers
// with different values (e.g., "API_KEY" set to "key1" by provider A and "key2" by provider B)
//
// Thread Safety: Thread-safe (no shared state, read-only access to config)
func validateCrossProviderConfig(cfg *config.Config) error {
// Build a map of env var names to their values and which providers set them
type envVarSource struct {
provider string
value string
}
envVarMap := make(map[string][]envVarSource)

for providerName, provider := range cfg.Providers {
for _, envVar := range provider.EnvVars {
// Parse env var to get key and value
parts := strings.SplitN(envVar, "=", 2)
if len(parts) != 2 {
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])

envVarMap[key] = append(envVarMap[key], envVarSource{
provider: providerName,
value: value,
})
}
}

// Check for collisions - env vars set by multiple providers with different values
for key, sources := range envVarMap {
if len(sources) > 1 {
// Check if all sources have the same value
firstValue := sources[0].value
allSame := true
for _, s := range sources {
if s.value != firstValue {
allSame = false
break
}
}
if !allSame {
return fmt.Errorf("environment variable collision: '%s' is set to different values by providers: %v",
key, sources)
}
}
}

return nil
}

// validateProviderModel validates a model name against provider capabilities.
// For built-in providers with default models, this ensures the model is reasonable.
// Returns an error if the model name is invalid.
func validateProviderModel(providerName, modelName string) error {
if modelName == "" {
return nil // Empty model is allowed (will use provider default)
}

// Check if this is a built-in provider
if def, ok := providers.GetBuiltInProvider(providerName); ok {
// If provider has a default model, do basic validation
if def.Model != "" {
// Check model name length (most LLM model names are reasonable length)
if len(modelName) > 100 {
return fmt.Errorf("model name '%s' for provider '%s' is too long (max 100 characters)", modelName, providerName)
}
// Check for valid characters (alphanumeric, hyphens, underscores, dots)
for _, r := range modelName {
if !isValidModelRune(r) {
return fmt.Errorf("model name '%s' for provider '%s' contains invalid characters", modelName, providerName)
}
}
}
}

return nil
}

// isValidModelRune returns true if the rune is valid in a model name.
func isValidModelRune(r rune) bool {
return (r >= 'a' && r <= 'z') ||
(r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') ||
r == '-' || r == '_' || r == '.'
}
11 changes: 6 additions & 5 deletions cmd/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/dkmnx/kairo/internal/config"
"github.com/dkmnx/kairo/internal/providers"
"github.com/dkmnx/kairo/internal/validate"
)

func TestProviderDefaults(t *testing.T) {
Expand Down Expand Up @@ -470,7 +471,7 @@ func TestConfig_CrossProviderValidation(t *testing.T) {
}

// This should detect the collision
err := validateCrossProviderConfig(cfg)
err := validate.ValidateCrossProviderConfig(cfg)
if err == nil {
t.Error("Expected error for env var collision, got nil")
}
Expand All @@ -483,18 +484,18 @@ func TestConfig_CrossProviderValidation(t *testing.T) {
t.Run("ModelValidation", func(t *testing.T) {
// Test with a provider that has a default model (zai has "glm-4.7")
// This should validate the model name
err := validateProviderModel("zai", "invalid@model#name!")
err := validate.ValidateProviderModel("zai", "invalid@model#name!")
if err == nil {
t.Error("Expected error for invalid model with special characters, got nil")
}
// Test with a model that's too long
longModel := strings.Repeat("a", 101)
err = validateProviderModel("zai", longModel)
err = validate.ValidateProviderModel("zai", longModel)
if err == nil {
t.Error("Expected error for model name that's too long, got nil")
}
// Test with a valid model - should not error
err = validateProviderModel("zai", "valid-model-name.123")
err = validate.ValidateProviderModel("zai", "valid-model-name.123")
if err != nil {
t.Errorf("Expected valid model to pass validation, got error: %v", err)
}
Expand All @@ -520,7 +521,7 @@ func TestConfig_CrossProviderValidation(t *testing.T) {
DefaultProvider: "zai",
}

err := validateCrossProviderConfig(cfg)
err := validate.ValidateCrossProviderConfig(cfg)
if err != nil {
t.Errorf("Expected valid config to pass validation, got error: %v", err)
}
Expand Down
114 changes: 114 additions & 0 deletions cmd/config_tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package cmd

import (
"fmt"
"os"
"path/filepath"
"time"
)

// createConfigBackup creates a backup of the current configuration file.
// Returns the path to the backup file or an error if the backup fails.
// The backup file is named with a timestamp to allow for multiple backups.
func createConfigBackup(configDir string) (string, error) {
configPath := getConfigPath(configDir)

// Read the current config file
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("failed to read config for backup: %w", err)
}

// Create backup filename with timestamp
backupPath := getBackupPath(configDir)

// Write the backup
if err := os.WriteFile(backupPath, data, 0600); err != nil {
return "", fmt.Errorf("failed to write backup file: %w", err)
}

return backupPath, nil
}

// rollbackConfig restores the configuration from a backup file.
// If successful, the current config is replaced with the backup.
// The backup file is preserved after rollback for safety.
func rollbackConfig(configDir, backupPath string) error {
// Verify backup exists
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
return fmt.Errorf("backup file not found: %s", backupPath)
}

// Read backup data
data, err := os.ReadFile(backupPath)
if err != nil {
return fmt.Errorf("failed to read backup file: %w", err)
}

// Write to config file
configPath := getConfigPath(configDir)
if err := os.WriteFile(configPath, data, 0600); err != nil {
return fmt.Errorf("failed to restore config from backup: %w", err)
}

return nil
}

// withConfigTransaction executes a function within a transaction-like context.
//
// This function creates a backup of the configuration before executing the
// provided function. If the function returns an error, the configuration
// is automatically rolled back to the backup. This provides atomic-like
// behavior for configuration updates.
//
// Parameters:
// - configDir: Directory containing the configuration file
// - fn: Function to execute within the transaction context
//
// Returns:
// - error: Returns error if transaction fails or rollback fails
//
// Error conditions:
// - Returns error when unable to create configuration backup
// - Returns error when fn returns an error (after attempting rollback)
// - Returns error if rollback fails after transaction failure (critical error)
//
// Thread Safety: Not thread-safe due to file I/O operations
// Security Notes: Backup files retain same permissions as original config (0600)
func withConfigTransaction(configDir string, fn func(txDir string) error) error {
// Create backup before transaction
backupPath, err := createConfigBackup(configDir)
if err != nil {
return fmt.Errorf("failed to create transaction backup: %w", err)
}

// Execute the transaction function
err = fn(configDir)

// If transaction failed, rollback
if err != nil {
if rbErr := rollbackConfig(configDir, backupPath); rbErr != nil {
// Rollback failed - this is a critical situation
return fmt.Errorf("transaction failed and rollback also failed: tx_err=%w, rollback_err=%w", err, rbErr)
}
return fmt.Errorf("transaction failed, changes rolled back: %w", err)
}

// Transaction succeeded - clean up the backup file
// Best-effort cleanup, ignore errors
_ = os.Remove(backupPath)

return nil
}

// getConfigPath returns the full path to the config file.
func getConfigPath(configDir string) string {
return filepath.Join(configDir, "config.yaml")
}

// getBackupPath returns a backup file path with timestamp.
func getBackupPath(configDir string) string {
// Use nanosecond precision to avoid filename conflicts with rapid successive operations
timestamp := time.Now().Format("20060102-150405.000000000")
return filepath.Join(configDir, fmt.Sprintf("config.yaml.backup.%s", timestamp))
}
2 changes: 1 addition & 1 deletion cmd/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var defaultCmd = &cobra.Command{
}
return
}
ui.PrintError(fmt.Sprintf("Error loading config: %v", err))
handleConfigError(cmd, err)
return
}

Expand Down
Loading
Loading