From 0a5ce3280ee18d1c10189f3d2f92f301a2513fef Mon Sep 17 00:00:00 2001 From: Nikolai Kaploniuk Date: Tue, 30 Dec 2025 00:01:59 -0800 Subject: [PATCH] feat: when using custom PR templates - prompt user to fill it in --- github/githubclient/client.go | 5 +- github/githubclient/gen/genclient/client.go | 18 ++-- .../githubclient/gen/genclient/operations.go | 20 ++-- github/githubclient/queries.graphql | 1 + github/template/interface.go | 2 +- github/template/template_basic/template.go | 2 +- .../template/template_basic/template_test.go | 8 +- github/template/template_custom/template.go | 91 ++++++++++++++++++- github/template/template_stack/template.go | 2 +- .../template/template_stack/template_test.go | 26 +++--- github/template/template_why_what/template.go | 2 +- .../template_why_what/template_test.go | 6 +- 12 files changed, 136 insertions(+), 47 deletions(-) diff --git a/github/githubclient/client.go b/github/githubclient/client.go index b59c811a..edb790da 100644 --- a/github/githubclient/client.go +++ b/github/githubclient/client.go @@ -385,7 +385,7 @@ func (c *client) CreatePullRequest(ctx context.Context, gitcmd git.GitInterface, templatizer := config_fetcher.PRTemplatizer(c.config, gitcmd) - body := templatizer.Body(info, commit) + body := templatizer.Body(info, commit, nil) resp, err := c.api.CreatePullRequest(ctx, genclient.CreatePullRequestInput{ RepositoryId: info.RepositoryID, BaseRefName: baseRefName, @@ -403,6 +403,7 @@ func (c *client) CreatePullRequest(ctx context.Context, gitcmd git.GitInterface, ToBranch: baseRefName, Commit: commit, Title: commit.Subject, + Body: resp.CreatePullRequest.PullRequest.Body, MergeStatus: github.PullRequestMergeStatus{ ChecksPass: github.CheckStatusUnknown, ReviewApproved: false, @@ -437,7 +438,7 @@ func (c *client) UpdatePullRequest(ctx context.Context, gitcmd git.GitInterface, templatizer := config_fetcher.PRTemplatizer(c.config, gitcmd) title := templatizer.Title(info, commit) - body := templatizer.Body(info, commit) + body := templatizer.Body(info, commit, pr) input := genclient.UpdatePullRequestInput{ PullRequestId: pr.ID, Title: &title, diff --git a/github/githubclient/gen/genclient/client.go b/github/githubclient/gen/genclient/client.go index c43e842f..7879d200 100644 --- a/github/githubclient/gen/genclient/client.go +++ b/github/githubclient/gen/genclient/client.go @@ -34,48 +34,48 @@ type Client interface { input CreatePullRequestInput, ) (*CreatePullRequestResponse, error) - // UpdatePullRequest from github/githubclient/queries.graphql:115 + // UpdatePullRequest from github/githubclient/queries.graphql:116 UpdatePullRequest(ctx context.Context, input UpdatePullRequestInput, ) (*UpdatePullRequestResponse, error) - // AddReviewers from github/githubclient/queries.graphql:127 + // AddReviewers from github/githubclient/queries.graphql:128 AddReviewers(ctx context.Context, input RequestReviewsInput, ) (*AddReviewersResponse, error) - // CommentPullRequest from github/githubclient/queries.graphql:139 + // CommentPullRequest from github/githubclient/queries.graphql:140 CommentPullRequest(ctx context.Context, input AddCommentInput, ) (*CommentPullRequestResponse, error) - // MergePullRequest from github/githubclient/queries.graphql:149 + // MergePullRequest from github/githubclient/queries.graphql:150 MergePullRequest(ctx context.Context, input MergePullRequestInput, ) (*MergePullRequestResponse, error) - // AutoMergePullRequest from github/githubclient/queries.graphql:161 + // AutoMergePullRequest from github/githubclient/queries.graphql:162 AutoMergePullRequest(ctx context.Context, input EnablePullRequestAutoMergeInput, ) (*AutoMergePullRequestResponse, error) - // ClosePullRequest from github/githubclient/queries.graphql:173 + // ClosePullRequest from github/githubclient/queries.graphql:174 ClosePullRequest(ctx context.Context, input ClosePullRequestInput, ) (*ClosePullRequestResponse, error) - // StarCheck from github/githubclient/queries.graphql:185 + // StarCheck from github/githubclient/queries.graphql:186 StarCheck(ctx context.Context, after *string, ) (*StarCheckResponse, error) - // StarGetRepo from github/githubclient/queries.graphql:201 + // StarGetRepo from github/githubclient/queries.graphql:202 StarGetRepo(ctx context.Context, owner string, name string, ) (*StarGetRepoResponse, error) - // StarAdd from github/githubclient/queries.graphql:210 + // StarAdd from github/githubclient/queries.graphql:211 StarAdd(ctx context.Context, input AddStarInput, ) (*StarAddResponse, error) diff --git a/github/githubclient/gen/genclient/operations.go b/github/githubclient/gen/genclient/operations.go index 1f5312e1..b79ffc71 100644 --- a/github/githubclient/gen/genclient/operations.go +++ b/github/githubclient/gen/genclient/operations.go @@ -278,6 +278,7 @@ type CreatePullRequestCreatePullRequest struct { type CreatePullRequestCreatePullRequestPullRequest struct { Id string Number int + Body string } // CreatePullRequestResponse response type for CreatePullRequest @@ -296,6 +297,7 @@ func (c *gqlclient) CreatePullRequest(ctx context.Context, pullRequest { id number + body } } } @@ -343,7 +345,7 @@ type UpdatePullRequestResponse struct { UpdatePullRequest *UpdatePullRequestUpdatePullRequest } -// UpdatePullRequest from github/githubclient/queries.graphql:115 +// UpdatePullRequest from github/githubclient/queries.graphql:116 func (c *gqlclient) UpdatePullRequest(ctx context.Context, input UpdatePullRequestInput, ) (*UpdatePullRequestResponse, error) { @@ -400,7 +402,7 @@ type AddReviewersResponse struct { RequestReviews *AddReviewersRequestReviews } -// AddReviewers from github/githubclient/queries.graphql:127 +// AddReviewers from github/githubclient/queries.graphql:128 func (c *gqlclient) AddReviewers(ctx context.Context, input RequestReviewsInput, ) (*AddReviewersResponse, error) { @@ -453,7 +455,7 @@ type CommentPullRequestResponse struct { AddComment *CommentPullRequestAddComment } -// CommentPullRequest from github/githubclient/queries.graphql:139 +// CommentPullRequest from github/githubclient/queries.graphql:140 func (c *gqlclient) CommentPullRequest(ctx context.Context, input AddCommentInput, ) (*CommentPullRequestResponse, error) { @@ -508,7 +510,7 @@ type MergePullRequestResponse struct { MergePullRequest *MergePullRequestMergePullRequest } -// MergePullRequest from github/githubclient/queries.graphql:149 +// MergePullRequest from github/githubclient/queries.graphql:150 func (c *gqlclient) MergePullRequest(ctx context.Context, input MergePullRequestInput, ) (*MergePullRequestResponse, error) { @@ -565,7 +567,7 @@ type AutoMergePullRequestResponse struct { EnablePullRequestAutoMerge *AutoMergePullRequestEnablePullRequestAutoMerge } -// AutoMergePullRequest from github/githubclient/queries.graphql:161 +// AutoMergePullRequest from github/githubclient/queries.graphql:162 func (c *gqlclient) AutoMergePullRequest(ctx context.Context, input EnablePullRequestAutoMergeInput, ) (*AutoMergePullRequestResponse, error) { @@ -622,7 +624,7 @@ type ClosePullRequestResponse struct { ClosePullRequest *ClosePullRequestClosePullRequest } -// ClosePullRequest from github/githubclient/queries.graphql:173 +// ClosePullRequest from github/githubclient/queries.graphql:174 func (c *gqlclient) ClosePullRequest(ctx context.Context, input ClosePullRequestInput, ) (*ClosePullRequestResponse, error) { @@ -689,7 +691,7 @@ type StarCheckResponse struct { Viewer StarCheckViewer } -// StarCheck from github/githubclient/queries.graphql:185 +// StarCheck from github/githubclient/queries.graphql:186 func (c *gqlclient) StarCheck(ctx context.Context, after *string, ) (*StarCheckResponse, error) { @@ -748,7 +750,7 @@ type StarGetRepoResponse struct { Repository *StarGetRepoRepository } -// StarGetRepo from github/githubclient/queries.graphql:201 +// StarGetRepo from github/githubclient/queries.graphql:202 func (c *gqlclient) StarGetRepo(ctx context.Context, owner string, name string, @@ -801,7 +803,7 @@ type StarAddResponse struct { AddStar *StarAddAddStar } -// StarAdd from github/githubclient/queries.graphql:210 +// StarAdd from github/githubclient/queries.graphql:211 func (c *gqlclient) StarAdd(ctx context.Context, input AddStarInput, ) (*StarAddResponse, error) { diff --git a/github/githubclient/queries.graphql b/github/githubclient/queries.graphql index bc529d79..2146786a 100644 --- a/github/githubclient/queries.graphql +++ b/github/githubclient/queries.graphql @@ -108,6 +108,7 @@ mutation CreatePullRequest( pullRequest { id number + body } } } diff --git a/github/template/interface.go b/github/template/interface.go index 578f0971..cc32525a 100644 --- a/github/template/interface.go +++ b/github/template/interface.go @@ -7,5 +7,5 @@ import ( type PRTemplatizer interface { Title(info *github.GitHubInfo, commit git.Commit) string - Body(info *github.GitHubInfo, commit git.Commit) string + Body(info *github.GitHubInfo, commit git.Commit, pr *github.PullRequest) string } diff --git a/github/template/template_basic/template.go b/github/template/template_basic/template.go index 3ccf5ab2..d253df07 100644 --- a/github/template/template_basic/template.go +++ b/github/template/template_basic/template.go @@ -16,7 +16,7 @@ func (t *BasicTemplatizer) Title(info *github.GitHubInfo, commit git.Commit) str return commit.Subject } -func (t *BasicTemplatizer) Body(info *github.GitHubInfo, commit git.Commit) string { +func (t *BasicTemplatizer) Body(info *github.GitHubInfo, commit git.Commit, pr *github.PullRequest) string { body := commit.Body body += "\n\n" body += template.ManualMergeNotice() diff --git a/github/template/template_basic/template_test.go b/github/template/template_basic/template_test.go index 05b66326..1e3fba1a 100644 --- a/github/template/template_basic/template_test.go +++ b/github/template/template_basic/template_test.go @@ -194,7 +194,7 @@ func TestBody(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := templatizer.Body(info, tt.commit) + got := templatizer.Body(info, tt.commit, nil) // Verify all expected strings are present for _, wantStr := range tt.wantContains { @@ -229,7 +229,7 @@ func TestBodyManualMergeNoticeFormat(t *testing.T) { Body: "Test body", } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // Verify the exact format of the manual merge notice expectedNotice := "⚠️ *Part of a stack created by [spr](https://github.com/ejoffe/spr). Do not merge manually using the UI - doing so may have unexpected results.*" @@ -280,7 +280,7 @@ func TestBodyPreservesOriginalContent(t *testing.T) { Body: tc.body, } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // The result should contain the original body (if not empty) if tc.body != "" { @@ -349,7 +349,7 @@ Performance improvements: for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := templatizer.Body(info, tt.commit) + result := templatizer.Body(info, tt.commit, nil) // Should contain the original body content assert.Contains(t, result, tt.commit.Body) diff --git a/github/template/template_custom/template.go b/github/template/template_custom/template.go index ee65d09f..c6453bea 100644 --- a/github/template/template_custom/template.go +++ b/github/template/template_custom/template.go @@ -1,9 +1,11 @@ package template_custom import ( + "bufio" "errors" "fmt" "os" + "os/exec" "path" "path/filepath" "strings" @@ -34,21 +36,104 @@ func (t *CustomTemplatizer) Title(info *github.GitHubInfo, commit git.Commit) st return commit.Subject } -func (t *CustomTemplatizer) Body(info *github.GitHubInfo, commit git.Commit) string { +func (t *CustomTemplatizer) Body(info *github.GitHubInfo, commit git.Commit, pr *github.PullRequest) string { body := t.formatBody(commit, info.PullRequests) pullRequestTemplate, err := t.readPRTemplate() if err != nil { log.Fatal().Err(err).Msg("failed to read PR template") } - body, err = t.insertBodyIntoPRTemplate(body, pullRequestTemplate, nil) + body, err = t.insertBodyIntoPRTemplate(body, pullRequestTemplate, pr) if err != nil { log.Fatal().Err(err).Msg("failed to insert body into PR template") } + + // Open editor for user to edit the PR content only when creating a new PR (pr == nil) + if pr != nil { + return body + } + + if !promptUserToEdit(commit) { + return body + } + + body, err = EditWithEditor(body) + if err != nil { + log.Fatal().Err(err).Msg("failed to edit PR content with editor") + } + return body } -func (t *CustomTemplatizer) formatBody(commit git.Commit, stack []*github.PullRequest) string { +// promptUserToEdit prompts the user if they want to edit the PR content in their editor +func promptUserToEdit(commit git.Commit) bool { + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Println() + fmt.Println("New PR for:") + fmt.Printf(" %s: %s\n", commit.CommitHash[:7], commit.Subject) + fmt.Println() + fmt.Print("Edit PR content? [Y/n]: ") + if !scanner.Scan() { + // On error or EOF, default to editing + return true + } + input := strings.ToLower(strings.TrimSpace(scanner.Text())) + switch input { + case "y", "yes": + return true + case "n", "no": + return false + case "": + // Empty input defaults to yes + return true + default: + // Invalid input, ask again + continue + } + } +} + +// EditWithEditor opens the default editor to allow the user to edit the provided content. +func EditWithEditor(initialContent string) (string, error) { + editor := os.Getenv("EDITOR") + if editor == "" { + editor = "vi" + } + + // Create temporary file to hold the content + tmpFile, err := os.CreateTemp("", "spr-pr-*.md") + if err != nil { + return "", fmt.Errorf("failed to create temporary file: %w", err) + } + defer os.Remove(tmpFile.Name()) + // Write initial content to temporary file + if _, err := tmpFile.WriteString(initialContent); err != nil { + tmpFile.Close() + return "", fmt.Errorf("failed to write to temporary file: %w", err) + } + tmpFile.Close() + + // Open editor + cmd := exec.Command(editor, tmpFile.Name()) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("editor command failed: %w", err) + } + + // Read edited content from temporary file + editedBytes, err := os.ReadFile(tmpFile.Name()) + if err != nil { + return "", fmt.Errorf("failed to read edited content: %w", err) + } + + return string(editedBytes), nil +} + +func (t *CustomTemplatizer) formatBody(commit git.Commit, stack []*github.PullRequest) string { if len(stack) <= 1 { return strings.TrimSpace(commit.Body) } diff --git a/github/template/template_stack/template.go b/github/template/template_stack/template.go index 74e1e700..960a1941 100644 --- a/github/template/template_stack/template.go +++ b/github/template/template_stack/template.go @@ -18,7 +18,7 @@ func (t *StackTemplatizer) Title(info *github.GitHubInfo, commit git.Commit) str return commit.Subject } -func (t *StackTemplatizer) Body(info *github.GitHubInfo, commit git.Commit) string { +func (t *StackTemplatizer) Body(info *github.GitHubInfo, commit git.Commit, pr *github.PullRequest) string { body := commit.Body // Always show stack section and notice diff --git a/github/template/template_stack/template_test.go b/github/template/template_stack/template_test.go index 81588ab2..bd72e372 100644 --- a/github/template/template_stack/template_test.go +++ b/github/template/template_stack/template_test.go @@ -79,7 +79,7 @@ func TestBody_EmptyStack(t *testing.T) { Body: "Commit body text", } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // Should contain the commit body assert.Contains(t, result, "Commit body text") @@ -133,7 +133,7 @@ func TestBody_WithStack_NoTitles(t *testing.T) { } // Test with commit2 (middle of stack) - result := templatizer.Body(info, commit2) + result := templatizer.Body(info, commit2, nil) // Should contain the commit body assert.Contains(t, result, "Second body") @@ -210,7 +210,7 @@ func TestBody_WithStack_WithTitles(t *testing.T) { } // Test with commit2 (middle of stack) - result := templatizer.Body(info, commit2) + result := templatizer.Body(info, commit2, nil) // Should contain the commit body assert.Contains(t, result, "Second body") @@ -254,7 +254,7 @@ func TestBody_StackOrder(t *testing.T) { }, } - result := templatizer.Body(info, commit2) + result := templatizer.Body(info, commit2, nil) // Stack should be in reverse order (3, 2, 1) // Find the stack section @@ -305,7 +305,7 @@ func TestBody_CurrentCommitAtStart(t *testing.T) { } // Test with commit3 (last in stack, first in reverse order) - result := templatizer.Body(info, commit3) + result := templatizer.Body(info, commit3, nil) // Should have arrow on #3 assert.Contains(t, result, "#3 ⬅") @@ -343,7 +343,7 @@ func TestBody_CurrentCommitAtEnd(t *testing.T) { } // Test with commit1 (first in stack, last in reverse order) - result := templatizer.Body(info, commit1) + result := templatizer.Body(info, commit1, nil) // Should have arrow on #1 assert.Contains(t, result, "#1 ⬅") @@ -368,7 +368,7 @@ func TestBody_EmptyBody(t *testing.T) { }, } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // Should still contain stack and notice assert.Contains(t, result, "#1") @@ -391,7 +391,7 @@ func TestBody_SinglePRInStack(t *testing.T) { }, } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // Should contain the body assert.Contains(t, result, "Single body") @@ -427,7 +427,7 @@ func TestBody_WithTitlesVsWithoutTitles(t *testing.T) { // Test without titles templatizerNoTitles := NewStackTemplatizer(false) - resultNoTitles := templatizerNoTitles.Body(info, commit2) + resultNoTitles := templatizerNoTitles.Body(info, commit2, nil) // Should NOT contain PR titles assert.NotContains(t, resultNoTitles, "First PR") @@ -438,7 +438,7 @@ func TestBody_WithTitlesVsWithoutTitles(t *testing.T) { // Test with titles templatizerWithTitles := NewStackTemplatizer(true) - resultWithTitles := templatizerWithTitles.Body(info, commit2) + resultWithTitles := templatizerWithTitles.Body(info, commit2, nil) // Should contain PR titles assert.Contains(t, resultWithTitles, "First PR #1") @@ -460,7 +460,7 @@ func TestBody_Structure(t *testing.T) { }, } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // Verify structure: body + \n\n + stack + \n\n + notice // The body should come first @@ -513,7 +513,7 @@ func TestBody_RealWorldExample(t *testing.T) { } // Test with middle commit - result := templatizer.Body(info, commit2) + result := templatizer.Body(info, commit2, nil) // Should contain commit body assert.Contains(t, result, "Created POST /login endpoint") @@ -608,7 +608,7 @@ It even includes some **markdown** formatting. templatizer := NewStackTemplatizer(false) for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - body := templatizer.Body(tc.info, tc.commit) + body := templatizer.Body(tc.info, tc.commit, nil) if body != tc.expected { t.Fatalf("expected: '%v', actual: '%v'", tc.expected, body) } diff --git a/github/template/template_why_what/template.go b/github/template/template_why_what/template.go index e69e8bb1..c81f49ad 100644 --- a/github/template/template_why_what/template.go +++ b/github/template/template_why_what/template.go @@ -20,7 +20,7 @@ func (t *WhyWhatTemplatizer) Title(info *github.GitHubInfo, commit git.Commit) s return commit.Subject } -func (t *WhyWhatTemplatizer) Body(info *github.GitHubInfo, commit git.Commit) string { +func (t *WhyWhatTemplatizer) Body(info *github.GitHubInfo, commit git.Commit, pr *github.PullRequest) string { // Split commit body by empty lines and filter out empty sections sections := splitByEmptyLines(commit.Body) diff --git a/github/template/template_why_what/template_test.go b/github/template/template_why_what/template_test.go index a9dd41cd..7a9179a4 100644 --- a/github/template/template_why_what/template_test.go +++ b/github/template/template_why_what/template_test.go @@ -212,7 +212,7 @@ func TestBody(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := templatizer.Body(info, tt.commit) + got := templatizer.Body(info, tt.commit, nil) // Check that all required strings are present for _, wantStr := range tt.contains { @@ -320,7 +320,7 @@ func TestBodyTemplateStructure(t *testing.T) { Body: "Why section\n\nWhat changed section\n\nTest plan section", } - result := templatizer.Body(info, commit) + result := templatizer.Body(info, commit, nil) // Verify sections appear in correct order whyIndex := strings.Index(result, "Why\n===") @@ -381,7 +381,7 @@ Manual review of the docs.`, for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := templatizer.Body(info, tt.commit) + result := templatizer.Body(info, tt.commit, nil) // Should always contain the required sections assert.Contains(t, result, "Why\n===")