Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 148 additions & 103 deletions pkg/github/repository_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,129 +64,174 @@ func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetR
RepositoryResourceContentsHandler(getClient, getRawClient)
}

// resourceRequestParams holds the extracted parameters from a resource request
type resourceRequestParams struct {
owner string
repo string
path string
}

// extractRequiredStringArg extracts a required string argument from the request
func extractRequiredStringArg(args map[string]interface{}, key string) (string, error) {
val, ok := args[key].([]string)
if !ok || len(val) == 0 {
return "", fmt.Errorf("%s is required", key)
}
return val[0], nil
}

// extractOptionalStringArg extracts an optional string argument from the request
func extractOptionalStringArg(args map[string]interface{}, key string) (string, bool) {
val, ok := args[key].([]string)
if !ok || len(val) == 0 {
return "", false
}
return val[0], true
}

// extractResourceParams extracts owner, repo, and path from the request arguments
func extractResourceParams(args map[string]interface{}) (*resourceRequestParams, error) {
owner, err := extractRequiredStringArg(args, "owner")
if err != nil {
return nil, err
}
repo, err := extractRequiredStringArg(args, "repo")
if err != nil {
return nil, err
}
path := ""
if p, ok := args["path"].([]string); ok {
path = strings.Join(p, "/")
}
return &resourceRequestParams{owner: owner, repo: repo, path: path}, nil
}

// setRefOptions sets the ref options based on the request arguments
func setRefOptions(args map[string]interface{}, opts *github.RepositoryContentGetOptions, rawOpts *raw.ContentOpts) {
if sha, ok := extractOptionalStringArg(args, "sha"); ok {
opts.Ref = sha
rawOpts.SHA = sha
}
if branch, ok := extractOptionalStringArg(args, "branch"); ok {
opts.Ref = "refs/heads/" + branch
rawOpts.Ref = "refs/heads/" + branch
}
if tag, ok := extractOptionalStringArg(args, "tag"); ok {
opts.Ref = "refs/tags/" + tag
rawOpts.Ref = "refs/tags/" + tag
}
}

// handlePRRef fetches the PR and sets the SHA for the ref options
func handlePRRef(ctx context.Context, args map[string]interface{}, owner, repo string, getClient GetClientFn, opts *github.RepositoryContentGetOptions, rawOpts *raw.ContentOpts) error {
prNumberStr, ok := extractOptionalStringArg(args, "prNumber")
if !ok {
return nil
}
githubClient, err := getClient(ctx)
if err != nil {
return fmt.Errorf("failed to get GitHub client: %w", err)
}
prNum, err := strconv.Atoi(prNumberStr)
if err != nil {
return fmt.Errorf("invalid pull request number: %w", err)
}
pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum)
if err != nil {
return fmt.Errorf("failed to get pull request: %w", err)
}
sha := pr.GetHead().GetSHA()
rawOpts.SHA = sha
opts.Ref = sha
return nil
}

// determineMimeType determines the MIME type for a file based on extension and response headers
func determineMimeType(path string, contentType string) string {
ext := filepath.Ext(path)
if ext == ".md" {
return "text/markdown"
}
if contentType != "" {
return contentType
}
return mime.TypeByExtension(ext)
}

// buildResourceContents builds the appropriate resource contents based on MIME type
func buildResourceContents(uri string, mimeType string, content []byte) []mcp.ResourceContents {
if strings.HasPrefix(mimeType, "text") || strings.HasPrefix(mimeType, "application") {
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: uri,
MIMEType: mimeType,
Text: string(content),
},
}
}
return []mcp.ResourceContents{
mcp.BlobResourceContents{
URI: uri,
MIMEType: mimeType,
Blob: base64.StdEncoding.EncodeToString(content),
},
}
}

// handleRawContentResponse processes the raw content response and returns resource contents
func handleRawContentResponse(resp *http.Response, uri string, path string) ([]mcp.ResourceContents, error) {
if resp.StatusCode == http.StatusOK {
mimeType := determineMimeType(path, resp.Header.Get("Content-Type"))
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read file content: %w", err)
}
return buildResourceContents(uri, mimeType, content), nil
}
if resp.StatusCode != http.StatusNotFound {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return nil, fmt.Errorf("failed to fetch raw content: %s", string(body))
}
return nil, errors.New("404 Not Found")
}

