diff --git a/fs/fserrors/error.go b/fs/fserrors/error.go index 556ccb398..6b8c364f8 100644 --- a/fs/fserrors/error.go +++ b/fs/fserrors/error.go @@ -188,6 +188,12 @@ func Cause(cause error) (retriable bool, err error) { // Unwrap 1 level if possible err = errors.Cause(err) + if err == nil { + // errors.Cause can return nil which isn't + // desirable so pick the previous error in + // this case. + err = prev + } if err == prev { // Unpack any struct or *struct with a field // of name Err which satisfies the error @@ -196,11 +202,11 @@ func Cause(cause error) (retriable bool, err error) { // others in the stdlib errType := reflect.TypeOf(err) errValue := reflect.ValueOf(err) - if errType.Kind() == reflect.Ptr { + if errValue.IsValid() && errType.Kind() == reflect.Ptr { errType = errType.Elem() errValue = errValue.Elem() } - if errType.Kind() == reflect.Struct { + if errValue.IsValid() && errType.Kind() == reflect.Struct { if errField := errValue.FieldByName("Err"); errField.IsValid() { errFieldValue := errField.Interface() if newErr, ok := errFieldValue.(error); ok { diff --git a/fs/fserrors/error_test.go b/fs/fserrors/error_test.go index b0aa7c21d..67f217ed3 100644 --- a/fs/fserrors/error_test.go +++ b/fs/fserrors/error_test.go @@ -39,7 +39,15 @@ type myError2 struct { Err error } -func (e *myError2) Error() string { return e.Err.Error() } +func (e *myError2) Error() string { + if e == nil { + return "myError2(nil)" + } + if e.Err == nil { + return "myError2{Err: nil}" + } + return e.Err.Error() +} type myError3 struct { Err int @@ -53,11 +61,23 @@ type myError4 struct { func (e *myError4) Error() string { return e.e.Error() } +type errorCause struct { + e error +} + +func (e *errorCause) Error() string { return fmt.Sprintf("%#v", e) } + +func (e *errorCause) Cause() error { return e.e } + func TestCause(t *testing.T) { e3 := &myError3{3} e4 := &myError4{io.EOF} - + eNil1 := &myError2{nil} + eNil2 := &myError2{Err: (*myError2)(nil)} errPotato := errors.New("potato") + nilCause1 := &errorCause{nil} + nilCause2 := &errorCause{(*myError2)(nil)} + for i, test := range []struct { err error wantRetriable bool @@ -70,10 +90,15 @@ func TestCause(t *testing.T) { {errUseOfClosedNetworkConnection, false, errUseOfClosedNetworkConnection}, {makeNetErr(syscall.EAGAIN), true, syscall.EAGAIN}, {makeNetErr(syscall.Errno(123123123)), false, syscall.Errno(123123123)}, + {eNil1, false, eNil1}, + {eNil2, false, eNil2.Err}, {myError1{io.EOF}, false, io.EOF}, {&myError2{io.EOF}, false, io.EOF}, {e3, false, e3}, {e4, false, e4}, + {&errorCause{errPotato}, false, errPotato}, + {nilCause1, false, nilCause1}, + {nilCause2, false, nilCause2.e}, } { gotRetriable, gotErr := Cause(test.err) what := fmt.Sprintf("test #%d: %v", i, test.err)