diff --git a/quicktest.go b/quicktest.go index 65243cf..b787dba 100644 --- a/quicktest.go +++ b/quicktest.go @@ -217,6 +217,38 @@ var ( tbType = reflect.TypeOf(new(testing.TB)).Elem() ) +// getRunFuncSignature checks the signature of the Run method of the type (ex: *testing.T) +// and returns the signature of its function argument (func(t *T) for *testing.T). +func getRunFuncSignature(t reflect.Type) (reflect.Type, error) { + badType := func(detail string) (reflect.Type, error) { + return nil, fmt.Errorf("cannot execute Run with underlying concrete type %s (%s)", t, detail) + } + m, ok := t.MethodByName("Run") + if !ok { + // c.TB doesn't implement a Run method. + return badType("no Run method") + } + mt := m.Type + // fmt.Println(mt) + if mt.NumIn() != 3 || + mt.In(1) != stringType || + mt.NumOut() != 1 || + mt.Out(0) != boolType { + // The Run method doesn't have the right argument counts and types. + return badType("wrong argument count for Run method") + } + farg := mt.In(2) + if farg.Kind() != reflect.Func || + farg.NumIn() != 1 || + farg.NumOut() != 0 || + !farg.In(0).AssignableTo(tbType) { + // The first argument to the Run function arg isn't right. + return badType("bad first argument type for Run method") + } + + return farg, nil +} + // Run runs f as a subtest of t called name. It's a wrapper around // the Run method of c.TB that provides the quicktest checker to f. When // the function completes, c.Done will be called to run any @@ -243,39 +275,56 @@ var ( // A panic is raised when Run is called and the embedded concrete type does not // implement a Run method with a correct signature. func (c *C) Run(name string, f func(c *C)) bool { - badType := func(m string) { - panic(fmt.Sprintf("cannot execute Run with underlying concrete type %T (%s)", c.TB, m)) - } - m := reflect.ValueOf(c.TB).MethodByName("Run") - if !m.IsValid() { - // c.TB doesn't implement a Run method. - badType("no Run method") - } - mt := m.Type() - if mt.NumIn() != 2 || - mt.In(0) != stringType || - mt.NumOut() != 1 || - mt.Out(0) != boolType { - // The Run method doesn't have the right argument counts and types. - badType("wrong argument count for Run method") + cFormat := c.getFormat() + // A wrapper for f that prepares its *C for the subtest + callF := func(tb testing.TB) { + cSub := New(tb) + defer cSub.Done() + cSub.SetFormat(cFormat) + f(cSub) } - farg := mt.In(1) - if farg.Kind() != reflect.Func || - farg.NumIn() != 1 || - farg.NumOut() != 0 || - !farg.In(0).AssignableTo(tbType) { - // The first argument to the Run function arg isn't right. - badType("bad first argument type for Run method") + + // Handle the various signatures of the Run method of c.TB + switch tb := c.TB.(type) { + + // *testing.T + case interface { + Run(string, func(*testing.T)) bool + }: + return tb.Run(name, func(t *testing.T) { + t.Helper() + callF(t) + }) + + // *testing.B + case interface { + Run(string, func(*testing.B)) bool + }: + return tb.Run(name, func(b *testing.B) { + callF(b) + }) + + // *quicktest.C + case interface{ Run(string, func(*C)) bool }: + return tb.Run(name, func(c *C) { + callF(c) + }) + + // any testing.TB, by using reflect + default: + farg, err := getRunFuncSignature(reflect.TypeOf(c.TB)) + if err != nil { + panic(err.Error()) + } + + fv := reflect.MakeFunc(farg, func(args []reflect.Value) []reflect.Value { + callF(args[0].Interface().(testing.TB)) + return nil + }) + + m := reflect.ValueOf(c.TB).MethodByName("Run") + return m.Call([]reflect.Value{reflect.ValueOf(name), fv})[0].Interface().(bool) } - cFormat := c.getFormat() - fv := reflect.MakeFunc(farg, func(args []reflect.Value) []reflect.Value { - c2 := New(args[0].Interface().(testing.TB)) - defer c2.Done() - c2.SetFormat(cFormat) - f(c2) - return nil - }) - return m.Call([]reflect.Value{reflect.ValueOf(name), fv})[0].Interface().(bool) } // Parallel signals that this test is to be run in parallel with (and only with) other parallel tests.