diff --git a/errwrap.go b/errwrap.go index 9a080e7..7fd2a47 100644 --- a/errwrap.go +++ b/errwrap.go @@ -172,10 +172,22 @@ func (w *wrappedError) Error() string { return w.Outer.Error() } +func (w *wrappedError) Is(err error) bool { + return errors.Is(w.Outer, err) +} + +func (w *wrappedError) As(target interface{}) bool { + return errors.As(w.Outer, target) +} + func (w *wrappedError) WrappedErrors() []error { return []error{w.Outer, w.Inner} } func (w *wrappedError) Unwrap() error { - return w.Inner + if i := errors.Unwrap(w.Inner); i != nil { + return Wrap(w.Inner, i) + } else { + return w.Inner + } } diff --git a/errwrap_test.go b/errwrap_test.go index 2ef422d..dbe2b5b 100644 --- a/errwrap_test.go +++ b/errwrap_test.go @@ -6,14 +6,17 @@ package errwrap import ( "errors" "fmt" + "strconv" "testing" ) func TestWrappedError_impl(t *testing.T) { + t.Parallel() var _ error = new(wrappedError) } func TestGetAll(t *testing.T) { + t.Parallel() cases := []struct { Err error Msg string @@ -58,15 +61,18 @@ func TestGetAll(t *testing.T) { } for i, tc := range cases { - actual := GetAll(tc.Err, tc.Msg) - if len(actual) != tc.Len { - t.Fatalf("%d: bad: %#v", i, actual) - } - for _, v := range actual { - if v.Error() != tc.Msg { + t.Run(fmt.Sprintf("Test: %d", i), func(t *testing.T) { + t.Parallel() + actual := GetAll(tc.Err, tc.Msg) + if len(actual) != tc.Len { t.Fatalf("%d: bad: %#v", i, actual) } - } + for _, v := range actual { + if v.Error() != tc.Msg { + t.Fatalf("%d: bad: %#v", i, actual) + } + } + }) } } @@ -113,6 +119,8 @@ func TestGetAllType(t *testing.T) { } func TestWrappedError_IsCompatibleWithErrorsUnwrap(t *testing.T) { + t.Parallel() + inner := errors.New("inner error") err := Wrap(errors.New("outer"), inner) actual := errors.Unwrap(err) @@ -120,3 +128,35 @@ func TestWrappedError_IsCompatibleWithErrorsUnwrap(t *testing.T) { t.Fatal("wrappedError did not unwrap to inner") } } + +func TestWrappedError_IsCompatibleWithErrorsIs(t *testing.T) { + t.Parallel() + + inner := errors.New("inner") + outer := errors.New("outer") + wrapped := Wrap(outer, inner) + if !errors.Is(wrapped, outer) { + t.Fatal("wrappedError did not errors.Is() to outer") + } else if !errors.Is(wrapped, inner) { + t.Fatal("wrappedError did not errors.Is() to inner") + } else if errors.Is(wrapped, errors.New("unexpected")) { + t.Fatal("wrappedError should not have errors.Is() to unexpected") + } +} + +type customError int + +func (c customError) Error() string { return strconv.Itoa(int(c)) } + +func TestWrappedError_IsCompatibleWithErrorsAs(t *testing.T) { + t.Parallel() + + inner := customError(123) + outer := errors.New("1234") + wrapped := Wrap(outer, inner) + + var c customError + if !errors.As(wrapped, &c) { + t.Fatal("wrappedError should have errors.As() to customError") + } +}