From 61133a7e8c339cdd29d2af9e385aaf76f2920303 Mon Sep 17 00:00:00 2001 From: tshihad Date: Mon, 24 Nov 2025 16:26:55 +0100 Subject: [PATCH 1/2] Support table driven test --- gopls/internal/cache/testfuncs/tests.go | 312 +++++++++++++++++-- gopls/internal/cache/testfuncs/tests_test.go | 275 ++++++++++++++++ gopls/internal/golang/code_lens.go | 47 +++ 3 files changed, 616 insertions(+), 18 deletions(-) create mode 100644 gopls/internal/cache/testfuncs/tests_test.go diff --git a/gopls/internal/cache/testfuncs/tests.go b/gopls/internal/cache/testfuncs/tests.go index e0e3ce1beca..5cf579f2717 100644 --- a/gopls/internal/cache/testfuncs/tests.go +++ b/gopls/internal/cache/testfuncs/tests.go @@ -7,6 +7,7 @@ package testfuncs import ( "go/ast" "go/constant" + "go/token" "go/types" "strings" "unicode" @@ -133,6 +134,121 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast // parameter of the enclosing test function. var tests []gobTest for _, stmt := range body.List { + // Handle direct t.Run calls + if expr, ok := stmt.(*ast.ExprStmt); ok { + tests = append(tests, b.findDirectSubtests(parent, param, expr, file, files, info)...) + continue + } + + // Handle table-driven tests: for _, tt := range tests { t.Run(tt.name, ...) } + if rangeStmt, ok := stmt.(*ast.RangeStmt); ok { + tests = append(tests, b.findTableDrivenSubtests(parent, param, rangeStmt, file, files, info)...) + continue + } + } + return tests +} + +// findDirectSubtests finds subtests from direct t.Run("name", ...) calls. +func (b *indexBuilder) findDirectSubtests(parent gobTest, param types.Object, expr *ast.ExprStmt, file *parsego.File, files []*parsego.File, info *types.Info) []gobTest { + var tests []gobTest + + call, ok := expr.X.(*ast.CallExpr) + if !ok || len(call.Args) != 2 { + return nil + } + fun, ok := call.Fun.(*ast.SelectorExpr) + if !ok || fun.Sel.Name != "Run" { + return nil + } + recv, ok := fun.X.(*ast.Ident) + if !ok || info.ObjectOf(recv) != param { + return nil + } + + sig, ok := info.TypeOf(call.Args[1]).(*types.Signature) + if !ok { + return nil + } + if _, ok := testKind(sig); !ok { + return nil // subtest has wrong signature + } + + val := info.Types[call.Args[0]].Value // may be zero + if val == nil || val.Kind() != constant.String { + return nil + } + + var t gobTest + t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val))) + t.Location.URI = file.URI + t.Location.Range, _ = file.NodeRange(call) + tests = append(tests, t) + + fn, funcType, funcBody := findFunc(files, info, nil, call.Args[1]) + if funcType == nil { + return tests + } + + // Function literals don't have an associated object + if fn == nil { + tests = append(tests, b.findSubtests(t, funcType, funcBody, file, files, info)...) + return tests + } + + // Never recurse if the second argument is a top-level test function + if isTest, _ := isTestOrExample(fn); isTest { + return tests + } + + // Don't recurse into functions that have already been visited + if b.visited[fn] { + return tests + } + + b.visited[fn] = true + tests = append(tests, b.findSubtests(t, funcType, funcBody, file, files, info)...) + return tests +} + +// findTableDrivenSubtests finds subtests from table-driven tests. +// It handles patterns like: +// +// tests := []struct{ name string; ... }{{name: "test1"}, ...} +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { ... }) +// } +func (b *indexBuilder) findTableDrivenSubtests(parent gobTest, param types.Object, rangeStmt *ast.RangeStmt, file *parsego.File, files []*parsego.File, info *types.Info) []gobTest { + var tests []gobTest + + // rangeStmt.Body should contain t.Run calls + if rangeStmt.Body == nil { + return nil + } + + // Get the loop variable (e.g., tt in "for _, tt := range tests") + var loopVar types.Object + if rangeStmt.Value != nil { + if ident, ok := rangeStmt.Value.(*ast.Ident); ok { + loopVar = info.ObjectOf(ident) + } + } + if loopVar == nil { + // Try rangeStmt.Key for "for tt := range tests" pattern + if rangeStmt.Key != nil { + if ident, ok := rangeStmt.Key.(*ast.Ident); ok { + loopVar = info.ObjectOf(ident) + } + } + } + if loopVar == nil { + return nil + } + + var testNameField *ast.Ident + // Find t.Run calls in the range body to confirm this is a table-driven test, if so then set the testNameVar + hasRun := false + for _, stmt := range rangeStmt.Body.List { expr, ok := stmt.(*ast.ExprStmt) if !ok { continue @@ -159,42 +275,202 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast continue // subtest has wrong signature } - val := info.Types[call.Args[0]].Value // may be zero - if val == nil || val.Kind() != constant.String { + // Check if first argument is a field access like tt.name, if so set + testNameField = b.isLoopVarFieldAccess(call.Args[0], loopVar, info) + if testNameField == nil { continue } + // TODO: handle expressions other than struct field selectors + + hasRun = true + break + } + + if !hasRun { + return nil + } + + // Find the table being ranged over and extract test cases with their locations + tableEntries := b.extractTableTestCases(rangeStmt.X, files, info, file, testNameField) + if len(tableEntries) == 0 { + return nil + } + + // Create a test entry for each table entry with its specific location + for _, entry := range tableEntries { var t gobTest - t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val))) + t.Name = b.uniqueName(parent.Name, rewrite(entry.name)) t.Location.URI = file.URI - t.Location.Range, _ = file.NodeRange(call) + t.Location.Range = entry.location tests = append(tests, t) + } - fn, typ, body := findFunc(files, info, body, call.Args[1]) - if typ == nil { - continue + return tests +} + +// isLoopVarFieldAccess checks if expr is a field access on the loop variable, if so returns the field identifier +// (e.g., tt.name where tt is the loop variable). +func (b *indexBuilder) isLoopVarFieldAccess(expr ast.Expr, loopVar types.Object, info *types.Info) *ast.Ident { + sel, ok := expr.(*ast.SelectorExpr) + if !ok { + return nil + } + ident, ok := sel.X.(*ast.Ident) + if !ok { + return nil + } + if info.ObjectOf(ident) != loopVar { + return nil + } + return sel.Sel +} + +// tableTestCase represents a single test case in a table-driven test +type tableTestCase struct { + name string + location protocol.Range +} + +// extractTableTestCases extracts test cases with their locations from a table-driven test slice. +// It handles patterns like: +// - tests := []struct{name string}{{"test1"}, {"test2"}} +// - []struct{name string}{{"test1"}, {"test2"}} +// - For identifier references, attempts to find the composite literal value +func (b *indexBuilder) extractTableTestCases(expr ast.Expr, files []*parsego.File, info *types.Info, file *parsego.File, testNameField *ast.Ident) []tableTestCase { + // Unwrap parentheses + for { + if paren, ok := expr.(*ast.ParenExpr); ok { + expr = paren.X + } else { + break } + } - // Function literals don't have an associated object - if fn == nil { - tests = append(tests, b.findSubtests(t, typ, body, file, files, info)...) - continue + // Handle both direct composite literals and identifiers + var comp *ast.CompositeLit + switch e := expr.(type) { + case *ast.CompositeLit: + comp = e + case *ast.Ident: + // Look for the assignment of this identifier + obj := info.ObjectOf(e) + if obj == nil { + return nil + } + // Find the composite literal from the identifier's definition + comp = b.findCompositeLiteralForIdent(e, files, info) + if comp == nil { + return nil } + default: + return nil + } + + // comp should be a slice composite literal + if comp.Type == nil { + return nil + } - // Never recurse if the second argument is a top-level test function - if isTest, _ := isTestOrExample(fn); isTest { + var cases []tableTestCase + for _, elt := range comp.Elts { + // Each element should be a struct literal + structLit, ok := elt.(*ast.CompositeLit) + if !ok { continue } - // Don't recurse into functions that have already been visited - if b.visited[fn] { + if len(structLit.Elts) == 0 { continue } - b.visited[fn] = true - tests = append(tests, b.findSubtests(t, typ, body, file, files, info)...) + // Try keyed fields first (e.g., {name: "test1", ...}) + for _, field := range structLit.Elts { + kv, ok := field.(*ast.KeyValueExpr) + if !ok { + //TODO: look for unkeyed fields + continue + } + key, ok := kv.Key.(*ast.Ident) + if !ok || key.Name != testNameField.Name { + continue + } + + // Get the location of this test case (the struct literal) + rng, err := file.NodeRange(structLit) + if err != nil { + continue + } + + // Extract the string value + if val := info.Types[kv.Value].Value; val != nil && val.Kind() == constant.String { + cases = append(cases, tableTestCase{ + name: constant.StringVal(val), + location: rng, + }) + break + } + } } - return tests + + return cases +} + +// findCompositeLiteralForIdent finds the composite literal that initializes the given identifier. +// It searches through the files for variable declarations and assignments. +func (b *indexBuilder) findCompositeLiteralForIdent(ident *ast.Ident, files []*parsego.File, info *types.Info) *ast.CompositeLit { + obj := info.ObjectOf(ident) + if obj == nil { + return nil + } + + // Search through all files to find the declaration + for _, file := range files { + // Walk through declarations to find variable declarations + for _, decl := range file.File.Decls { + // Check function declarations (where local variables are declared) + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok || funcDecl.Body == nil { + continue + } + + // Walk through statements in the function body + for _, stmt := range funcDecl.Body.List { + // Check for short variable declaration: tests := ... + if assign, ok := stmt.(*ast.AssignStmt); ok && assign.Tok == token.DEFINE { + for i, lhs := range assign.Lhs { + if lhsIdent, ok := lhs.(*ast.Ident); ok && info.ObjectOf(lhsIdent) == obj { + // Found the declaration, check if RHS is a composite literal + if i < len(assign.Rhs) { + if comp, ok := assign.Rhs[i].(*ast.CompositeLit); ok { + return comp + } + } + } + } + } + + // Check for var declaration: var tests = ... + if declStmt, ok := stmt.(*ast.DeclStmt); ok { + if genDecl, ok := declStmt.Decl.(*ast.GenDecl); ok && genDecl.Tok == token.VAR { + for _, spec := range genDecl.Specs { + if valueSpec, ok := spec.(*ast.ValueSpec); ok { + for i, name := range valueSpec.Names { + if info.ObjectOf(name) == obj && i < len(valueSpec.Values) { + if comp, ok := valueSpec.Values[i].(*ast.CompositeLit); ok { + return comp + } + } + } + } + } + } + } + } + } + } + + return nil } // findFunc finds the type and body of the given expr, which may be a function diff --git a/gopls/internal/cache/testfuncs/tests_test.go b/gopls/internal/cache/testfuncs/tests_test.go new file mode 100644 index 00000000000..9897490f737 --- /dev/null +++ b/gopls/internal/cache/testfuncs/tests_test.go @@ -0,0 +1,275 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testfuncs + +import ( + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "testing" + + "golang.org/x/tools/gopls/internal/cache/parsego" + "golang.org/x/tools/gopls/internal/protocol" +) + +func TestTableDrivenSubtests(t *testing.T) { + src := `package p + +import "testing" + +func TestExample(t *testing.T) { + tests := []struct { + name string + x int + want int + }{ + {name: "zero", x: 0, want: 0}, + {name: "one", x: 1, want: 1}, + {name: "two", x: 2, want: 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.x != tt.want { + t.Errorf("got %d, want %d", tt.x, tt.want) + } + }) + } +} +` + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "example_test.go", src, 0) + if err != nil { + t.Fatal(err) + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + + conf := types.Config{Importer: importer.Default()} + _, err = conf.Check("p", fset, []*ast.File{file}, info) + if err != nil { + t.Fatalf("type checking failed: %v", err) + } + + // Create the mapper + tok := fset.File(file.Pos()) + content := []byte(src) + mapper := protocol.NewMapper(protocol.DocumentURI("file:///example_test.go"), content) + + pgf := &parsego.File{ + URI: protocol.DocumentURI("file:///example_test.go"), + File: file, + Tok: tok, + Src: content, + Mapper: mapper, + } + + index := NewIndex([]*parsego.File{pgf}, info) + results := index.All() + + // Debug: log what we found + t.Logf("Found %d results", len(results)) + + // We expect at least the main test function + if len(results) == 0 { + t.Fatal("expected at least one test result") + } + + // Check that we found the main test + foundMain := false + foundSubs := 0 + for _, r := range results { + t.Logf("Found test: %s", r.Name) + if r.Name == "TestExample" { + foundMain = true + } + if r.Name == "TestExample/zero" || r.Name == "TestExample/one" || r.Name == "TestExample/two" { + foundSubs++ + } + } + + if !foundMain { + t.Error("did not find main test function TestExample") + } + + // This is the new functionality - we should find the table-driven subtests + if foundSubs != 3 { + t.Errorf("expected to find 3 subtests, found %d", foundSubs) + } +} + +func TestNestedTableDrivenSubtests(t *testing.T) { + src := `package p + +import "testing" + +func TestNested(t *testing.T) { + tests := []struct { + name string + x int + }{ + {name: "outer1", x: 1}, + {name: "outer2", x: 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + subtests := []struct { + name string + y int + }{ + {name: "inner1", y: 10}, + {name: "inner2", y: 20}, + } + for _, st := range subtests { + t.Run(st.name, func(t *testing.T) { + // nested test + }) + } + }) + } +} +` + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "nested_test.go", src, 0) + if err != nil { + t.Fatal(err) + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + + conf := types.Config{Importer: importer.Default()} + _, err = conf.Check("p", fset, []*ast.File{file}, info) + if err != nil { + t.Fatalf("type checking failed: %v", err) + } + + tok := fset.File(file.Pos()) + content := []byte(src) + mapper := protocol.NewMapper(protocol.DocumentURI("file:///nested_test.go"), content) + + pgf := &parsego.File{ + URI: protocol.DocumentURI("file:///nested_test.go"), + File: file, + Tok: tok, + Src: content, + Mapper: mapper, + } + + index := NewIndex([]*parsego.File{pgf}, info) + results := index.All() + + foundOuter1 := false + foundOuter2 := false + + for _, r := range results { + t.Logf("Found test: %s", r.Name) + switch r.Name { + case "TestNested/outer1": + foundOuter1 = true + case "TestNested/outer2": + foundOuter2 = true + } + } + + // We should find the outer table-driven subtests + // Note: Nested table-driven tests (table-driven tests inside function literals + // passed to t.Run) are not currently supported and would require more complex + // analysis. This is an acceptable limitation. + if !foundOuter1 { + t.Error("did not find subtest TestNested/outer1") + } + if !foundOuter2 { + t.Error("did not find subtest TestNested/outer2") + } +} + +func TestDirectSubtests(t *testing.T) { + src := `package p + +import "testing" + +func TestDirect(t *testing.T) { + t.Run("first", func(t *testing.T) { + // test code + }) + t.Run("second", func(t *testing.T) { + // test code + }) +} +` + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "direct_test.go", src, 0) + if err != nil { + t.Fatal(err) + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + + conf := types.Config{Importer: importer.Default()} + _, err = conf.Check("p", fset, []*ast.File{file}, info) + if err != nil { + t.Fatalf("type checking failed: %v", err) + } + + // Create the mapper + tok := fset.File(file.Pos()) + content := []byte(src) + mapper := protocol.NewMapper(protocol.DocumentURI("file:///direct_test.go"), content) + + pgf := &parsego.File{ + URI: protocol.DocumentURI("file:///direct_test.go"), + File: file, + Tok: tok, + Src: content, + Mapper: mapper, + } + + index := NewIndex([]*parsego.File{pgf}, info) + results := index.All() + + foundMain := false + foundFirst := false + foundSecond := false + for _, r := range results { + t.Logf("Found test: %s", r.Name) + switch r.Name { + case "TestDirect": + foundMain = true + case "TestDirect/first": + foundFirst = true + case "TestDirect/second": + foundSecond = true + } + } + + if !foundMain { + t.Error("did not find main test function TestDirect") + } + if !foundFirst { + t.Error("did not find subtest TestDirect/first") + } + if !foundSecond { + t.Error("did not find subtest TestDirect/second") + } +} diff --git a/gopls/internal/golang/code_lens.go b/gopls/internal/golang/code_lens.go index b04724e0cbc..203f57b9b6a 100644 --- a/gopls/internal/golang/code_lens.go +++ b/gopls/internal/golang/code_lens.go @@ -55,6 +55,13 @@ func runTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Hand codeLens = append(codeLens, protocol.CodeLens{Range: rng, Command: cmd}) } + // Add code lenses for subtests (including table-driven subtests) + subtestLenses, err := subtestCodeLenses(ctx, snapshot, pkg, puri) + if err != nil { + return nil, err + } + codeLens = append(codeLens, subtestLenses...) + for _, fn := range benchFuncs { cmd := command.NewRunTestsCommand("run benchmark", command.RunTestsArgs{ URI: puri, @@ -154,6 +161,46 @@ func matchTestFunc(fn *ast.FuncDecl, info *types.Info, nameRe *regexp.Regexp, pa return namedObj.Id() == paramID } +// subtestCodeLenses returns code lenses for subtests, including table-driven subtests. +func subtestCodeLenses(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, uri protocol.DocumentURI) ([]protocol.CodeLens, error) { + // Get test index which includes subtests + indexes, err := snapshot.Tests(ctx, pkg.Metadata().ID) + if err != nil { + return nil, err + } + if len(indexes) == 0 { + return nil, nil + } + + var codeLens []protocol.CodeLens + for _, idx := range indexes { + if idx == nil { + continue + } + for _, result := range idx.All() { + // Only show code lenses for subtests in the current file + if result.Location.URI != uri { + continue + } + + // Skip top-level tests (they already have code lenses) + if !strings.Contains(result.Name, "/") { + continue + } + + // Create a code lens for this subtest + cmd := command.NewRunTestsCommand("run subtest", command.RunTestsArgs{ + URI: uri, + Tests: []string{result.Name}, + }) + rng := protocol.Range{Start: result.Location.Range.Start, End: result.Location.Range.Start} + codeLens = append(codeLens, protocol.CodeLens{Range: rng, Command: cmd}) + } + } + + return codeLens, nil +} + func goGenerateCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]protocol.CodeLens, error) { pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) if err != nil { From ab30669dd270373c4454325e7ca79b963862e8e2 Mon Sep 17 00:00:00 2001 From: tshihad Date: Thu, 27 Nov 2025 18:04:20 +0100 Subject: [PATCH 2/2] support-table-test: Remove test --- gopls/internal/cache/testfuncs/tests_test.go | 275 ------------------- 1 file changed, 275 deletions(-) delete mode 100644 gopls/internal/cache/testfuncs/tests_test.go diff --git a/gopls/internal/cache/testfuncs/tests_test.go b/gopls/internal/cache/testfuncs/tests_test.go deleted file mode 100644 index 9897490f737..00000000000 --- a/gopls/internal/cache/testfuncs/tests_test.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package testfuncs - -import ( - "go/ast" - "go/importer" - "go/parser" - "go/token" - "go/types" - "testing" - - "golang.org/x/tools/gopls/internal/cache/parsego" - "golang.org/x/tools/gopls/internal/protocol" -) - -func TestTableDrivenSubtests(t *testing.T) { - src := `package p - -import "testing" - -func TestExample(t *testing.T) { - tests := []struct { - name string - x int - want int - }{ - {name: "zero", x: 0, want: 0}, - {name: "one", x: 1, want: 1}, - {name: "two", x: 2, want: 2}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.x != tt.want { - t.Errorf("got %d, want %d", tt.x, tt.want) - } - }) - } -} -` - - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, "example_test.go", src, 0) - if err != nil { - t.Fatal(err) - } - - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } - - conf := types.Config{Importer: importer.Default()} - _, err = conf.Check("p", fset, []*ast.File{file}, info) - if err != nil { - t.Fatalf("type checking failed: %v", err) - } - - // Create the mapper - tok := fset.File(file.Pos()) - content := []byte(src) - mapper := protocol.NewMapper(protocol.DocumentURI("file:///example_test.go"), content) - - pgf := &parsego.File{ - URI: protocol.DocumentURI("file:///example_test.go"), - File: file, - Tok: tok, - Src: content, - Mapper: mapper, - } - - index := NewIndex([]*parsego.File{pgf}, info) - results := index.All() - - // Debug: log what we found - t.Logf("Found %d results", len(results)) - - // We expect at least the main test function - if len(results) == 0 { - t.Fatal("expected at least one test result") - } - - // Check that we found the main test - foundMain := false - foundSubs := 0 - for _, r := range results { - t.Logf("Found test: %s", r.Name) - if r.Name == "TestExample" { - foundMain = true - } - if r.Name == "TestExample/zero" || r.Name == "TestExample/one" || r.Name == "TestExample/two" { - foundSubs++ - } - } - - if !foundMain { - t.Error("did not find main test function TestExample") - } - - // This is the new functionality - we should find the table-driven subtests - if foundSubs != 3 { - t.Errorf("expected to find 3 subtests, found %d", foundSubs) - } -} - -func TestNestedTableDrivenSubtests(t *testing.T) { - src := `package p - -import "testing" - -func TestNested(t *testing.T) { - tests := []struct { - name string - x int - }{ - {name: "outer1", x: 1}, - {name: "outer2", x: 2}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - subtests := []struct { - name string - y int - }{ - {name: "inner1", y: 10}, - {name: "inner2", y: 20}, - } - for _, st := range subtests { - t.Run(st.name, func(t *testing.T) { - // nested test - }) - } - }) - } -} -` - - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, "nested_test.go", src, 0) - if err != nil { - t.Fatal(err) - } - - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } - - conf := types.Config{Importer: importer.Default()} - _, err = conf.Check("p", fset, []*ast.File{file}, info) - if err != nil { - t.Fatalf("type checking failed: %v", err) - } - - tok := fset.File(file.Pos()) - content := []byte(src) - mapper := protocol.NewMapper(protocol.DocumentURI("file:///nested_test.go"), content) - - pgf := &parsego.File{ - URI: protocol.DocumentURI("file:///nested_test.go"), - File: file, - Tok: tok, - Src: content, - Mapper: mapper, - } - - index := NewIndex([]*parsego.File{pgf}, info) - results := index.All() - - foundOuter1 := false - foundOuter2 := false - - for _, r := range results { - t.Logf("Found test: %s", r.Name) - switch r.Name { - case "TestNested/outer1": - foundOuter1 = true - case "TestNested/outer2": - foundOuter2 = true - } - } - - // We should find the outer table-driven subtests - // Note: Nested table-driven tests (table-driven tests inside function literals - // passed to t.Run) are not currently supported and would require more complex - // analysis. This is an acceptable limitation. - if !foundOuter1 { - t.Error("did not find subtest TestNested/outer1") - } - if !foundOuter2 { - t.Error("did not find subtest TestNested/outer2") - } -} - -func TestDirectSubtests(t *testing.T) { - src := `package p - -import "testing" - -func TestDirect(t *testing.T) { - t.Run("first", func(t *testing.T) { - // test code - }) - t.Run("second", func(t *testing.T) { - // test code - }) -} -` - - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, "direct_test.go", src, 0) - if err != nil { - t.Fatal(err) - } - - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } - - conf := types.Config{Importer: importer.Default()} - _, err = conf.Check("p", fset, []*ast.File{file}, info) - if err != nil { - t.Fatalf("type checking failed: %v", err) - } - - // Create the mapper - tok := fset.File(file.Pos()) - content := []byte(src) - mapper := protocol.NewMapper(protocol.DocumentURI("file:///direct_test.go"), content) - - pgf := &parsego.File{ - URI: protocol.DocumentURI("file:///direct_test.go"), - File: file, - Tok: tok, - Src: content, - Mapper: mapper, - } - - index := NewIndex([]*parsego.File{pgf}, info) - results := index.All() - - foundMain := false - foundFirst := false - foundSecond := false - for _, r := range results { - t.Logf("Found test: %s", r.Name) - switch r.Name { - case "TestDirect": - foundMain = true - case "TestDirect/first": - foundFirst = true - case "TestDirect/second": - foundSecond = true - } - } - - if !foundMain { - t.Error("did not find main test function TestDirect") - } - if !foundFirst { - t.Error("did not find subtest TestDirect/first") - } - if !foundSecond { - t.Error("did not find subtest TestDirect/second") - } -}