package batcher import ( "context" "errors" "fmt" "sync" "sync/atomic" "testing" "time" "github.com/rclone/rclone/fs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type ( Result string Item string ) func TestBatcherNew(t *testing.T) { ctx := context.Background() ci := fs.GetConfig(ctx) opt := Options{ Mode: "async", Size: 100, Timeout: 1 * time.Second, MaxBatchSize: 1000, DefaultTimeoutSync: 500 * time.Millisecond, DefaultTimeoutAsync: 10 * time.Second, DefaultBatchSizeAsync: 100, } commitBatch := func(ctx context.Context, items []Item, results []Result, errors []error) (err error) { return nil } b, err := New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) require.True(t, b.Batching()) b.Shutdown() opt.Mode = "sync" b, err = New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) require.True(t, b.Batching()) b.Shutdown() opt.Mode = "off" b, err = New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) require.False(t, b.Batching()) b.Shutdown() opt.Mode = "bad" _, err = New[Item, Result](ctx, nil, commitBatch, opt) require.ErrorContains(t, err, "batch mode") opt.Mode = "async" opt.Size = opt.MaxBatchSize + 1 _, err = New[Item, Result](ctx, nil, commitBatch, opt) require.ErrorContains(t, err, "batch size") opt.Mode = "sync" opt.Size = 0 opt.Timeout = 0 b, err = New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) assert.Equal(t, ci.Transfers, b.opt.Size) assert.Equal(t, opt.DefaultTimeoutSync, b.opt.Timeout) b.Shutdown() opt.Mode = "async" opt.Size = 0 opt.Timeout = 0 b, err = New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) assert.Equal(t, opt.DefaultBatchSizeAsync, b.opt.Size) assert.Equal(t, opt.DefaultTimeoutAsync, b.opt.Timeout) b.Shutdown() // Check we get an error on commit _, err = b.Commit(ctx, "last", Item("last")) require.ErrorContains(t, err, "shutting down") } func TestBatcherCommit(t *testing.T) { ctx := context.Background() opt := Options{ Mode: "sync", Size: 3, Timeout: 1 * time.Second, MaxBatchSize: 1000, DefaultTimeoutSync: 500 * time.Millisecond, DefaultTimeoutAsync: 10 * time.Second, DefaultBatchSizeAsync: 100, } var wg sync.WaitGroup errFail := errors.New("fail") var commits int var totalSize int commitBatch := func(ctx context.Context, items []Item, results []Result, errors []error) (err error) { commits += 1 totalSize += len(items) for i := range items { if items[i] == "5" { errors[i] = errFail } else { results[i] = Result(items[i]) + " result" } } return nil } b, err := New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) defer b.Shutdown() for i := 0; i < 10; i++ { wg.Add(1) s := fmt.Sprintf("%d", i) go func() { defer wg.Done() result, err := b.Commit(ctx, s, Item(s)) if s == "5" { assert.True(t, errors.Is(err, errFail)) } else { require.NoError(t, err) assert.Equal(t, Result(s+" result"), result) } }() } wg.Wait() assert.Equal(t, 4, commits) assert.Equal(t, 10, totalSize) } func TestBatcherCommitFail(t *testing.T) { ctx := context.Background() opt := Options{ Mode: "sync", Size: 3, Timeout: 1 * time.Second, MaxBatchSize: 1000, DefaultTimeoutSync: 500 * time.Millisecond, DefaultTimeoutAsync: 10 * time.Second, DefaultBatchSizeAsync: 100, } var wg sync.WaitGroup errFail := errors.New("fail") var commits int var totalSize int commitBatch := func(ctx context.Context, items []Item, results []Result, errors []error) (err error) { commits += 1 totalSize += len(items) return errFail } b, err := New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) defer b.Shutdown() for i := 0; i < 10; i++ { wg.Add(1) s := fmt.Sprintf("%d", i) go func() { defer wg.Done() _, err := b.Commit(ctx, s, Item(s)) assert.True(t, errors.Is(err, errFail)) }() } wg.Wait() assert.Equal(t, 4, commits) assert.Equal(t, 10, totalSize) } func TestBatcherCommitShutdown(t *testing.T) { ctx := context.Background() opt := Options{ Mode: "sync", Size: 3, Timeout: 1 * time.Second, MaxBatchSize: 1000, DefaultTimeoutSync: 500 * time.Millisecond, DefaultTimeoutAsync: 10 * time.Second, DefaultBatchSizeAsync: 100, } var wg sync.WaitGroup var commits int var totalSize int commitBatch := func(ctx context.Context, items []Item, results []Result, errors []error) (err error) { commits += 1 totalSize += len(items) for i := range items { results[i] = Result(items[i]) } return nil } b, err := New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) for i := 0; i < 10; i++ { wg.Add(1) s := fmt.Sprintf("%d", i) go func() { defer wg.Done() result, err := b.Commit(ctx, s, Item(s)) assert.NoError(t, err) assert.Equal(t, Result(s), result) }() } time.Sleep(100 * time.Millisecond) b.Shutdown() // shutdown with batches outstanding wg.Wait() assert.Equal(t, 4, commits) assert.Equal(t, 10, totalSize) } func TestBatcherCommitAsync(t *testing.T) { ctx := context.Background() opt := Options{ Mode: "async", Size: 3, Timeout: 1 * time.Second, MaxBatchSize: 1000, DefaultTimeoutSync: 500 * time.Millisecond, DefaultTimeoutAsync: 10 * time.Second, DefaultBatchSizeAsync: 100, } var wg sync.WaitGroup errFail := errors.New("fail") var commits atomic.Int32 var totalSize atomic.Int32 commitBatch := func(ctx context.Context, items []Item, results []Result, errors []error) (err error) { wg.Add(1) defer wg.Done() // t.Logf("commit %d", len(items)) commits.Add(1) totalSize.Add(int32(len(items))) for i := range items { if items[i] == "5" { errors[i] = errFail } else { results[i] = Result(items[i]) + " result" } } return nil } b, err := New[Item, Result](ctx, nil, commitBatch, opt) require.NoError(t, err) defer b.Shutdown() for i := 0; i < 10; i++ { wg.Add(1) s := fmt.Sprintf("%d", i) go func() { defer wg.Done() result, err := b.Commit(ctx, s, Item(s)) // Async just returns straight away require.NoError(t, err) assert.Equal(t, Result(""), result) }() } time.Sleep(2 * time.Second) // wait for batch timeout - needed with async wg.Wait() assert.Equal(t, int32(4), commits.Load()) assert.Equal(t, int32(10), totalSize.Load()) }