diff --git a/backend/s3/s3.go b/backend/s3/s3.go
index bd2de938e..28299066b 100644
--- a/backend/s3/s3.go
+++ b/backend/s3/s3.go
@@ -24,6 +24,7 @@ import (
 	"net/url"
 	"path"
 	"regexp"
+	"strconv"
 	"strings"
 	"time"
 
@@ -1625,7 +1626,7 @@ func pathEscape(s string) string {
 //
 // It adds the boiler plate to the req passed in and calls the s3
 // method
-func (f *Fs) copy(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string) error {
+func (f *Fs) copy(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string, srcSize int64) error {
 	req.Bucket = &dstBucket
 	req.ACL = &f.opt.ACL
 	req.Key = &dstPath
@@ -1640,12 +1641,117 @@ func (f *Fs) copy(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPa
 	if req.StorageClass == nil && f.opt.StorageClass != "" {
 		req.StorageClass = &f.opt.StorageClass
 	}
+
+	if srcSize >= int64(f.opt.UploadCutoff) {
+		return f.copyMultipart(ctx, req, dstBucket, dstPath, srcBucket, srcPath, srcSize)
+	}
 	return f.pacer.Call(func() (bool, error) {
 		_, err := f.c.CopyObjectWithContext(ctx, req)
 		return f.shouldRetry(err)
 	})
 }
 
+// calculateRange returns a "bytes=start-end" source range for part partIndex,
+// where both offsets are zero-based and the end is inclusive.
+func calculateRange(partSize, partIndex, numParts, totalSize int64) string {
+	start := partIndex * partSize
+	var ends string
+	if partIndex == numParts-1 {
+		if totalSize >= 1 {
+			ends = strconv.FormatInt(totalSize-1, 10) // inclusive end, so the last byte is totalSize-1
+		}
+	} else {
+		ends = strconv.FormatInt(start+partSize-1, 10)
+	}
+	return fmt.Sprintf("bytes=%v-%v", start, ends)
+}
+
+func (f *Fs) copyMultipart(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string, srcSize int64) (err error) {
+	var cout *s3.CreateMultipartUploadOutput
+	if err := f.pacer.Call(func() (bool, error) {
+		var err error
+		cout, err = f.c.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
+			Bucket: &dstBucket,
+			Key:    &dstPath,
+		})
+		return f.shouldRetry(err)
+	}); err != nil {
+		return err
+	}
+	uid := cout.UploadId
+
+	defer func() {
+		if err != nil {
+			// Try to abort the upload, but ignore any error.
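+			// An incomplete multipart upload would otherwise keep its copied
+			// parts stored (and charged for) until removed by a lifecycle rule.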
+			_ = f.pacer.Call(func() (bool, error) {
+				_, err := f.c.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{
+					Bucket:       &dstBucket,
+					Key:          &dstPath,
+					UploadId:     uid,
+					RequestPayer: req.RequestPayer,
+				})
+				return f.shouldRetry(err)
+			})
+		}
+	}()
+
+	partSize := int64(f.opt.ChunkSize)
+	numParts := (srcSize-1)/partSize + 1
+
+	var parts []*s3.CompletedPart
+	for partNum := int64(1); partNum <= numParts; partNum++ {
+		if err := f.pacer.Call(func() (bool, error) {
+			partNum := partNum
+			uploadPartReq := &s3.UploadPartCopyInput{
+				Bucket:          &dstBucket,
+				Key:             &dstPath,
+				PartNumber:      &partNum,
+				UploadId:        uid,
+				CopySourceRange: aws.String(calculateRange(partSize, partNum-1, numParts, srcSize)),
+				// Arguments copied from req
+				CopySource:                     req.CopySource,
+				CopySourceIfMatch:              req.CopySourceIfMatch,
+				CopySourceIfModifiedSince:      req.CopySourceIfModifiedSince,
+				CopySourceIfNoneMatch:          req.CopySourceIfNoneMatch,
+				CopySourceIfUnmodifiedSince:    req.CopySourceIfUnmodifiedSince,
+				CopySourceSSECustomerAlgorithm: req.CopySourceSSECustomerAlgorithm,
+				CopySourceSSECustomerKey:       req.CopySourceSSECustomerKey,
+				CopySourceSSECustomerKeyMD5:    req.CopySourceSSECustomerKeyMD5,
+				RequestPayer:                   req.RequestPayer,
+				SSECustomerAlgorithm:           req.SSECustomerAlgorithm,
+				SSECustomerKey:                 req.SSECustomerKey,
+				SSECustomerKeyMD5:              req.SSECustomerKeyMD5,
+			}
+			uout, err := f.c.UploadPartCopyWithContext(ctx, uploadPartReq)
+			if err != nil {
+				return f.shouldRetry(err)
+			}
+			parts = append(parts, &s3.CompletedPart{
+				PartNumber: &partNum,
+				ETag:       uout.CopyPartResult.ETag,
+			})
+			return false, nil
+		}); err != nil {
+			return err
+		}
+	}
+
+	return f.pacer.Call(func() (bool, error) {
+		_, err := f.c.CompleteMultipartUploadWithContext(ctx, &s3.CompleteMultipartUploadInput{
+			Bucket: &dstBucket,
+			Key:    &dstPath,
+			MultipartUpload: &s3.CompletedMultipartUpload{
+				Parts: parts,
+			},
+			RequestPayer: req.RequestPayer,
+			UploadId:     uid,
+		})
+		return f.shouldRetry(err)
+	})
+}
+
 // Copy src to this remote using server side copy operations.
 //
 // This is stored with the remote path given
@@ -1670,7 +1776,7 @@ func (f *Fs) Copy(ctx context.Context, src fs.Object, remote string) (fs.Object,
 	req := s3.CopyObjectInput{
 		MetadataDirective: aws.String(s3.MetadataDirectiveCopy),
 	}
-	err = f.copy(ctx, &req, dstBucket, dstPath, srcBucket, srcPath)
+	err = f.copy(ctx, &req, dstBucket, dstPath, srcBucket, srcPath, srcObj.Size())
 	if err != nil {
 		return nil, err
 	}
@@ -1833,7 +1939,7 @@ func (o *Object) SetModTime(ctx context.Context, modTime time.Time) error {
 		Metadata:          o.meta,
 		MetadataDirective: aws.String(s3.MetadataDirectiveReplace), // replace metadata with that passed in
 	}
-	return o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath)
+	return o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath, o.bytes)
 }
 
 // Storable returns a boolean indicating if this object is storable
@@ -2071,7 +2177,7 @@ func (o *Object) SetTier(tier string) (err error) {
 		MetadataDirective: aws.String(s3.MetadataDirectiveCopy),
 		StorageClass:      aws.String(tier),
 	}
-	err = o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath)
+	err = o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath, o.bytes)
 	if err != nil {
 		return err
 	}
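
A quick check of the range arithmetic above (a standalone sketch for illustration only, not part of the patch: it duplicates calculateRange with the inclusive-end fix, and the 10/25 byte sizes are made up so it runs on its own):

package main

import (
	"fmt"
	"strconv"
)

// calculateRange as added in the patch, copied here so the sketch compiles.
func calculateRange(partSize, partIndex, numParts, totalSize int64) string {
	start := partIndex * partSize
	var ends string
	if partIndex == numParts-1 {
		if totalSize >= 1 {
			ends = strconv.FormatInt(totalSize-1, 10) // inclusive end
		}
	} else {
		ends = strconv.FormatInt(start+partSize-1, 10)
	}
	return fmt.Sprintf("bytes=%v-%v", start, ends)
}

func main() {
	// A 25 byte object copied in 10 byte chunks needs 3 parts.
	const partSize, totalSize = int64(10), int64(25)
	numParts := (totalSize-1)/partSize + 1 // same rounding as copyMultipart
	for i := int64(0); i < numParts; i++ {
		fmt.Println(calculateRange(partSize, i, numParts, totalSize))
	}
	// Prints:
	//   bytes=0-9
	//   bytes=10-19
	//   bytes=20-24
}

The last range ends at totalSize-1 rather than totalSize because CopySourceRange (the x-amz-copy-source-range header) is inclusive at both ends, like an HTTP Range header; an end offset equal to the object size is out of range for the source object.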