diff --git a/lib/pool/reader_writer.go b/lib/pool/reader_writer.go index 18bd11e8e..c6cf2caea 100644 --- a/lib/pool/reader_writer.go +++ b/lib/pool/reader_writer.go @@ -3,6 +3,7 @@ package pool import ( "errors" "io" + "sync" ) // RWAccount is a function which will be called after every read @@ -12,15 +13,25 @@ import ( type RWAccount func(n int) error // RW contains the state for the read/writer +// +// It can be used as a FIFO to read data from a source and write it out again. type RW struct { - pool *Pool // pool to get pages from - pages [][]byte // backing store - size int // size written - out int // offset we are reading from - lastOffset int // size in last page - account RWAccount // account for a read - reads int // count how many times the data has been read - accountOn int // only account on or after this read + // Written once variables in initialization + pool *Pool // pool to get pages from + account RWAccount // account for a read + accountOn int // only account on or after this read + + // Shared variables between Read and Write + // Write updates these but Read reads from them + // They must all stay in sync together + mu sync.Mutex // protect the shared variables + pages [][]byte // backing store + size int // size written + lastOffset int // size in last page + + // Read side Variables + out int // offset we are reading from + reads int // count how many times the data has been read } var ( @@ -47,6 +58,8 @@ func NewRW(pool *Pool) *RW { // called after every read from the RW. // // It may return an error which will be passed back to the user. +// +// Not thread safe - call in initialization only. func (rw *RW) SetAccounting(account RWAccount) *RW { rw.account = account return rw @@ -73,6 +86,8 @@ type DelayAccountinger interface { // e.g. when calculating hashes. // // Set this to 0 to account everything. +// +// Not thread safe - call in initialization only. func (rw *RW) DelayAccounting(i int) { rw.accountOn = i rw.reads = 0 @@ -82,6 +97,8 @@ func (rw *RW) DelayAccounting(i int) { // // Ensure there are pages before calling this. func (rw *RW) readPage(i int) (page []byte) { + rw.mu.Lock() + defer rw.mu.Unlock() // Count a read of the data if we read the first page if i == 0 { rw.reads++ @@ -111,6 +128,13 @@ func (rw *RW) accountRead(n int) error { return nil } +// Returns true if we have read to EOF +func (rw *RW) eof() bool { + rw.mu.Lock() + defer rw.mu.Unlock() + return rw.out >= rw.size +} + // Read reads up to len(p) bytes into p. It returns the number of // bytes read (0 <= n <= len(p)) and any error encountered. If some // data is available but not len(p) bytes, Read returns what is @@ -121,7 +145,7 @@ func (rw *RW) Read(p []byte) (n int, err error) { page []byte ) for len(p) > 0 { - if rw.out >= rw.size { + if rw.eof() { return n, io.EOF } page = rw.readPage(rw.out) @@ -148,7 +172,7 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) { nn int page []byte ) - for rw.out < rw.size { + for !rw.eof() { page = rw.readPage(rw.out) nn, err = w.Write(page) n += int64(nn) @@ -166,6 +190,8 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) { // Get the page we are writing to func (rw *RW) writePage() (page []byte) { + rw.mu.Lock() + defer rw.mu.Unlock() if len(rw.pages) > 0 && rw.lastOffset < rw.pool.bufferSize { return rw.pages[len(rw.pages)-1][rw.lastOffset:] } @@ -187,8 +213,10 @@ func (rw *RW) Write(p []byte) (n int, err error) { nn = copy(page, p) p = p[nn:] n += nn + rw.mu.Lock() rw.size += nn rw.lastOffset += nn + rw.mu.Unlock() } return n, nil } @@ -208,8 +236,10 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) { page = rw.writePage() nn, err = r.Read(page) n += int64(nn) + rw.mu.Lock() rw.size += nn rw.lastOffset += nn + rw.mu.Unlock() } if err == io.EOF { err = nil @@ -229,7 +259,9 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) { // beyond the end of the written data is an error. func (rw *RW) Seek(offset int64, whence int) (int64, error) { var abs int64 + rw.mu.Lock() size := int64(rw.size) + rw.mu.Unlock() switch whence { case io.SeekStart: abs = offset @@ -252,6 +284,8 @@ func (rw *RW) Seek(offset int64, whence int) (int64, error) { // Close the buffer returning memory to the pool func (rw *RW) Close() error { + rw.mu.Lock() + defer rw.mu.Unlock() for _, page := range rw.pages { rw.pool.Put(page) } @@ -261,6 +295,8 @@ func (rw *RW) Close() error { // Size returns the number of bytes in the buffer func (rw *RW) Size() int64 { + rw.mu.Lock() + defer rw.mu.Unlock() return int64(rw.size) } diff --git a/lib/pool/reader_writer_test.go b/lib/pool/reader_writer_test.go index e9e02c22e..7f5315fc6 100644 --- a/lib/pool/reader_writer_test.go +++ b/lib/pool/reader_writer_test.go @@ -4,10 +4,12 @@ import ( "bytes" "errors" "io" + "sync" "testing" "time" "github.com/rclone/rclone/lib/random" + "github.com/rclone/rclone/lib/readers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -489,3 +491,125 @@ func TestRWBoundaryConditions(t *testing.T) { }) } } + +// The RW should be thread safe for reading and writing concurrently +func TestRWConcurrency(t *testing.T) { + const bufSize = 1024 + + // Write data of size using Write + write := func(rw *RW, size int64) { + in := readers.NewPatternReader(size) + buf := make([]byte, bufSize) + nn := int64(0) + for { + nr, inErr := in.Read(buf) + if inErr != nil && inErr != io.EOF { + require.NoError(t, inErr) + } + nw, rwErr := rw.Write(buf[:nr]) + require.NoError(t, rwErr) + assert.Equal(t, nr, nw) + nn += int64(nw) + if inErr == io.EOF { + break + } + } + assert.Equal(t, size, nn) + } + + // Write the data using ReadFrom + readFrom := func(rw *RW, size int64) { + in := readers.NewPatternReader(size) + nn, err := rw.ReadFrom(in) + assert.NoError(t, err) + assert.Equal(t, size, nn) + } + + // Read the data back from inP and check it is OK + check := func(in io.Reader, size int64) { + ck := readers.NewPatternReader(size) + ckBuf := make([]byte, bufSize) + rwBuf := make([]byte, bufSize) + nn := int64(0) + for { + nck, ckErr := ck.Read(ckBuf) + if ckErr != io.EOF { + require.NoError(t, ckErr) + } + var nin int + var inErr error + for { + var nnin int + nnin, inErr = in.Read(rwBuf[nin:]) + if inErr != io.EOF { + require.NoError(t, inErr) + } + nin += nnin + nn += int64(nnin) + if nin >= len(rwBuf) || nn >= size || inErr != io.EOF { + break + } + } + require.Equal(t, ckBuf[:nck], rwBuf[:nin]) + if ckErr == io.EOF && inErr == io.EOF { + break + } + } + assert.Equal(t, size, nn) + } + + // Read the data back and check it is OK + read := func(rw *RW, size int64) { + check(rw, size) + } + + // Read the data back and check it is OK in using WriteTo + writeTo := func(rw *RW, size int64) { + in, out := io.Pipe() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + check(in, size) + }() + var n int64 + for n < size { + nn, err := rw.WriteTo(out) + assert.NoError(t, err) + n += nn + } + assert.Equal(t, size, n) + require.NoError(t, out.Close()) + wg.Wait() + } + + type test struct { + name string + fn func(*RW, int64) + } + + const size = blockSize*255 + 255 + + // Read and Write the data with a range of block sizes and functions + for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} { + t.Run(write.name, func(t *testing.T) { + for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} { + t.Run(read.name, func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + rw := NewRW(rwPool) + go func() { + defer wg.Done() + read.fn(rw, size) + }() + go func() { + defer wg.Done() + write.fn(rw, size) + }() + wg.Wait() + }) + } + }) + } + +}