pool: Add ability to wait for a write to RW

This commit is contained in:
Nick Craig-Wood 2024-03-13 16:32:45 +00:00
parent cb2d2d72a0
commit d08b49d723
2 changed files with 43 additions and 10 deletions

View File

@ -1,9 +1,11 @@
package pool package pool
import ( import (
"context"
"errors" "errors"
"io" "io"
"sync" "sync"
"time"
) )
// RWAccount is a function which will be called after every read // RWAccount is a function which will be called after every read
@ -24,10 +26,11 @@ type RW struct {
// Shared variables between Read and Write // Shared variables between Read and Write
// Write updates these but Read reads from them // Write updates these but Read reads from them
// They must all stay in sync together // They must all stay in sync together
mu sync.Mutex // protect the shared variables mu sync.Mutex // protect the shared variables
pages [][]byte // backing store pages [][]byte // backing store
size int // size written size int // size written
lastOffset int // size in last page lastOffset int // size in last page
written chan struct{} // signalled when a write happens
// Read side Variables // Read side Variables
out int // offset we are reading from out int // offset we are reading from
@ -48,10 +51,12 @@ var (
// //
// When writing it only appends data. Seek only applies to reading. // When writing it only appends data. Seek only applies to reading.
func NewRW(pool *Pool) *RW { func NewRW(pool *Pool) *RW {
return &RW{ rw := &RW{
pool: pool, pool: pool,
pages: make([][]byte, 0, 16), pages: make([][]byte, 0, 16),
written: make(chan struct{}, 1),
} }
return rw
} }
// SetAccounting should be provided with a function which will be // SetAccounting should be provided with a function which will be
@ -217,6 +222,7 @@ func (rw *RW) Write(p []byte) (n int, err error) {
rw.size += nn rw.size += nn
rw.lastOffset += nn rw.lastOffset += nn
rw.mu.Unlock() rw.mu.Unlock()
rw.signalWrite() // signal more data available
} }
return n, nil return n, nil
} }
@ -240,6 +246,7 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
rw.size += nn rw.size += nn
rw.lastOffset += nn rw.lastOffset += nn
rw.mu.Unlock() rw.mu.Unlock()
rw.signalWrite() // signal more data available
} }
if err == io.EOF { if err == io.EOF {
err = nil err = nil
@ -247,6 +254,29 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
return n, err return n, err
} }
// signal that a write has happened
func (rw *RW) signalWrite() {
select {
case rw.written <- struct{}{}:
default:
}
}
// WaitWrite sleeps until a data is written to the RW or Close is
// called or the context is cancelled occurs or for a maximum of 1
// Second then returns.
//
// This can be used when calling Read while the buffer is filling up.
func (rw *RW) WaitWrite(ctx context.Context) {
timer := time.NewTimer(time.Second)
select {
case <-timer.C:
case <-ctx.Done():
case <-rw.written:
}
timer.Stop()
}
// Seek sets the offset for the next Read (not Write - this is always // Seek sets the offset for the next Read (not Write - this is always
// appended) to offset, interpreted according to whence: SeekStart // appended) to offset, interpreted according to whence: SeekStart
// means relative to the start of the file, SeekCurrent means relative // means relative to the start of the file, SeekCurrent means relative
@ -286,6 +316,7 @@ func (rw *RW) Seek(offset int64, whence int) (int64, error) {
func (rw *RW) Close() error { func (rw *RW) Close() error {
rw.mu.Lock() rw.mu.Lock()
defer rw.mu.Unlock() defer rw.mu.Unlock()
rw.signalWrite() // signal more data available
for _, page := range rw.pages { for _, page := range rw.pages {
rw.pool.Put(page) rw.pool.Put(page)
} }

View File

@ -2,6 +2,7 @@ package pool
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"io" "io"
"sync" "sync"
@ -526,7 +527,7 @@ func TestRWConcurrency(t *testing.T) {
} }
// Read the data back from inP and check it is OK // Read the data back from inP and check it is OK
check := func(in io.Reader, size int64) { check := func(in io.Reader, size int64, rw *RW) {
ck := readers.NewPatternReader(size) ck := readers.NewPatternReader(size)
ckBuf := make([]byte, bufSize) ckBuf := make([]byte, bufSize)
rwBuf := make([]byte, bufSize) rwBuf := make([]byte, bufSize)
@ -549,6 +550,7 @@ func TestRWConcurrency(t *testing.T) {
if nin >= len(rwBuf) || nn >= size || inErr != io.EOF { if nin >= len(rwBuf) || nn >= size || inErr != io.EOF {
break break
} }
rw.WaitWrite(context.Background())
} }
require.Equal(t, ckBuf[:nck], rwBuf[:nin]) require.Equal(t, ckBuf[:nck], rwBuf[:nin])
if ckErr == io.EOF && inErr == io.EOF { if ckErr == io.EOF && inErr == io.EOF {
@ -560,7 +562,7 @@ func TestRWConcurrency(t *testing.T) {
// Read the data back and check it is OK // Read the data back and check it is OK
read := func(rw *RW, size int64) { read := func(rw *RW, size int64) {
check(rw, size) check(rw, size, rw)
} }
// Read the data back and check it is OK in using WriteTo // Read the data back and check it is OK in using WriteTo
@ -570,7 +572,7 @@ func TestRWConcurrency(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
check(in, size) check(in, size, rw)
}() }()
var n int64 var n int64
for n < size { for n < size {