diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index a7ec8e20f..feff61068 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -13,6 +13,14 @@ import ( "github.com/shurcooL/githubv4" ) +func discussionOwnerOption() mcp.ToolOption { + return mcp.WithString("owner", mcp.Required(), mcp.Description(DescriptionRepositoryOwner)) +} + +func discussionRepoOption() mcp.ToolOption { + return mcp.WithString("repo", mcp.Required(), mcp.Description(DescriptionRepositoryName)) +} + func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_discussions", mcp.WithDescription(t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository")), @@ -20,14 +28,8 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp Title: t("TOOL_LIST_DISCUSSIONS_USER_TITLE", "List discussions"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + discussionOwnerOption(), + discussionRepoOption(), mcp.WithString("category", mcp.Description("Optional filter by discussion category ID. If provided, only discussions with this category are listed."), ), @@ -162,14 +164,8 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper Title: t("TOOL_GET_DISCUSSION_USER_TITLE", "Get discussion"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + discussionOwnerOption(), + discussionRepoOption(), mcp.WithNumber("discussionNumber", mcp.Required(), mcp.Description("Discussion Number"), @@ -241,8 +237,8 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati Title: t("TOOL_GET_DISCUSSION_COMMENTS_USER_TITLE", "Get discussion comments"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", mcp.Required(), mcp.Description("Repository owner")), - mcp.WithString("repo", mcp.Required(), mcp.Description("Repository name")), + discussionOwnerOption(), + discussionRepoOption(), mcp.WithNumber("discussionNumber", mcp.Required(), mcp.Description("Discussion Number")), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -301,14 +297,8 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl Title: t("TOOL_LIST_DISCUSSION_CATEGORIES_USER_TITLE", "List discussion categories"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + discussionOwnerOption(), + discussionRepoOption(), mcp.WithNumber("first", mcp.Description("Number of categories to return per page (min 1, max 100)"), mcp.Min(1), diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index a454db630..b1ffb22c5 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -64,129 +64,153 @@ func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetR RepositoryResourceContentsHandler(getClient, getRawClient) } -// RepositoryResourceContentsHandler returns a handler function for repository content requests. -func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { - return func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { - // the matcher will give []string with one element - // https://github.com/mark3labs/mcp-go/pull/54 - o, ok := request.Params.Arguments["owner"].([]string) - if !ok || len(o) == 0 { - return nil, errors.New("owner is required") - } - owner := o[0] +// extractStringArg extracts a string argument from the request arguments +func extractStringArg(args map[string]any, key string) (string, bool) { + v, ok := args[key].([]string) + if !ok || len(v) == 0 { + return "", false + } + return v[0], true +} - r, ok := request.Params.Arguments["repo"].([]string) - if !ok || len(r) == 0 { - return nil, errors.New("repo is required") - } - repo := r[0] +// extractPathArg extracts and joins path parts from the request arguments +func extractPathArg(args map[string]any) string { + p, ok := args["path"].([]string) + if !ok { + return "" + } + return strings.Join(p, "/") +} - // path should be a joined list of the path parts - path := "" - p, ok := request.Params.Arguments["path"].([]string) - if ok { - path = strings.Join(p, "/") - } +// resolveRefOptions resolves the ref options based on the request arguments +func resolveRefOptions(ctx context.Context, args map[string]any, owner, repo string, getClient GetClientFn) (*github.RepositoryContentGetOptions, *raw.ContentOpts, error) { + opts := &github.RepositoryContentGetOptions{} + rawOpts := &raw.ContentOpts{} - opts := &github.RepositoryContentGetOptions{} - rawOpts := &raw.ContentOpts{} + if sha, ok := extractStringArg(args, "sha"); ok { + opts.Ref = sha + rawOpts.SHA = sha + return opts, rawOpts, nil + } + + if branch, ok := extractStringArg(args, "branch"); ok { + opts.Ref = "refs/heads/" + branch + rawOpts.Ref = "refs/heads/" + branch + return opts, rawOpts, nil + } - sha, ok := request.Params.Arguments["sha"].([]string) - if ok && len(sha) > 0 { - opts.Ref = sha[0] - rawOpts.SHA = sha[0] + if tag, ok := extractStringArg(args, "tag"); ok { + opts.Ref = "refs/tags/" + tag + rawOpts.Ref = "refs/tags/" + tag + return opts, rawOpts, nil + } + + if prNumberStr, ok := extractStringArg(args, "prNumber"); ok { + githubClient, err := getClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + prNum, err := strconv.Atoi(prNumberStr) + if err != nil { + return nil, nil, fmt.Errorf("invalid pull request number: %w", err) + } + pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum) + if err != nil { + return nil, nil, fmt.Errorf("failed to get pull request: %w", err) } + sha := pr.GetHead().GetSHA() + rawOpts.SHA = sha + opts.Ref = sha + } + + return opts, rawOpts, nil +} + +// determineMimeType determines the MIME type for a file based on extension and response headers +func determineMimeType(path string, contentTypeHeader string) string { + ext := filepath.Ext(path) + if ext == ".md" { + return "text/markdown" + } + if contentTypeHeader != "" { + return contentTypeHeader + } + return mime.TypeByExtension(ext) +} - branch, ok := request.Params.Arguments["branch"].([]string) - if ok && len(branch) > 0 { - opts.Ref = "refs/heads/" + branch[0] - rawOpts.Ref = "refs/heads/" + branch[0] +// buildResourceContents builds the appropriate resource contents based on MIME type +func buildResourceContents(uri, mimeType string, content []byte) []mcp.ResourceContents { + if strings.HasPrefix(mimeType, "text") || strings.HasPrefix(mimeType, "application") { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: uri, + MIMEType: mimeType, + Text: string(content), + }, } + } + return []mcp.ResourceContents{ + mcp.BlobResourceContents{ + URI: uri, + MIMEType: mimeType, + Blob: base64.StdEncoding.EncodeToString(content), + }, + } +} + +// RepositoryResourceContentsHandler returns a handler function for repository content requests. +func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + args := request.Params.Arguments - tag, ok := request.Params.Arguments["tag"].([]string) - if ok && len(tag) > 0 { - opts.Ref = "refs/tags/" + tag[0] - rawOpts.Ref = "refs/tags/" + tag[0] + owner, ok := extractStringArg(args, "owner") + if !ok { + return nil, errors.New("owner is required") } - prNumber, ok := request.Params.Arguments["prNumber"].([]string) - if ok && len(prNumber) > 0 { - // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - prNum, err := strconv.Atoi(prNumber[0]) - if err != nil { - return nil, fmt.Errorf("invalid pull request number: %w", err) - } - pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum) - if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) - } - sha := pr.GetHead().GetSHA() - rawOpts.SHA = sha - opts.Ref = sha + + repo, ok := extractStringArg(args, "repo") + if !ok { + return nil, errors.New("repo is required") } - // if it's a directory + + path := extractPathArg(args) if path == "" || strings.HasSuffix(path, "/") { return nil, fmt.Errorf("directories are not supported: %s", path) } - rawClient, err := getRawClient(ctx) + _, rawOpts, err := resolveRefOptions(ctx, args, owner, repo, getClient) + if err != nil { + return nil, err + } + + rawClient, err := getRawClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub raw content client: %w", err) } resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) - defer func() { - _ = resp.Body.Close() - }() - // If the raw content is not found, we will fall back to the GitHub API (in case it is a directory) - switch { - case err != nil: + if err != nil { return nil, fmt.Errorf("failed to get raw content: %w", err) - case resp.StatusCode == http.StatusOK: - ext := filepath.Ext(path) - mimeType := resp.Header.Get("Content-Type") - if ext == ".md" { - mimeType = "text/markdown" - } else if mimeType == "" { - mimeType = mime.TypeByExtension(ext) - } + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusOK { content, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read file content: %w", err) } + mimeType := determineMimeType(path, resp.Header.Get("Content-Type")) + return buildResourceContents(request.Params.URI, mimeType, content), nil + } - switch { - case strings.HasPrefix(mimeType, "text"), strings.HasPrefix(mimeType, "application"): - return []mcp.ResourceContents{ - mcp.TextResourceContents{ - URI: request.Params.URI, - MIMEType: mimeType, - Text: string(content), - }, - }, nil - default: - return []mcp.ResourceContents{ - mcp.BlobResourceContents{ - URI: request.Params.URI, - MIMEType: mimeType, - Blob: base64.StdEncoding.EncodeToString(content), - }, - }, nil - } - case resp.StatusCode != http.StatusNotFound: - // If we got a response but it is not 200 OK, we return an error + if resp.StatusCode != http.StatusNotFound { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } return nil, fmt.Errorf("failed to fetch raw content: %s", string(body)) - default: - // This should be unreachable because GetContents should return an error if neither file nor directory content is found. - return nil, errors.New("404 Not Found") } + + return nil, errors.New("404 Not Found") } }