// RepositoryResourceContentsHandler returns a handler function for repository content requests.
func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
return func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
// the matcher will give []string with one element
// https://github.com/mark3labs/mcp-go/pull/54
o, ok := request.Params.Arguments["owner"].([]string)
if !ok || len(o) == 0 {
return nil, errors.New("owner is required")
}
owner := o[0]

r, ok := request.Params.Arguments["repo"].([]string)
if !ok || len(r) == 0 {
return nil, errors.New("repo is required")
params, err := extractResourceParams(request.Params.Arguments)
if err != nil {
return nil, err
}
repo := r[0]

// path should be a joined list of the path parts
path := ""
p, ok := request.Params.Arguments["path"].([]string)
if ok {
path = strings.Join(p, "/")
if params.path == "" || strings.HasSuffix(params.path, "/") {
return nil, fmt.Errorf("directories are not supported: %s", params.path)
}

opts := &github.RepositoryContentGetOptions{}
rawOpts := &raw.ContentOpts{}

sha, ok := request.Params.Arguments["sha"].([]string)
if ok && len(sha) > 0 {
opts.Ref = sha[0]
rawOpts.SHA = sha[0]
}
setRefOptions(request.Params.Arguments, opts, rawOpts)

branch, ok := request.Params.Arguments["branch"].([]string)
if ok && len(branch) > 0 {
opts.Ref = "refs/heads/" + branch[0]
rawOpts.Ref = "refs/heads/" + branch[0]
if err := handlePRRef(ctx, request.Params.Arguments, params.owner, params.repo, getClient, opts, rawOpts); err != nil {
return nil, err
}

tag, ok := request.Params.Arguments["tag"].([]string)
if ok && len(tag) > 0 {
opts.Ref = "refs/tags/" + tag[0]
rawOpts.Ref = "refs/tags/" + tag[0]
}
prNumber, ok := request.Params.Arguments["prNumber"].([]string)
if ok && len(prNumber) > 0 {
// fetch the PR from the API to get the latest commit and use SHA
githubClient, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}
prNum, err := strconv.Atoi(prNumber[0])
if err != nil {
return nil, fmt.Errorf("invalid pull request number: %w", err)
}
pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum)
if err != nil {
return nil, fmt.Errorf("failed to get pull request: %w", err)
}
sha := pr.GetHead().GetSHA()
rawOpts.SHA = sha
opts.Ref = sha
}
// if it's a directory
if path == "" || strings.HasSuffix(path, "/") {
return nil, fmt.Errorf("directories are not supported: %s", path)
}
rawClient, err := getRawClient(ctx)

if err != nil {
return nil, fmt.Errorf("failed to get GitHub raw content client: %w", err)
}

resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts)
resp, err := rawClient.GetRawContent(ctx, params.owner, params.repo, params.path, rawOpts)
if err != nil {
return nil, fmt.Errorf("failed to get raw content: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
// If the raw content is not found, we will fall back to the GitHub API (in case it is a directory)
switch {
case err != nil:
return nil, fmt.Errorf("failed to get raw content: %w", err)
case resp.StatusCode == http.StatusOK:
ext := filepath.Ext(path)
mimeType := resp.Header.Get("Content-Type")
if ext == ".md" {
mimeType = "text/markdown"
} else if mimeType == "" {
mimeType = mime.TypeByExtension(ext)
}

content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read file content: %w", err)
}

switch {
case strings.HasPrefix(mimeType, "text"), strings.HasPrefix(mimeType, "application"):
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: request.Params.URI,
MIMEType: mimeType,
Text: string(content),
},
}, nil
default:
return []mcp.ResourceContents{
mcp.BlobResourceContents{
URI: request.Params.URI,
MIMEType: mimeType,
Blob: base64.StdEncoding.EncodeToString(content),
},
}, nil
}
case resp.StatusCode != http.StatusNotFound:
// If we got a response but it is not 200 OK, we return an error
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return nil, fmt.Errorf("failed to fetch raw content: %s", string(body))
default:
// This should be unreachable because GetContents should return an error if neither file nor directory content is found.
return nil, errors.New("404 Not Found")
}

return handleRawContentResponse(resp, request.Params.URI, params.path)
}
}
Loading