package operations import ( "context" "errors" "io" "testing" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/hash" "github.com/rclone/rclone/fstest/mockobject" "github.com/rclone/rclone/lib/pool" "github.com/rclone/rclone/lib/readers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // check interfaces var ( _ io.ReadSeekCloser = (*ReOpen)(nil) _ pool.DelayAccountinger = (*ReOpen)(nil) ) var errorTestError = errors.New("test error") // this is a wrapper for a mockobject with a custom Open function // // breaks indicate the number of bytes to read before returning an // error type reOpenTestObject struct { fs.Object t *testing.T wantStart int64 breaks []int64 unknownSize bool } // Open opens the file for read. Call Close() on the returned io.ReadCloser // // This will break after reading the number of bytes in breaks func (o *reOpenTestObject) Open(ctx context.Context, options ...fs.OpenOption) (io.ReadCloser, error) { gotHash := false gotRange := false startPos := int64(0) for _, option := range options { switch x := option.(type) { case *fs.HashesOption: gotHash = true case *fs.RangeOption: gotRange = true startPos = x.Start if o.unknownSize { assert.Equal(o.t, int64(-1), x.End) } case *fs.SeekOption: startPos = x.Offset } } assert.Equal(o.t, o.wantStart, startPos) // Check if ranging, mustn't have hash if offset != 0 if gotHash && gotRange { assert.Equal(o.t, int64(0), startPos) } rc, err := o.Object.Open(ctx, options...) if err != nil { return nil, err } if len(o.breaks) > 0 { // Pop a breakpoint off N := o.breaks[0] o.breaks = o.breaks[1:] o.wantStart += N // If 0 then return an error immediately if N == 0 { return nil, errorTestError } // Read N bytes then an error r := io.MultiReader(&io.LimitedReader{R: rc, N: N}, readers.ErrorReader{Err: errorTestError}) // Wrap with Close in a new readCloser rc = readCloser{Reader: r, Closer: rc} } return rc, nil } func TestReOpen(t *testing.T) { for _, testName := range []string{"Normal", "WithRangeOption", "WithSeekOption", "UnknownSize"} { t.Run(testName, func(t *testing.T) { // Contents for the mock object var ( reOpenTestcontents = []byte("0123456789") expectedRead = reOpenTestcontents rangeOption *fs.RangeOption seekOption *fs.SeekOption unknownSize = false ) switch testName { case "Normal": case "WithRangeOption": rangeOption = &fs.RangeOption{Start: 1, End: 7} // range is inclusive expectedRead = reOpenTestcontents[1:8] case "WithSeekOption": seekOption = &fs.SeekOption{Offset: 2} expectedRead = reOpenTestcontents[2:] case "UnknownSize": rangeOption = &fs.RangeOption{Start: 1, End: -1} expectedRead = reOpenTestcontents[1:] unknownSize = true default: panic("bad test name") } // Start the test with the given breaks testReOpen := func(breaks []int64, maxRetries int) (*ReOpen, *reOpenTestObject, error) { srcOrig := mockobject.New("potato").WithContent(reOpenTestcontents, mockobject.SeekModeNone) srcOrig.SetUnknownSize(unknownSize) src := &reOpenTestObject{ Object: srcOrig, t: t, breaks: breaks, unknownSize: unknownSize, } opts := []fs.OpenOption{} if rangeOption == nil && seekOption == nil { opts = append(opts, &fs.HashesOption{Hashes: hash.NewHashSet(hash.MD5)}) } if rangeOption != nil { opts = append(opts, rangeOption) src.wantStart = rangeOption.Start } if seekOption != nil { opts = append(opts, seekOption) src.wantStart = seekOption.Offset } rc, err := NewReOpen(context.Background(), src, maxRetries, opts...) return rc, src, err } t.Run("Basics", func(t *testing.T) { // open h, _, err := testReOpen(nil, 10) assert.NoError(t, err) // Check contents read correctly got, err := io.ReadAll(h) assert.NoError(t, err) assert.Equal(t, expectedRead, got) // Check read after end var buf = make([]byte, 1) n, err := h.Read(buf) assert.Equal(t, 0, n) assert.Equal(t, io.EOF, err) // Rewind the stream _, err = h.Seek(0, io.SeekStart) require.NoError(t, err) // Check contents read correctly got, err = io.ReadAll(h) assert.NoError(t, err) assert.Equal(t, expectedRead, got) // Check close assert.NoError(t, h.Close()) // Check double close assert.Equal(t, errFileClosed, h.Close()) // Check read after close n, err = h.Read(buf) assert.Equal(t, 0, n) assert.Equal(t, errFileClosed, err) }) t.Run("ErrorAtStart", func(t *testing.T) { // open with immediate breaking h, _, err := testReOpen([]int64{0}, 10) assert.Equal(t, errorTestError, err) assert.Nil(t, h) }) t.Run("WithErrors", func(t *testing.T) { // open with a few break points but less than the max h, _, err := testReOpen([]int64{2, 1, 3}, 10) assert.NoError(t, err) // check contents got, err := io.ReadAll(h) assert.NoError(t, err) assert.Equal(t, expectedRead, got) // check close assert.NoError(t, h.Close()) }) t.Run("TooManyErrors", func(t *testing.T) { // open with a few break points but >= the max h, _, err := testReOpen([]int64{2, 1, 3}, 3) assert.NoError(t, err) // check contents got, err := io.ReadAll(h) assert.Equal(t, errorTestError, err) assert.Equal(t, expectedRead[:6], got) // check old error is returned var buf = make([]byte, 1) n, err := h.Read(buf) assert.Equal(t, 0, n) assert.Equal(t, errTooManyTries, err) // Check close assert.Equal(t, errFileClosed, h.Close()) }) t.Run("Seek", func(t *testing.T) { // open h, src, err := testReOpen([]int64{2, 1, 3}, 10) assert.NoError(t, err) // Seek to end pos, err := h.Seek(int64(len(expectedRead)), io.SeekStart) assert.NoError(t, err) assert.Equal(t, int64(len(expectedRead)), pos) // Seek to start pos, err = h.Seek(0, io.SeekStart) assert.NoError(t, err) assert.Equal(t, int64(0), pos) // Should not allow seek past end pos, err = h.Seek(int64(len(expectedRead))+1, io.SeekCurrent) if !unknownSize { assert.Equal(t, errSeekPastEnd, err) assert.Equal(t, len(expectedRead), int(pos)) } else { assert.Equal(t, nil, err) assert.Equal(t, len(expectedRead)+1, int(pos)) // Seek back to start to get tests in sync pos, err = h.Seek(0, io.SeekStart) assert.NoError(t, err) assert.Equal(t, int64(0), pos) } // Should not allow seek to negative position start pos, err = h.Seek(-1, io.SeekCurrent) assert.Equal(t, errNegativeSeek, err) assert.Equal(t, 0, int(pos)) // Should not allow seek with invalid whence pos, err = h.Seek(0, 3) assert.Equal(t, errInvalidWhence, err) assert.Equal(t, 0, int(pos)) // check read dst := make([]byte, 5) n, err := h.Read(dst) assert.Nil(t, err) assert.Equal(t, 5, n) assert.Equal(t, expectedRead[:5], dst) // Test io.SeekCurrent pos, err = h.Seek(-3, io.SeekCurrent) assert.Nil(t, err) assert.Equal(t, 2, int(pos)) // Reset the start after a seek, taking into account the offset setWantStart := func(x int64) { src.wantStart = x if rangeOption != nil { src.wantStart += rangeOption.Start } else if seekOption != nil { src.wantStart += seekOption.Offset } } // check read setWantStart(2) n, err = h.Read(dst) assert.Nil(t, err) assert.Equal(t, 5, n) assert.Equal(t, expectedRead[2:7], dst) pos, err = h.Seek(-2, io.SeekCurrent) assert.Nil(t, err) assert.Equal(t, 5, int(pos)) // Test io.SeekEnd pos, err = h.Seek(-3, io.SeekEnd) if !unknownSize { assert.Nil(t, err) assert.Equal(t, len(expectedRead)-3, int(pos)) } else { assert.Equal(t, errBadEndSeek, err) assert.Equal(t, 0, int(pos)) // sync pos, err = h.Seek(1, io.SeekCurrent) assert.Nil(t, err) assert.Equal(t, 6, int(pos)) } // check read dst = make([]byte, 3) setWantStart(int64(len(expectedRead) - 3)) n, err = h.Read(dst) assert.Nil(t, err) assert.Equal(t, 3, n) assert.Equal(t, expectedRead[len(expectedRead)-3:], dst) // check close assert.NoError(t, h.Close()) _, err = h.Seek(0, io.SeekCurrent) assert.Equal(t, errFileClosed, err) }) t.Run("AccountRead", func(t *testing.T) { h, _, err := testReOpen(nil, 10) assert.NoError(t, err) var total int h.SetAccounting(func(n int) error { total += n return nil }) dst := make([]byte, 3) n, err := h.Read(dst) assert.Equal(t, 3, n) assert.NoError(t, err) assert.Equal(t, 3, total) }) t.Run("AccountReadDelay", func(t *testing.T) { h, _, err := testReOpen(nil, 10) assert.NoError(t, err) var total int h.SetAccounting(func(n int) error { total += n return nil }) rewind := func() { _, err := h.Seek(0, io.SeekStart) require.NoError(t, err) } h.DelayAccounting(3) dst := make([]byte, 16) n, err := h.Read(dst) assert.Equal(t, len(expectedRead), n) assert.Equal(t, io.EOF, err) assert.Equal(t, 0, total) rewind() n, err = h.Read(dst) assert.Equal(t, len(expectedRead), n) assert.Equal(t, io.EOF, err) assert.Equal(t, 0, total) rewind() n, err = h.Read(dst) assert.Equal(t, len(expectedRead), n) assert.Equal(t, io.EOF, err) assert.Equal(t, len(expectedRead), total) rewind() n, err = h.Read(dst) assert.Equal(t, len(expectedRead), n) assert.Equal(t, io.EOF, err) assert.Equal(t, 2*len(expectedRead), total) rewind() }) t.Run("AccountReadError", func(t *testing.T) { // Test accounting errors h, _, err := testReOpen(nil, 10) assert.NoError(t, err) h.SetAccounting(func(n int) error { return errorTestError }) dst := make([]byte, 3) n, err := h.Read(dst) assert.Equal(t, 3, n) assert.Equal(t, errorTestError, err) }) }) } }