diff --git a/backend/azureblob/azureblob.go b/backend/azureblob/azureblob.go
index 39d797122..a6929cc4f 100644
--- a/backend/azureblob/azureblob.go
+++ b/backend/azureblob/azureblob.go
@@ -4,8 +4,10 @@
 package azureblob
 
 import (
+	"bytes"
 	"context"
 	"crypto/md5"
+	"crypto/rand"
 	"encoding/base64"
 	"encoding/binary"
 	"encoding/hex"
@@ -1882,6 +1884,12 @@ func (o *Object) getBlobSVC() *blob.Client {
 	return o.fs.getBlobSVC(container, directory)
 }
 
+// getBlockBlobSVC creates a block blob client
+func (o *Object) getBlockBlobSVC() *blockblob.Client {
+	container, directory := o.split()
+	return o.fs.getBlockBlobSVC(container, directory)
+}
+
 // clearMetaData clears enough metadata so readMetaData will re-read it
 func (o *Object) clearMetaData() {
 	o.modTime = time.Time{}
@@ -2052,6 +2060,58 @@ func (rs *readSeekCloser) Close() error {
 	return nil
 }
 
+// a creator for blockIDs with an incrementing part and a random part
+//
+// The random part is to make sure that blockIDs don't collide between
+// uploads. We need block IDs not to be shared between upload attempts
+// so we can remove the uncommitted blocks properly on errors.
+type blockIDCreator struct {
+	random [8]byte // randomness to make sure blocks don't collide
+}
+
+// create a new blockID creator with a random suffix
+func newBlockIDCreator() (bic *blockIDCreator, err error) {
+	bic = &blockIDCreator{}
+	n, err := rand.Read(bic.random[:])
+	if err != nil {
+		return nil, fmt.Errorf("crypto rand failed: %w", err)
+	}
+	if n != len(bic.random) {
+		return nil, errors.New("crypto rand failed: short read")
+	}
+	return bic, nil
+}
+
+// create a new block ID for chunkNumber
+func (bic *blockIDCreator) newBlockID(chunkNumber uint64) string {
+	var binaryBlockID [16]byte
+	// block counter as big endian in the first 8 bytes
+	binary.BigEndian.PutUint64(binaryBlockID[:8], chunkNumber)
+	// random bits at the end
+	copy(binaryBlockID[8:], bic.random[:])
+	// return base64 encoded value
+	return base64.StdEncoding.EncodeToString(binaryBlockID[:])
+}
+
+// Check the chunkNumber is correct in the id
+func (bic *blockIDCreator) checkID(chunkNumber uint64, id string) error {
+	binaryBlockID, err := base64.StdEncoding.DecodeString(id)
+	if err != nil {
+		return fmt.Errorf("internal error: bad block ID: %w", err)
+	}
+	if len(binaryBlockID) != 16 {
+		return errors.New("internal error: bad block ID length")
+	}
+	gotChunkNumber := binary.BigEndian.Uint64(binaryBlockID[:8])
+	if chunkNumber != gotChunkNumber {
+		return fmt.Errorf("internal error: expecting decoded chunkNumber %d but got %d", chunkNumber, gotChunkNumber)
+	}
+	if !bytes.Equal(binaryBlockID[8:], bic.random[:]) {
+		return errors.New("internal error: random bytes are incorrect")
+	}
+	return nil
+}
+
 // record chunk number and id for Close
 type azBlock struct {
 	chunkNumber uint64
@@ -2067,6 +2127,7 @@ type azChunkWriter struct {
 	blocksMu sync.Mutex // protects the below
 	blocks   []azBlock  // list of blocks for finalize
 	o        *Object
+	bic      *blockIDCreator
 }
 
 // OpenChunkWriter returns the chunk size and a ChunkWriter
@@ -2129,6 +2190,10 @@ func (f *Fs) OpenChunkWriter(ctx context.Context, remote string, src fs.ObjectIn
 		Concurrency: o.fs.opt.UploadConcurrency,
 		//LeavePartsOnError: o.fs.opt.LeavePartsOnError,
 	}
+	chunkWriter.bic, err = newBlockIDCreator()
+	if err != nil {
+		return info, nil, err
+	}
 	fs.Debugf(o, "open chunk writer: started multipart upload")
 	return info, chunkWriter, nil
 }
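Review note: the ID scheme is easier to see outside the diff. The old IDs, removed in the WriteChunk hunk below, were just a little-endian counter, so chunk N of every upload attempt got the same ID; the new IDs append a per-upload random suffix. A standalone before/after sketch (illustrative code, not rclone's):

```go
package main

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/binary"
	"fmt"
)

func main() {
	// Old scheme (removed below): the bare chunk counter, little endian.
	// Chunk 0 of *every* upload attempt encoded to "AAAAAAAAAAA=", so
	// uncommitted blocks left behind by a failed attempt were
	// indistinguishable from the current attempt's blocks.
	var old [8]byte
	binary.LittleEndian.PutUint64(old[:], 0)
	fmt.Println(base64.StdEncoding.EncodeToString(old[:])) // AAAAAAAAAAA=

	// New scheme: 8-byte big endian counter + 8 random bytes per upload.
	var id [16]byte
	binary.BigEndian.PutUint64(id[:8], 0)
	if _, err := rand.Read(id[8:]); err != nil {
		panic(err)
	}
	fmt.Println(base64.StdEncoding.EncodeToString(id[:])) // unique per upload
}
```

Azure requires all block IDs within a blob to be the same length (and at most 64 bytes before encoding), which the fixed 16-byte layout satisfies.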
@@ -2152,10 +2217,8 @@ func (w *azChunkWriter) WriteChunk(ctx context.Context, chunkNumber int, reader
 	}
 	md5sum := m.Sum(nil)
 
-	// increment the blockID and save the blocks for finalize
-	var binaryBlockID [8]byte // block counter as LSB first 8 bytes
-	binary.LittleEndian.PutUint64(binaryBlockID[:], uint64(chunkNumber))
-	blockID := base64.StdEncoding.EncodeToString(binaryBlockID[:])
+	// Create a new blockID
+	blockID := w.bic.newBlockID(uint64(chunkNumber))
 
 	// Save the blockID for the commit
 	w.blocksMu.Lock()
@@ -2196,31 +2259,111 @@ func (w *azChunkWriter) WriteChunk(ctx context.Context, chunkNumber int, reader
 	return currentChunkSize, err
 }
 
-// Abort the multipart upload.
-//
-// FIXME it would be nice to delete uncommitted blocks.
-//
-// See: https://github.com/rclone/rclone/issues/5583
-//
-// However there doesn't seem to be an easy way of doing this other than
-// by deleting the target.
-//
-// This means that a failed upload deletes the target which isn't ideal.
-//
-// Uploading a zero length blob and deleting it will remove the
-// uncommitted blocks I think.
-//
-// Could check to see if a file exists already and if it doesn't then
-// create a 0 length file and delete it to flush the uncommitted
-// blocks.
-//
-// This is what azcopy does
-// https://github.com/MicrosoftDocs/azure-docs/issues/36347#issuecomment-541457962
-func (w *azChunkWriter) Abort(ctx context.Context) error {
-	fs.Debugf(w.o, "multipart upload aborted (did nothing - see issue #5583)")
+// Clear uncommitted blocks
+//
+// There isn't an API to clear uncommitted blocks.
+//
+// However they are released when a Commit is called. Doing this will
+// instantiate the object so we don't want to overwrite an existing
+// object.
+//
+// We will use this algorithm:
+//
+// Attempt to read committed blocks from the object
+// If the object exists
+// - Commit the existing blocks again
+// - This should get rid of the uncommitted blocks without changing the existing object
+// If the object does not exist then
+// - Commit an empty block list
+// - This will get rid of the uncommitted blocks
+// - This will also create a 0 length blob
+// - So delete the 0 length blob
+func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
+	fs.Debugf(o, "Clearing uncommitted blocks")
+	var (
+		blockBlobSVC = o.getBlockBlobSVC()
+		objectExists = true
+		blockIDs     []string
+		blockList    blockblob.GetBlockListResponse
+		properties   *blob.GetPropertiesResponse
+		options      *blockblob.CommitBlockListOptions
+	)
+
+	properties, err = o.readMetaDataAlways(ctx)
+	if err == fs.ErrorObjectNotFound {
+		objectExists = false
+	} else if err != nil {
+		return fmt.Errorf("clear uncommitted blocks: failed to read metadata: %w", err)
+	}
+
+	if objectExists {
+		// Get the committed block list
+		err = o.fs.pacer.Call(func() (bool, error) {
+			blockList, err = blockBlobSVC.GetBlockList(ctx, blockblob.BlockListTypeAll, nil)
+			return o.fs.shouldRetry(ctx, err)
+		})
+		if err != nil {
+			return fmt.Errorf("clear uncommitted blocks: failed to read uncommitted block list: %w", err)
+		}
+		if len(blockList.UncommittedBlocks) == 0 {
+			fs.Debugf(o, "No uncommitted blocks - exiting")
+			return nil
+		}
+		fs.Debugf(o, "%d Uncommitted blocks found", len(blockList.UncommittedBlocks))
+		objectExists = true
+		uncommittedBlocks := make(map[string]struct{}, len(blockList.UncommittedBlocks))
+		for _, block := range blockList.UncommittedBlocks {
+			uncommittedBlocks[*block.Name] = struct{}{}
+		}
+		for _, block := range blockList.CommittedBlocks {
+			name := *block.Name
+			if _, found := uncommittedBlocks[name]; found {
+				return fmt.Errorf("clear uncommitted blocks: can't safely clear uncommitted blocks as committed and uncommitted IDs overlap. Delete the existing object to clear the uncommitted blocks")
+			}
+			blockIDs = append(blockIDs, name)
+		}
+
+		// Reconstruct metadata from existing object as CommitBlockList overwrites it
+		options = &blockblob.CommitBlockListOptions{
+			Metadata:    properties.Metadata,
+			Tier:        (*blob.AccessTier)(properties.AccessTier),
+			HTTPHeaders: &blob.HTTPHeaders{
+				BlobCacheControl:       properties.CacheControl,
+				BlobContentDisposition: properties.ContentDisposition,
+				BlobContentEncoding:    properties.ContentEncoding,
+				BlobContentLanguage:    properties.ContentLanguage,
+				BlobContentMD5:         properties.ContentMD5,
+				BlobContentType:        properties.ContentType,
+			},
+		}
+	}
+
+	// Commit only the committed blocks
+	fs.Debugf(o, "Committing %d blocks to remove uncommitted blocks", len(blockIDs))
+	err = o.fs.pacer.Call(func() (bool, error) {
+		_, err := blockBlobSVC.CommitBlockList(ctx, blockIDs, options)
+		return o.fs.shouldRetry(ctx, err)
+	})
+	if err != nil {
+		return fmt.Errorf("clear uncommitted blocks: failed to commit block list: %w", err)
+	}
+
+	// If object didn't exist before, then delete it
+	if !objectExists {
+		fs.Debugf(o, "Removing empty object")
+		err = o.Remove(ctx)
+		if err != nil {
+			return fmt.Errorf("clear uncommitted blocks: failed to remove empty object: %w", err)
+		}
+	}
 	return nil
 }
 
+// Abort the multipart upload.
+func (w *azChunkWriter) Abort(ctx context.Context) error {
+	return w.o.clearUncommittedBlocks(ctx)
+}
+
 // Close and finalise the multipart upload
 func (w *azChunkWriter) Close(ctx context.Context) (err error) {
 	// sort the completed parts by part number
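Review note: stripped of rclone's pacer, retry and metadata handling, the algorithm above reduces to two SDK calls. A minimal sketch against the azblob SDK — `clearUncommitted` and `blobExists` are illustrative names, and real code must also re-apply metadata, tier and HTTP headers as the hunk does:

```go
package azexample

import (
	"context"

	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
)

// clearUncommitted sketches the algorithm above. blobExists would come
// from a prior GetProperties/readMetaData call.
func clearUncommitted(ctx context.Context, bb *blockblob.Client, blobExists bool) error {
	var committed []string
	if blobExists {
		list, err := bb.GetBlockList(ctx, blockblob.BlockListTypeAll, nil)
		if err != nil {
			return err
		}
		if len(list.UncommittedBlocks) == 0 {
			return nil // nothing to clear
		}
		for _, block := range list.CommittedBlocks {
			committed = append(committed, *block.Name)
		}
	}
	// Re-committing just the committed blocks releases the uncommitted
	// ones. With no existing blob this commits an empty list, which
	// creates a 0 length blob...
	if _, err := bb.CommitBlockList(ctx, committed, nil); err != nil {
		return err
	}
	// ...which we then delete again.
	if !blobExists {
		if _, err := bb.Delete(ctx, nil); err != nil {
			return err
		}
	}
	return nil
}
```

Committing also rewrites the blob's properties, which is why the hunk above reconstructs Metadata, Tier and HTTPHeaders from GetProperties before re-committing rather than passing nil options as this sketch does.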
@@ -2234,13 +2377,9 @@ func (w *azChunkWriter) Close(ctx context.Context) (err error) {
 		if w.blocks[i].chunkNumber != uint64(i) {
 			return fmt.Errorf("internal error: expecting chunkNumber %d but got %d", i, w.blocks[i].chunkNumber)
 		}
-		chunkBytes, err := base64.StdEncoding.DecodeString(w.blocks[i].id)
+		err := w.bic.checkID(w.blocks[i].chunkNumber, w.blocks[i].id)
 		if err != nil {
-			return fmt.Errorf("internal error: bad block ID: %w", err)
-		}
-		chunkNumber := binary.LittleEndian.Uint64(chunkBytes)
-		if w.blocks[i].chunkNumber != chunkNumber {
-			return fmt.Errorf("internal error: expecting decoded chunkNumber %d but got %d", w.blocks[i].chunkNumber, chunkNumber)
+			return err
 		}
 		blockIDs[i] = w.blocks[i].id
 	}
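Review note: checkID, used in the Close hunk above, is a consistency check before the final commit: it re-derives both the counter and the per-upload random suffix from each saved ID, so an ID minted by a different upload attempt is rejected rather than silently committed. A compact standalone round trip (blockIDCreator itself is unexported, so this re-implements the layout):

```go
package main

import (
	"bytes"
	"encoding/base64"
	"encoding/binary"
	"fmt"
)

// check mirrors checkID above: decode, then verify counter and suffix.
func check(chunk uint64, random [8]byte, id string) error {
	raw, err := base64.StdEncoding.DecodeString(id)
	if err != nil || len(raw) != 16 {
		return fmt.Errorf("bad block ID")
	}
	if got := binary.BigEndian.Uint64(raw[:8]); got != chunk {
		return fmt.Errorf("expecting chunkNumber %d but got %d", chunk, got)
	}
	if !bytes.Equal(raw[8:], random[:]) {
		return fmt.Errorf("random bytes are incorrect")
	}
	return nil
}

func main() {
	random := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
	var raw [16]byte
	binary.BigEndian.PutUint64(raw[:8], 42)
	copy(raw[8:], random[:])
	id := base64.StdEncoding.EncodeToString(raw[:])

	fmt.Println(check(42, random, id))        // <nil>
	fmt.Println(check(43, random, id))        // counter mismatch
	fmt.Println(check(42, [8]byte{9, 9}, id)) // suffix from another upload
}
```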
diff --git a/backend/azureblob/azureblob_internal_test.go b/backend/azureblob/azureblob_internal_test.go
index f1606cdc3..6c96dbb97 100644
--- a/backend/azureblob/azureblob_internal_test.go
+++ b/backend/azureblob/azureblob_internal_test.go
@@ -3,9 +3,11 @@
 package azureblob
 
 import (
+	"encoding/base64"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func (f *Fs) InternalTest(t *testing.T) {
@@ -16,3 +18,31 @@ func (f *Fs) InternalTest(t *testing.T) {
 	enabled = f.Features().GetTier
 	assert.True(t, enabled)
 }
+
+func TestBlockIDCreator(t *testing.T) {
+	// Check creation and random number
+	bic, err := newBlockIDCreator()
+	require.NoError(t, err)
+	bic2, err := newBlockIDCreator()
+	require.NoError(t, err)
+	assert.NotEqual(t, bic.random, bic2.random)
+	assert.NotEqual(t, bic.random, [8]byte{})
+
+	// Set random to known value for tests
+	bic.random = [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+	chunkNumber := uint64(0xFEDCBA9876543210)
+
+	// Check creation of ID
+	want := base64.StdEncoding.EncodeToString([]byte{0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 1, 2, 3, 4, 5, 6, 7, 8})
+	assert.Equal(t, "/ty6mHZUMhABAgMEBQYHCA==", want)
+	got := bic.newBlockID(chunkNumber)
+	assert.Equal(t, want, got)
+	assert.Equal(t, "/ty6mHZUMhABAgMEBQYHCA==", got)
+
+	// Test checkID is working
+	assert.NoError(t, bic.checkID(chunkNumber, got))
+	assert.ErrorContains(t, bic.checkID(chunkNumber, "$"+got), "illegal base64")
+	assert.ErrorContains(t, bic.checkID(chunkNumber, "AAAA"+got), "bad block ID length")
+	assert.ErrorContains(t, bic.checkID(chunkNumber+1, got), "expecting decoded")
+	assert.ErrorContains(t, bic2.checkID(chunkNumber, got), "random bytes")
+}
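Review note: the expected literal in TestBlockIDCreator can be re-derived by hand, and the "illegal base64" substring it matches comes from encoding/base64's CorruptInputError. A quick standalone check of both:

```go
package main

import (
	"encoding/base64"
	"fmt"
)

func main() {
	// 0xFEDCBA9876543210 big endian, then the fixed suffix 1..8.
	raw := []byte{0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 1, 2, 3, 4, 5, 6, 7, 8}
	fmt.Println(base64.StdEncoding.EncodeToString(raw)) // /ty6mHZUMhABAgMEBQYHCA==

	// "$" is not valid base64, giving the error substring the test
	// asserts on for checkID's decode failure path.
	_, err := base64.StdEncoding.DecodeString("$/ty6mHZUMhABAgMEBQYHCA==")
	fmt.Println(err) // illegal base64 data at input byte 0
}
```

TestBlockIDCreator touches no Azure endpoints, so it runs as a plain unit test without remote configuration.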