mirror of https://github.com/rclone/rclone.git, synced 2026-03-01 02:41:11 +00:00

Compare commits: test...fix-azureb (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b0981ce041 | |
| | 90bf8e517c | |
| | 197cc57fd5 | |
| | 4fdc43d5c9 | |
| | d17425eb1f | |
| | 83d0c186a7 | |
| | 2887806f33 | |
| | 9ed4295e34 | |
| | 2fa1a52f22 | |
```diff
@@ -52,6 +52,7 @@ import (
     "github.com/rclone/rclone/lib/multipart"
     "github.com/rclone/rclone/lib/pacer"
     "github.com/rclone/rclone/lib/pool"
+    "github.com/rclone/rclone/lib/transferaccounter"
     "golang.org/x/sync/errgroup"
 )
 
@@ -343,6 +344,16 @@ In tests, copy speed increases almost linearly with copy
 concurrency.`,
             Default:  512,
             Advanced: true,
+        }, {
+            Name: "copy_total_concurrency",
+            Help: `Global concurrency limit for multipart copy chunks.
+
+This limits the total number of multipart copy chunks running at once
+across all files.
+
+Set to 0 to disable this limiter.`,
+            Default:  0,
+            Advanced: true,
         }, {
             Name: "use_copy_blob",
             Help: `Whether to use the Copy Blob API when copying to the same storage account.
```
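For context, `copy_concurrency` (above) bounds chunks per file, while the new `copy_total_concurrency` bounds chunks across all files being copied at once. Assuming these options belong to the Azure Blob backend, which the Copy Blob and blockblob references elsewhere in the diff suggest, a remote that caps global copy-chunk parallelism might be configured like this (remote name invented for the example):

```ini
[myazure]
type = azureblob
# per-file multipart copy concurrency (default 512)
copy_concurrency = 512
# cap chunks in flight across *all* files; 0 (the default) disables the limiter
copy_total_concurrency = 64
```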
```diff
@@ -526,6 +537,7 @@ type Options struct {
     ChunkSize            fs.SizeSuffix `config:"chunk_size"`
     CopyCutoff           fs.SizeSuffix `config:"copy_cutoff"`
     CopyConcurrency      int           `config:"copy_concurrency"`
+    CopyTotalConcurrency int           `config:"copy_total_concurrency"`
     UseCopyBlob          bool          `config:"use_copy_blob"`
     UploadConcurrency    int           `config:"upload_concurrency"`
     ListChunkSize        uint          `config:"list_chunk"`
@@ -560,6 +572,7 @@ type Fs struct {
     cache        *bucket.Cache              // cache for container creation status
     pacer        *fs.Pacer                  // To pace and retry the API calls
     uploadToken  *pacer.TokenDispenser      // control concurrency
+    copyToken    *pacer.TokenDispenser      // global multipart copy concurrency limiter
     publicAccess container.PublicAccessType // Container Public Access Level
 
     // user delegation cache
@@ -802,6 +815,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
         ci:          ci,
         pacer:       fs.NewPacer(ctx, pacer.NewS3(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant))),
         uploadToken: pacer.NewTokenDispenser(ci.Transfers),
+        copyToken:   pacer.NewTokenDispenser(opt.CopyTotalConcurrency),
         cache:       bucket.NewCache(),
         cntSVCcache: make(map[string]*container.Client, 1),
     }
```
Both the acquire and the release on the new `copyToken` are guarded by `f.opt.CopyTotalConcurrency > 0`, since with the limiter disabled (the default) the dispenser is created with zero tokens and an unguarded `Get()` would block forever.

```diff
@@ -1865,18 +1879,26 @@ func (f *Fs) copyMultipart(ctx context.Context, remote, dstContainer, dstPath st
         blockIDs = make([]string, numParts) // list of blocks for finalize
         g, gCtx  = errgroup.WithContext(ctx)
         checker  = newCheckForInvalidBlockOrBlob("copy", o)
+        account  = transferaccounter.Get(ctx)
     )
     g.SetLimit(f.opt.CopyConcurrency)
 
     fs.Debugf(o, "Starting multipart copy with %d parts of size %v", numParts, fs.SizeSuffix(partSize))
+    account.Start()
     for partNum := uint64(0); partNum < uint64(numParts); partNum++ {
         // Fail fast, in case an errgroup managed function returns an error
         // gCtx is cancelled. There is no point in uploading all the other parts.
         if gCtx.Err() != nil {
             break
         }
+        if f.opt.CopyTotalConcurrency > 0 {
+            f.copyToken.Get()
+        }
         partNum := partNum // for closure
         g.Go(func() error {
+            if f.opt.CopyTotalConcurrency > 0 {
+                defer f.copyToken.Put()
+            }
             blockID := bic.newBlockID(partNum)
             options := blockblob.StageBlockFromURLOptions{
                 Range: blob.HTTPRange{
@@ -1910,6 +1932,7 @@ func (f *Fs) copyMultipart(ctx context.Context, remote, dstContainer, dstPath st
                 return fmt.Errorf("multipart copy: failed to copy chunk %d with %v bytes: %w", partNum+1, -1, err)
             }
             blockIDs[partNum] = blockID
+            account.Add(options.Range.Count)
             return nil
         })
     }
```
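The ordering in the copy loop above is deliberate: `copyToken.Get()` runs in the submitting goroutine before `g.Go`, so once the global limit is reached the loop itself blocks and no pile of idle workers builds up, while `Put()` is deferred inside the worker so the token is released only when the chunk finishes. A minimal self-contained sketch of that shape (names hypothetical, a plain buffered channel standing in for `pacer.TokenDispenser`):

```go
package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	tokens := make(chan struct{}, 4) // hypothetical global limit of 4
	for i := 0; i < 4; i++ {
		tokens <- struct{}{}
	}

	var g errgroup.Group
	g.SetLimit(16) // per-file limit, as with CopyConcurrency
	for part := 0; part < 32; part++ {
		<-tokens // block the submitter: no new goroutine until a token frees up
		part := part
		g.Go(func() error {
			defer func() { tokens <- struct{}{} }() // release from the worker when done
			fmt.Println("copying part", part)
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		fmt.Println("error:", err)
	}
}
```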
The temporary pacer workaround in `clearUncommittedBlocks` is dropped: with the reentrancy detection added to lib/pacer later in this diff, calling the shared `o.fs.pacer` recursively no longer deadlocks under `--max-connections`.

```diff
@@ -2765,8 +2788,6 @@ func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
         blockList  blockblob.GetBlockListResponse
         properties *blob.GetPropertiesResponse
         options    *blockblob.CommitBlockListOptions
-        // Use temporary pacer as this can be called recursively which can cause a deadlock with --max-connections
-        pacer = fs.NewPacer(ctx, pacer.NewS3(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant)))
     )
 
     properties, err = o.readMetaDataAlways(ctx)
@@ -2778,7 +2799,7 @@ func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
 
     if objectExists {
         // Get the committed block list
-        err = pacer.Call(func() (bool, error) {
+        err = o.fs.pacer.Call(func() (bool, error) {
             blockList, err = blockBlobSVC.GetBlockList(ctx, blockblob.BlockListTypeAll, nil)
             return o.fs.shouldRetry(ctx, err)
         })
@@ -2820,7 +2841,7 @@ func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
 
     // Commit only the committed blocks
     fs.Debugf(o, "Committing %d blocks to remove uncommitted blocks", len(blockIDs))
-    err = pacer.Call(func() (bool, error) {
+    err = o.fs.pacer.Call(func() (bool, error) {
         _, err := blockBlobSVC.CommitBlockList(ctx, blockIDs, options)
         return o.fs.shouldRetry(ctx, err)
     })
```
```diff
@@ -11,6 +11,7 @@ import (
     "unicode/utf8"
 
     "github.com/rclone/rclone/fs/rc"
+    "github.com/rclone/rclone/lib/transferaccounter"
 
     "github.com/rclone/rclone/fs"
     "github.com/rclone/rclone/fs/asyncreader"
@@ -312,6 +313,15 @@ func (acc *Account) serverSideEnd(n int64) {
     }
 }
 
+// NewServerSideCopyAccounter returns a TransferAccounter for a server
+// side copy and a new ctx with it embedded
+func (acc *Account) NewServerSideCopyAccounter(ctx context.Context) (context.Context, *transferaccounter.TransferAccounter) {
+    return transferaccounter.New(ctx, func(n int64) {
+        acc.stats.AddServerSideCopyBytes(n)
+        acc.accountReadNoNetwork(n)
+    })
+}
+
 // ServerSideCopyEnd accounts for a read of n bytes in a server-side copy
 func (acc *Account) ServerSideCopyEnd(n int64) {
     acc.stats.AddServerSideCopy(n)
```
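The accounter crosses the operations/backend boundary only through the context: `NewServerSideCopyAccounter` embeds it, and a backend that supports progress reporting recovers it with `transferaccounter.Get`. Backends that never call `Get` simply never `Start()` the accounter, which is what the fallback in `serverSideCopy` below keys off. A toy sketch of the handoff, with `backendCopy` as a hypothetical stand-in for a backend's copy routine:

```go
package main

import (
	"context"
	"fmt"

	"github.com/rclone/rclone/lib/transferaccounter"
)

// backendCopy stands in for a backend's server-side copy. It reports
// progress through whatever accounter the caller embedded in ctx.
func backendCopy(ctx context.Context) {
	account := transferaccounter.Get(ctx) // a no-op accounter if none is set
	account.Start()
	for i := 0; i < 3; i++ {
		account.Add(1024) // account each copied chunk as it completes
	}
}

func main() {
	var total int64
	ctx, ta := transferaccounter.New(context.Background(), func(n int64) { total += n })
	backendCopy(ctx)
	fmt.Println(ta.Started(), total) // true 3072
}
```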
```diff
@@ -358,6 +368,17 @@ func (acc *Account) accountRead(n int) {
     acc.limitPerFileBandwidth(n)
 }
 
+// Account the read if not using network (eg for server side copies)
+func (acc *Account) accountReadNoNetwork(n int64) {
+    // Update Stats
+    acc.values.mu.Lock()
+    acc.values.lpBytes += int(n)
+    acc.values.bytes += n
+    acc.values.mu.Unlock()
+
+    acc.stats.BytesNoNetwork(n)
+}
+
 // read bytes from the io.Reader passed in and account them
 func (acc *Account) read(in io.Reader, p []byte) (n int, err error) {
     bytesUntilLimit, err := acc.checkReadBefore()
```
```diff
@@ -938,6 +938,13 @@ func (s *StatsInfo) AddServerSideMove(n int64) {
     s.mu.Unlock()
 }
 
+// AddServerSideCopyBytes adds bytes for a server side copy
+func (s *StatsInfo) AddServerSideCopyBytes(n int64) {
+    s.mu.Lock()
+    s.serverSideCopyBytes += n
+    s.mu.Unlock()
+}
+
 // AddServerSideCopy counts a server side copy
 func (s *StatsInfo) AddServerSideCopy(n int64) {
     s.mu.Lock()
```
```diff
@@ -385,12 +385,14 @@ func (sg *statsGroups) sum(ctx context.Context) *StatsInfo {
             sum.checkQueueSize += stats.checkQueueSize
             sum.transfers += stats.transfers
             sum.transferring.merge(stats.transferring)
+            sum.transferQueue += stats.transferQueue
             sum.transferQueueSize += stats.transferQueueSize
             sum.listed += stats.listed
             sum.renames += stats.renames
             sum.renameQueue += stats.renameQueue
             sum.renameQueueSize += stats.renameQueueSize
             sum.deletes += stats.deletes
+            sum.deletesSize += stats.deletesSize
             sum.deletedDirs += stats.deletedDirs
             sum.inProgress.merge(stats.inProgress)
             sum.startedTransfers = append(sum.startedTransfers, stats.startedTransfers...)
@@ -399,6 +401,10 @@ func (sg *statsGroups) sum(ctx context.Context) *StatsInfo {
             stats.average.mu.Lock()
             sum.average.speed += stats.average.speed
             stats.average.mu.Unlock()
+            sum.serverSideCopies += stats.serverSideCopies
+            sum.serverSideCopyBytes += stats.serverSideCopyBytes
+            sum.serverSideMoves += stats.serverSideMoves
+            sum.serverSideMoveBytes += stats.serverSideMoveBytes
         }
         stats.mu.RUnlock()
     }
```
In `serverSideCopy`, size-based accounting becomes a fallback: when the backend reported progress through the TransferAccounter (`ta.Started()` is true), the bytes are already counted, so `ServerSideCopyEnd` gets 0; on error, any partial accounting is reversed with `ta.Reset()`.

```diff
@@ -148,9 +148,17 @@ func (c *copy) serverSideCopy(ctx context.Context) (actionTaken string, newDst f
     }
     in := c.tr.Account(ctx, nil) // account the transfer
     in.ServerSideTransferStart()
-    newDst, err = doCopy(ctx, c.src, c.remoteForCopy)
+    newCtx, ta := in.NewServerSideCopyAccounter(ctx)
+    newDst, err = doCopy(newCtx, c.src, c.remoteForCopy)
     if err == nil {
-        in.ServerSideCopyEnd(newDst.Size()) // account the bytes for the server-side transfer
+        var n int64
+        if !ta.Started() {
+            n = newDst.Size()
+        }
+        in.ServerSideCopyEnd(n) // account the bytes for the server-side transfer
+    } else {
+        // Rewind any stats counted on error
+        ta.Reset()
     }
     _ = in.Close()
     if errors.Is(err, fs.ErrorCantCopy) {
```
```diff
@@ -64,25 +64,15 @@ type multiThreadCopyState struct {
 }
 
 // Copy a single chunk into place
-func (mc *multiThreadCopyState) copyChunk(ctx context.Context, chunk int, writer fs.ChunkWriter) (err error) {
+func (mc *multiThreadCopyState) copyChunk(ctx context.Context, chunk int, writer fs.ChunkWriter, start, end, size int64, rw *pool.RW) (err error) {
     defer func() {
+        if !mc.noBuffering {
+            fs.CheckClose(rw, &err)
+        }
         if err != nil {
             fs.Debugf(mc.src, "multi-thread copy: chunk %d/%d failed: %v", chunk+1, mc.numChunks, err)
         }
     }()
-    start := int64(chunk) * mc.partSize
-    if start >= mc.size {
-        return nil
-    }
-    end := min(start+mc.partSize, mc.size)
-    size := end - start
-
-    // Reserve the memory first so we don't open the source and wait for memory buffers for ages
-    var rw *pool.RW
-    if !mc.noBuffering {
-        rw = multipart.NewRW().Reserve(size)
-        defer fs.CheckClose(rw, &err)
-    }
 
     fs.Debugf(mc.src, "multi-thread copy: chunk %d/%d (%d-%d) size %v starting", chunk+1, mc.numChunks, start, end, fs.SizeSuffix(size))
 
@@ -226,9 +216,24 @@ func multiThreadCopy(ctx context.Context, f fs.Fs, remote string, src fs.Object,
         if gCtx.Err() != nil {
             break
         }
-        chunk := chunk
+        // Work out how big and where the chunk is
+        start := int64(chunk) * mc.partSize
+        if start >= mc.size {
+            continue
+        }
+        end := min(start+mc.partSize, mc.size)
+        size := end - start
+
+        // Reserve the memory first so we don't open the source and wait for memory buffers for ages
+        // This also avoids creating an excess of goroutines all waiting on memory.
+        var rw *pool.RW
+        if !mc.noBuffering {
+            rw = multipart.NewRW().Reserve(size)
+        }
+
         g.Go(func() error {
-            return mc.copyChunk(gCtx, chunk, chunkWriter)
+            return mc.copyChunk(gCtx, chunk, chunkWriter, start, end, size, rw)
         })
     }
 
```
```diff
@@ -4,6 +4,8 @@ package pacer
 import (
     "errors"
     "fmt"
+    "runtime"
+    "strings"
     "sync"
     "time"
 
```
```diff
@@ -153,31 +155,43 @@ func (p *Pacer) ModifyCalculator(f func(Calculator)) {
 // This must be called as a pair with endCall.
 //
 // This waits for the pacer token
-func (p *Pacer) beginCall() {
+func (p *Pacer) beginCall(limitConnections bool) {
     // pacer starts with a token in and whenever we take one out
     // XXX ms later we put another in. We could do this with a
     // Ticker more accurately, but then we'd have to work out how
     // not to run it when it wasn't needed
-    <-p.pacer
-    if p.maxConnections > 0 {
-        <-p.connTokens
-    }
-
     p.mu.Lock()
-    // Restart the timer
-    go func(t time.Duration) {
-        time.Sleep(t)
-        p.pacer <- struct{}{}
-    }(p.state.SleepTime)
+    sleepTime := p.state.SleepTime
     p.mu.Unlock()
+
+    if sleepTime > 0 {
+        <-p.pacer
+
+        // Re-read the sleep time as it may be stale
+        // after waiting for the pacer token
+        p.mu.Lock()
+        sleepTime = p.state.SleepTime
+        p.mu.Unlock()
+
+        // Restart the timer
+        go func(t time.Duration) {
+            time.Sleep(t)
+            p.pacer <- struct{}{}
+        }(sleepTime)
+    }
+
+    if limitConnections {
+        <-p.connTokens
+    }
 }
 
 // endCall implements the pacing algorithm
 //
 // This should calculate a new sleepTime. It takes a boolean as to
 // whether the operation should be retried or not.
-func (p *Pacer) endCall(retry bool, err error) {
-    if p.maxConnections > 0 {
+func (p *Pacer) endCall(retry bool, err error, limitConnections bool) {
+    if limitConnections {
         p.connTokens <- struct{}{}
     }
     p.mu.Lock()
@@ -191,13 +205,44 @@ func (p *Pacer) endCall(retry bool, err error) {
     p.mu.Unlock()
 }
 
+// Detect the pacer being called reentrantly.
+//
+// This looks for Pacer.call in the call stack and returns true if it
+// is found.
+//
+// Ideally we would do this by passing a context about but there are
+// an awful lot of Pacer calls!
+//
+// This is only needed when p.maxConnections > 0 which isn't a common
+// configuration so adding a bit of extra slowdown here is not a
+// problem.
+func pacerReentered() bool {
+    var pcs [48]uintptr
+    n := runtime.Callers(3, pcs[:]) // skip runtime.Callers, pacerReentered and call
+    frames := runtime.CallersFrames(pcs[:n])
+    for {
+        f, more := frames.Next()
+        if strings.HasSuffix(f.Function, "(*Pacer).call") {
+            return true
+        }
+        if !more {
+            break
+        }
+    }
+    return false
+}
+
 // call implements Call but with settable retries
 func (p *Pacer) call(fn Paced, retries int) (err error) {
     var retry bool
+    limitConnections := false
+    if p.maxConnections > 0 && !pacerReentered() {
+        limitConnections = true
+    }
     for i := 1; i <= retries; i++ {
-        p.beginCall()
+        p.beginCall(limitConnections)
         retry, err = p.invoker(i, retries, fn)
-        p.endCall(retry, err)
+        p.endCall(retry, err, limitConnections)
         if !retry {
             break
         }
```
```diff
@@ -108,7 +108,7 @@ func waitForPace(p *Pacer, duration time.Duration) (when time.Time) {
 func TestBeginCall(t *testing.T) {
     p := New(MaxConnectionsOption(10), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond))))
     emptyTokens(p)
-    go p.beginCall()
+    go p.beginCall(true)
     if !waitForPace(p, 10*time.Millisecond).IsZero() {
         t.Errorf("beginSleep fired too early #1")
     }
@@ -131,7 +131,7 @@ func TestBeginCall(t *testing.T) {
 func TestBeginCallZeroConnections(t *testing.T) {
     p := New(MaxConnectionsOption(0), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond))))
     emptyTokens(p)
-    go p.beginCall()
+    go p.beginCall(false)
     if !waitForPace(p, 10*time.Millisecond).IsZero() {
         t.Errorf("beginSleep fired too early #1")
     }
@@ -257,7 +257,7 @@ func TestEndCall(t *testing.T) {
     p := New(MaxConnectionsOption(5))
     emptyTokens(p)
     p.state.ConsecutiveRetries = 1
-    p.endCall(true, nil)
+    p.endCall(true, nil, true)
     assert.Equal(t, 1, len(p.connTokens))
     assert.Equal(t, 2, p.state.ConsecutiveRetries)
 }
@@ -266,7 +266,7 @@ func TestEndCallZeroConnections(t *testing.T) {
     p := New(MaxConnectionsOption(0))
     emptyTokens(p)
     p.state.ConsecutiveRetries = 1
-    p.endCall(false, nil)
+    p.endCall(false, nil, false)
     assert.Equal(t, 0, len(p.connTokens))
     assert.Equal(t, 0, p.state.ConsecutiveRetries)
 }
```
```diff
@@ -353,6 +353,78 @@ func TestCallParallel(t *testing.T) {
     wait.Broadcast()
 }
 
+func BenchmarkPacerReentered(b *testing.B) {
+    for b.Loop() {
+        _ = pacerReentered()
+    }
+}
+
+func BenchmarkPacerReentered100(b *testing.B) {
+    var fn func(level int)
+    fn = func(level int) {
+        if level > 0 {
+            fn(level - 1)
+            return
+        }
+        for b.Loop() {
+            _ = pacerReentered()
+        }
+    }
+    fn(100)
+}
+
+func TestCallMaxConnectionsRecursiveDeadlock(t *testing.T) {
+    p := New(CalculatorOption(NewDefault(MinSleep(1*time.Millisecond), MaxSleep(2*time.Millisecond))))
+    p.SetMaxConnections(1)
+    dp := &dummyPaced{retry: false}
+    err := p.Call(func() (bool, error) {
+        // check we have taken the connection token
+        // no tokens left means deadlock on the recursive call
+        assert.Equal(t, 0, len(p.connTokens))
+        return false, p.Call(dp.fn)
+    })
+    assert.Equal(t, 1, dp.called)
+    assert.Equal(t, errFoo, err)
+}
+
+func TestCallMaxConnectionsRecursiveDeadlock2(t *testing.T) {
+    p := New(CalculatorOption(NewDefault(MinSleep(1*time.Millisecond), MaxSleep(2*time.Millisecond))))
+    p.SetMaxConnections(1)
+    dp := &dummyPaced{retry: false}
+    wg := new(sync.WaitGroup)
+
+    // Normal
+    for range 100 {
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            err := p.Call(func() (bool, error) {
+                // check we have taken the connection token
+                assert.Equal(t, 0, len(p.connTokens))
+                return false, nil
+            })
+            assert.NoError(t, err)
+        }()
+
+        // Now attempt a recursive call
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            err := p.Call(func() (bool, error) {
+                // check we have taken the connection token
+                assert.Equal(t, 0, len(p.connTokens))
+                // Do recursive call
+                return false, p.Call(dp.fn)
+            })
+            assert.Equal(t, errFoo, err)
+        }()
+    }
+
+    // Tidy up
+    wg.Wait()
+}
+
 func TestRetryAfterError_NonNilErr(t *testing.T) {
     orig := errors.New("test failure")
     dur := 2 * time.Second
```
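The deadlock these tests pin down is mechanical: with `--max-connections 1` the outer `Call` holds the only connection token, so a nested `Call` waiting for that token can never be satisfied. Skipping the token acquisition on a reentrant call, which is what `pacerReentered` enables, breaks the cycle. A miniature of the failure mode, detached from the pacer (all names hypothetical):

```go
package main

import "fmt"

// connTokens models the pacer's connection limiter with capacity 1.
var connTokens = make(chan struct{}, 1)

func init() { connTokens <- struct{}{} }

// call runs fn while holding a connection token. If reentered is true
// the token is skipped, mirroring the effect of pacerReentered.
func call(fn func(), reentered bool) {
	if !reentered {
		<-connTokens
		defer func() { connTokens <- struct{}{} }()
	}
	fn()
}

func main() {
	call(func() {
		// An inner call with reentered=false would block here forever:
		// the outer call still holds the only token.
		call(func() { fmt.Println("inner call ran") }, true)
	}, false)
}
```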
lib/transferaccounter/transferaccounter.go (new file, 75 lines)
```diff
@@ -0,0 +1,75 @@
+// Package transferaccounter provides utilities for accounting server side transfers.
+package transferaccounter
+
+import (
+    "context"
+    "sync/atomic"
+)
+
+// Context key type for accounter
+type accounterContextKeyType struct{}
+
+// Context key for accounter
+var accounterContextKey = accounterContextKeyType{}
+
+// TransferAccounter is used to account server side and other transfers.
+type TransferAccounter struct {
+    add     func(n int64)
+    total   atomic.Int64
+    started bool
+}
+
+// New creates a TransferAccounter using the add function passed in.
+//
+// Note that the add function should be goroutine safe.
+//
+// It adds the new TransferAccounter to the context.
+func New(ctx context.Context, add func(n int64)) (context.Context, *TransferAccounter) {
+    ta := &TransferAccounter{
+        add: add,
+    }
+    newCtx := context.WithValue(ctx, accounterContextKey, ta)
+    return newCtx, ta
+}
+
+// Start the transfer. Call this before calling Add().
+func (ta *TransferAccounter) Start() {
+    ta.started = true
+}
+
+// Started returns if the transfer has had Start() called or not.
+func (ta *TransferAccounter) Started() bool {
+    return ta.started
+}
+
+// Add n bytes to the transfer
+func (ta *TransferAccounter) Add(n int64) {
+    ta.add(n)
+    ta.total.Add(n)
+}
+
+// Reset reverses out all accounted stats if Started() has been called
+func (ta *TransferAccounter) Reset() {
+    if ta.started {
+        ta.Add(-ta.total.Load())
+    }
+}
+
+// A transfer accounter which does nothing
+var nullAccounter = &TransferAccounter{
+    add: func(n int64) {},
+}
+
+// Get returns a *TransferAccounter from the ctx.
+//
+// If none is found it will return a dummy one to keep the code simple.
+func Get(ctx context.Context) *TransferAccounter {
+    if ctx == nil {
+        return nullAccounter
+    }
+    c := ctx.Value(accounterContextKey)
+    if c == nil {
+        return nullAccounter
+    }
+    return c.(*TransferAccounter)
+}
```
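As a quick illustration of the lifecycle defined above: `Add` both forwards bytes to the callback and keeps a running total, and `Reset` reverses everything out through the same callback, but only when `Start` was called, so an accounter a backend never engaged has nothing to unwind. A sketch using the package as declared in this diff:

```go
package main

import (
	"context"
	"fmt"

	"github.com/rclone/rclone/lib/transferaccounter"
)

func main() {
	var stats int64
	_, ta := transferaccounter.New(context.Background(), func(n int64) { stats += n })

	ta.Start()
	ta.Add(4096)
	ta.Add(4096)
	fmt.Println(stats) // 8192

	// Simulate a failed copy: reverse out everything accounted so far.
	ta.Reset()
	fmt.Println(stats) // 0
}
```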
lib/transferaccounter/transferaccounter_test.go (new file, 91 lines)
```diff
@@ -0,0 +1,91 @@
+package transferaccounter
+
+import (
+    "context"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+)
+
+func TestNew(t *testing.T) {
+    // Dummy add function
+    var totalBytes int64
+    addFn := func(n int64) {
+        totalBytes += n
+    }
+
+    // Create the accounter
+    ctx := context.Background()
+    _, ta := New(ctx, addFn)
+
+    // Verify object creation
+    require.NotNil(t, ta)
+    assert.False(t, ta.Started(), "New accounter should not be started by default")
+
+    // Test Start()
+    ta.Start()
+    assert.True(t, ta.Started(), "Accounter should be started after calling Start()")
+
+    // Test Add() logic
+    ta.Add(100)
+    ta.Add(50)
+    assert.Equal(t, int64(150), totalBytes, "The add function should have been called with cumulative values")
+    assert.Equal(t, int64(150), ta.total.Load(), "Internal counter did not count")
+
+    // Test Reset() logic
+    ta.Reset()
+    assert.Equal(t, int64(0), totalBytes, "The Reset function failed")
+    assert.Equal(t, int64(0), ta.total.Load(), "Internal counter did not reset")
+}
+
+func TestGet(t *testing.T) {
+    t.Run("Retrieve existing accounter", func(t *testing.T) {
+        // Create a specific accounter to identify later
+        expectedTotal := int64(0)
+        ctx, originalTa := New(context.Background(), func(n int64) { expectedTotal += n })
+
+        // Retrieve it
+        retrievedTa := Get(ctx)
+
+        // Assert it is the exact same pointer
+        assert.Equal(t, originalTa, retrievedTa)
+
+        // Verify functionality passes through
+        retrievedTa.Add(10)
+        assert.Equal(t, int64(10), expectedTotal)
+    })
+
+    t.Run("Context does not contain accounter", func(t *testing.T) {
+        ctx := context.Background()
+        ta := Get(ctx)
+
+        assert.NotNil(t, ta, "Get should never return nil")
+        assert.Equal(t, nullAccounter, ta, "Should return the global nullAccounter")
+    })
+
+    t.Run("Context is nil", func(t *testing.T) {
+        ta := Get(nil) //nolint:staticcheck // we want to test this
+
+        assert.NotNil(t, ta, "Get should never return nil")
+        assert.Equal(t, nullAccounter, ta, "Should return the global nullAccounter")
+    })
+}
+
+func TestNullAccounterBehavior(t *testing.T) {
+    // Ensure the null accounter (returned when context is missing/nil)
+    // can be called without panicking.
+    ta := Get(nil) //nolint:staticcheck // we want to test this
+
+    assert.NotPanics(t, func() {
+        ta.Start()
+    })
+
+    // Even after start, it acts as a valid object
+    assert.True(t, ta.Started())
+
+    assert.NotPanics(t, func() {
+        ta.Add(1000)
+    })
+}
```