diff --git a/internal/semantic/commands/collection_test.go b/internal/semantic/commands/collection_test.go index 69d7d41..9fa7d4e 100644 --- a/internal/semantic/commands/collection_test.go +++ b/internal/semantic/commands/collection_test.go @@ -13,12 +13,12 @@ func TestDeriveCollectionFromPath(t *testing.T) { }{ { name: "simple path with code suffix", - path: ".llm-index/code", + path: ".index/code", expected: "code", }, { name: "simple path with docs suffix", - path: ".llm-index/docs", + path: ".index/docs", expected: "docs", }, { @@ -28,17 +28,17 @@ func TestDeriveCollectionFromPath(t *testing.T) { }, { name: "path with trailing slash", - path: ".llm-index/code/", + path: ".index/code/", expected: "code", }, { - name: "default .llm-index should return empty", - path: ".llm-index", + name: "default .index should return empty", + path: ".index", expected: "", }, { - name: "llm-index without dot should return empty", - path: "llm-index", + name: "index without dot should return empty", + path: "index", expected: "", }, { @@ -48,27 +48,27 @@ func TestDeriveCollectionFromPath(t *testing.T) { }, { name: "path with hyphen converted to underscore", - path: ".llm-index/my-collection", + path: ".index/my-collection", expected: "my_collection", }, { name: "path with dot converted to underscore", - path: ".llm-index/v1.0", + path: ".index/v1.0", expected: "v1_0", }, { name: "path starting with number gets prefix", - path: ".llm-index/123test", + path: ".index/123test", expected: "idx_123test", }, { name: "alphanumeric path unchanged", - path: ".llm-index/MyProject2024", + path: ".index/MyProject2024", expected: "MyProject2024", }, { name: "path with underscores preserved", - path: ".llm-index/my_project_code", + path: ".index/my_project_code", expected: "my_project_code", }, { @@ -78,7 +78,7 @@ func TestDeriveCollectionFromPath(t *testing.T) { }, { name: "special characters removed", - path: ".llm-index/test@project#1", + path: ".index/test@project#1", expected: "testproject1", }, } @@ -108,7 +108,7 @@ func TestResolveCollectionName(t *testing.T) { t.Run("priority 1: explicit collection flag", func(t *testing.T) { collectionName = "explicit_collection" - indexDir = ".llm-index/code" + indexDir = ".index/code" os.Setenv("QDRANT_COLLECTION", "env_collection") result := resolveCollectionName() @@ -119,7 +119,7 @@ func TestResolveCollectionName(t *testing.T) { t.Run("priority 2: derive from index-dir", func(t *testing.T) { collectionName = "" - indexDir = ".llm-index/myproject" + indexDir = ".index/myproject" os.Setenv("QDRANT_COLLECTION", "env_collection") result := resolveCollectionName() @@ -130,7 +130,7 @@ func TestResolveCollectionName(t *testing.T) { t.Run("priority 3: environment variable", func(t *testing.T) { collectionName = "" - indexDir = ".llm-index" // default, won't derive + indexDir = ".index" // default, won't derive os.Setenv("QDRANT_COLLECTION", "env_collection") result := resolveCollectionName() @@ -141,7 +141,7 @@ func TestResolveCollectionName(t *testing.T) { t.Run("priority 4: default llm_semantic", func(t *testing.T) { collectionName = "" - indexDir = ".llm-index" + indexDir = ".index" os.Unsetenv("QDRANT_COLLECTION") result := resolveCollectionName() @@ -163,7 +163,7 @@ func TestResolveCollectionName(t *testing.T) { t.Run("default index-dir with no env uses default", func(t *testing.T) { collectionName = "" - indexDir = ".llm-index" + indexDir = ".index" os.Unsetenv("QDRANT_COLLECTION") result := resolveCollectionName() @@ -188,7 +188,7 @@ func TestResolveCollectionName_EdgeCases(t *testing.T) { t.Run("whitespace-only collection name treated as empty", func(t *testing.T) { collectionName = " " - indexDir = ".llm-index/code" + indexDir = ".index/code" os.Unsetenv("QDRANT_COLLECTION") result := resolveCollectionName() @@ -250,7 +250,7 @@ func TestIndexDirFlag_InRootCmd(t *testing.T) { t.Fatal("--index-dir flag not found in root command") } - if flag.DefValue != ".llm-index" { - t.Errorf("--index-dir default should be '.llm-index', got %q", flag.DefValue) + if flag.DefValue != ".index" { + t.Errorf("--index-dir default should be '.index', got %q", flag.DefValue) } } diff --git a/internal/semantic/commands/index.go b/internal/semantic/commands/index.go index 35e886b..276e8fe 100644 --- a/internal/semantic/commands/index.go +++ b/internal/semantic/commands/index.go @@ -311,18 +311,18 @@ func runCalibration(ctx context.Context, storage semantic.Storage, embedder sema func resolveIndexPath(rootPath string) string { // If custom index dir specified - if indexDir != "" && indexDir != ".llm-index" { + if indexDir != "" && indexDir != ".index" { return filepath.Join(indexDir, "semantic.db") } // Try git root first gitRoot, err := findGitRootFrom(rootPath) if err == nil { - return filepath.Join(gitRoot, ".llm-index", "semantic.db") + return filepath.Join(gitRoot, ".index", "semantic.db") } // Fall back to the indexed directory - return filepath.Join(rootPath, ".llm-index", "semantic.db") + return filepath.Join(rootPath, ".index", "semantic.db") } func findGitRootFrom(startPath string) (string, error) { diff --git a/internal/semantic/commands/memory.go b/internal/semantic/commands/memory.go index 402c326..1838f7f 100644 --- a/internal/semantic/commands/memory.go +++ b/internal/semantic/commands/memory.go @@ -100,7 +100,7 @@ func runMemoryStore(ctx context.Context, opts memoryStoreOpts) error { indexPath = findIndexPath() if indexPath == "" { // Create default index path - indexPath = ".llm-index/semantic.db" + indexPath = ".index/semantic.db" } } @@ -277,6 +277,25 @@ func runMemorySearch(ctx context.Context, opts memorySearchOpts) error { return fmt.Errorf("search failed: %w", err) } + // Track retrieval stats (automatic for memory profile) + if len(results) > 0 { + if tracker, ok := storage.(semantic.MemoryStatsTracker); ok { + // Build retrieval batch + retrievals := make([]semantic.MemoryRetrieval, len(results)) + for i, r := range results { + retrievals[i] = semantic.MemoryRetrieval{ + MemoryID: r.Entry.ID, + Score: r.Score, + } + } + // Track in background - don't fail search if tracking fails + if err := tracker.TrackMemoryRetrievalBatch(ctx, retrievals, opts.query); err != nil { + // Log error but don't fail the search + fmt.Fprintf(os.Stderr, "Warning: failed to track retrieval stats: %v\n", err) + } + } + } + // Output results if opts.jsonOutput || opts.minOutput { return outputMemoryJSON(results, opts.minOutput) @@ -586,7 +605,7 @@ func runMemoryImport(ctx context.Context, opts memoryImportOpts) error { indexPath = findIndexPath() if indexPath == "" { // Create default index path - indexPath = ".llm-index/semantic.db" + indexPath = ".index/semantic.db" } } diff --git a/internal/semantic/commands/root.go b/internal/semantic/commands/root.go index 3c9f009..90374f0 100644 --- a/internal/semantic/commands/root.go +++ b/internal/semantic/commands/root.go @@ -96,7 +96,7 @@ Supports any OpenAI-compatible embedding API (Ollama, vLLM, OpenAI, Azure, etc.) rootCmd.PersistentFlags().StringVar(&apiURL, "api-url", getDefaultAPIURL(), "Embedding API URL (OpenAI-compatible)") rootCmd.PersistentFlags().StringVar(&model, "model", getDefaultModel(), "Embedding model name (or set LLM_SEMANTIC_MODEL env var)") rootCmd.PersistentFlags().StringVar(&apiKey, "api-key", "", "API key (or set LLM_SEMANTIC_API_KEY env var)") - rootCmd.PersistentFlags().StringVar(&indexDir, "index-dir", ".llm-index", "Directory for semantic index") + rootCmd.PersistentFlags().StringVar(&indexDir, "index-dir", ".index", "Directory for semantic index") rootCmd.PersistentFlags().StringVar(&storageType, "storage", "sqlite", "Storage backend: sqlite (default) or qdrant") rootCmd.PersistentFlags().StringVar(&collectionName, "collection", "", "Qdrant collection name (default: QDRANT_COLLECTION env or 'llm_semantic')") rootCmd.PersistentFlags().StringVar(&embedderType, "embedder", "openai", "Embedding provider: openai (default), cohere, huggingface, openrouter") @@ -128,7 +128,7 @@ func getAPIKey() string { // resolveCollectionName returns the Qdrant collection name using this priority: // 1. --collection flag if specified // 2. Profile-specific config value (e.g., code_collection) -// 3. Derived from --index-dir (e.g., ".llm-index/code" → "code", ".llm-index/docs" → "docs") +// 3. Derived from --index-dir (e.g., ".index/code" → "code", ".index/docs" → "docs") // 4. QDRANT_COLLECTION environment variable // 5. Default: "llm_semantic" func resolveCollectionName() string { @@ -146,9 +146,9 @@ func resolveCollectionName() string { } // Priority 3: derive from index-dir if non-default - if indexDir != "" && indexDir != ".llm-index" { + if indexDir != "" && indexDir != ".index" { // Extract the last path component as collection name - // e.g., ".llm-index/code" → "code", "indexes/docs" → "docs" + // e.g., ".index/code" → "code", "indexes/docs" → "docs" derived := deriveCollectionFromPath(indexDir) if derived != "" { return derived @@ -204,8 +204,8 @@ func deriveCollectionFromPath(path string) string { name = path } - // Skip if it's just ".llm-index" or similar default - if name == ".llm-index" || name == "llm-index" || name == "" { + // Skip if it's just ".index" or similar default + if name == ".index" || name == "index" || name == "" { return "" } @@ -345,7 +345,7 @@ func ResetGlobalsForTesting() { apiURL = getDefaultAPIURL() model = getDefaultModel() apiKey = "" - indexDir = ".llm-index" + indexDir = ".index" storageType = "sqlite" collectionName = "" embedderType = "openai" diff --git a/internal/semantic/commands/search.go b/internal/semantic/commands/search.go index 4c9225a..1e4b98c 100644 --- a/internal/semantic/commands/search.go +++ b/internal/semantic/commands/search.go @@ -237,15 +237,15 @@ func findIndexPath() string { } } - // Try .llm-index in current directory - path := filepath.Join(".llm-index", "semantic.db") + // Try .index in current directory + path := filepath.Join(".index", "semantic.db") if _, err := os.Stat(path); err == nil { return path } // Try to find git root and check there if gitRoot, err := findGitRoot(); err == nil { - path := filepath.Join(gitRoot, ".llm-index", "semantic.db") + path := filepath.Join(gitRoot, ".index", "semantic.db") if _, err := os.Stat(path); err == nil { return path } diff --git a/internal/semantic/errors.go b/internal/semantic/errors.go index 802b899..1728233 100644 --- a/internal/semantic/errors.go +++ b/internal/semantic/errors.go @@ -95,7 +95,7 @@ func ErrStorageFailure(operation string, cause error) *SemanticError { Type: ErrTypeUnknown, Message: fmt.Sprintf("storage %s failed", operation), Cause: cause, - Hint: "The index may be corrupted. Try: rm -rf .llm-index && llm-semantic index .", + Hint: "The index may be corrupted. Try: rm -rf .index && llm-semantic index .", } } diff --git a/internal/semantic/lexical_index.go b/internal/semantic/lexical_index.go index f45a599..72973bf 100644 --- a/internal/semantic/lexical_index.go +++ b/internal/semantic/lexical_index.go @@ -155,7 +155,12 @@ func (idx *LexicalIndex) initSchema() error { } // Backfill existing chunks if FTS is empty but chunks exist - return idx.backfillFTS() + if err := idx.backfillFTS(); err != nil { + return err + } + + // Create memory stats tracking tables + return idx.initStatsSchema() } // backfillFTS populates the FTS5 index from existing chunks. @@ -434,20 +439,15 @@ func sanitizeCollectionName(name string) string { } // getFTSPath returns the path for a parallel FTS database. -// If dataDir is empty, uses ~/.llm-semantic/. -// Returns empty string if home directory cannot be determined and dataDir is empty. +// dataDir should be the project's .index directory (e.g., {gitRoot}/.index/). +// Returns empty string if dataDir is empty (caller should handle). func getFTSPath(collection string, dataDir string) string { if dataDir == "" { - home, err := os.UserHomeDir() - if err != nil { - // Return empty string to signal error - caller should handle - return "" - } - dataDir = filepath.Join(home, ".llm-semantic") + // Return empty string to signal error - caller should handle + return "" } - safeCollection := sanitizeCollectionName(collection) - return filepath.Join(dataDir, fmt.Sprintf("qdrant-fts-%s.db", safeCollection)) + return filepath.Join(dataDir, "qdrant_fts.db") } // nullableInt64Lexical returns nil if value is 0, otherwise returns the value. @@ -487,3 +487,322 @@ func (idx *LexicalIndex) migrateMtimeColumn() error { return nil } + +// ===== Memory Stats Tracking ===== + +// initStatsSchema creates the memory stats tracking tables. +func (idx *LexicalIndex) initStatsSchema() error { + // Create memory_stats table for tracking retrieval counts + statsSchema := ` + CREATE TABLE IF NOT EXISTS memory_stats ( + memory_id TEXT PRIMARY KEY, + retrieval_count INTEGER DEFAULT 0, + last_retrieved TEXT, + status TEXT DEFAULT 'active' + ); + ` + if _, err := idx.db.Exec(statsSchema); err != nil { + return fmt.Errorf("failed to create memory_stats table: %w", err) + } + + // Create retrieval_log table for detailed tracking + logSchema := ` + CREATE TABLE IF NOT EXISTS retrieval_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + query TEXT, + score REAL, + timestamp TEXT DEFAULT (datetime('now')) + ); + ` + if _, err := idx.db.Exec(logSchema); err != nil { + return fmt.Errorf("failed to create retrieval_log table: %w", err) + } + + // Create index for pruning old log entries + indexSchema := ` + CREATE INDEX IF NOT EXISTS idx_retrieval_log_timestamp ON retrieval_log(timestamp); + CREATE INDEX IF NOT EXISTS idx_retrieval_log_memory_id ON retrieval_log(memory_id); + ` + if _, err := idx.db.Exec(indexSchema); err != nil { + return fmt.Errorf("failed to create retrieval_log indexes: %w", err) + } + + return nil +} + +// TrackRetrieval records a memory retrieval event. +// This should be called after search results are returned for profile=memory. +func (idx *LexicalIndex) TrackRetrieval(memoryID string, query string, score float32) error { + idx.mu.Lock() + defer idx.mu.Unlock() + + if idx.closed { + return fmt.Errorf("index is closed") + } + + tx, err := idx.db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + // Update or insert stats + _, err = tx.Exec(` + INSERT INTO memory_stats (memory_id, retrieval_count, last_retrieved, status) + VALUES (?, 1, datetime('now'), 'active') + ON CONFLICT(memory_id) DO UPDATE SET + retrieval_count = retrieval_count + 1, + last_retrieved = datetime('now') + `, memoryID) + if err != nil { + return fmt.Errorf("failed to update memory_stats: %w", err) + } + + // Log the retrieval + _, err = tx.Exec(` + INSERT INTO retrieval_log (memory_id, query, score) + VALUES (?, ?, ?) + `, memoryID, query, score) + if err != nil { + return fmt.Errorf("failed to insert retrieval_log: %w", err) + } + + return tx.Commit() +} + +// TrackRetrievalBatch records multiple memory retrieval events in a single transaction. +// More efficient than calling TrackRetrieval multiple times. +func (idx *LexicalIndex) TrackRetrievalBatch(retrievals []struct { + MemoryID string + Score float32 +}, query string) error { + if len(retrievals) == 0 { + return nil + } + + idx.mu.Lock() + defer idx.mu.Unlock() + + if idx.closed { + return fmt.Errorf("index is closed") + } + + tx, err := idx.db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + // Prepare statements for batch operations + stmtStats, err := tx.Prepare(` + INSERT INTO memory_stats (memory_id, retrieval_count, last_retrieved, status) + VALUES (?, 1, datetime('now'), 'active') + ON CONFLICT(memory_id) DO UPDATE SET + retrieval_count = retrieval_count + 1, + last_retrieved = datetime('now') + `) + if err != nil { + return fmt.Errorf("failed to prepare stats statement: %w", err) + } + defer stmtStats.Close() + + stmtLog, err := tx.Prepare(` + INSERT INTO retrieval_log (memory_id, query, score) + VALUES (?, ?, ?) + `) + if err != nil { + return fmt.Errorf("failed to prepare log statement: %w", err) + } + defer stmtLog.Close() + + for _, r := range retrievals { + if _, err := stmtStats.Exec(r.MemoryID); err != nil { + return fmt.Errorf("failed to update stats for %s: %w", r.MemoryID, err) + } + if _, err := stmtLog.Exec(r.MemoryID, query, r.Score); err != nil { + return fmt.Errorf("failed to log retrieval for %s: %w", r.MemoryID, err) + } + } + + return tx.Commit() +} + +// GetMemoryStats returns stats for a specific memory entry. +func (idx *LexicalIndex) GetMemoryStats(memoryID string) (*RetrievalStats, error) { + idx.mu.RLock() + defer idx.mu.RUnlock() + + if idx.closed { + return nil, fmt.Errorf("index is closed") + } + + var stats RetrievalStats + var lastRetrieved sql.NullString + + err := idx.db.QueryRow(` + SELECT memory_id, retrieval_count, last_retrieved, status + FROM memory_stats + WHERE memory_id = ? + `, memoryID).Scan(&stats.MemoryID, &stats.RetrievalCount, &lastRetrieved, &stats.Status) + + if err == sql.ErrNoRows { + return nil, nil // No stats yet + } + if err != nil { + return nil, fmt.Errorf("failed to get memory stats: %w", err) + } + + if lastRetrieved.Valid { + stats.LastRetrieved = lastRetrieved.String + } + + // Calculate average score from log + var avgScore sql.NullFloat64 + err = idx.db.QueryRow(` + SELECT AVG(score) FROM retrieval_log WHERE memory_id = ? + `, memoryID).Scan(&avgScore) + if err == nil && avgScore.Valid { + stats.AvgScore = float32(avgScore.Float64) + } + + return &stats, nil +} + +// GetAllMemoryStats returns stats for all tracked memories. +func (idx *LexicalIndex) GetAllMemoryStats() ([]RetrievalStats, error) { + idx.mu.RLock() + defer idx.mu.RUnlock() + + if idx.closed { + return nil, fmt.Errorf("index is closed") + } + + rows, err := idx.db.Query(` + SELECT + ms.memory_id, + ms.retrieval_count, + ms.last_retrieved, + ms.status, + COALESCE(AVG(rl.score), 0) as avg_score + FROM memory_stats ms + LEFT JOIN retrieval_log rl ON ms.memory_id = rl.memory_id + GROUP BY ms.memory_id + ORDER BY ms.retrieval_count DESC + `) + if err != nil { + return nil, fmt.Errorf("failed to query memory stats: %w", err) + } + defer rows.Close() + + var results []RetrievalStats + for rows.Next() { + var stats RetrievalStats + var lastRetrieved sql.NullString + var avgScore sql.NullFloat64 + + if err := rows.Scan(&stats.MemoryID, &stats.RetrievalCount, &lastRetrieved, &stats.Status, &avgScore); err != nil { + return nil, fmt.Errorf("failed to scan memory stats: %w", err) + } + + if lastRetrieved.Valid { + stats.LastRetrieved = lastRetrieved.String + } + if avgScore.Valid { + stats.AvgScore = float32(avgScore.Float64) + } + + results = append(results, stats) + } + + return results, rows.Err() +} + +// GetRetrievalHistory returns recent retrieval log entries for a memory. +func (idx *LexicalIndex) GetRetrievalHistory(memoryID string, limit int) ([]RetrievalLogEntry, error) { + idx.mu.RLock() + defer idx.mu.RUnlock() + + if idx.closed { + return nil, fmt.Errorf("index is closed") + } + + if limit <= 0 { + limit = 100 + } + + rows, err := idx.db.Query(` + SELECT id, memory_id, query, score, timestamp + FROM retrieval_log + WHERE memory_id = ? + ORDER BY timestamp DESC + LIMIT ? + `, memoryID, limit) + if err != nil { + return nil, fmt.Errorf("failed to query retrieval history: %w", err) + } + defer rows.Close() + + var results []RetrievalLogEntry + for rows.Next() { + var entry RetrievalLogEntry + var query sql.NullString + var score sql.NullFloat64 + + if err := rows.Scan(&entry.ID, &entry.MemoryID, &query, &score, &entry.Timestamp); err != nil { + return nil, fmt.Errorf("failed to scan retrieval log entry: %w", err) + } + + if query.Valid { + entry.Query = query.String + } + if score.Valid { + entry.Score = float32(score.Float64) + } + + results = append(results, entry) + } + + return results, rows.Err() +} + +// PruneRetrievalLog removes retrieval log entries older than the specified duration. +func (idx *LexicalIndex) PruneRetrievalLog(olderThanDays int) (int64, error) { + idx.mu.Lock() + defer idx.mu.Unlock() + + if idx.closed { + return 0, fmt.Errorf("index is closed") + } + + result, err := idx.db.Exec(` + DELETE FROM retrieval_log + WHERE timestamp < datetime('now', '-' || ? || ' days') + `, olderThanDays) + if err != nil { + return 0, fmt.Errorf("failed to prune retrieval log: %w", err) + } + + return result.RowsAffected() +} + +// UpdateMemoryStatus updates the status of a memory entry (active, archived, etc.). +func (idx *LexicalIndex) UpdateMemoryStatus(memoryID string, status string) error { + idx.mu.Lock() + defer idx.mu.Unlock() + + if idx.closed { + return fmt.Errorf("index is closed") + } + + _, err := idx.db.Exec(` + INSERT INTO memory_stats (memory_id, retrieval_count, status) + VALUES (?, 0, ?) + ON CONFLICT(memory_id) DO UPDATE SET status = ? + `, memoryID, status, status) + if err != nil { + return fmt.Errorf("failed to update memory status: %w", err) + } + + return nil +} diff --git a/internal/semantic/lexical_index_test.go b/internal/semantic/lexical_index_test.go index 1756756..ec35995 100644 --- a/internal/semantic/lexical_index_test.go +++ b/internal/semantic/lexical_index_test.go @@ -417,26 +417,21 @@ func TestSanitizeCollectionName(t *testing.T) { // TestGetFTSPath verifies FTS database path generation. func TestGetFTSPath(t *testing.T) { - // Test with custom data dir + // Test with custom data dir (project-local .index/) t.Run("CustomDataDir", func(t *testing.T) { tmpDir := t.TempDir() path := getFTSPath("test_collection", tmpDir) - expected := filepath.Join(tmpDir, "qdrant-fts-test_collection.db") + expected := filepath.Join(tmpDir, "qdrant_fts.db") if path != expected { t.Errorf("getFTSPath = %q, want %q", path, expected) } }) - // Test with default data dir (home dir) - t.Run("DefaultDataDir", func(t *testing.T) { - home, err := os.UserHomeDir() - if err != nil { - t.Skip("cannot get home dir") - } + // Test with empty data dir (should return empty string) + t.Run("EmptyDataDir", func(t *testing.T) { path := getFTSPath("my_collection", "") - expected := filepath.Join(home, ".llm-semantic", "qdrant-fts-my_collection.db") - if path != expected { - t.Errorf("getFTSPath = %q, want %q", path, expected) + if path != "" { + t.Errorf("getFTSPath with empty dataDir = %q, want empty string", path) } }) } diff --git a/internal/semantic/storage.go b/internal/semantic/storage.go index 0241ead..add0c85 100644 --- a/internal/semantic/storage.go +++ b/internal/semantic/storage.go @@ -136,3 +136,28 @@ type Storage interface { // Overwrites any existing calibration data. SetCalibrationMetadata(ctx context.Context, meta *CalibrationMetadata) error } + +// MemoryStatsTracker is an optional interface for tracking memory retrieval statistics. +// Storage implementations that support memory stats should implement this interface. +type MemoryStatsTracker interface { + // TrackMemoryRetrieval records a single memory retrieval event. + TrackMemoryRetrieval(ctx context.Context, memoryID string, query string, score float32) error + + // TrackMemoryRetrievalBatch records multiple memory retrieval events in a single transaction. + TrackMemoryRetrievalBatch(ctx context.Context, retrievals []MemoryRetrieval, query string) error + + // GetMemoryStats returns stats for a specific memory entry. + GetMemoryStats(ctx context.Context, memoryID string) (*RetrievalStats, error) + + // GetAllMemoryStats returns stats for all tracked memories. + GetAllMemoryStats(ctx context.Context) ([]RetrievalStats, error) + + // GetMemoryRetrievalHistory returns recent retrieval log entries for a memory. + GetMemoryRetrievalHistory(ctx context.Context, memoryID string, limit int) ([]RetrievalLogEntry, error) + + // PruneMemoryRetrievalLog removes retrieval log entries older than the specified duration. + PruneMemoryRetrievalLog(ctx context.Context, olderThanDays int) (int64, error) + + // UpdateMemoryStatsStatus updates the status of a memory entry. + UpdateMemoryStatsStatus(ctx context.Context, memoryID string, status string) error +} diff --git a/internal/semantic/storage_qdrant.go b/internal/semantic/storage_qdrant.go index 7fb1daf..16a4510 100644 --- a/internal/semantic/storage_qdrant.go +++ b/internal/semantic/storage_qdrant.go @@ -34,7 +34,7 @@ type QdrantConfig struct { URL string // Full URL like https://abc123.qdrant.io:6334 CollectionName string EmbeddingDim int - FTSDataDir string // Directory for parallel FTS database (default: ~/.llm-semantic/) + FTSDataDir string // Directory for parallel FTS database (should be project's .index/) InMemoryFTS bool // Use in-memory FTS (for testing) } @@ -129,7 +129,7 @@ func NewQdrantStorage(config QdrantConfig) (*QdrantStorage, error) { ftsPath = getFTSPath(config.CollectionName, config.FTSDataDir) if ftsPath == "" { client.Close() - return nil, fmt.Errorf("failed to determine FTS database path: cannot get home directory") + return nil, fmt.Errorf("failed to determine FTS database path: FTSDataDir is required") } } diff --git a/internal/semantic/storage_sqlite.go b/internal/semantic/storage_sqlite.go index 84e2320..2663a8f 100644 --- a/internal/semantic/storage_sqlite.go +++ b/internal/semantic/storage_sqlite.go @@ -136,6 +136,11 @@ func (s *SQLiteStorage) initSchema() error { return err } + // Initialize memory stats tracking tables + if err := s.initMemoryStatsSchema(); err != nil { + return err + } + return nil } @@ -1075,7 +1080,7 @@ func sortResultsByScore(results []SearchResult) { // IndexPath returns the default index path for a repository func IndexPath(repoRoot string) string { - return filepath.Join(repoRoot, ".llm-index", "semantic.db") + return filepath.Join(repoRoot, ".index", "semantic.db") } // nullableInt64 returns nil if value is 0, otherwise returns the value. @@ -1143,3 +1148,318 @@ func (s *SQLiteStorage) SetCalibrationMetadata(ctx context.Context, meta *Calibr return nil } + +// ===== Memory Stats Tracking ===== + +// initMemoryStatsSchema creates the memory stats tracking tables. +func (s *SQLiteStorage) initMemoryStatsSchema() error { + // Create memory_stats table for tracking retrieval counts + statsSchema := ` + CREATE TABLE IF NOT EXISTS memory_stats ( + memory_id TEXT PRIMARY KEY, + retrieval_count INTEGER DEFAULT 0, + last_retrieved TEXT, + status TEXT DEFAULT 'active' + ); + ` + if _, err := s.db.Exec(statsSchema); err != nil { + return fmt.Errorf("failed to create memory_stats table: %w", err) + } + + // Create retrieval_log table for detailed tracking + logSchema := ` + CREATE TABLE IF NOT EXISTS retrieval_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + query TEXT, + score REAL, + timestamp TEXT DEFAULT (datetime('now')) + ); + ` + if _, err := s.db.Exec(logSchema); err != nil { + return fmt.Errorf("failed to create retrieval_log table: %w", err) + } + + // Create indexes for efficient queries + indexSchema := ` + CREATE INDEX IF NOT EXISTS idx_retrieval_log_timestamp ON retrieval_log(timestamp); + CREATE INDEX IF NOT EXISTS idx_retrieval_log_memory_id ON retrieval_log(memory_id); + ` + if _, err := s.db.Exec(indexSchema); err != nil { + return fmt.Errorf("failed to create retrieval_log indexes: %w", err) + } + + return nil +} + +// TrackMemoryRetrieval records a memory retrieval event. +// This should be called after search results are returned for memory search. +func (s *SQLiteStorage) TrackMemoryRetrieval(ctx context.Context, memoryID string, query string, score float32) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return ErrStorageClosed + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + // Update or insert stats + _, err = tx.ExecContext(ctx, ` + INSERT INTO memory_stats (memory_id, retrieval_count, last_retrieved, status) + VALUES (?, 1, datetime('now'), 'active') + ON CONFLICT(memory_id) DO UPDATE SET + retrieval_count = retrieval_count + 1, + last_retrieved = datetime('now') + `, memoryID) + if err != nil { + return fmt.Errorf("failed to update memory_stats: %w", err) + } + + // Log the retrieval + _, err = tx.ExecContext(ctx, ` + INSERT INTO retrieval_log (memory_id, query, score) + VALUES (?, ?, ?) + `, memoryID, query, score) + if err != nil { + return fmt.Errorf("failed to insert retrieval_log: %w", err) + } + + return tx.Commit() +} + +// TrackMemoryRetrievalBatch records multiple memory retrieval events in a single transaction. +func (s *SQLiteStorage) TrackMemoryRetrievalBatch(ctx context.Context, retrievals []MemoryRetrieval, query string) error { + if len(retrievals) == 0 { + return nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return ErrStorageClosed + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + // Prepare statements for batch operations + stmtStats, err := tx.PrepareContext(ctx, ` + INSERT INTO memory_stats (memory_id, retrieval_count, last_retrieved, status) + VALUES (?, 1, datetime('now'), 'active') + ON CONFLICT(memory_id) DO UPDATE SET + retrieval_count = retrieval_count + 1, + last_retrieved = datetime('now') + `) + if err != nil { + return fmt.Errorf("failed to prepare stats statement: %w", err) + } + defer stmtStats.Close() + + stmtLog, err := tx.PrepareContext(ctx, ` + INSERT INTO retrieval_log (memory_id, query, score) + VALUES (?, ?, ?) + `) + if err != nil { + return fmt.Errorf("failed to prepare log statement: %w", err) + } + defer stmtLog.Close() + + for _, r := range retrievals { + if _, err := stmtStats.ExecContext(ctx, r.MemoryID); err != nil { + return fmt.Errorf("failed to update stats for %s: %w", r.MemoryID, err) + } + if _, err := stmtLog.ExecContext(ctx, r.MemoryID, query, r.Score); err != nil { + return fmt.Errorf("failed to log retrieval for %s: %w", r.MemoryID, err) + } + } + + return tx.Commit() +} + +// GetMemoryStats returns stats for a specific memory entry. +func (s *SQLiteStorage) GetMemoryStats(ctx context.Context, memoryID string) (*RetrievalStats, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.closed { + return nil, ErrStorageClosed + } + + var stats RetrievalStats + var lastRetrieved sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT memory_id, retrieval_count, last_retrieved, status + FROM memory_stats + WHERE memory_id = ? + `, memoryID).Scan(&stats.MemoryID, &stats.RetrievalCount, &lastRetrieved, &stats.Status) + + if err == sql.ErrNoRows { + return nil, nil // No stats yet + } + if err != nil { + return nil, fmt.Errorf("failed to get memory stats: %w", err) + } + + if lastRetrieved.Valid { + stats.LastRetrieved = lastRetrieved.String + } + + // Calculate average score from log + var avgScore sql.NullFloat64 + err = s.db.QueryRowContext(ctx, ` + SELECT AVG(score) FROM retrieval_log WHERE memory_id = ? + `, memoryID).Scan(&avgScore) + if err == nil && avgScore.Valid { + stats.AvgScore = float32(avgScore.Float64) + } + + return &stats, nil +} + +// GetAllMemoryStats returns stats for all tracked memories. +func (s *SQLiteStorage) GetAllMemoryStats(ctx context.Context) ([]RetrievalStats, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.closed { + return nil, ErrStorageClosed + } + + rows, err := s.db.QueryContext(ctx, ` + SELECT + ms.memory_id, + ms.retrieval_count, + ms.last_retrieved, + ms.status, + COALESCE(AVG(rl.score), 0) as avg_score + FROM memory_stats ms + LEFT JOIN retrieval_log rl ON ms.memory_id = rl.memory_id + GROUP BY ms.memory_id + ORDER BY ms.retrieval_count DESC + `) + if err != nil { + return nil, fmt.Errorf("failed to query memory stats: %w", err) + } + defer rows.Close() + + var results []RetrievalStats + for rows.Next() { + var stats RetrievalStats + var lastRetrieved sql.NullString + var avgScore sql.NullFloat64 + + if err := rows.Scan(&stats.MemoryID, &stats.RetrievalCount, &lastRetrieved, &stats.Status, &avgScore); err != nil { + return nil, fmt.Errorf("failed to scan memory stats: %w", err) + } + + if lastRetrieved.Valid { + stats.LastRetrieved = lastRetrieved.String + } + if avgScore.Valid { + stats.AvgScore = float32(avgScore.Float64) + } + + results = append(results, stats) + } + + return results, rows.Err() +} + +// GetMemoryRetrievalHistory returns recent retrieval log entries for a memory. +func (s *SQLiteStorage) GetMemoryRetrievalHistory(ctx context.Context, memoryID string, limit int) ([]RetrievalLogEntry, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.closed { + return nil, ErrStorageClosed + } + + if limit <= 0 { + limit = 100 + } + + rows, err := s.db.QueryContext(ctx, ` + SELECT id, memory_id, query, score, timestamp + FROM retrieval_log + WHERE memory_id = ? + ORDER BY timestamp DESC + LIMIT ? + `, memoryID, limit) + if err != nil { + return nil, fmt.Errorf("failed to query retrieval history: %w", err) + } + defer rows.Close() + + var results []RetrievalLogEntry + for rows.Next() { + var entry RetrievalLogEntry + var query sql.NullString + var score sql.NullFloat64 + + if err := rows.Scan(&entry.ID, &entry.MemoryID, &query, &score, &entry.Timestamp); err != nil { + return nil, fmt.Errorf("failed to scan retrieval log entry: %w", err) + } + + if query.Valid { + entry.Query = query.String + } + if score.Valid { + entry.Score = float32(score.Float64) + } + + results = append(results, entry) + } + + return results, rows.Err() +} + +// PruneMemoryRetrievalLog removes retrieval log entries older than the specified duration. +func (s *SQLiteStorage) PruneMemoryRetrievalLog(ctx context.Context, olderThanDays int) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return 0, ErrStorageClosed + } + + result, err := s.db.ExecContext(ctx, ` + DELETE FROM retrieval_log + WHERE timestamp < datetime('now', '-' || ? || ' days') + `, olderThanDays) + if err != nil { + return 0, fmt.Errorf("failed to prune retrieval log: %w", err) + } + + return result.RowsAffected() +} + +// UpdateMemoryStatsStatus updates the status of a memory entry (active, archived, etc.). +func (s *SQLiteStorage) UpdateMemoryStatsStatus(ctx context.Context, memoryID string, status string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return ErrStorageClosed + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO memory_stats (memory_id, retrieval_count, status) + VALUES (?, 0, ?) + ON CONFLICT(memory_id) DO UPDATE SET status = ? + `, memoryID, status, status) + if err != nil { + return fmt.Errorf("failed to update memory status: %w", err) + } + + return nil +} diff --git a/internal/semantic/storage_sqlite_test.go b/internal/semantic/storage_sqlite_test.go index 1345c07..aa96641 100644 --- a/internal/semantic/storage_sqlite_test.go +++ b/internal/semantic/storage_sqlite_test.go @@ -446,3 +446,208 @@ func TestSQLiteStorage_ConcurrentReadWrite(t *testing.T) { func idForWorker(workerID, iteration int) string { return fmt.Sprintf("worker-%d-%d", workerID, iteration) } + +// ===== Memory Stats Tracking Tests ===== + +func TestSQLiteStorage_TrackMemoryRetrieval(t *testing.T) { + storage, err := NewSQLiteStorage(":memory:", 4) + if err != nil { + t.Fatalf("Failed to create SQLite storage: %v", err) + } + defer storage.Close() + + ctx := context.Background() + + // Track a retrieval + err = storage.TrackMemoryRetrieval(ctx, "mem-001", "test query", 0.95) + if err != nil { + t.Fatalf("TrackMemoryRetrieval failed: %v", err) + } + + // Verify stats were created + stats, err := storage.GetMemoryStats(ctx, "mem-001") + if err != nil { + t.Fatalf("GetMemoryStats failed: %v", err) + } + if stats == nil { + t.Fatal("Expected stats to be non-nil") + } + if stats.RetrievalCount != 1 { + t.Errorf("Expected RetrievalCount=1, got %d", stats.RetrievalCount) + } + if stats.Status != "active" { + t.Errorf("Expected Status='active', got %q", stats.Status) + } + + // Track another retrieval - count should increment + err = storage.TrackMemoryRetrieval(ctx, "mem-001", "another query", 0.85) + if err != nil { + t.Fatalf("Second TrackMemoryRetrieval failed: %v", err) + } + + stats, err = storage.GetMemoryStats(ctx, "mem-001") + if err != nil { + t.Fatalf("GetMemoryStats failed: %v", err) + } + if stats.RetrievalCount != 2 { + t.Errorf("Expected RetrievalCount=2, got %d", stats.RetrievalCount) + } +} + +func TestSQLiteStorage_TrackMemoryRetrievalBatch(t *testing.T) { + storage, err := NewSQLiteStorage(":memory:", 4) + if err != nil { + t.Fatalf("Failed to create SQLite storage: %v", err) + } + defer storage.Close() + + ctx := context.Background() + + // Track a batch of retrievals + retrievals := []MemoryRetrieval{ + {MemoryID: "mem-001", Score: 0.95}, + {MemoryID: "mem-002", Score: 0.85}, + {MemoryID: "mem-003", Score: 0.75}, + } + err = storage.TrackMemoryRetrievalBatch(ctx, retrievals, "batch query") + if err != nil { + t.Fatalf("TrackMemoryRetrievalBatch failed: %v", err) + } + + // Verify all stats were created + allStats, err := storage.GetAllMemoryStats(ctx) + if err != nil { + t.Fatalf("GetAllMemoryStats failed: %v", err) + } + if len(allStats) != 3 { + t.Errorf("Expected 3 stats entries, got %d", len(allStats)) + } + + // Verify individual stats + for _, id := range []string{"mem-001", "mem-002", "mem-003"} { + stats, err := storage.GetMemoryStats(ctx, id) + if err != nil { + t.Fatalf("GetMemoryStats for %s failed: %v", id, err) + } + if stats == nil { + t.Errorf("Expected stats for %s to be non-nil", id) + continue + } + if stats.RetrievalCount != 1 { + t.Errorf("Expected RetrievalCount=1 for %s, got %d", id, stats.RetrievalCount) + } + } +} + +func TestSQLiteStorage_GetMemoryRetrievalHistory(t *testing.T) { + storage, err := NewSQLiteStorage(":memory:", 4) + if err != nil { + t.Fatalf("Failed to create SQLite storage: %v", err) + } + defer storage.Close() + + ctx := context.Background() + + // Track multiple retrievals for the same memory + queries := []string{"query1", "query2", "query3"} + for i, q := range queries { + err = storage.TrackMemoryRetrieval(ctx, "mem-001", q, float32(0.9-float32(i)*0.1)) + if err != nil { + t.Fatalf("TrackMemoryRetrieval failed: %v", err) + } + } + + // Get history + history, err := storage.GetMemoryRetrievalHistory(ctx, "mem-001", 10) + if err != nil { + t.Fatalf("GetMemoryRetrievalHistory failed: %v", err) + } + if len(history) != 3 { + t.Errorf("Expected 3 history entries, got %d", len(history)) + } + + // Verify all queries are present in history (order may vary due to same-second timestamps) + queryMap := make(map[string]bool) + for _, h := range history { + queryMap[h.Query] = true + } + for _, q := range queries { + if !queryMap[q] { + t.Errorf("Expected query %q in history", q) + } + } +} + +func TestSQLiteStorage_UpdateMemoryStatsStatus(t *testing.T) { + storage, err := NewSQLiteStorage(":memory:", 4) + if err != nil { + t.Fatalf("Failed to create SQLite storage: %v", err) + } + defer storage.Close() + + ctx := context.Background() + + // Track a retrieval to create the stats entry + err = storage.TrackMemoryRetrieval(ctx, "mem-001", "test", 0.9) + if err != nil { + t.Fatalf("TrackMemoryRetrieval failed: %v", err) + } + + // Update status + err = storage.UpdateMemoryStatsStatus(ctx, "mem-001", "archived") + if err != nil { + t.Fatalf("UpdateMemoryStatsStatus failed: %v", err) + } + + // Verify status was updated + stats, err := storage.GetMemoryStats(ctx, "mem-001") + if err != nil { + t.Fatalf("GetMemoryStats failed: %v", err) + } + if stats.Status != "archived" { + t.Errorf("Expected Status='archived', got %q", stats.Status) + } + + // Retrieval count should be preserved + if stats.RetrievalCount != 1 { + t.Errorf("Expected RetrievalCount=1, got %d", stats.RetrievalCount) + } +} + +func TestSQLiteStorage_PruneMemoryRetrievalLog(t *testing.T) { + storage, err := NewSQLiteStorage(":memory:", 4) + if err != nil { + t.Fatalf("Failed to create SQLite storage: %v", err) + } + defer storage.Close() + + ctx := context.Background() + + // Track some retrievals + for i := 0; i < 5; i++ { + err = storage.TrackMemoryRetrieval(ctx, fmt.Sprintf("mem-%03d", i), "test", 0.9) + if err != nil { + t.Fatalf("TrackMemoryRetrieval failed: %v", err) + } + } + + // Pruning with 0 days should delete nothing (all are from "now") + deleted, err := storage.PruneMemoryRetrievalLog(ctx, 0) + if err != nil { + t.Fatalf("PruneMemoryRetrievalLog failed: %v", err) + } + if deleted != 0 { + t.Errorf("Expected 0 deleted, got %d", deleted) + } +} + +func TestSQLiteStorage_MemoryStatsInterface(t *testing.T) { + storage, err := NewSQLiteStorage(":memory:", 4) + if err != nil { + t.Fatalf("Failed to create SQLite storage: %v", err) + } + defer storage.Close() + + // Verify SQLiteStorage implements MemoryStatsTracker + var _ MemoryStatsTracker = storage +} diff --git a/internal/semantic/types.go b/internal/semantic/types.go index 36dbc9c..d03ebec 100644 --- a/internal/semantic/types.go +++ b/internal/semantic/types.go @@ -171,3 +171,29 @@ type IndexHealth struct { NewFiles int `json:"new_files"` ModifiedFiles int `json:"modified_files"` } + +// ===== Memory Retrieval Stats Types ===== + +// RetrievalStats holds retrieval statistics for a memory entry. +type RetrievalStats struct { + MemoryID string `json:"memory_id"` + RetrievalCount int `json:"retrieval_count"` + LastRetrieved string `json:"last_retrieved,omitempty"` + Status string `json:"status"` + AvgScore float32 `json:"avg_score,omitempty"` +} + +// MemoryRetrieval represents a single retrieval event for batch tracking. +type MemoryRetrieval struct { + MemoryID string + Score float32 +} + +// RetrievalLogEntry represents a single retrieval event log. +type RetrievalLogEntry struct { + ID int64 `json:"id"` + MemoryID string `json:"memory_id"` + Query string `json:"query"` + Score float32 `json:"score"` + Timestamp string `json:"timestamp"` +} diff --git a/internal/support/commands/yaml.go b/internal/support/commands/yaml.go index 3d538b4..a39efd9 100644 --- a/internal/support/commands/yaml.go +++ b/internal/support/commands/yaml.go @@ -279,6 +279,8 @@ func newYamlSetCmd() *cobra.Command { var create bool var jsonOutput bool var minOutput bool + var dryRun bool + var quiet bool cmd := &cobra.Command{ Use: "set KEY VALUE", @@ -288,10 +290,16 @@ func newYamlSetCmd() *cobra.Command { Creates intermediate keys if they don't exist. Preserves comments where possible. +Use '-' as VALUE to read from stdin (for piping values or multi-line input). +Use --dry-run to preview changes without writing to the file. +Use --quiet to suppress success messages (errors still output). + Examples: yaml set --file config.yaml helper.llm claude yaml set --file config.yaml helper.max_lines 2500 - yaml set --file config.yaml new.nested.key value --create`, + yaml set --file config.yaml new.nested.key value --create + echo "piped" | yaml set --file config.yaml key - + yaml set --file config.yaml key value --dry-run`, Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { if file == "" { @@ -301,15 +309,50 @@ Examples: key := args[0] value := args[1] + // Handle stdin input when value is "-" + if value == "-" { + stdinValue, err := readFromStdin(cmd) + if err != nil { + return fmt.Errorf("failed to read value from stdin: %w", err) + } + value = stdinValue + } + // Check if file exists if _, err := os.Stat(file); os.IsNotExist(err) { if !create { return fmt.Errorf("config file not found: %s (hint: use --create to create it, or run 'llm-support yaml init --file %s' first)", file, file) } - // Create empty file - if err := os.WriteFile(file, []byte(""), 0644); err != nil { - return fmt.Errorf("failed to create config file: %w", err) + // Create empty file (unless dry-run) + if !dryRun { + if err := os.WriteFile(file, []byte(""), 0644); err != nil { + return fmt.Errorf("failed to create config file: %w", err) + } + } + } + + // Set value (convert numeric strings to numbers) + var typedValue interface{} = value + if num, err := parseNumber(value); err == nil { + typedValue = num + } + + // Handle dry-run mode - preview without writing + if dryRun { + var oldValue interface{} + if _, err := os.Stat(file); err == nil { + // Acquire read lock to prevent race condition with concurrent writes + lock, err := yamlFileLock(file, false) // false = read lock + if err != nil { + return fmt.Errorf("failed to acquire read lock for dry-run: %w", err) + } + data, err := readYAMLAsMap(file) + lock.Unlock() + if err == nil { + oldValue, _ = getValueAtPath(data, key) + } } + return outputDryRunPreview(cmd, file, key, oldValue, typedValue, jsonOutput, minOutput) } // Acquire write lock @@ -319,17 +362,16 @@ Examples: } defer lock.Unlock() - // Set value (convert numeric strings to numbers) - var typedValue interface{} = value - if num, err := parseNumber(value); err == nil { - typedValue = num - } - // Use comment-preserving set (uses AST manipulation) if err := setValuePreservingComments(file, key, typedValue); err != nil { return fmt.Errorf("failed to set value: %w", err) } + // Skip success output if quiet + if quiet { + return nil + } + // Output if jsonOutput { if minOutput { @@ -357,6 +399,8 @@ Examples: cmd.Flags().BoolVar(&create, "create", false, "Create file if it doesn't exist") cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON") cmd.Flags().BoolVar(&minOutput, "min", false, "Minimal output") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Preview changes without writing to file") + cmd.Flags().BoolVar(&quiet, "quiet", false, "Suppress success messages (errors still output)") cmd.MarkFlagRequired("file") return cmd @@ -369,14 +413,19 @@ func newYamlMultigetCmd() *cobra.Command { var separator string var jsonOutput bool var minOutput bool + var requiredFile string cmd := &cobra.Command{ - Use: "multiget KEY1 [KEY2 ...]", + Use: "multiget [KEY1 KEY2 ...]", Short: "Retrieve multiple values", Long: `Retrieve multiple values from the YAML config file in a single operation. Values are returned in argument order. +Keys can be specified via: + - Positional arguments + - --required-file (one key per line, # comments supported) + Output Formats: default: key=value (one per line) --json: {"key1": "value1", "key2": "value2"} @@ -385,13 +434,36 @@ Output Formats: Examples: yaml multiget --file config.yaml helper.llm project.type yaml multiget --file config.yaml helper.llm missing.key --defaults '{"missing.key": "default"}' - yaml multiget --file config.yaml helper.llm project.type --min`, - Args: cobra.MinimumNArgs(1), + yaml multiget --file config.yaml helper.llm project.type --min + yaml multiget --file config.yaml --required-file keys.txt`, + Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { if file == "" { return fmt.Errorf("--file flag is required") } + // Collect keys from all sources + var keys []string + + // From positional args + keys = append(keys, args...) + + // From --required-file + if requiredFile != "" { + fileKeys, err := parseRequiredKeysFile(requiredFile) + if err != nil { + return fmt.Errorf("failed to read required keys file: %w", err) + } + keys = append(keys, fileKeys...) + } + + // Deduplicate keys while preserving order + keys = yamlUniqueStrings(keys) + + if len(keys) == 0 { + return fmt.Errorf("no keys specified (use positional args or --required-file)") + } + // Parse defaults if provided var defaultsMap map[string]string if defaults != "" { @@ -419,7 +491,7 @@ Examples: // Get all values results := make(map[string]string) var orderedKeys []string - for _, key := range args { + for _, key := range keys { orderedKeys = append(orderedKeys, key) value, found := getValueAtPath(data, key) if found { @@ -469,6 +541,7 @@ Examples: cmd.Flags().StringVar(&separator, "separator", "\n", "Value separator for --min output") cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON") cmd.Flags().BoolVar(&minOutput, "min", false, "Minimal output (values only)") + cmd.Flags().StringVar(&requiredFile, "required-file", "", "File containing required keys (one per line, # comments supported)") cmd.MarkFlagRequired("file") return cmd @@ -480,6 +553,8 @@ func newYamlMultisetCmd() *cobra.Command { var create bool var jsonOutput bool var minOutput bool + var dryRun bool + var quiet bool cmd := &cobra.Command{ Use: "multiset KEY1 VALUE1 [KEY2 VALUE2 ...]", @@ -488,10 +563,13 @@ func newYamlMultisetCmd() *cobra.Command { All keys are validated before any writes occur. Arguments must be in KEY VALUE pairs. +Use --dry-run to preview changes without writing to the file. +Use --quiet to suppress success messages (errors still output). Examples: yaml multiset --file config.yaml helper.llm claude helper.max_lines 2500 - yaml multiset --file config.yaml --create new.key value1 other.key value2`, + yaml multiset --file config.yaml --create new.key value1 other.key value2 + yaml multiset --file config.yaml key1 val1 key2 val2 --dry-run`, Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { if file == "" { @@ -517,12 +595,62 @@ Examples: if !create { return fmt.Errorf("config file not found: %s (hint: use --create)", file) } - // Create empty file - if err := os.WriteFile(file, []byte(""), 0644); err != nil { - return fmt.Errorf("failed to create config file: %w", err) + // Create empty file (unless dry-run) + if !dryRun { + if err := os.WriteFile(file, []byte(""), 0644); err != nil { + return fmt.Errorf("failed to create config file: %w", err) + } } } + // Handle dry-run mode - preview without writing + if dryRun { + var data map[string]interface{} + if _, statErr := os.Stat(file); statErr == nil { + // Acquire read lock to prevent race condition with concurrent writes + lock, lockErr := yamlFileLock(file, false) // false = read lock + if lockErr != nil { + return fmt.Errorf("failed to acquire read lock for dry-run: %w", lockErr) + } + var readErr error + data, readErr = readYAMLAsMap(file) + lock.Unlock() + if readErr != nil { + return fmt.Errorf("failed to read config file: %w", readErr) + } + } else { + data = make(map[string]interface{}) + } + + var changes []dryRunChange + for _, pair := range pairs { + var typedValue interface{} = pair.value + if num, parseErr := parseNumber(pair.value); parseErr == nil { + typedValue = num + } + + oldValue, _ := getValueAtPath(data, pair.key) + changes = append(changes, dryRunChange{ + Key: pair.key, + OldValue: oldValue, + NewValue: typedValue, + }) + } + return outputMultiDryRunPreview(cmd, file, changes, jsonOutput, minOutput) + } + + // Read file for actual update + var data map[string]interface{} + var err error + if _, statErr := os.Stat(file); statErr == nil { + data, err = readYAMLAsMap(file) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + } else { + data = make(map[string]interface{}) + } + // Acquire write lock lock, err := yamlFileLock(file, true) if err != nil { @@ -530,17 +658,11 @@ Examples: } defer lock.Unlock() - // Read file - data, err := readYAMLAsMap(file) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - // Set all values var keys []string for _, pair := range pairs { var typedValue interface{} = pair.value - if num, err := parseNumber(pair.value); err == nil { + if num, parseErr := parseNumber(pair.value); parseErr == nil { typedValue = num } @@ -555,6 +677,11 @@ Examples: return fmt.Errorf("failed to write config file: %w", err) } + // Skip success output if quiet + if quiet { + return nil + } + // Output // NOTE: Output format intentionally matches context_multiset for consistency if jsonOutput { @@ -583,6 +710,8 @@ Examples: cmd.Flags().BoolVar(&create, "create", false, "Create file if it doesn't exist") cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON") cmd.Flags().BoolVar(&minOutput, "min", false, "Minimal output") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Preview changes without writing to file") + cmd.Flags().BoolVar(&quiet, "quiet", false, "Suppress success messages (errors still output)") cmd.MarkFlagRequired("file") return cmd diff --git a/internal/support/commands/yaml_helpers.go b/internal/support/commands/yaml_helpers.go index 692675c..9e1ef28 100644 --- a/internal/support/commands/yaml_helpers.go +++ b/internal/support/commands/yaml_helpers.go @@ -1,17 +1,26 @@ package commands import ( + "bufio" + "encoding/json" "fmt" + "io" "os" "path/filepath" "strings" + "unicode/utf8" "github.com/goccy/go-yaml" "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/parser" "github.com/gofrs/flock" + "github.com/spf13/cobra" ) +// invalidArrayIndex is a sentinel value for parseArrayIndex when input is not a valid integer. +// We use -999 because -1, -2, etc. are valid negative indices for accessing array elements from the end. +const invalidArrayIndex = -999 + // Built-in templates for yaml init var planningTemplate = `# llm-support YAML configuration @@ -174,6 +183,7 @@ func convertDotPathToYAMLPath(dotPath string) string { } // getValueAtPath retrieves a value from YAML data using dot notation +// Supports negative array indices: -1 = last, -2 = second to last, etc. func getValueAtPath(data map[string]interface{}, dotPath string) (interface{}, bool) { parts := parseDotPath(dotPath) if len(parts) == 0 { @@ -196,8 +206,16 @@ func getValueAtPath(data map[string]interface{}, dotPath string) (interface{}, b } current = val case []interface{}: - // Handle array index + // Handle array index (including negative indices) idx := parseArrayIndex(part) + if idx == invalidArrayIndex { + // Not a valid index + return nil, false + } + // Handle negative indices: -1 = last, -2 = second to last, etc. + if idx < 0 { + idx = len(v) + idx + } if idx < 0 || idx >= len(v) { return nil, false } @@ -211,20 +229,89 @@ func getValueAtPath(data map[string]interface{}, dotPath string) (interface{}, b } // setValueAtPath sets a value in YAML data using dot notation, creating intermediate keys +// Supports array index traversal and setting values at array indices (including negative indices) func setValueAtPath(data map[string]interface{}, dotPath string, value interface{}) error { parts := parseDotPath(dotPath) if len(parts) == 0 { return fmt.Errorf("empty path") } - current := data + // Use interface{} to track current position (can be map or array) + var currentVal interface{} = data + var parentVal interface{} + var parentKey string + var parentIdx int = -1 + for i := 0; i < len(parts)-1; i++ { part := parts[i] - if existing, ok := current[part]; ok { - if m, ok := existing.(map[string]interface{}); ok { - current = m - } else if m, ok := existing.(map[interface{}]interface{}); ok { + switch curr := currentVal.(type) { + case map[string]interface{}: + if existing, ok := curr[part]; ok { + if m, ok := existing.(map[string]interface{}); ok { + parentVal = currentVal + parentKey = part + parentIdx = -1 + currentVal = m + } else if m, ok := existing.(map[interface{}]interface{}); ok { + // Convert to map[string]interface{} + converted := make(map[string]interface{}) + for k, v := range m { + if ks, ok := k.(string); ok { + converted[ks] = v + } + } + curr[part] = converted + parentVal = currentVal + parentKey = part + parentIdx = -1 + currentVal = converted + } else if arr, ok := existing.([]interface{}); ok { + // Next part should be an array index + parentVal = currentVal + parentKey = part + parentIdx = -1 + currentVal = arr + } else { + // Replace non-map value with a new map + newMap := make(map[string]interface{}) + curr[part] = newMap + parentVal = currentVal + parentKey = part + parentIdx = -1 + currentVal = newMap + } + } else { + // Create intermediate map + newMap := make(map[string]interface{}) + curr[part] = newMap + parentVal = currentVal + parentKey = part + parentIdx = -1 + currentVal = newMap + } + + case []interface{}: + // Handle array index traversal + idx := parseArrayIndex(part) + if idx == invalidArrayIndex { + return fmt.Errorf("invalid array index: %s", part) + } + // Handle negative indices + if idx < 0 { + idx = len(curr) + idx + } + if idx < 0 || idx >= len(curr) { + return fmt.Errorf("array index out of bounds: %d (length: %d)", idx, len(curr)) + } + + elem := curr[idx] + if m, ok := elem.(map[string]interface{}); ok { + parentVal = currentVal + parentIdx = idx + parentKey = "" + currentVal = m + } else if m, ok := elem.(map[interface{}]interface{}); ok { // Convert to map[string]interface{} converted := make(map[string]interface{}) for k, v := range m { @@ -232,24 +319,62 @@ func setValueAtPath(data map[string]interface{}, dotPath string, value interface converted[ks] = v } } - current[part] = converted - current = converted + curr[idx] = converted + parentVal = currentVal + parentIdx = idx + parentKey = "" + currentVal = converted + } else if arr, ok := elem.([]interface{}); ok { + parentVal = currentVal + parentIdx = idx + parentKey = "" + currentVal = arr } else { - // Replace non-map value with a new map + // Need to traverse further but hit a scalar - replace with map newMap := make(map[string]interface{}) - current[part] = newMap - current = newMap + curr[idx] = newMap + parentVal = currentVal + parentIdx = idx + parentKey = "" + currentVal = newMap } - } else { - // Create intermediate map - newMap := make(map[string]interface{}) - current[part] = newMap - current = newMap + + default: + return fmt.Errorf("cannot traverse path at %s", part) } } // Set the final value - current[parts[len(parts)-1]] = value + finalPart := parts[len(parts)-1] + + switch curr := currentVal.(type) { + case map[string]interface{}: + curr[finalPart] = value + case []interface{}: + idx := parseArrayIndex(finalPart) + if idx == invalidArrayIndex { + return fmt.Errorf("invalid array index: %s", finalPart) + } + // Handle negative indices + if idx < 0 { + idx = len(curr) + idx + } + if idx < 0 || idx >= len(curr) { + return fmt.Errorf("array index out of bounds: %d (length: %d)", idx, len(curr)) + } + curr[idx] = value + // Update parent reference since slice assignment doesn't modify the original + if parentVal != nil { + if parentIdx >= 0 { + parentVal.([]interface{})[parentIdx] = curr + } else if parentKey != "" { + parentVal.(map[string]interface{})[parentKey] = curr + } + } + default: + return fmt.Errorf("cannot set value at path") + } + return nil } @@ -293,7 +418,8 @@ func deleteValueAtPath(data map[string]interface{}, dotPath string) error { return nil } -// parseDotPath splits a dot-notation path into parts, handling escaped dots +// parseDotPath splits a dot-notation path into parts, handling escaped dots and bracket notation +// Supports: "a.b.c", "items[0].name", "items[-1]", "a\.b.c" (escaped dot), "a[0][1]" (nested arrays) func parseDotPath(path string) []string { if path == "" { return nil @@ -304,20 +430,44 @@ func parseDotPath(path string) []string { for i := 0; i < len(path); i++ { ch := path[i] - if ch == '\\' && i+1 < len(path) && path[i+1] == '.' { - // Escaped dot + + switch { + case ch == '\\' && i+1 < len(path) && path[i+1] == '.': + // Escaped dot - include literal dot current.WriteByte('.') i++ - } else if ch == '.' { + + case ch == '.': + // Dot separator - flush current segment if current.Len() > 0 { parts = append(parts, current.String()) current.Reset() } - } else { + + case ch == '[': + // Array bracket - flush current segment if any, then parse index + if current.Len() > 0 { + parts = append(parts, current.String()) + current.Reset() + } + // Find closing bracket + end := strings.Index(path[i:], "]") + if end == -1 { + // Malformed - treat as literal + current.WriteByte(ch) + continue + } + // Extract index (handles negative indices like -1) + indexStr := path[i+1 : i+end] + parts = append(parts, indexStr) + i += end // Skip past closing bracket + + default: current.WriteByte(ch) } } + // Flush final segment if current.Len() > 0 { parts = append(parts, current.String()) } @@ -325,7 +475,8 @@ func parseDotPath(path string) []string { return parts } -// parseArrayIndex parses an array index from a path part like "0" or "[0]" +// parseArrayIndex parses an array index from a path part like "0", "[0]", "-1", or "[-1]" +// Returns invalidArrayIndex as a sentinel value for invalid input (since -1, -2, etc. are valid negative indices) func parseArrayIndex(part string) int { // Remove brackets if present part = strings.TrimPrefix(part, "[") @@ -334,7 +485,7 @@ func parseArrayIndex(part string) int { var idx int _, err := fmt.Sscanf(part, "%d", &idx) if err != nil { - return -1 + return invalidArrayIndex } return idx } @@ -494,6 +645,23 @@ func setValuePreservingComments(filePath string, dotPath string, value interface return fmt.Errorf("invalid YAML: %w", err) } + // Check if path contains negative indices - if so, we need to resolve them first + // since the yaml library doesn't support negative indices + if strings.Contains(dotPath, "-") { + // Read current data to resolve negative indices + data, mapErr := readYAMLAsMap(filePath) + if mapErr != nil { + return mapErr + } + + // Resolve negative indices in the path + resolvedPath, resolveErr := resolveNegativeIndices(data, dotPath) + if resolveErr != nil { + return resolveErr + } + dotPath = resolvedPath + } + // Convert dot path to YAML path syntax yamlPath := convertDotPathToYAMLPath(dotPath) @@ -534,3 +702,280 @@ func setValuePreservingComments(filePath string, dotPath string, value interface } return writeYAMLFile(filePath, data) } + +// resolveNegativeIndices converts negative array indices to positive ones by examining the data +// For example: items[-1].name with a 3-element array becomes items[2].name +func resolveNegativeIndices(data map[string]interface{}, dotPath string) (string, error) { + parts := parseDotPath(dotPath) + if len(parts) == 0 { + return dotPath, nil + } + + current := interface{}(data) + var resolvedParts []string + + for i, part := range parts { + idx := parseArrayIndex(part) + if idx != invalidArrayIndex && idx < 0 { + // Negative index - need to resolve it + arr, ok := current.([]interface{}) + if !ok { + return "", fmt.Errorf("negative index %d used on non-array at %s", idx, strings.Join(parts[:i], ".")) + } + resolvedIdx := len(arr) + idx + if resolvedIdx < 0 || resolvedIdx >= len(arr) { + return "", fmt.Errorf("array index out of bounds: %d (length: %d)", idx, len(arr)) + } + resolvedParts = append(resolvedParts, fmt.Sprintf("[%d]", resolvedIdx)) + current = arr[resolvedIdx] + } else { + resolvedParts = append(resolvedParts, part) + // Navigate to next level + switch v := current.(type) { + case map[string]interface{}: + current = v[part] + case map[interface{}]interface{}: + current = v[part] + case []interface{}: + if idx != invalidArrayIndex && idx >= 0 && idx < len(v) { + current = v[idx] + } + } + } + } + + // Reconstruct path - parts that are numeric should use bracket notation + var result strings.Builder + for i, part := range resolvedParts { + if strings.HasPrefix(part, "[") { + result.WriteString(part) + } else { + if i > 0 { + result.WriteString(".") + } + result.WriteString(part) + } + } + return result.String(), nil +} + +// maxStdinSize is the maximum allowed size for stdin input (10MB) +// This prevents memory exhaustion from malicious or accidental large inputs +const maxStdinSize = 10 * 1024 * 1024 + +// readFromStdin reads all content from stdin, trimming trailing newline +// Uses cmd.InOrStdin() for testability - tests can use cmd.SetIn() +// Enforces a maximum size limit to prevent memory exhaustion +func readFromStdin(cmd *cobra.Command) (string, error) { + reader := cmd.InOrStdin() + + // Check if stdin has data (not a terminal) + // Only check if it's actually os.Stdin (not a test buffer) + if f, ok := reader.(*os.File); ok && f == os.Stdin { + stat, err := f.Stat() + if err != nil { + return "", fmt.Errorf("failed to stat stdin: %w", err) + } + if (stat.Mode() & os.ModeCharDevice) != 0 { + return "", fmt.Errorf("no data piped to stdin (use 'echo value | command' or 'command < file')") + } + } + + // Read with size limit to prevent memory exhaustion + limitedReader := io.LimitReader(reader, maxStdinSize+1) + content, err := io.ReadAll(limitedReader) + if err != nil { + return "", fmt.Errorf("error reading stdin: %w", err) + } + + // Check if we hit the limit + if len(content) > maxStdinSize { + return "", fmt.Errorf("stdin input exceeds maximum size of %d bytes", maxStdinSize) + } + + // Validate UTF-8 encoding to prevent YAML corruption from binary data + if !utf8.Valid(content) { + return "", fmt.Errorf("stdin input is not valid UTF-8 (binary data not supported for YAML values)") + } + + // Trim trailing newline (common from echo) + // Handle both Unix (\n) and Windows (\r\n) line endings + result := strings.TrimSuffix(string(content), "\n") + result = strings.TrimSuffix(result, "\r") + return result, nil +} + +// parseRequiredKeysFile reads keys from a file, one per line +// Supports # comments and skips empty lines +// Validates the file path to prevent path traversal attacks +func parseRequiredKeysFile(filePath string) ([]string, error) { + // Clean the path to resolve any . or .. components + cleanPath := filepath.Clean(filePath) + + // Resolve to absolute path + absPath, err := filepath.Abs(cleanPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve file path: %w", err) + } + + // Verify the file exists and is a regular file (not a symlink to something else) + info, err := os.Lstat(absPath) + if err != nil { + return nil, fmt.Errorf("failed to stat required keys file: %w", err) + } + + // Don't follow symlinks to prevent traversal via symlink + if info.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("required keys file cannot be a symlink for security reasons") + } + + if !info.Mode().IsRegular() { + return nil, fmt.Errorf("required keys file must be a regular file") + } + + file, err := os.Open(absPath) + if err != nil { + return nil, fmt.Errorf("failed to open required keys file: %w", err) + } + defer file.Close() + + var keys []string + scanner := bufio.NewScanner(file) + // Increase buffer size to handle keys longer than default 64KB limit + // YAML keys with long paths (e.g., deeply.nested.keys.with.many.segments) + // may exceed the default scanner buffer + const maxKeyLength = 256 * 1024 // 256KB should be more than enough for any key + buf := make([]byte, maxKeyLength) + scanner.Buffer(buf, maxKeyLength) + for scanner.Scan() { + line := scanner.Text() + + // Strip inline comments + if idx := strings.Index(line, "#"); idx != -1 { + line = line[:idx] + } + + // Trim whitespace + line = strings.TrimSpace(line) + + // Skip empty lines + if line == "" { + continue + } + + keys = append(keys, line) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading required keys file: %w", err) + } + + return keys, nil +} + +// yamlUniqueStrings removes duplicates from a string slice while preserving order +// Used for deduplicating keys from multiple sources in yaml-multiget +func yamlUniqueStrings(input []string) []string { + seen := make(map[string]bool) + result := make([]string, 0, len(input)) + + for _, s := range input { + if !seen[s] { + seen[s] = true + result = append(result, s) + } + } + + return result +} + +// dryRunChange represents a single key change for dry-run preview +type dryRunChange struct { + Key string `json:"key"` + OldValue interface{} `json:"old_value"` + NewValue interface{} `json:"new_value"` +} + +// outputDryRunPreview outputs a single change preview +func outputDryRunPreview(cmd *cobra.Command, filePath, key string, oldValue, newValue interface{}, jsonOutput, minOutput bool) error { + if jsonOutput { + return outputDryRunJSON(cmd, filePath, []dryRunChange{{ + Key: key, + OldValue: oldValue, + NewValue: newValue, + }}, minOutput) + } + + if minOutput { + // Minimal: just key: old → new + fmt.Fprintf(cmd.OutOrStdout(), "%s: %v → %v\n", key, formatDryRunValue(oldValue), formatDryRunValue(newValue)) + return nil + } + + // Default text output + fmt.Fprintln(cmd.OutOrStdout(), "DRY RUN - No changes written") + fmt.Fprintf(cmd.OutOrStdout(), "File: %s\n", filePath) + fmt.Fprintf(cmd.OutOrStdout(), "Key: %s\n", key) + fmt.Fprintf(cmd.OutOrStdout(), "Old: %s\n", formatDryRunValue(oldValue)) + fmt.Fprintf(cmd.OutOrStdout(), "New: %s\n", formatDryRunValue(newValue)) + + return nil +} + +// outputMultiDryRunPreview outputs multiple change previews +func outputMultiDryRunPreview(cmd *cobra.Command, filePath string, changes []dryRunChange, jsonOutput, minOutput bool) error { + if jsonOutput { + return outputDryRunJSON(cmd, filePath, changes, minOutput) + } + + if minOutput { + for _, c := range changes { + fmt.Fprintf(cmd.OutOrStdout(), "%s: %v → %v\n", c.Key, formatDryRunValue(c.OldValue), formatDryRunValue(c.NewValue)) + } + return nil + } + + // Default text output + fmt.Fprintln(cmd.OutOrStdout(), "DRY RUN - No changes written") + fmt.Fprintf(cmd.OutOrStdout(), "File: %s\n", filePath) + fmt.Fprintf(cmd.OutOrStdout(), "Changes (%d):\n", len(changes)) + for _, c := range changes { + fmt.Fprintf(cmd.OutOrStdout(), " %s:\n", c.Key) + fmt.Fprintf(cmd.OutOrStdout(), " Old: %v\n", formatDryRunValue(c.OldValue)) + fmt.Fprintf(cmd.OutOrStdout(), " New: %v\n", formatDryRunValue(c.NewValue)) + } + + return nil +} + +// outputDryRunJSON outputs dry-run preview as JSON +func outputDryRunJSON(cmd *cobra.Command, filePath string, changes []dryRunChange, minOutput bool) error { + if minOutput { + // Minimal JSON: just array of changes + output := map[string]interface{}{ + "dry_run": true, + "changes": changes, + } + jsonBytes, _ := json.Marshal(output) + fmt.Fprintln(cmd.OutOrStdout(), string(jsonBytes)) + return nil + } + + // Full JSON with formatting + output := map[string]interface{}{ + "dry_run": true, + "file": filePath, + "changes": changes, + } + encoder := json.NewEncoder(cmd.OutOrStdout()) + encoder.SetIndent("", " ") + return encoder.Encode(output) +} + +// formatDryRunValue formats a value for display (handles nil, etc.) +func formatDryRunValue(v interface{}) string { + if v == nil { + return "" + } + return fmt.Sprintf("%v", v) +} diff --git a/internal/support/commands/yaml_test.go b/internal/support/commands/yaml_test.go index caa32e1..18ea2cc 100644 --- a/internal/support/commands/yaml_test.go +++ b/internal/support/commands/yaml_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "os" "path/filepath" + "reflect" "strings" "testing" ) @@ -1792,9 +1793,9 @@ func TestParseArrayIndex(t *testing.T) { {"10", 10}, {"[0]", 0}, {"[5]", 5}, - {"abc", -1}, - {"", -1}, - {"[abc]", -1}, + {"abc", invalidArrayIndex}, // Invalid index returns sentinel value + {"", invalidArrayIndex}, // Invalid index returns sentinel value + {"[abc]", invalidArrayIndex}, // Invalid index returns sentinel value } for _, tt := range tests { @@ -1831,10 +1832,13 @@ func TestGetValueAtPath_ArrayAccess(t *testing.T) { t.Error("expected not to find items.10") } - // Test negative array index - _, found = getValueAtPath(data, "items.-1") - if found { - t.Error("expected not to find items.-1") + // Test negative array index (now valid - returns last element) + val, found = getValueAtPath(data, "items.-1") + if !found { + t.Error("expected to find items.-1 (last element)") + } + if val != "third" { + t.Errorf("expected 'third' for items.-1, got: %v", val) } // Test empty path returns full data @@ -2120,3 +2124,1104 @@ data: t.Errorf("expected 'two', got: %s", output) } } + +// ============================================================================ +// Array Bracket Notation Tests (Sprint 9.0) +// ============================================================================ + +func TestParseDotPath_BracketNotation(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"items[0]", []string{"items", "0"}}, + {"items[0].name", []string{"items", "0", "name"}}, + {"a.items[2].b", []string{"a", "items", "2", "b"}}, + {"items[-1]", []string{"items", "-1"}}, + {"items[-1].value", []string{"items", "-1", "value"}}, + {"a[0][1]", []string{"a", "0", "1"}}, // Nested arrays + {"simple.path", []string{"simple", "path"}}, + {`a\.b.c`, []string{"a.b", "c"}}, // Escaped dot regression + {`a\.b[0].c`, []string{"a.b", "0", "c"}}, // Escaped dot with bracket + {"", nil}, // Empty path + {"single", []string{"single"}}, + {"items[10].deep.nested[0]", []string{"items", "10", "deep", "nested", "0"}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := parseDotPath(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("parseDotPath(%q) = %v, want %v", + tt.input, result, tt.expected) + } + }) + } +} + +func TestYamlGet_ArrayBracketNotation(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +items: + - name: first + value: 1 + - name: second + value: 2 + - name: third + value: 3 +`) + + tests := []struct { + key string + expected string + }{ + {"items[0].name", "first"}, + {"items[1].name", "second"}, + {"items[2].value", "3"}, + {"items[-1].name", "third"}, // Last element + {"items[-2].name", "second"}, // Second to last + {"items[-3].value", "1"}, // Third to last (first) + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"get", "--file", configPath, tt.key, "--min"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := strings.TrimSpace(buf.String()) + if output != tt.expected { + t.Errorf("get %s = %q, want %q", tt.key, output, tt.expected) + } + }) + } +} + +func TestYamlGet_ArrayBracketNotation_OutOfBounds(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +items: + - first + - second +`) + + tests := []struct { + key string + }{ + {"items[5]"}, // Beyond end + {"items[-5]"}, // Beyond start (negative) + {"items[100]"}, // Way beyond end + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"get", "--file", configPath, tt.key}) + + err := cmd.Execute() + if err == nil { + t.Errorf("expected error for out-of-bounds index %s", tt.key) + } + }) + } +} + +func TestYamlSet_ArrayBracketNotation(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +items: + - name: first + value: 1 + - name: second + value: 2 +`) + + // Set value at items[0].name + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "items[0].name", "updated"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify the value was set + cmd2 := newYamlCmd() + buf2 := new(bytes.Buffer) + cmd2.SetOut(buf2) + cmd2.SetArgs([]string{"get", "--file", configPath, "items[0].name", "--min"}) + + err = cmd2.Execute() + if err != nil { + t.Fatalf("expected no error reading back, got: %v", err) + } + + output := strings.TrimSpace(buf2.String()) + if output != "updated" { + t.Errorf("expected 'updated', got: %s", output) + } +} + +func TestYamlSet_ArrayBracketNotation_NegativeIndex(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +items: + - name: first + - name: second + - name: third +`) + + // Set value at items[-1].name (last element) + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "items[-1].name", "last_updated"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify the last element was updated + cmd2 := newYamlCmd() + buf2 := new(bytes.Buffer) + cmd2.SetOut(buf2) + cmd2.SetArgs([]string{"get", "--file", configPath, "items[2].name", "--min"}) + + err = cmd2.Execute() + if err != nil { + t.Fatalf("expected no error reading back, got: %v", err) + } + + output := strings.TrimSpace(buf2.String()) + if output != "last_updated" { + t.Errorf("expected 'last_updated', got: %s", output) + } +} + +func TestYamlMultiget_ArrayBracketNotation(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +items: + - name: first + - name: second +`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiget", "--file", configPath, "items[0].name", "items[1].name", "--min"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + if len(lines) != 2 { + t.Fatalf("expected 2 lines, got %d", len(lines)) + } + if lines[0] != "first" { + t.Errorf("expected first line 'first', got: %s", lines[0]) + } + if lines[1] != "second" { + t.Errorf("expected second line 'second', got: %s", lines[1]) + } +} + +func TestParseArrayIndex_NegativeIndices(t *testing.T) { + tests := []struct { + input string + expected int + }{ + {"0", 0}, + {"1", 1}, + {"-1", -1}, + {"-2", -2}, + {"-10", -10}, + {"[0]", 0}, + {"[-1]", -1}, + {"[-5]", -5}, + {"abc", invalidArrayIndex}, + {"", invalidArrayIndex}, + {"[abc]", invalidArrayIndex}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := parseArrayIndex(tt.input) + if result != tt.expected { + t.Errorf("parseArrayIndex(%q) = %d, want %d", tt.input, result, tt.expected) + } + }) + } +} + +func TestGetValueAtPath_NegativeArrayIndex(t *testing.T) { + data := map[string]interface{}{ + "items": []interface{}{"first", "second", "third"}, + } + + tests := []struct { + path string + expected string + found bool + }{ + {"items.-1", "third", true}, // Last + {"items.-2", "second", true}, // Second to last + {"items.-3", "first", true}, // Third to last (first) + {"items.-4", "", false}, // Beyond start + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + val, found := getValueAtPath(data, tt.path) + if found != tt.found { + t.Errorf("getValueAtPath(%q) found = %v, want %v", tt.path, found, tt.found) + return + } + if found && val != tt.expected { + t.Errorf("getValueAtPath(%q) = %v, want %v", tt.path, val, tt.expected) + } + }) + } +} + +func TestSetValueAtPath_ArrayIndex(t *testing.T) { + // Test setting value at array index + data := map[string]interface{}{ + "items": []interface{}{ + map[string]interface{}{"name": "first"}, + map[string]interface{}{"name": "second"}, + }, + } + + err := setValueAtPath(data, "items.0.name", "updated") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, found := getValueAtPath(data, "items.0.name") + if !found { + t.Fatal("expected to find items.0.name") + } + if val != "updated" { + t.Errorf("expected 'updated', got: %v", val) + } +} + +func TestSetValueAtPath_NegativeArrayIndex(t *testing.T) { + // Test setting value at negative array index + data := map[string]interface{}{ + "items": []interface{}{ + map[string]interface{}{"name": "first"}, + map[string]interface{}{"name": "second"}, + map[string]interface{}{"name": "third"}, + }, + } + + err := setValueAtPath(data, "items.-1.name", "last_updated") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, found := getValueAtPath(data, "items.2.name") + if !found { + t.Fatal("expected to find items.2.name") + } + if val != "last_updated" { + t.Errorf("expected 'last_updated', got: %v", val) + } +} + +func TestSetValueAtPath_ArrayIndexOutOfBounds(t *testing.T) { + data := map[string]interface{}{ + "items": []interface{}{"first", "second"}, + } + + err := setValueAtPath(data, "items.5", "invalid") + if err == nil { + t.Fatal("expected error for out-of-bounds index") + } + if !strings.Contains(err.Error(), "out of bounds") { + t.Errorf("expected 'out of bounds' error, got: %v", err) + } +} + +func TestYamlGet_NestedArrayBracketNotation(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +data: + lists: + - items: + - a + - b + - items: + - c + - d +`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"get", "--file", configPath, "data.lists[1].items[0]", "--min"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := strings.TrimSpace(buf.String()) + if output != "c" { + t.Errorf("expected 'c', got: %s", output) + } +} + +// ============================================================================ +// Stdin Input Support Tests (Sprint 9.0 - Task 02) +// ============================================================================ + +func TestYamlSet_StdinValue(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + // Simulate stdin using a buffer + var stdin bytes.Buffer + stdin.WriteString("from-stdin") + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetIn(&stdin) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "-"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify value was set + content, _ := os.ReadFile(configPath) + if !strings.Contains(string(content), "from-stdin") { + t.Errorf("expected 'from-stdin' in file, got: %s", string(content)) + } +} + +func TestYamlSet_StdinMultiline(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `description: old`) + + var stdin bytes.Buffer + stdin.WriteString("line1\nline2\nline3") + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetIn(&stdin) + cmd.SetArgs([]string{"set", "--file", configPath, "description", "-"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify multi-line value (YAML uses | for multi-line strings) + content, _ := os.ReadFile(configPath) + contentStr := string(content) + if !strings.Contains(contentStr, "line1") || !strings.Contains(contentStr, "line2") { + t.Errorf("expected multi-line content, got: %s", contentStr) + } +} + +func TestYamlSet_StdinEmpty(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + var stdin bytes.Buffer // Empty + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetIn(&stdin) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "-"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error for empty stdin, got: %v", err) + } + + // Value should be empty string + content, _ := os.ReadFile(configPath) + // YAML represents empty string as "" or '' + if !strings.Contains(string(content), `key: ""`) && !strings.Contains(string(content), `key: ''`) && !strings.Contains(string(content), "key:") { + t.Errorf("expected empty key value, got: %s", string(content)) + } +} + +func TestYamlSet_StdinWithTrailingNewline(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + var stdin bytes.Buffer + stdin.WriteString("value-with-newline\n") // Trailing newline like from echo + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetIn(&stdin) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "-"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Read back to verify trailing newline was trimmed + cmd2 := newYamlCmd() + buf2 := new(bytes.Buffer) + cmd2.SetOut(buf2) + cmd2.SetArgs([]string{"get", "--file", configPath, "key", "--min"}) + + err = cmd2.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := strings.TrimSpace(buf2.String()) + if output != "value-with-newline" { + t.Errorf("expected 'value-with-newline', got: %q", output) + } +} + +func TestYamlSet_RegularValueStillWorks(t *testing.T) { + // Verify that regular values (not "-") still work + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "regular-value"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify value was set + content, _ := os.ReadFile(configPath) + if !strings.Contains(string(content), "regular-value") { + t.Errorf("expected 'regular-value' in file, got: %s", string(content)) + } +} + +func TestYamlSet_LiteralDashValue(t *testing.T) { + // Test that a literal dash is interpreted as stdin (by design) + // This test documents the behavior + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + var stdin bytes.Buffer + stdin.WriteString("stdin-value") + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetIn(&stdin) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "-"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify stdin was used, not literal "-" + content, _ := os.ReadFile(configPath) + if strings.Contains(string(content), "key: -") { + t.Error("expected stdin value, not literal '-'") + } + if !strings.Contains(string(content), "stdin-value") { + t.Errorf("expected 'stdin-value' in file, got: %s", string(content)) + } +} + +// ============================================================================ +// Dry-Run Flag Tests (Sprint 9.0 - Task 03) +// ============================================================================ + +func TestYamlSet_DryRun(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new-value", "--dry-run"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify output contains preview + output := buf.String() + if !strings.Contains(output, "DRY RUN") { + t.Errorf("expected 'DRY RUN' in output, got: %s", output) + } + if !strings.Contains(output, "Old: original") { + t.Errorf("expected 'Old: original' in output, got: %s", output) + } + if !strings.Contains(output, "New: new-value") { + t.Errorf("expected 'New: new-value' in output, got: %s", output) + } + + // Verify file NOT modified + content, _ := os.ReadFile(configPath) + if !strings.Contains(string(content), "original") { + t.Error("expected file to still contain 'original'") + } + if strings.Contains(string(content), "new-value") { + t.Error("expected file to NOT contain 'new-value'") + } +} + +func TestYamlSet_DryRunJSON(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new-value", "--dry-run", "--json"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify JSON output + var result map[string]interface{} + if err := json.Unmarshal(buf.Bytes(), &result); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if result["dry_run"] != true { + t.Error("expected dry_run: true") + } + + // File NOT modified + content, _ := os.ReadFile(configPath) + if strings.Contains(string(content), "new-value") { + t.Error("expected file to NOT contain 'new-value'") + } +} + +func TestYamlSet_DryRunMin(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new-value", "--dry-run", "--min"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify minimal output: key: old → new + output := buf.String() + if !strings.Contains(output, "key:") && !strings.Contains(output, "→") { + t.Errorf("expected 'key: old → new' format, got: %s", output) + } +} + +func TestYamlSet_DryRunNewKey(t *testing.T) { + // Test dry-run for a key that doesn't exist yet + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `existing: value`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "new_key", "new-value", "--dry-run"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify output shows + output := buf.String() + if !strings.Contains(output, "") { + t.Errorf("expected '' for new key, got: %s", output) + } +} + +func TestYamlMultiset_DryRun(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +a: 1 +b: 2 +`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiset", "--file", configPath, "a", "10", "b", "20", "--dry-run"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify output contains preview + output := buf.String() + if !strings.Contains(output, "DRY RUN") { + t.Errorf("expected 'DRY RUN' in output, got: %s", output) + } + if !strings.Contains(output, "Changes (2)") { + t.Errorf("expected 'Changes (2)' in output, got: %s", output) + } + + // Verify file NOT modified + content, _ := os.ReadFile(configPath) + if strings.Contains(string(content), "10") || strings.Contains(string(content), "20") { + t.Error("expected file to NOT contain new values") + } +} + +func TestYamlMultiset_DryRunJSON(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +a: 1 +b: 2 +`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiset", "--file", configPath, "a", "10", "b", "20", "--dry-run", "--json"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify JSON output + var result map[string]interface{} + if err := json.Unmarshal(buf.Bytes(), &result); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if result["dry_run"] != true { + t.Error("expected dry_run: true") + } + changes, ok := result["changes"].([]interface{}) + if !ok || len(changes) != 2 { + t.Errorf("expected 2 changes, got: %v", result["changes"]) + } +} + +// ============================================================================ +// Task 04: Quiet Flag Tests +// ============================================================================ + +func TestYamlSet_Quiet(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new", "--quiet"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify NO output on success + if buf.String() != "" { + t.Errorf("expected empty output with --quiet, got: %q", buf.String()) + } + + // But file WAS modified + content, _ := os.ReadFile(configPath) + if !strings.Contains(string(content), "new") { + t.Error("expected file to contain 'new'") + } +} + +func TestYamlSet_QuietStillReturnsErrors(t *testing.T) { + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", "/nonexistent/path/config.yaml", "key", "value", "--quiet"}) + + err := cmd.Execute() + // Errors still returned + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} + +func TestYamlMultiset_Quiet(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +a: 1 +b: 2 +`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiset", "--file", configPath, "a", "10", "b", "20", "--quiet"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // No output + if buf.String() != "" { + t.Errorf("expected empty output with --quiet, got: %q", buf.String()) + } + + // File modified + content, _ := os.ReadFile(configPath) + if !strings.Contains(string(content), "a: 10") { + t.Error("expected file to contain 'a: 10'") + } +} + +func TestYamlSet_QuietWithJSON(t *testing.T) { + // --quiet should take precedence over --json for success output + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new", "--quiet", "--json"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Even with --json, quiet suppresses output + if buf.String() != "" { + t.Errorf("expected empty output with --quiet --json, got: %q", buf.String()) + } +} + +func TestYamlSet_QuietWithMin(t *testing.T) { + // --quiet should take precedence over --min for success output + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new", "--quiet", "--min"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Even with --min, quiet suppresses output + if buf.String() != "" { + t.Errorf("expected empty output with --quiet --min, got: %q", buf.String()) + } +} + +func TestYamlSet_QuietDoesNotSuppressDryRun(t *testing.T) { + // Dry-run output should NOT be suppressed by quiet + dir := createTempDir(t) + configPath := createTestYAML(t, dir, `key: original`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"set", "--file", configPath, "key", "new", "--quiet", "--dry-run"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Dry-run preview is still output even with --quiet + // (quiet only suppresses success messages, dry-run is informational) + output := buf.String() + if !strings.Contains(output, "DRY RUN") { + t.Errorf("expected dry-run output even with --quiet, got: %q", output) + } +} + +func TestYamlMultiset_QuietWithJSON(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +a: 1 +b: 2 +`) + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiset", "--file", configPath, "a", "10", "b", "20", "--quiet", "--json"}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Even with --json, quiet suppresses output + if buf.String() != "" { + t.Errorf("expected empty output with --quiet --json, got: %q", buf.String()) + } +} + +// ============================================================================ +// Task 05: Required Keys from File Tests +// ============================================================================ + +func TestParseRequiredKeysFile(t *testing.T) { + // Create temp file with test content + content := `# Database config +db.host +db.port +db.name + +# API config +api.url +api.key # inline comment + +# Empty lines above ignored +` + tmpFile, err := os.CreateTemp("", "required-keys-*.txt") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + tmpFile.WriteString(content) + tmpFile.Close() + + keys, err := parseRequiredKeysFile(tmpFile.Name()) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + expected := []string{"db.host", "db.port", "db.name", "api.url", "api.key"} + if !reflect.DeepEqual(keys, expected) { + t.Errorf("expected %v, got %v", expected, keys) + } +} + +func TestParseRequiredKeysFile_NotFound(t *testing.T) { + _, err := parseRequiredKeysFile("nonexistent-file-12345.txt") + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} + +func TestYamlMultiget_RequiredFile(t *testing.T) { + // Create YAML file + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +db: + host: localhost + port: 5432 +api: + url: https://api.example.com +`) + + // Create required keys file + keysFile, err := os.CreateTemp("", "keys-*.txt") + if err != nil { + t.Fatalf("failed to create keys file: %v", err) + } + defer os.Remove(keysFile.Name()) + keysFile.WriteString("db.host\ndb.port\napi.url\n") + keysFile.Close() + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiget", "--file", configPath, "--required-file", keysFile.Name()}) + + err = cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "localhost") { + t.Errorf("expected output to contain 'localhost', got: %s", output) + } + if !strings.Contains(output, "5432") { + t.Errorf("expected output to contain '5432', got: %s", output) + } + if !strings.Contains(output, "https://api.example.com") { + t.Errorf("expected output to contain 'https://api.example.com', got: %s", output) + } +} + +func TestYamlMultiget_RequiredFileMissingKey(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +db: + host: localhost +`) + + // Create required keys file with a key that doesn't exist + keysFile, err := os.CreateTemp("", "keys-*.txt") + if err != nil { + t.Fatalf("failed to create keys file: %v", err) + } + defer os.Remove(keysFile.Name()) + keysFile.WriteString("db.host\ndb.port\n") // db.port doesn't exist + keysFile.Close() + + cmd := newYamlCmd() + cmd.SetArgs([]string{"multiget", "--file", configPath, "--required-file", keysFile.Name()}) + + err = cmd.Execute() + if err == nil { + t.Fatal("expected error for missing key") + } + if !strings.Contains(err.Error(), "db.port") { + t.Errorf("expected error to mention 'db.port', got: %v", err) + } +} + +func TestYamlMultiget_CombinedSources(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +a: 1 +b: 2 +c: 3 +`) + + // Create required keys file with "a" + keysFile, err := os.CreateTemp("", "keys-*.txt") + if err != nil { + t.Fatalf("failed to create keys file: %v", err) + } + defer os.Remove(keysFile.Name()) + keysFile.WriteString("a\n") + keysFile.Close() + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + // Combine: positional "b" and "c", --required-file has "a" + cmd.SetArgs([]string{"multiget", "--file", configPath, "b", "c", "--required-file", keysFile.Name()}) + + err = cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := buf.String() + // All three keys should be present + if !strings.Contains(output, "b=2") { + t.Errorf("expected output to contain 'b=2', got: %s", output) + } + if !strings.Contains(output, "c=3") { + t.Errorf("expected output to contain 'c=3', got: %s", output) + } + if !strings.Contains(output, "a=1") { + t.Errorf("expected output to contain 'a=1', got: %s", output) + } +} + +func TestYamlUniqueStrings(t *testing.T) { + input := []string{"a", "b", "a", "c", "b", "d"} + expected := []string{"a", "b", "c", "d"} + result := yamlUniqueStrings(input) + if !reflect.DeepEqual(result, expected) { + t.Errorf("expected %v, got %v", expected, result) + } +} + +func TestYamlMultiget_DuplicateKeys(t *testing.T) { + // Verify duplicate keys are handled correctly + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +a: 1 +b: 2 +`) + + // Create required keys file with duplicates + keysFile, err := os.CreateTemp("", "keys-*.txt") + if err != nil { + t.Fatalf("failed to create keys file: %v", err) + } + defer os.Remove(keysFile.Name()) + keysFile.WriteString("a\nb\na\n") // "a" appears twice + keysFile.Close() + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiget", "--file", configPath, "--required-file", keysFile.Name()}) + + err = cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := buf.String() + // Each key should appear only once in output + count := strings.Count(output, "a=1") + if count != 1 { + t.Errorf("expected 'a=1' to appear once, appeared %d times in: %s", count, output) + } +} + +func TestYamlMultiget_RequiredFileWithBracketNotation(t *testing.T) { + dir := createTempDir(t) + configPath := createTestYAML(t, dir, ` +items: + - name: first + - name: second +`) + + // Create required keys file with bracket notation + keysFile, err := os.CreateTemp("", "keys-*.txt") + if err != nil { + t.Fatalf("failed to create keys file: %v", err) + } + defer os.Remove(keysFile.Name()) + keysFile.WriteString("items[0].name\nitems[1].name\n") + keysFile.Close() + + cmd := newYamlCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"multiget", "--file", configPath, "--required-file", keysFile.Name()}) + + err = cmd.Execute() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "first") { + t.Errorf("expected output to contain 'first', got: %s", output) + } + if !strings.Contains(output, "second") { + t.Errorf("expected output to contain 'second', got: %s", output) + } +} diff --git a/internal/support/mcpserver/handlers.go b/internal/support/mcpserver/handlers.go index 068eb3f..87a8cf7 100644 --- a/internal/support/mcpserver/handlers.go +++ b/internal/support/mcpserver/handlers.go @@ -959,6 +959,16 @@ func buildYamlSetArgs(args map[string]interface{}) []string { cmdArgs = append(cmdArgs, "--create") } + // Dry-run flag + if getBool(args, "dry_run") { + cmdArgs = append(cmdArgs, "--dry-run") + } + + // Quiet flag + if getBool(args, "quiet") { + cmdArgs = append(cmdArgs, "--quiet") + } + // Output format flags if getBoolDefault(args, "json", true) { cmdArgs = append(cmdArgs, "--json") @@ -996,6 +1006,11 @@ func buildYamlMultigetArgs(args map[string]interface{}) []string { } } + // Required file flag + if requiredFile, ok := args["required_file"].(string); ok && requiredFile != "" { + cmdArgs = append(cmdArgs, "--required-file", requiredFile) + } + // Output format flags if getBoolDefault(args, "json", true) { cmdArgs = append(cmdArgs, "--json") @@ -1032,6 +1047,16 @@ func buildYamlMultisetArgs(args map[string]interface{}) []string { cmdArgs = append(cmdArgs, "--create") } + // Dry-run flag + if getBool(args, "dry_run") { + cmdArgs = append(cmdArgs, "--dry-run") + } + + // Quiet flag + if getBool(args, "quiet") { + cmdArgs = append(cmdArgs, "--quiet") + } + // Output format flags if getBoolDefault(args, "json", true) { cmdArgs = append(cmdArgs, "--json") diff --git a/internal/support/mcpserver/handlers_test.go b/internal/support/mcpserver/handlers_test.go index d7a37a6..c2f4f30 100644 --- a/internal/support/mcpserver/handlers_test.go +++ b/internal/support/mcpserver/handlers_test.go @@ -1123,6 +1123,21 @@ func TestBuildYamlSetArgs(t *testing.T) { args: map[string]interface{}{"file": "/tmp/config.yaml", "key": "helper.llm", "value": "claude", "json": false, "min": false}, want: []string{"yaml", "set", "--file", "/tmp/config.yaml", "helper.llm", "claude"}, }, + { + name: "with dry_run flag", + args: map[string]interface{}{"file": "/tmp/config.yaml", "key": "helper.llm", "value": "claude", "dry_run": true}, + want: []string{"yaml", "set", "--file", "/tmp/config.yaml", "helper.llm", "claude", "--dry-run", "--json", "--min"}, + }, + { + name: "with quiet flag", + args: map[string]interface{}{"file": "/tmp/config.yaml", "key": "helper.llm", "value": "claude", "quiet": true}, + want: []string{"yaml", "set", "--file", "/tmp/config.yaml", "helper.llm", "claude", "--quiet", "--json", "--min"}, + }, + { + name: "with both dry_run and quiet", + args: map[string]interface{}{"file": "/tmp/config.yaml", "key": "helper.llm", "value": "claude", "dry_run": true, "quiet": true}, + want: []string{"yaml", "set", "--file", "/tmp/config.yaml", "helper.llm", "claude", "--dry-run", "--quiet", "--json", "--min"}, + }, } for _, tt := range tests { @@ -1179,6 +1194,15 @@ func TestBuildYamlMultigetArgs(t *testing.T) { }, want: []string{"yaml", "multiget", "--file", "/tmp/config.yaml", "helper.llm", "--json", "--min"}, }, + { + name: "with required_file parameter", + args: map[string]interface{}{ + "file": "/tmp/config.yaml", + "keys": []interface{}{"helper.llm"}, + "required_file": "/tmp/required-keys.txt", + }, + want: []string{"yaml", "multiget", "--file", "/tmp/config.yaml", "helper.llm", "--required-file", "/tmp/required-keys.txt", "--json", "--min"}, + }, } for _, tt := range tests { @@ -1227,6 +1251,49 @@ func TestBuildYamlMultisetArgs(t *testing.T) { wantSuffix: []string{"--json", "--min"}, skipDeepEqual: true, }, + { + name: "with dry_run flag", + args: map[string]interface{}{ + "file": "/tmp/config.yaml", + "pairs": map[string]interface{}{ + "key": "value", + }, + "dry_run": true, + }, + wantPrefix: []string{"yaml", "multiset", "--file", "/tmp/config.yaml"}, + wantContains: []string{"key", "value", "--dry-run"}, + wantSuffix: []string{"--json", "--min"}, + skipDeepEqual: true, + }, + { + name: "with quiet flag", + args: map[string]interface{}{ + "file": "/tmp/config.yaml", + "pairs": map[string]interface{}{ + "key": "value", + }, + "quiet": true, + }, + wantPrefix: []string{"yaml", "multiset", "--file", "/tmp/config.yaml"}, + wantContains: []string{"key", "value", "--quiet"}, + wantSuffix: []string{"--json", "--min"}, + skipDeepEqual: true, + }, + { + name: "with both dry_run and quiet", + args: map[string]interface{}{ + "file": "/tmp/config.yaml", + "pairs": map[string]interface{}{ + "key": "value", + }, + "dry_run": true, + "quiet": true, + }, + wantPrefix: []string{"yaml", "multiset", "--file", "/tmp/config.yaml"}, + wantContains: []string{"key", "value", "--dry-run", "--quiet"}, + wantSuffix: []string{"--json", "--min"}, + skipDeepEqual: true, + }, } for _, tt := range tests { diff --git a/internal/support/mcpserver/tools.go b/internal/support/mcpserver/tools.go index e00b494..78713b5 100644 --- a/internal/support/mcpserver/tools.go +++ b/internal/support/mcpserver/tools.go @@ -866,7 +866,7 @@ func GetToolDefinitions() []ToolDefinition { // 26. YAML set - store value at dot-notation key { Name: ToolPrefix + "yaml_set", - Description: "Store a value in YAML config file at the specified key. Creates intermediate keys if needed. Preserves comments.", + Description: "Store a value in YAML config file at the specified key. Creates intermediate keys if needed. Preserves comments. Supports array bracket notation (e.g., items[0].name, items[-1]).", InputSchema: json.RawMessage(`{ "type": "object", "properties": { @@ -876,7 +876,7 @@ func GetToolDefinitions() []ToolDefinition { }, "key": { "type": "string", - "description": "Dot-notation key (e.g., helper.llm)" + "description": "Dot-notation key (e.g., helper.llm, items[0].name)" }, "value": { "type": "string", @@ -886,6 +886,14 @@ func GetToolDefinitions() []ToolDefinition { "type": "boolean", "description": "Create file if it doesn't exist" }, + "dry_run": { + "type": "boolean", + "description": "Preview changes without writing to file" + }, + "quiet": { + "type": "boolean", + "description": "Suppress success messages (errors still returned)" + }, "json": { "type": "boolean", "description": "Output as JSON" @@ -902,7 +910,7 @@ func GetToolDefinitions() []ToolDefinition { // 27. YAML multiget - retrieve multiple values { Name: ToolPrefix + "yaml_multiget", - Description: "Retrieve multiple values from YAML config file in a single operation. More efficient than multiple get calls.", + Description: "Retrieve multiple values from YAML config file in a single operation. More efficient than multiple get calls. Supports array bracket notation (e.g., items[0].name).", InputSchema: json.RawMessage(`{ "type": "object", "properties": { @@ -919,6 +927,10 @@ func GetToolDefinitions() []ToolDefinition { "type": "object", "description": "Default values for keys (e.g., {\"helper.llm\": \"gemini\"})" }, + "required_file": { + "type": "string", + "description": "Path to file containing required keys (one per line, # comments supported)" + }, "json": { "type": "boolean", "description": "Output as JSON" @@ -928,14 +940,14 @@ func GetToolDefinitions() []ToolDefinition { "description": "Minimal output - values only, newline-separated" } }, - "required": ["file", "keys"] + "required": ["file"] }`), }, // 28. YAML multiset - set multiple key-value pairs { Name: ToolPrefix + "yaml_multiset", - Description: "Set multiple key-value pairs in YAML config file atomically. Validates all keys before writing.", + Description: "Set multiple key-value pairs in YAML config file atomically. Validates all keys before writing. Supports array bracket notation (e.g., items[0].name).", InputSchema: json.RawMessage(`{ "type": "object", "properties": { @@ -951,6 +963,14 @@ func GetToolDefinitions() []ToolDefinition { "type": "boolean", "description": "Create file if it doesn't exist" }, + "dry_run": { + "type": "boolean", + "description": "Preview changes without writing to file" + }, + "quiet": { + "type": "boolean", + "description": "Suppress success messages (errors still returned)" + }, "json": { "type": "boolean", "description": "Output as JSON" diff --git a/internal/support/mcpserver/tools_test.go b/internal/support/mcpserver/tools_test.go index 18b2b6c..01c7647 100644 --- a/internal/support/mcpserver/tools_test.go +++ b/internal/support/mcpserver/tools_test.go @@ -138,7 +138,7 @@ func TestToolSchemaRequiredFields(t *testing.T) { "llm_support_context_multiget": {"dir", "keys"}, "llm_support_yaml_get": {"file", "key"}, "llm_support_yaml_set": {"file", "key", "value"}, - "llm_support_yaml_multiget": {"file", "keys"}, + "llm_support_yaml_multiget": {"file"}, "llm_support_yaml_multiset": {"file", "pairs"}, }