diff --git a/claudetool/lsp/format.go b/claudetool/lsp/format.go new file mode 100644 index 0000000..9d08aaa --- /dev/null +++ b/claudetool/lsp/format.go @@ -0,0 +1,171 @@ +package lsp + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +const maxReferences = 50 + +// formatDefinition formats definition locations for display. +func formatDefinition(locations []Location, wd string) string { + if len(locations) == 0 { + return "No definition found." + } + + var sb strings.Builder + for i, loc := range locations { + if i > 0 { + sb.WriteString("\n\n") + } + path := filePathFromURI(loc.URI) + relPath := relativePath(path, wd) + line := loc.Range.Start.Line + 1 // convert 0-based to 1-based for display + sb.WriteString(fmt.Sprintf("**%s:%d**\n", relPath, line)) + + // Read source context around the definition + context := readSourceContext(path, loc.Range.Start.Line, 5) + if context != "" { + sb.WriteString("```\n") + sb.WriteString(context) + sb.WriteString("\n```") + } + } + return sb.String() +} + +// formatReferences formats reference locations for display, grouped by file. +func formatReferences(locations []Location, wd string) string { + if len(locations) == 0 { + return "No references found." + } + + // Group by file + type fileRef struct { + path string + refs []Location + } + fileOrder := []string{} + byFile := make(map[string][]Location) + for _, loc := range locations { + path := filePathFromURI(loc.URI) + if _, exists := byFile[path]; !exists { + fileOrder = append(fileOrder, path) + } + byFile[path] = append(byFile[path], loc) + } + + var sb strings.Builder + total := len(locations) + truncated := total > maxReferences + if truncated { + sb.WriteString(fmt.Sprintf("Found %d references (showing first %d):\n\n", total, maxReferences)) + } else { + sb.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", total)) + } + + shown := 0 + for _, path := range fileOrder { + if shown >= maxReferences { + break + } + refs := byFile[path] + relPath := relativePath(path, wd) + sb.WriteString(fmt.Sprintf("**%s**\n", relPath)) + for _, ref := range refs { + if shown >= maxReferences { + break + } + line := ref.Range.Start.Line + 1 + context := readSourceLine(path, ref.Range.Start.Line) + if context != "" { + sb.WriteString(fmt.Sprintf(" L%d: %s\n", line, strings.TrimSpace(context))) + } else { + sb.WriteString(fmt.Sprintf(" L%d\n", line)) + } + shown++ + } + sb.WriteString("\n") + } + return strings.TrimRight(sb.String(), "\n") +} + +// formatHover formats hover information for display. +func formatHover(hover *Hover) string { + if hover == nil || hover.Contents.Value == "" { + return "No hover information available." + } + return hover.Contents.Value +} + +// formatSymbols formats workspace symbol results for display. +func formatSymbols(symbols []SymbolInformation, wd string) string { + if len(symbols) == 0 { + return "No symbols found." + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d symbol(s):\n\n", len(symbols))) + for _, sym := range symbols { + path := filePathFromURI(sym.Location.URI) + relPath := relativePath(path, wd) + line := sym.Location.Range.Start.Line + 1 + kind := SymbolKindName(sym.Kind) + if sym.ContainerName != "" { + sb.WriteString(fmt.Sprintf("- %s (%s) in %s — %s:%d\n", sym.Name, kind, sym.ContainerName, relPath, line)) + } else { + sb.WriteString(fmt.Sprintf("- %s (%s) — %s:%d\n", sym.Name, kind, relPath, line)) + } + } + return strings.TrimRight(sb.String(), "\n") +} + +// readSourceContext reads a few lines around the given 0-based line from a file. +func readSourceContext(path string, line int, contextLines int) string { + data, err := os.ReadFile(path) + if err != nil { + return "" + } + lines := strings.Split(string(data), "\n") + start := line - contextLines + if start < 0 { + start = 0 + } + end := line + contextLines + 1 + if end > len(lines) { + end = len(lines) + } + var sb strings.Builder + for i := start; i < end; i++ { + marker := " " + if i == line { + marker = "> " + } + sb.WriteString(fmt.Sprintf("%s%4d | %s\n", marker, i+1, lines[i])) + } + return strings.TrimRight(sb.String(), "\n") +} + +// readSourceLine reads a single 0-based line from a file. +func readSourceLine(path string, line int) string { + data, err := os.ReadFile(path) + if err != nil { + return "" + } + lines := strings.Split(string(data), "\n") + if line < 0 || line >= len(lines) { + return "" + } + return lines[line] +} + +// relativePath returns the path relative to wd, or the original path if it can't be made relative. +func relativePath(path, wd string) string { + rel, err := filepath.Rel(wd, path) + if err != nil { + return path + } + return rel +} diff --git a/claudetool/lsp/format_test.go b/claudetool/lsp/format_test.go new file mode 100644 index 0000000..23c63f2 --- /dev/null +++ b/claudetool/lsp/format_test.go @@ -0,0 +1,227 @@ +package lsp + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestFormatDefinitionEmpty(t *testing.T) { + result := formatDefinition(nil, "/tmp") + if result != "No definition found." { + t.Errorf("got %q, want %q", result, "No definition found.") + } +} + +func TestFormatDefinitionSingle(t *testing.T) { + dir := t.TempDir() + testFile := filepath.Join(dir, "main.go") + os.WriteFile(testFile, []byte("package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n"), 0o644) + + locations := []Location{ + { + URI: fileURI(testFile), + Range: Range{ + Start: Position{Line: 2, Character: 5}, + End: Position{Line: 2, Character: 9}, + }, + }, + } + + result := formatDefinition(locations, dir) + if !strings.Contains(result, "main.go:3") { + t.Errorf("expected result to contain 'main.go:3' (1-based line), got:\n%s", result) + } + if !strings.Contains(result, "func main()") { + t.Errorf("expected result to contain source context, got:\n%s", result) + } +} + +func TestFormatReferencesEmpty(t *testing.T) { + result := formatReferences(nil, "/tmp") + if result != "No references found." { + t.Errorf("got %q, want %q", result, "No references found.") + } +} + +func TestFormatReferencesGroupedByFile(t *testing.T) { + dir := t.TempDir() + file1 := filepath.Join(dir, "a.go") + file2 := filepath.Join(dir, "b.go") + os.WriteFile(file1, []byte("package main\n\nvar x = 1\n"), 0o644) + os.WriteFile(file2, []byte("package main\n\nvar y = x\n"), 0o644) + + locations := []Location{ + {URI: fileURI(file1), Range: Range{Start: Position{Line: 2, Character: 4}}}, + {URI: fileURI(file2), Range: Range{Start: Position{Line: 2, Character: 8}}}, + } + + result := formatReferences(locations, dir) + if !strings.Contains(result, "a.go") { + t.Errorf("expected result to contain a.go, got:\n%s", result) + } + if !strings.Contains(result, "b.go") { + t.Errorf("expected result to contain b.go, got:\n%s", result) + } + if !strings.Contains(result, "2 reference") { + t.Errorf("expected result to mention 2 references, got:\n%s", result) + } +} + +func TestFormatReferenceTruncation(t *testing.T) { + dir := t.TempDir() + testFile := filepath.Join(dir, "test.go") + os.WriteFile(testFile, []byte("package main\n"), 0o644) + + // Create more than maxReferences locations + var locations []Location + for i := 0; i < maxReferences+10; i++ { + locations = append(locations, Location{ + URI: fileURI(testFile), + Range: Range{ + Start: Position{Line: 0, Character: 0}, + }, + }) + } + + result := formatReferences(locations, dir) + if !strings.Contains(result, "showing first 50") { + t.Errorf("expected truncation message, got:\n%s", result) + } +} + +func TestFormatHoverEmpty(t *testing.T) { + result := formatHover(nil) + if result != "No hover information available." { + t.Errorf("got %q", result) + } + + result = formatHover(&Hover{Contents: MarkupContent{}}) + if result != "No hover information available." { + t.Errorf("got %q", result) + } +} + +func TestFormatHoverWithContent(t *testing.T) { + hover := &Hover{ + Contents: MarkupContent{ + Kind: "markdown", + Value: "```go\nfunc Println(a ...any) (n int, err error)\n```\nPrintln formats using the default formats...", + }, + } + result := formatHover(hover) + if !strings.Contains(result, "func Println") { + t.Errorf("expected hover content, got:\n%s", result) + } +} + +func TestFormatSymbolsEmpty(t *testing.T) { + result := formatSymbols(nil, "/tmp") + if result != "No symbols found." { + t.Errorf("got %q", result) + } +} + +func TestFormatSymbols(t *testing.T) { + symbols := []SymbolInformation{ + { + Name: "ProcessOneTurn", + Kind: SymbolKindFunction, + Location: Location{URI: fileURI("/project/loop/loop.go"), Range: Range{Start: Position{Line: 99}}}, + }, + { + Name: "Run", + Kind: SymbolKindMethod, + ContainerName: "Server", + Location: Location{URI: fileURI("/project/server/server.go"), Range: Range{Start: Position{Line: 49}}}, + }, + } + + result := formatSymbols(symbols, "/project") + if !strings.Contains(result, "ProcessOneTurn") { + t.Errorf("expected ProcessOneTurn in output, got:\n%s", result) + } + if !strings.Contains(result, "Function") { + t.Errorf("expected Function kind, got:\n%s", result) + } + if !strings.Contains(result, "loop/loop.go:100") { + t.Errorf("expected 1-based line number, got:\n%s", result) + } + if !strings.Contains(result, "in Server") { + t.Errorf("expected container name, got:\n%s", result) + } +} + +func TestRelativePath(t *testing.T) { + tests := []struct { + path string + wd string + want string + }{ + {"/project/src/main.go", "/project", "src/main.go"}, + {"/other/file.go", "/project", "../other/file.go"}, + {"/project/file.go", "/project", "file.go"}, + } + for _, tt := range tests { + got := relativePath(tt.path, tt.wd) + if got != tt.want { + t.Errorf("relativePath(%q, %q) = %q, want %q", tt.path, tt.wd, got, tt.want) + } + } +} + +func TestFileURI(t *testing.T) { + got := fileURI("/home/user/file.go") + if !strings.HasPrefix(got, "file://") { + t.Errorf("fileURI should start with file://, got %q", got) + } + if !strings.Contains(got, "file.go") { + t.Errorf("fileURI should contain filename, got %q", got) + } +} + +func TestFilePathFromURI(t *testing.T) { + // Round-trip test + original := "/home/user/project/main.go" + uri := fileURI(original) + back := filePathFromURI(uri) + if back != original { + t.Errorf("round trip: %q -> %q -> %q", original, uri, back) + } +} + +func TestLanguageID(t *testing.T) { + tests := []struct { + path string + want string + }{ + {"main.go", "go"}, + {"app.ts", "typescript"}, + {"app.tsx", "typescriptreact"}, + {"app.js", "javascript"}, + {"app.jsx", "javascriptreact"}, + {"script.py", "python"}, + {"main.rs", "rust"}, + {"Main.java", "java"}, + {"readme.txt", "plaintext"}, + } + for _, tt := range tests { + got := languageID(tt.path) + if got != tt.want { + t.Errorf("languageID(%q) = %q, want %q", tt.path, got, tt.want) + } + } +} + +func TestSymbolKindName(t *testing.T) { + if got := SymbolKindName(SymbolKindFunction); got != "Function" { + t.Errorf("SymbolKindName(Function) = %q", got) + } + if got := SymbolKindName(SymbolKindStruct); got != "Struct" { + t.Errorf("SymbolKindName(Struct) = %q", got) + } + if got := SymbolKindName(SymbolKind(999)); got != "Unknown" { + t.Errorf("SymbolKindName(999) = %q", got) + } +} diff --git a/claudetool/lsp/manager.go b/claudetool/lsp/manager.go new file mode 100644 index 0000000..d1359e8 --- /dev/null +++ b/claudetool/lsp/manager.go @@ -0,0 +1,113 @@ +package lsp + +import ( + "context" + "fmt" + "log/slog" + "os/exec" + "path/filepath" + "strings" + "sync" +) + +// Manager manages the lifecycle of LSP servers, routing requests by file extension. +type Manager struct { + workingDirFn func() string + configs []ServerConfig + extToConfig map[string]*ServerConfig // extension -> config + + mu sync.Mutex + servers map[string]*Server // config.Name -> running server +} + +// NewManager creates a new LSP server manager. +func NewManager(workingDirFn func() string) *Manager { + configs := DefaultServers() + extToConfig := make(map[string]*ServerConfig) + for i := range configs { + for _, ext := range configs[i].Extensions { + extToConfig[ext] = &configs[i] + } + } + return &Manager{ + workingDirFn: workingDirFn, + configs: configs, + extToConfig: extToConfig, + servers: make(map[string]*Server), + } +} + +// GetServer returns a running LSP server for the given file path, starting one if needed. +func (m *Manager) GetServer(ctx context.Context, filePath string) (*Server, error) { + ext := filepath.Ext(filePath) + cfg, ok := m.extToConfig[ext] + if !ok { + return nil, fmt.Errorf("no LSP server configured for %s files", ext) + } + + // Check if binary exists + if _, err := exec.LookPath(cfg.Command); err != nil { + return nil, fmt.Errorf("LSP server %q not found. %s", cfg.Command, cfg.InstallHint) + } + + rootURI := m.rootURI() + + m.mu.Lock() + defer m.mu.Unlock() + + // Check for existing server + if srv, ok := m.servers[cfg.Name]; ok { + if srv.Alive() && srv.RootURI() == rootURI { + return srv, nil + } + // Dead or stale root — shut down and restart + slog.Info("lsp: restarting server", "server", cfg.Name, "reason", "dead or root changed") + srv.Shutdown() + delete(m.servers, cfg.Name) + } + + // Start new server + srv, err := NewServer(ctx, *cfg, rootURI) + if err != nil { + return nil, err + } + m.servers[cfg.Name] = srv + slog.Info("lsp: started server", "server", cfg.Name, "rootURI", rootURI) + return srv, nil +} + +// ConfigForExt returns the server config for a file extension, or nil if none. +func (m *Manager) ConfigForExt(ext string) *ServerConfig { + return m.extToConfig[ext] +} + +// Close shuts down all running LSP servers. +func (m *Manager) Close() { + m.mu.Lock() + defer m.mu.Unlock() + for name, srv := range m.servers { + slog.Info("lsp: shutting down server", "server", name) + srv.Shutdown() + } + m.servers = make(map[string]*Server) +} + +func (m *Manager) rootURI() string { + wd := m.workingDirFn() + root, err := findRepoRoot(wd) + if err != nil { + root = wd + } + return fileURI(root) +} + +// findRepoRoot finds the git repository root from the given directory. +func findRepoRoot(wd string) (string, error) { + cmd := exec.Command("git", "rev-parse", "--show-toplevel") + cmd.Dir = wd + out, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("failed to find git repository root: %w", err) + } + return strings.TrimSpace(string(out)), nil +} diff --git a/claudetool/lsp/manager_test.go b/claudetool/lsp/manager_test.go new file mode 100644 index 0000000..3c4693c --- /dev/null +++ b/claudetool/lsp/manager_test.go @@ -0,0 +1,127 @@ +package lsp + +import ( + "context" + "os/exec" + "path/filepath" + "testing" +) + +func TestManagerConfigForExt(t *testing.T) { + m := NewManager(func() string { return "/tmp" }) + + tests := []struct { + ext string + wantName string + wantNil bool + }{ + {".go", "gopls", false}, + {".ts", "typescript-language-server", false}, + {".tsx", "typescript-language-server", false}, + {".js", "typescript-language-server", false}, + {".jsx", "typescript-language-server", false}, + {".py", "", true}, + {".rs", "", true}, + {".txt", "", true}, + } + for _, tt := range tests { + cfg := m.ConfigForExt(tt.ext) + if tt.wantNil { + if cfg != nil { + t.Errorf("ConfigForExt(%q) = %v, want nil", tt.ext, cfg) + } + continue + } + if cfg == nil { + t.Errorf("ConfigForExt(%q) = nil, want %q", tt.ext, tt.wantName) + continue + } + if cfg.Name != tt.wantName { + t.Errorf("ConfigForExt(%q).Name = %q, want %q", tt.ext, cfg.Name, tt.wantName) + } + } +} + +func TestManagerGetServerMissingBinary(t *testing.T) { + m := NewManager(func() string { return t.TempDir() }) + + // Override configs to use a nonexistent binary + m.configs = []ServerConfig{ + { + Name: "fake-lsp", + Command: "definitely-not-a-real-binary-12345", + Args: []string{}, + Extensions: []string{".fake"}, + InstallHint: "Install fake-lsp: go install fake-lsp@latest", + }, + } + m.extToConfig = map[string]*ServerConfig{ + ".fake": &m.configs[0], + } + + ctx := context.Background() + _, err := m.GetServer(ctx, "/tmp/test.fake") + if err == nil { + t.Fatal("expected error for missing binary") + } + // Error should contain the install hint + if got := err.Error(); !contains(got, "Install fake-lsp") { + t.Errorf("error %q should contain install hint", got) + } +} + +func TestManagerGetServerNoConfig(t *testing.T) { + m := NewManager(func() string { return t.TempDir() }) + + ctx := context.Background() + _, err := m.GetServer(ctx, "/tmp/test.py") + if err == nil { + t.Fatal("expected error for unconfigured extension") + } +} + +func TestManagerGetServerReusesServer(t *testing.T) { + if _, err := exec.LookPath("gopls"); err != nil { + t.Skip("gopls not installed") + } + + dir := t.TempDir() + m := NewManager(func() string { return dir }) + defer m.Close() + + ctx := context.Background() + goFile := filepath.Join(dir, "main.go") + + srv1, err := m.GetServer(ctx, goFile) + if err != nil { + t.Fatalf("first GetServer: %v", err) + } + + srv2, err := m.GetServer(ctx, goFile) + if err != nil { + t.Fatalf("second GetServer: %v", err) + } + + if srv1 != srv2 { + t.Error("expected same server instance for repeat calls") + } +} + +func TestManagerClose(t *testing.T) { + m := NewManager(func() string { return t.TempDir() }) + // Close with no servers should not panic + m.Close() +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstr(s, substr)) +} + +func containsSubstr(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/claudetool/lsp/protocol.go b/claudetool/lsp/protocol.go new file mode 100644 index 0000000..05132ab --- /dev/null +++ b/claudetool/lsp/protocol.go @@ -0,0 +1,286 @@ +package lsp + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "os/exec" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Client is a JSON-RPC 2.0 client that communicates with an LSP server over stdin/stdout. +type Client struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + + mu sync.Mutex + nextID atomic.Int64 + pending map[int64]chan *rpcResponse + closed chan struct{} + closeErr error +} + +type rpcRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id,omitempty"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type rpcNotification struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type rpcResponse struct { + JSONRPC string `json:"jsonrpc"` + ID *json.RawMessage `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *rpcError `json:"error,omitempty"` +} + +type rpcError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *rpcError) Error() string { + return fmt.Sprintf("LSP error %d: %s", e.Code, e.Message) +} + +// NewClient starts the given command and returns a JSON-RPC 2.0 client connected to it. +func NewClient(cmd *exec.Cmd) (*Client, error) { + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("stdout pipe: %w", err) + } + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start LSP server: %w", err) + } + c := &Client{ + cmd: cmd, + stdin: stdin, + stdout: stdout, + pending: make(map[int64]chan *rpcResponse), + closed: make(chan struct{}), + } + go c.readLoop() + return c, nil +} + +// Call sends a JSON-RPC request and waits for the response. +func (c *Client) Call(ctx context.Context, method string, params any, result any) error { + id := c.nextID.Add(1) + ch := make(chan *rpcResponse, 1) + + c.mu.Lock() + c.pending[id] = ch + c.mu.Unlock() + + defer func() { + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + }() + + req := rpcRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + Params: params, + } + if err := c.send(req); err != nil { + return fmt.Errorf("send %s: %w", method, err) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closed: + return fmt.Errorf("LSP server closed: %w", c.closeErr) + case resp := <-ch: + if resp.Error != nil { + return resp.Error + } + if result != nil && len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, result); err != nil { + return fmt.Errorf("unmarshal %s result: %w", method, err) + } + } + return nil + } +} + +// Notify sends a JSON-RPC notification (no response expected). +func (c *Client) Notify(method string, params any) error { + n := rpcNotification{ + JSONRPC: "2.0", + Method: method, + Params: params, + } + return c.send(n) +} + +// Close shuts down the LSP server gracefully. +func (c *Client) Close() error { + // Send shutdown request with timeout (server may already be dead or unresponsive) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + _ = c.Call(ctx, "shutdown", nil, nil) + cancel() + _ = c.Notify("exit", nil) + _ = c.stdin.Close() + + // Wait briefly for process to exit, then kill it + done := make(chan error, 1) + go func() { done <- c.cmd.Wait() }() + + var err error + select { + case err = <-done: + case <-time.After(3 * time.Second): + if c.cmd.Process != nil { + _ = c.cmd.Process.Kill() + } + err = <-done + } + + // Signal readLoop to stop + select { + case <-c.closed: + default: + c.closeErr = err + close(c.closed) + } + return err +} + +// Alive returns true if the underlying process is still running. +func (c *Client) Alive() bool { + select { + case <-c.closed: + return false + default: + return c.cmd.ProcessState == nil // not yet exited + } +} + +func (c *Client) send(msg any) error { + data, err := json.Marshal(msg) + if err != nil { + return err + } + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + c.mu.Lock() + defer c.mu.Unlock() + if _, err := io.WriteString(c.stdin, header); err != nil { + return err + } + _, err = c.stdin.Write(data) + return err +} + +func (c *Client) readLoop() { + reader := bufio.NewReader(c.stdout) + for { + // Read headers + contentLength := -1 + for { + line, err := reader.ReadString('\n') + if err != nil { + c.mu.Lock() + c.closeErr = err + c.mu.Unlock() + select { + case <-c.closed: + default: + close(c.closed) + } + return + } + line = strings.TrimSpace(line) + if line == "" { + break // End of headers + } + if strings.HasPrefix(line, "Content-Length: ") { + n, err := strconv.Atoi(strings.TrimPrefix(line, "Content-Length: ")) + if err == nil { + contentLength = n + } + } + } + if contentLength < 0 { + continue + } + + // Read body + body := make([]byte, contentLength) + if _, err := io.ReadFull(reader, body); err != nil { + c.mu.Lock() + c.closeErr = err + c.mu.Unlock() + select { + case <-c.closed: + default: + close(c.closed) + } + return + } + + var resp rpcResponse + if err := json.Unmarshal(body, &resp); err != nil { + slog.Debug("lsp: failed to unmarshal response", "err", err) + continue + } + + // If no ID, it's a server notification — ignore + if resp.ID == nil { + continue + } + + var id int64 + if err := json.Unmarshal(*resp.ID, &id); err != nil { + slog.Debug("lsp: failed to unmarshal response ID", "err", err) + continue + } + + c.mu.Lock() + ch, ok := c.pending[id] + c.mu.Unlock() + if ok { + ch <- &resp + } + } +} + +// EncodeHeader returns the Content-Length header for a JSON-RPC message. +// Exported for testing. +func EncodeHeader(bodyLen int) string { + return fmt.Sprintf("Content-Length: %d\r\n\r\n", bodyLen) +} + +// DecodeHeader parses a Content-Length header value from a header line. +// Exported for testing. +func DecodeHeader(line string) (int, bool) { + val, ok := strings.CutPrefix(strings.TrimSpace(line), "Content-Length: ") + if !ok { + return 0, false + } + n, err := strconv.Atoi(val) + if err != nil { + return 0, false + } + return n, true +} diff --git a/claudetool/lsp/protocol_test.go b/claudetool/lsp/protocol_test.go new file mode 100644 index 0000000..5b9de3e --- /dev/null +++ b/claudetool/lsp/protocol_test.go @@ -0,0 +1,199 @@ +package lsp + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os/exec" + "testing" + "time" +) + +func TestEncodeHeader(t *testing.T) { + got := EncodeHeader(42) + want := "Content-Length: 42\r\n\r\n" + if got != want { + t.Errorf("EncodeHeader(42) = %q, want %q", got, want) + } +} + +func TestDecodeHeader(t *testing.T) { + tests := []struct { + input string + want int + ok bool + }{ + {"Content-Length: 42", 42, true}, + {"Content-Length: 0", 0, true}, + {"Content-Length: 12345", 12345, true}, + {"content-type: json", 0, false}, + {"", 0, false}, + {"Content-Length: abc", 0, false}, + {" Content-Length: 10 ", 10, true}, + } + for _, tt := range tests { + got, ok := DecodeHeader(tt.input) + if got != tt.want || ok != tt.ok { + t.Errorf("DecodeHeader(%q) = (%d, %v), want (%d, %v)", tt.input, got, ok, tt.want, tt.ok) + } + } +} + +func TestRPCRequestEncoding(t *testing.T) { + req := rpcRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + Params: map[string]any{"processId": 123}, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatal(err) + } + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatal(err) + } + if parsed["jsonrpc"] != "2.0" { + t.Errorf("jsonrpc = %v, want 2.0", parsed["jsonrpc"]) + } + if parsed["method"] != "initialize" { + t.Errorf("method = %v, want initialize", parsed["method"]) + } + if parsed["id"].(float64) != 1 { + t.Errorf("id = %v, want 1", parsed["id"]) + } +} + +func TestRPCNotificationEncoding(t *testing.T) { + notif := rpcNotification{ + JSONRPC: "2.0", + Method: "initialized", + } + data, err := json.Marshal(notif) + if err != nil { + t.Fatal(err) + } + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatal(err) + } + if _, hasID := parsed["id"]; hasID { + t.Error("notification should not have id field") + } + if parsed["method"] != "initialized" { + t.Errorf("method = %v, want initialized", parsed["method"]) + } +} + +func TestClientContextCancellation(t *testing.T) { + // Use `sleep` so nothing is echoed back to stdout + cmd := exec.Command("sleep", "60") + client, err := NewClient(cmd) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = client.Call(ctx, "test/method", nil, nil) + if err == nil { + t.Fatal("expected error from cancelled context") + } + + // Kill the process directly — don't use Close() which tries graceful shutdown + if client.cmd.Process != nil { + client.cmd.Process.Kill() + } +} + +func TestClientAlive(t *testing.T) { + cmd := exec.Command("cat") + client, err := NewClient(cmd) + if err != nil { + t.Fatal(err) + } + + if !client.Alive() { + t.Error("expected client to be alive after creation") + } + + client.stdin.Close() + time.Sleep(100 * time.Millisecond) + + if client.Alive() { + t.Error("expected client to be dead after stdin close") + } +} + +func TestClientCallWithMockServer(t *testing.T) { + // Create pipes for mock communication + clientReader, serverWriter := io.Pipe() + serverReader, clientWriter := io.Pipe() + + cmd := exec.Command("true") // dummy, not used + client := &Client{ + cmd: cmd, + stdin: clientWriter, + stdout: clientReader, + pending: make(map[int64]chan *rpcResponse), + closed: make(chan struct{}), + } + go client.readLoop() + + // Mock server goroutine: read one JSON-RPC message and respond + go func() { + reader := bufio.NewReader(serverReader) + // Read Content-Length header + var contentLength int + for { + line, err := reader.ReadString('\n') + if err != nil { + return + } + if n, ok := DecodeHeader(line); ok { + contentLength = n + } + if line == "\r\n" || line == "\n" { + break + } + } + // Read body + body := make([]byte, contentLength) + if _, err := io.ReadFull(reader, body); err != nil { + return + } + + // Parse request to get ID + var req struct { + ID int64 `json:"id"` + } + json.Unmarshal(body, &req) + + // Send response + resp := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":{"capabilities":{}}}`, req.ID) + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp)) + serverWriter.Write([]byte(header + resp)) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var result map[string]any + err := client.Call(ctx, "initialize", map[string]any{"processId": 1}, &result) + if err != nil { + t.Fatalf("Call failed: %v", err) + } + if result == nil { + t.Fatal("expected non-nil result") + } + if _, ok := result["capabilities"]; !ok { + t.Error("expected capabilities in result") + } + + clientWriter.Close() + serverWriter.Close() +} diff --git a/claudetool/lsp/register.go b/claudetool/lsp/register.go new file mode 100644 index 0000000..6346483 --- /dev/null +++ b/claudetool/lsp/register.go @@ -0,0 +1,14 @@ +package lsp + +import "shelley.exe.dev/llm" + +// RegisterLSPTools creates the LSP code intelligence tools and returns them with a cleanup function. +// The cleanup function shuts down all running LSP servers. +func RegisterLSPTools(workingDirFn func() string) ([]*llm.Tool, func()) { + manager := NewManager(workingDirFn) + tool := &CodeIntelTool{ + manager: manager, + workingDir: workingDirFn, + } + return []*llm.Tool{tool.Tool()}, manager.Close +} diff --git a/claudetool/lsp/server.go b/claudetool/lsp/server.go new file mode 100644 index 0000000..d89475b --- /dev/null +++ b/claudetool/lsp/server.go @@ -0,0 +1,261 @@ +package lsp + +import ( + "context" + "fmt" + "log/slog" + "net/url" + "os" + "os/exec" + "path/filepath" + "sync" +) + +// ServerConfig describes how to start an LSP server for a given language. +type ServerConfig struct { + Name string // e.g., "gopls", "typescript-language-server" + Command string // binary name + Args []string // command-line arguments + Extensions []string // file extensions this server handles (e.g., ".go", ".ts") + InstallHint string // message shown if the binary is not found +} + +// DefaultServers returns built-in server configurations. +func DefaultServers() []ServerConfig { + return []ServerConfig{ + { + Name: "gopls", + Command: "gopls", + Args: []string{"serve"}, + Extensions: []string{".go"}, + InstallHint: "Install gopls: go install golang.org/x/tools/gopls@latest", + }, + { + Name: "typescript-language-server", + Command: "typescript-language-server", + Args: []string{"--stdio"}, + Extensions: []string{".ts", ".tsx", ".js", ".jsx"}, + InstallHint: "Install typescript-language-server: npm install -g typescript-language-server typescript", + }, + } +} + +// Server wraps a running LSP server process. +type Server struct { + client *Client + config ServerConfig + rootURI string + + mu sync.Mutex + openFiles map[string]int // URI -> version +} + +// NewServer starts an LSP server and initializes it with the given root URI. +func NewServer(ctx context.Context, config ServerConfig, rootURI string) (*Server, error) { + cmd := exec.CommandContext(ctx, config.Command, config.Args...) + cmd.Stderr = os.Stderr // let LSP server errors show + client, err := NewClient(cmd) + if err != nil { + return nil, fmt.Errorf("start %s: %w", config.Name, err) + } + + s := &Server{ + client: client, + config: config, + rootURI: rootURI, + openFiles: make(map[string]int), + } + + if err := s.initialize(ctx); err != nil { + client.Close() + return nil, fmt.Errorf("initialize %s: %w", config.Name, err) + } + + return s, nil +} + +func (s *Server) initialize(ctx context.Context) error { + params := InitializeParams{ + ProcessID: os.Getpid(), + RootURI: s.rootURI, + Capabilities: ClientCapabilities{ + TextDocument: &TextDocumentClientCapabilities{ + Definition: &DefinitionClientCapabilities{}, + }, + }, + } + + var result InitializeResult + if err := s.client.Call(ctx, "initialize", params, &result); err != nil { + return err + } + + slog.Debug("lsp: initialized", "server", s.config.Name, "rootURI", s.rootURI) + return s.client.Notify("initialized", struct{}{}) +} + +// OpenFile opens or refreshes a file in the LSP server by reading it from disk. +func (s *Server) OpenFile(ctx context.Context, filePath string) error { + uri := fileURI(filePath) + content, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("read file %s: %w", filePath, err) + } + + lang := languageID(filePath) + + s.mu.Lock() + version, isOpen := s.openFiles[uri] + if isOpen { + // File already open — send didChange with incremented version + version++ + s.openFiles[uri] = version + s.mu.Unlock() + return s.client.Notify("textDocument/didChange", DidChangeTextDocumentParams{ + TextDocument: VersionedTextDocumentIdentifier{URI: uri, Version: version}, + ContentChanges: []TextDocumentContentChangeEvent{ + {Text: string(content)}, + }, + }) + } + s.openFiles[uri] = 1 + s.mu.Unlock() + + return s.client.Notify("textDocument/didOpen", DidOpenTextDocumentParams{ + TextDocument: TextDocumentItem{ + URI: uri, + LanguageID: lang, + Version: 1, + Text: string(content), + }, + }) +} + +// CloseFile closes a file in the LSP server. +func (s *Server) CloseFile(uri string) { + s.mu.Lock() + delete(s.openFiles, uri) + s.mu.Unlock() + _ = s.client.Notify("textDocument/didClose", DidCloseTextDocumentParams{ + TextDocument: TextDocumentIdentifier{URI: uri}, + }) +} + +// Definition returns the definition location(s) for the symbol at the given position. +func (s *Server) Definition(ctx context.Context, uri string, pos Position) ([]Location, error) { + params := TextDocumentPositionParams{ + TextDocument: TextDocumentIdentifier{URI: uri}, + Position: pos, + } + // LSP spec: result is Location | Location[] | null + // Try array first, then single location + var locations []Location + if err := s.client.Call(ctx, "textDocument/definition", params, &locations); err != nil { + return nil, err + } + return locations, nil +} + +// References returns all references to the symbol at the given position. +func (s *Server) References(ctx context.Context, uri string, pos Position) ([]Location, error) { + params := ReferenceParams{ + TextDocument: TextDocumentIdentifier{URI: uri}, + Position: pos, + Context: ReferenceContext{IncludeDeclaration: true}, + } + var locations []Location + if err := s.client.Call(ctx, "textDocument/references", params, &locations); err != nil { + return nil, err + } + return locations, nil +} + +// HoverResult returns hover information for the symbol at the given position. +func (s *Server) HoverResult(ctx context.Context, uri string, pos Position) (*Hover, error) { + params := TextDocumentPositionParams{ + TextDocument: TextDocumentIdentifier{URI: uri}, + Position: pos, + } + var hover Hover + if err := s.client.Call(ctx, "textDocument/hover", params, &hover); err != nil { + return nil, err + } + return &hover, nil +} + +// WorkspaceSymbols searches for symbols matching the given query across the workspace. +func (s *Server) WorkspaceSymbols(ctx context.Context, query string) ([]SymbolInformation, error) { + params := WorkspaceSymbolParams{Query: query} + var symbols []SymbolInformation + if err := s.client.Call(ctx, "workspace/symbol", params, &symbols); err != nil { + return nil, err + } + return symbols, nil +} + +// Alive returns true if the underlying LSP process is still running. +func (s *Server) Alive() bool { + return s.client.Alive() +} + +// RootURI returns the root URI this server was initialized with. +func (s *Server) RootURI() string { + return s.rootURI +} + +// Shutdown gracefully shuts down the server. +func (s *Server) Shutdown() { + s.mu.Lock() + for uri := range s.openFiles { + _ = s.client.Notify("textDocument/didClose", DidCloseTextDocumentParams{ + TextDocument: TextDocumentIdentifier{URI: uri}, + }) + } + s.openFiles = make(map[string]int) + s.mu.Unlock() + _ = s.client.Close() +} + +// fileURI converts an absolute file path to a file:// URI. +func fileURI(path string) string { + u := &url.URL{Scheme: "file", Path: path} + return u.String() +} + +// filePathFromURI converts a file:// URI back to a file path. +func filePathFromURI(uri string) string { + u, err := url.Parse(uri) + if err != nil || u.Scheme != "file" { + return uri + } + return u.Path +} + +// languageID returns the LSP language ID for a file path based on its extension. +func languageID(path string) string { + ext := filepath.Ext(path) + switch ext { + case ".go": + return "go" + case ".ts": + return "typescript" + case ".tsx": + return "typescriptreact" + case ".js": + return "javascript" + case ".jsx": + return "javascriptreact" + case ".py": + return "python" + case ".rs": + return "rust" + case ".java": + return "java" + case ".c", ".h": + return "c" + case ".cpp", ".hpp", ".cc": + return "cpp" + default: + return "plaintext" + } +} diff --git a/claudetool/lsp/tool.go b/claudetool/lsp/tool.go new file mode 100644 index 0000000..81141cd --- /dev/null +++ b/claudetool/lsp/tool.go @@ -0,0 +1,191 @@ +package lsp + +import ( + "context" + "encoding/json" + "path/filepath" + "strings" + + "shelley.exe.dev/llm" +) + +const ( + codeIntelName = "code_intelligence" + codeIntelDescription = `Provides compiler-accurate code intelligence powered by Language Server Protocol (LSP). + +Operations: +- definition: Go to the definition of a symbol at a given file position +- references: Find all references to a symbol at a given file position +- hover: Get type information and documentation for a symbol at a given file position +- symbols: Search for symbols (functions, types, variables) across the workspace + +Use this for precise, semantic code navigation. For text-based search, use keyword_search instead. +Requires an LSP server installed for the file's language (e.g., gopls for Go, typescript-language-server for TypeScript). + +Note: The first call for a language may be slow while the LSP server starts and indexes the workspace. +` + codeIntelInputSchema = `{ + "type": "object", + "required": ["operation"], + "properties": { + "operation": { + "type": "string", + "enum": ["definition", "references", "hover", "symbols"], + "description": "The code intelligence operation to perform" + }, + "file": { + "type": "string", + "description": "File path (absolute or relative to working directory). Required for definition, references, hover." + }, + "line": { + "type": "integer", + "description": "Line number (1-based). Required for definition, references, hover." + }, + "column": { + "type": "integer", + "description": "Column number (1-based). Required for definition, references, hover." + }, + "query": { + "type": "string", + "description": "Symbol name to search for. Required for symbols operation." + } + } +}` +) + +type codeIntelInput struct { + Operation string `json:"operation"` + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Query string `json:"query"` +} + +// CodeIntelTool provides LSP-based code intelligence. +type CodeIntelTool struct { + manager *Manager + workingDir func() string +} + +// Tool returns the llm.Tool definition for code intelligence. +func (c *CodeIntelTool) Tool() *llm.Tool { + return &llm.Tool{ + Name: codeIntelName, + Description: codeIntelDescription, + InputSchema: llm.MustSchema(codeIntelInputSchema), + Run: c.Run, + } +} + +// Run executes the code intelligence tool. +func (c *CodeIntelTool) Run(ctx context.Context, m json.RawMessage) llm.ToolOut { + var input codeIntelInput + if err := json.Unmarshal(m, &input); err != nil { + return llm.ErrorfToolOut("failed to parse input: %w", err) + } + + switch input.Operation { + case "definition", "references", "hover": + return c.runPositionOp(ctx, input) + case "symbols": + return c.runSymbols(ctx, input) + default: + return llm.ErrorfToolOut("unknown operation %q: must be one of definition, references, hover, symbols", input.Operation) + } +} + +func (c *CodeIntelTool) runPositionOp(ctx context.Context, input codeIntelInput) llm.ToolOut { + if input.File == "" { + return llm.ErrorfToolOut("file is required for %s operation", input.Operation) + } + if input.Line < 1 { + return llm.ErrorfToolOut("line is required and must be >= 1 for %s operation", input.Operation) + } + if input.Column < 1 { + return llm.ErrorfToolOut("column is required and must be >= 1 for %s operation", input.Operation) + } + + // Resolve relative paths + filePath := input.File + if !filepath.IsAbs(filePath) { + filePath = filepath.Join(c.workingDir(), filePath) + } + filePath = filepath.Clean(filePath) + + // Get LSP server + srv, err := c.manager.GetServer(ctx, filePath) + if err != nil { + return llm.ErrorfToolOut("%s", err) + } + + // Open/refresh file in LSP + if err := srv.OpenFile(ctx, filePath); err != nil { + return llm.ErrorfToolOut("failed to open file in LSP: %s", err) + } + + uri := fileURI(filePath) + // Convert 1-based input to 0-based LSP positions + pos := Position{ + Line: input.Line - 1, + Character: input.Column - 1, + } + + wd := c.workingDir() + + switch input.Operation { + case "definition": + locations, err := srv.Definition(ctx, uri, pos) + if err != nil { + return llm.ErrorfToolOut("definition failed: %s", err) + } + return llm.ToolOut{LLMContent: llm.TextContent(formatDefinition(locations, wd))} + + case "references": + locations, err := srv.References(ctx, uri, pos) + if err != nil { + return llm.ErrorfToolOut("references failed: %s", err) + } + return llm.ToolOut{LLMContent: llm.TextContent(formatReferences(locations, wd))} + + case "hover": + hover, err := srv.HoverResult(ctx, uri, pos) + if err != nil { + return llm.ErrorfToolOut("hover failed: %s", err) + } + return llm.ToolOut{LLMContent: llm.TextContent(formatHover(hover))} + } + + return llm.ErrorfToolOut("unreachable") +} + +func (c *CodeIntelTool) runSymbols(ctx context.Context, input codeIntelInput) llm.ToolOut { + if input.Query == "" { + return llm.ErrorfToolOut("query is required for symbols operation") + } + + // For symbols, we need any server — use the working dir to pick one. + // Try to find an appropriate server by looking for common files. + // Fall back to gopls if available. + wd := c.workingDir() + + // Try to get a server. For workspace symbols, we'll try Go first then TS. + var srv *Server + var errs []string + for _, ext := range []string{".go", ".ts"} { + var err error + srv, err = c.manager.GetServer(ctx, filepath.Join(wd, "dummy"+ext)) + if err == nil { + break + } + errs = append(errs, err.Error()) + } + if srv == nil { + return llm.ErrorfToolOut("no LSP server available for symbols. Tried: %s", strings.Join(errs, "; ")) + } + + symbols, err := srv.WorkspaceSymbols(ctx, input.Query) + if err != nil { + return llm.ErrorfToolOut("workspace symbols failed: %s", err) + } + return llm.ToolOut{LLMContent: llm.TextContent(formatSymbols(symbols, wd))} +} diff --git a/claudetool/lsp/tool_test.go b/claudetool/lsp/tool_test.go new file mode 100644 index 0000000..63a75f4 --- /dev/null +++ b/claudetool/lsp/tool_test.go @@ -0,0 +1,151 @@ +package lsp + +import ( + "context" + "encoding/json" + "testing" +) + +func TestToolInputValidation(t *testing.T) { + tool := &CodeIntelTool{ + manager: NewManager(func() string { return t.TempDir() }), + workingDir: func() string { return t.TempDir() }, + } + + tests := []struct { + name string + input codeIntelInput + wantErr string + }{ + { + name: "unknown operation", + input: codeIntelInput{Operation: "unknown"}, + wantErr: "unknown operation", + }, + { + name: "definition missing file", + input: codeIntelInput{Operation: "definition", Line: 1, Column: 1}, + wantErr: "file is required", + }, + { + name: "definition missing line", + input: codeIntelInput{Operation: "definition", File: "test.go", Column: 1}, + wantErr: "line is required", + }, + { + name: "definition missing column", + input: codeIntelInput{Operation: "definition", File: "test.go", Line: 1}, + wantErr: "column is required", + }, + { + name: "references missing file", + input: codeIntelInput{Operation: "references", Line: 1, Column: 1}, + wantErr: "file is required", + }, + { + name: "hover missing file", + input: codeIntelInput{Operation: "hover", Line: 1, Column: 1}, + wantErr: "file is required", + }, + { + name: "symbols missing query", + input: codeIntelInput{Operation: "symbols"}, + wantErr: "query is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.input) + if err != nil { + t.Fatal(err) + } + result := tool.Run(context.Background(), data) + if result.Error == nil { + t.Fatal("expected error") + } + if got := result.Error.Error(); !containsSubstr(got, tt.wantErr) { + t.Errorf("error %q should contain %q", got, tt.wantErr) + } + }) + } +} + +func TestToolPositionConversion(t *testing.T) { + // Test that 1-based input converts to 0-based LSP positions + // by checking the input struct parsing + input := codeIntelInput{ + Operation: "definition", + File: "test.go", + Line: 10, + Column: 5, + } + + // 1-based line 10 should become 0-based line 9 + pos := Position{ + Line: input.Line - 1, + Character: input.Column - 1, + } + if pos.Line != 9 { + t.Errorf("position line = %d, want 9", pos.Line) + } + if pos.Character != 4 { + t.Errorf("position character = %d, want 4", pos.Character) + } +} + +func TestToolRelativePathResolution(t *testing.T) { + wd := "/home/user/project" + tool := &CodeIntelTool{ + manager: NewManager(func() string { return wd }), + workingDir: func() string { return wd }, + } + + // The tool should resolve relative paths against working dir. + // We can't fully test this without an LSP server, but we can verify + // the input parsing doesn't fail on relative paths. + input := codeIntelInput{ + Operation: "definition", + File: "src/main.go", + Line: 1, + Column: 1, + } + data, err := json.Marshal(input) + if err != nil { + t.Fatal(err) + } + result := tool.Run(context.Background(), data) + // Should get an error about LSP server, not about path resolution + if result.Error == nil { + t.Fatal("expected error (no LSP server)") + } + errMsg := result.Error.Error() + if containsSubstr(errMsg, "file is required") || containsSubstr(errMsg, "line is required") { + t.Errorf("unexpected validation error: %s", errMsg) + } +} + +func TestToolSchemaIsValid(t *testing.T) { + tool := &CodeIntelTool{ + manager: NewManager(func() string { return "/tmp" }), + workingDir: func() string { return "/tmp" }, + } + llmTool := tool.Tool() + if llmTool.Name != "code_intelligence" { + t.Errorf("tool name = %q, want %q", llmTool.Name, "code_intelligence") + } + if llmTool.Run == nil { + t.Error("tool Run function is nil") + } + if len(llmTool.InputSchema) == 0 { + t.Error("tool InputSchema is empty") + } + + // Verify schema is valid JSON + var schema map[string]any + if err := json.Unmarshal(llmTool.InputSchema, &schema); err != nil { + t.Fatalf("schema is not valid JSON: %v", err) + } + if schema["type"] != "object" { + t.Errorf("schema type = %v, want object", schema["type"]) + } +} diff --git a/claudetool/lsp/types.go b/claudetool/lsp/types.go new file mode 100644 index 0000000..190fca7 --- /dev/null +++ b/claudetool/lsp/types.go @@ -0,0 +1,212 @@ +package lsp + +// LSP protocol types — minimal subset needed for code intelligence operations. + +// Position in a text document (0-based line and character). +type Position struct { + Line int `json:"line"` + Character int `json:"character"` +} + +// Range in a text document. +type Range struct { + Start Position `json:"start"` + End Position `json:"end"` +} + +// Location represents a location inside a resource. +type Location struct { + URI string `json:"uri"` + Range Range `json:"range"` +} + +// TextDocumentIdentifier identifies a text document by its URI. +type TextDocumentIdentifier struct { + URI string `json:"uri"` +} + +// TextDocumentItem is an item to transfer a text document from client to server. +type TextDocumentItem struct { + URI string `json:"uri"` + LanguageID string `json:"languageId"` + Version int `json:"version"` + Text string `json:"text"` +} + +// TextDocumentPositionParams is a parameter literal used in requests that require a position in a text document. +type TextDocumentPositionParams struct { + TextDocument TextDocumentIdentifier `json:"textDocument"` + Position Position `json:"position"` +} + +// ReferenceContext controls whether declarations should be included in references results. +type ReferenceContext struct { + IncludeDeclaration bool `json:"includeDeclaration"` +} + +// ReferenceParams extends TextDocumentPositionParams with reference context. +type ReferenceParams struct { + TextDocument TextDocumentIdentifier `json:"textDocument"` + Position Position `json:"position"` + Context ReferenceContext `json:"context"` +} + +// WorkspaceSymbolParams is the params for a workspace/symbol request. +type WorkspaceSymbolParams struct { + Query string `json:"query"` +} + +// SymbolKind represents the kind of a symbol. +type SymbolKind int + +const ( + SymbolKindFile SymbolKind = 1 + SymbolKindModule SymbolKind = 2 + SymbolKindNamespace SymbolKind = 3 + SymbolKindPackage SymbolKind = 4 + SymbolKindClass SymbolKind = 5 + SymbolKindMethod SymbolKind = 6 + SymbolKindProperty SymbolKind = 7 + SymbolKindField SymbolKind = 8 + SymbolKindConstructor SymbolKind = 9 + SymbolKindEnum SymbolKind = 10 + SymbolKindInterface SymbolKind = 11 + SymbolKindFunction SymbolKind = 12 + SymbolKindVariable SymbolKind = 13 + SymbolKindConstant SymbolKind = 14 + SymbolKindString SymbolKind = 15 + SymbolKindNumber SymbolKind = 16 + SymbolKindBoolean SymbolKind = 17 + SymbolKindArray SymbolKind = 18 + SymbolKindObject SymbolKind = 19 + SymbolKindKey SymbolKind = 20 + SymbolKindNull SymbolKind = 21 + SymbolKindEnumMember SymbolKind = 22 + SymbolKindStruct SymbolKind = 23 + SymbolKindEvent SymbolKind = 24 + SymbolKindOperator SymbolKind = 25 + SymbolKindTypeParameter SymbolKind = 26 +) + +// SymbolKindName returns a human-readable name for a SymbolKind. +func SymbolKindName(k SymbolKind) string { + names := map[SymbolKind]string{ + SymbolKindFile: "File", + SymbolKindModule: "Module", + SymbolKindNamespace: "Namespace", + SymbolKindPackage: "Package", + SymbolKindClass: "Class", + SymbolKindMethod: "Method", + SymbolKindProperty: "Property", + SymbolKindField: "Field", + SymbolKindConstructor: "Constructor", + SymbolKindEnum: "Enum", + SymbolKindInterface: "Interface", + SymbolKindFunction: "Function", + SymbolKindVariable: "Variable", + SymbolKindConstant: "Constant", + SymbolKindString: "String", + SymbolKindNumber: "Number", + SymbolKindBoolean: "Boolean", + SymbolKindArray: "Array", + SymbolKindObject: "Object", + SymbolKindKey: "Key", + SymbolKindNull: "Null", + SymbolKindEnumMember: "EnumMember", + SymbolKindStruct: "Struct", + SymbolKindEvent: "Event", + SymbolKindOperator: "Operator", + SymbolKindTypeParameter: "TypeParameter", + } + if name, ok := names[k]; ok { + return name + } + return "Unknown" +} + +// SymbolInformation represents information about a programming construct. +type SymbolInformation struct { + Name string `json:"name"` + Kind SymbolKind `json:"kind"` + Location Location `json:"location"` + ContainerName string `json:"containerName,omitempty"` +} + +// Hover is the result of a hover request. +type Hover struct { + Contents MarkupContent `json:"contents"` + Range *Range `json:"range,omitempty"` +} + +// MarkupContent represents a string value with a specific content type. +type MarkupContent struct { + Kind string `json:"kind"` + Value string `json:"value"` +} + +// InitializeParams is sent as the first request from client to server. +type InitializeParams struct { + ProcessID int `json:"processId"` + RootURI string `json:"rootUri"` + Capabilities ClientCapabilities `json:"capabilities"` +} + +// ClientCapabilities define capabilities the editor / tool provides. +type ClientCapabilities struct { + TextDocument *TextDocumentClientCapabilities `json:"textDocument,omitempty"` +} + +// TextDocumentClientCapabilities define capabilities the editor / tool provides on text documents. +type TextDocumentClientCapabilities struct { + Definition *DefinitionClientCapabilities `json:"definition,omitempty"` +} + +// DefinitionClientCapabilities indicates whether definition supports dynamic registration. +type DefinitionClientCapabilities struct { + DynamicRegistration bool `json:"dynamicRegistration,omitempty"` +} + +// InitializeResult is the result returned from the initialize request. +type InitializeResult struct { + Capabilities ServerCapabilities `json:"capabilities"` +} + +// ServerCapabilities define capabilities the language server provides. +type ServerCapabilities struct { + TextDocumentSync any `json:"textDocumentSync,omitempty"` + DefinitionProvider bool `json:"definitionProvider,omitempty"` + ReferencesProvider bool `json:"referencesProvider,omitempty"` + HoverProvider bool `json:"hoverProvider,omitempty"` + WorkspaceSymbolProvider bool `json:"workspaceSymbolProvider,omitempty"` + DocumentSymbolProvider bool `json:"documentSymbolProvider,omitempty"` + CompletionProvider any `json:"completionProvider,omitempty"` + SignatureHelpProvider any `json:"signatureHelpProvider,omitempty"` + DocumentFormattingProvider bool `json:"documentFormattingProvider,omitempty"` +} + +// DidOpenTextDocumentParams is sent when a text document is opened. +type DidOpenTextDocumentParams struct { + TextDocument TextDocumentItem `json:"textDocument"` +} + +// DidCloseTextDocumentParams is sent when a text document is closed. +type DidCloseTextDocumentParams struct { + TextDocument TextDocumentIdentifier `json:"textDocument"` +} + +// TextDocumentContentChangeEvent describes an event describing a change to a text document. +type TextDocumentContentChangeEvent struct { + Text string `json:"text"` +} + +// VersionedTextDocumentIdentifier is a text document identifier with a version. +type VersionedTextDocumentIdentifier struct { + URI string `json:"uri"` + Version int `json:"version"` +} + +// DidChangeTextDocumentParams is sent when the content of a text document changes. +type DidChangeTextDocumentParams struct { + TextDocument VersionedTextDocumentIdentifier `json:"textDocument"` + ContentChanges []TextDocumentContentChangeEvent `json:"contentChanges"` +} diff --git a/claudetool/toolset.go b/claudetool/toolset.go index 31294f2..1add4bc 100644 --- a/claudetool/toolset.go +++ b/claudetool/toolset.go @@ -6,6 +6,7 @@ import ( "sync" "shelley.exe.dev/claudetool/browse" + "shelley.exe.dev/claudetool/lsp" "shelley.exe.dev/llm" ) @@ -44,6 +45,8 @@ type ToolSetConfig struct { EnableJITInstall bool // EnableBrowser enables browser tools. EnableBrowser bool + // EnableCodeIntelligence enables LSP-based code intelligence tools. + EnableCodeIntelligence bool // ModelID is the model being used for this conversation. // Used to determine tool configuration (e.g., simplified patch schema for weaker models). ModelID string @@ -154,7 +157,8 @@ func NewToolSet(ctx context.Context, cfg ToolSetConfig) *ToolSet { tools = append(tools, subagentTool.Tool()) } - var cleanup func() + var cleanups []func() + if cfg.EnableBrowser { // Get max image dimension from the LLM service maxImageDimension := 0 @@ -167,7 +171,22 @@ func NewToolSet(ctx context.Context, cfg ToolSetConfig) *ToolSet { if len(browserTools) > 0 { tools = append(tools, browserTools...) } - cleanup = browserCleanup + cleanups = append(cleanups, browserCleanup) + } + + if cfg.EnableCodeIntelligence { + lspTools, lspCleanup := lsp.RegisterLSPTools(wd.Get) + tools = append(tools, lspTools...) + cleanups = append(cleanups, lspCleanup) + } + + var cleanup func() + if len(cleanups) > 0 { + cleanup = func() { + for _, fn := range cleanups { + fn() + } + } } return &ToolSet{ diff --git a/cmd/shelley/main.go b/cmd/shelley/main.go index d26082b..a677f0e 100644 --- a/cmd/shelley/main.go +++ b/cmd/shelley/main.go @@ -239,10 +239,11 @@ func setupToolSetConfig(llmProvider claudetool.LLMServiceProvider) claudetool.To wd = "/" } return claudetool.ToolSetConfig{ - WorkingDir: wd, - LLMProvider: llmProvider, - EnableJITInstall: claudetool.EnableBashToolJITInstall, - EnableBrowser: true, + WorkingDir: wd, + LLMProvider: llmProvider, + EnableJITInstall: claudetool.EnableBashToolJITInstall, + EnableBrowser: true, + EnableCodeIntelligence: true, } }