1
0
mirror of https://github.com/rclone/rclone.git synced 2025-12-19 09:43:14 +00:00

Compare commits

..

1 Commits

Author SHA1 Message Date
Nick Craig-Wood
7e9eae8c90 rc: add rmdirs and leaveRoot to operations/delete 2025-09-15 15:47:00 +01:00
20 changed files with 112 additions and 377 deletions

View File

@@ -2797,6 +2797,8 @@ func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
blockList blockblob.GetBlockListResponse blockList blockblob.GetBlockListResponse
properties *blob.GetPropertiesResponse properties *blob.GetPropertiesResponse
options *blockblob.CommitBlockListOptions options *blockblob.CommitBlockListOptions
// Use temporary pacer as this can be called recursively which can cause a deadlock with --max-connections
pacer = fs.NewPacer(ctx, pacer.NewS3(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant)))
) )
properties, err = o.readMetaDataAlways(ctx) properties, err = o.readMetaDataAlways(ctx)
@@ -2808,7 +2810,7 @@ func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
if objectExists { if objectExists {
// Get the committed block list // Get the committed block list
err = o.fs.pacer.Call(func() (bool, error) { err = pacer.Call(func() (bool, error) {
blockList, err = blockBlobSVC.GetBlockList(ctx, blockblob.BlockListTypeAll, nil) blockList, err = blockBlobSVC.GetBlockList(ctx, blockblob.BlockListTypeAll, nil)
return o.fs.shouldRetry(ctx, err) return o.fs.shouldRetry(ctx, err)
}) })
@@ -2850,7 +2852,7 @@ func (o *Object) clearUncommittedBlocks(ctx context.Context) (err error) {
// Commit only the committed blocks // Commit only the committed blocks
fs.Debugf(o, "Committing %d blocks to remove uncommitted blocks", len(blockIDs)) fs.Debugf(o, "Committing %d blocks to remove uncommitted blocks", len(blockIDs))
err = o.fs.pacer.Call(func() (bool, error) { err = pacer.Call(func() (bool, error) {
_, err := blockBlobSVC.CommitBlockList(ctx, blockIDs, options) _, err := blockBlobSVC.CommitBlockList(ctx, blockIDs, options)
return o.fs.shouldRetry(ctx, err) return o.fs.shouldRetry(ctx, err)
}) })

View File

@@ -2224,17 +2224,13 @@ func (f *Fs) OpenChunkWriter(ctx context.Context, remote string, src fs.ObjectIn
return info, nil, err return info, nil, err
} }
up, err := f.newLargeUpload(ctx, o, nil, src, f.opt.ChunkSize, false, nil, options...)
if err != nil {
return info, nil, err
}
info = fs.ChunkWriterInfo{ info = fs.ChunkWriterInfo{
ChunkSize: up.chunkSize, ChunkSize: int64(f.opt.ChunkSize),
Concurrency: o.fs.opt.UploadConcurrency, Concurrency: o.fs.opt.UploadConcurrency,
//LeavePartsOnError: o.fs.opt.LeavePartsOnError, //LeavePartsOnError: o.fs.opt.LeavePartsOnError,
} }
return info, up, nil up, err := f.newLargeUpload(ctx, o, nil, src, f.opt.ChunkSize, false, nil, options...)
return info, up, err
} }
// Remove an object // Remove an object

View File

@@ -5,7 +5,6 @@ package api
import ( import (
"fmt" "fmt"
"net/url"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
@@ -137,25 +136,8 @@ type Link struct {
} }
// Valid reports whether l is non-nil, has an URL, and is not expired. // Valid reports whether l is non-nil, has an URL, and is not expired.
// It primarily checks the URL's expire query parameter, falling back to the Expire field.
func (l *Link) Valid() bool { func (l *Link) Valid() bool {
if l == nil || l.URL == "" { return l != nil && l.URL != "" && time.Now().Add(10*time.Second).Before(time.Time(l.Expire))
return false
}
// Primary validation: check URL's expire query parameter
if u, err := url.Parse(l.URL); err == nil {
if expireStr := u.Query().Get("expire"); expireStr != "" {
// Try parsing as Unix timestamp (seconds)
if expireInt, err := strconv.ParseInt(expireStr, 10, 64); err == nil {
expireTime := time.Unix(expireInt, 0)
return time.Now().Add(10 * time.Second).Before(expireTime)
}
}
}
// Fallback validation: use the Expire field if URL parsing didn't work
return time.Now().Add(10 * time.Second).Before(time.Time(l.Expire))
} }
// URL is a basic form of URL // URL is a basic form of URL

View File

@@ -1,99 +0,0 @@
package api
import (
"fmt"
"testing"
"time"
)
// TestLinkValid tests the Link.Valid method for various scenarios
func TestLinkValid(t *testing.T) {
tests := []struct {
name string
link *Link
expected bool
desc string
}{
{
name: "nil link",
link: nil,
expected: false,
desc: "nil link should be invalid",
},
{
name: "empty URL",
link: &Link{URL: ""},
expected: false,
desc: "empty URL should be invalid",
},
{
name: "valid URL with future expire parameter",
link: &Link{
URL: fmt.Sprintf("https://example.com/file?expire=%d", time.Now().Add(time.Hour).Unix()),
},
expected: true,
desc: "URL with future expire parameter should be valid",
},
{
name: "expired URL with past expire parameter",
link: &Link{
URL: fmt.Sprintf("https://example.com/file?expire=%d", time.Now().Add(-time.Hour).Unix()),
},
expected: false,
desc: "URL with past expire parameter should be invalid",
},
{
name: "URL expire parameter takes precedence over Expire field",
link: &Link{
URL: fmt.Sprintf("https://example.com/file?expire=%d", time.Now().Add(time.Hour).Unix()),
Expire: Time(time.Now().Add(-time.Hour)), // Fallback is expired
},
expected: true,
desc: "URL expire parameter should take precedence over Expire field",
},
{
name: "URL expire parameter within 10 second buffer should be invalid",
link: &Link{
URL: fmt.Sprintf("https://example.com/file?expire=%d", time.Now().Add(5*time.Second).Unix()),
},
expected: false,
desc: "URL expire parameter within 10 second buffer should be invalid",
},
{
name: "fallback to Expire field when no URL expire parameter",
link: &Link{
URL: "https://example.com/file",
Expire: Time(time.Now().Add(time.Hour)),
},
expected: true,
desc: "should fallback to Expire field when URL has no expire parameter",
},
{
name: "fallback to Expire field when URL expire parameter is invalid",
link: &Link{
URL: "https://example.com/file?expire=invalid",
Expire: Time(time.Now().Add(time.Hour)),
},
expected: true,
desc: "should fallback to Expire field when URL expire parameter is unparseable",
},
{
name: "invalid when both URL expire and Expire field are expired",
link: &Link{
URL: fmt.Sprintf("https://example.com/file?expire=%d", time.Now().Add(-time.Hour).Unix()),
Expire: Time(time.Now().Add(-time.Hour)),
},
expected: false,
desc: "should be invalid when both URL expire and Expire field are expired",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.link.Valid()
if result != tt.expected {
t.Errorf("Link.Valid() = %v, expected %v. %s", result, tt.expected, tt.desc)
}
})
}
}

View File

@@ -192,9 +192,6 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
return nil, err return nil, err
} }
// if root is empty or ends with / (must be a directory)
isRootDir := isPathDir(root)
root = strings.Trim(root, "/") root = strings.Trim(root, "/")
f := &Fs{ f := &Fs{
@@ -221,11 +218,6 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
if share == "" || dir == "" { if share == "" || dir == "" {
return f, nil return f, nil
} }
// Skip stat check if root is already a directory
if isRootDir {
return f, nil
}
cn, err := f.getConnection(ctx, share) cn, err := f.getConnection(ctx, share)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -902,11 +894,6 @@ func ensureSuffix(s, suffix string) string {
return s + suffix return s + suffix
} }
// isPathDir determines if a path represents a directory based on trailing slash
func isPathDir(path string) bool {
return path == "" || strings.HasSuffix(path, "/")
}
func trimPathPrefix(s, prefix string) string { func trimPathPrefix(s, prefix string) string {
// we need to clean the paths to make tests pass! // we need to clean the paths to make tests pass!
s = betterPathClean(s) s = betterPathClean(s)

View File

@@ -1,41 +0,0 @@
// Unit tests for internal SMB functions
package smb
import "testing"
// TestIsPathDir tests the isPathDir function logic
func TestIsPathDir(t *testing.T) {
tests := []struct {
path string
expected bool
}{
// Empty path should be considered a directory
{"", true},
// Paths with trailing slash should be directories
{"/", true},
{"share/", true},
{"share/dir/", true},
{"share/dir/subdir/", true},
// Paths without trailing slash should not be directories
{"share", false},
{"share/dir", false},
{"share/dir/file", false},
{"share/dir/subdir/file", false},
// Edge cases
{"share//", true},
{"share///", true},
{"share/dir//", true},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
result := isPathDir(tt.path)
if result != tt.expected {
t.Errorf("isPathDir(%q) = %v, want %v", tt.path, result, tt.expected)
}
})
}
}

View File

@@ -4,19 +4,15 @@ package bilib
import ( import (
"bytes" "bytes"
"log/slog" "log/slog"
"sync"
"github.com/rclone/rclone/fs/log" "github.com/rclone/rclone/fs/log"
) )
// CaptureOutput runs a function capturing its output at log level INFO. // CaptureOutput runs a function capturing its output at log level INFO.
func CaptureOutput(fun func()) []byte { func CaptureOutput(fun func()) []byte {
var mu sync.Mutex
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
oldLevel := log.Handler.SetLevel(slog.LevelInfo) oldLevel := log.Handler.SetLevel(slog.LevelInfo)
log.Handler.SetOutput(func(level slog.Level, text string) { log.Handler.SetOutput(func(level slog.Level, text string) {
mu.Lock()
defer mu.Unlock()
buf.WriteString(text) buf.WriteString(text)
}) })
defer func() { defer func() {
@@ -24,7 +20,5 @@ func CaptureOutput(fun func()) []byte {
log.Handler.SetLevel(oldLevel) log.Handler.SetLevel(oldLevel)
}() }()
fun() fun()
mu.Lock()
defer mu.Unlock()
return buf.Bytes() return buf.Bytes()
} }

View File

@@ -330,7 +330,7 @@ func testBisync(ctx context.Context, t *testing.T, path1, path2 string) {
baseDir, err := os.Getwd() baseDir, err := os.Getwd()
require.NoError(t, err, "get current directory") require.NoError(t, err, "get current directory")
randName := time.Now().Format("150405") + random.String(2) // some bucket backends don't like dots, keep this short to avoid linux errors randName := time.Now().Format("150405") + random.String(2) // some bucket backends don't like dots, keep this short to avoid linux errors
tempDir := filepath.Join(t.TempDir(), randName) tempDir := filepath.Join(os.TempDir(), randName)
workDir := filepath.Join(tempDir, "workdir") workDir := filepath.Join(tempDir, "workdir")
b := &bisyncTest{ b := &bisyncTest{

View File

@@ -707,7 +707,8 @@ func (b *bisyncRun) modifyListing(ctx context.Context, src fs.Fs, dst fs.Fs, res
prettyprint(dstList.list, "dstList", fs.LogLevelDebug) prettyprint(dstList.list, "dstList", fs.LogLevelDebug)
// clear stats so we only do this once // clear stats so we only do this once
accounting.Stats(ctx).RemoveDoneTransfers() accounting.MaxCompletedTransfers = 0
accounting.Stats(ctx).PruneTransfers()
} }
if b.DebugName != "" { if b.DebugName != "" {

View File

@@ -246,7 +246,9 @@ func (b *bisyncRun) fastCopy(ctx context.Context, fsrc, fdst fs.Fs, files bilib.
} }
b.SyncCI = fs.GetConfig(ctxCopy) // allows us to request graceful shutdown b.SyncCI = fs.GetConfig(ctxCopy) // allows us to request graceful shutdown
accounting.Stats(ctxCopy).SetMaxCompletedTransfers(-1) // we need a complete list in the event of graceful shutdown if accounting.MaxCompletedTransfers != -1 {
accounting.MaxCompletedTransfers = -1 // we need a complete list in the event of graceful shutdown
}
ctxCopy, b.CancelSync = context.WithCancel(ctxCopy) ctxCopy, b.CancelSync = context.WithCancel(ctxCopy)
b.testFn() b.testFn()
err := sync.Sync(ctxCopy, fdst, fsrc, b.opt.CreateEmptySrcDirs) err := sync.Sync(ctxCopy, fdst, fsrc, b.opt.CreateEmptySrcDirs)

View File

@@ -208,7 +208,6 @@ func newServer(ctx context.Context, f fs.Fs, opt *Options, vfsOpt *vfscommon.Opt
// Serve HTTP until the server is shutdown // Serve HTTP until the server is shutdown
func (s *HTTP) Serve() error { func (s *HTTP) Serve() error {
s.server.Serve() s.server.Serve()
fs.Logf(s.f, "HTTP Server started on %s", s.server.URLs())
s.server.Wait() s.server.Wait()
return nil return nil
} }

View File

@@ -1013,5 +1013,3 @@ put them back in again.` >}}
- Robin Rolf <imer@imer.cc> - Robin Rolf <imer@imer.cc>
- Jean-Christophe Cura <jcaspes@gmail.com> - Jean-Christophe Cura <jcaspes@gmail.com>
- russcoss <russcoss@outlook.com> - russcoss <russcoss@outlook.com>
- Matt LaPaglia <mlapaglia@gmail.com>
- Youfu Zhang <1315097+zhangyoufu@users.noreply.github.com>

View File

@@ -21,7 +21,7 @@ you started to share on Windows. On smbd, it's the section title in `smb.conf`
(usually in `/etc/samba/`) file. (usually in `/etc/samba/`) file.
You can find shares by querying the root if you're unsure (e.g. `rclone lsd remote:`). You can find shares by querying the root if you're unsure (e.g. `rclone lsd remote:`).
You can't access the shared printers from rclone, obviously. You can't access to the shared printers from rclone, obviously.
You can't use Anonymous access for logging in. You have to use the `guest` user You can't use Anonymous access for logging in. You have to use the `guest` user
with an empty password instead. The rclone client tries to avoid 8.3 names when with an empty password instead. The rclone client tries to avoid 8.3 names when

View File

@@ -22,10 +22,7 @@ const (
averageStopAfter = time.Minute averageStopAfter = time.Minute
) )
// MaxCompletedTransfers specifies the default maximum number of // MaxCompletedTransfers specifies maximum number of completed transfers in startedTransfers list
// completed transfers in startedTransfers list. This can be adjusted
// for a given StatsInfo by calling the SetMaxCompletedTransfers
// method.
var MaxCompletedTransfers = 100 var MaxCompletedTransfers = 100
// StatsInfo accounts all transfers // StatsInfo accounts all transfers
@@ -67,7 +64,6 @@ type StatsInfo struct {
serverSideCopyBytes int64 serverSideCopyBytes int64
serverSideMoves int64 serverSideMoves int64
serverSideMoveBytes int64 serverSideMoveBytes int64
maxCompletedTransfers int
} }
type averageValues struct { type averageValues struct {
@@ -92,19 +88,10 @@ func NewStats(ctx context.Context) *StatsInfo {
inProgress: newInProgress(ctx), inProgress: newInProgress(ctx),
startTime: time.Now(), startTime: time.Now(),
average: averageValues{}, average: averageValues{},
maxCompletedTransfers: MaxCompletedTransfers,
} }
return s return s
} }
// SetMaxCompletedTransfers sets the maximum number of completed transfers to keep.
func (s *StatsInfo) SetMaxCompletedTransfers(n int) *StatsInfo {
s.mu.Lock()
s.maxCompletedTransfers = n
s.mu.Unlock()
return s
}
// RemoteStats returns stats for rc // RemoteStats returns stats for rc
// //
// If short is true then the transfers and checkers won't be added. // If short is true then the transfers and checkers won't be added.
@@ -925,31 +912,22 @@ func (s *StatsInfo) RemoveTransfer(transfer *Transfer) {
} }
// PruneTransfers makes sure there aren't too many old transfers by removing // PruneTransfers makes sure there aren't too many old transfers by removing
// a single finished transfer. Returns true if it removed a transfer. // single finished transfer.
func (s *StatsInfo) PruneTransfers() bool { func (s *StatsInfo) PruneTransfers() {
s.mu.Lock() if MaxCompletedTransfers < 0 {
defer s.mu.Unlock() return
if s.maxCompletedTransfers < 0 {
return false
} }
removed := false s.mu.Lock()
// remove a transfer from the start if we are over quota // remove a transfer from the start if we are over quota
if len(s.startedTransfers) > s.maxCompletedTransfers+s.ci.Transfers { if len(s.startedTransfers) > MaxCompletedTransfers+s.ci.Transfers {
for i, tr := range s.startedTransfers { for i, tr := range s.startedTransfers {
if tr.IsDone() { if tr.IsDone() {
s._removeTransfer(tr, i) s._removeTransfer(tr, i)
removed = true
break break
} }
} }
} }
return removed s.mu.Unlock()
}
// RemoveDoneTransfers removes all Done transfers.
func (s *StatsInfo) RemoveDoneTransfers() {
for s.PruneTransfers() {
}
} }
// AddServerSideMove counts a server side move // AddServerSideMove counts a server side move

View File

@@ -465,27 +465,3 @@ func TestPruneTransfers(t *testing.T) {
}) })
} }
} }
func TestRemoveDoneTransfers(t *testing.T) {
ctx := context.Background()
s := NewStats(ctx)
const transfers = 10
for i := int64(1); i <= int64(transfers); i++ {
s.AddTransfer(&Transfer{
startedAt: time.Unix(i, 0),
completedAt: time.Unix(i+1, 0),
})
}
s.mu.Lock()
assert.Equal(t, time.Duration(transfers)*time.Second, s._totalDuration())
assert.Equal(t, transfers, len(s.startedTransfers))
s.mu.Unlock()
s.RemoveDoneTransfers()
s.mu.Lock()
assert.Equal(t, time.Duration(transfers)*time.Second, s._totalDuration())
assert.Equal(t, transfers, len(s.startedTransfers))
s.mu.Unlock()
}

View File

@@ -208,7 +208,7 @@ func init() {
{name: "rmdir", title: "Remove an empty directory or container"}, {name: "rmdir", title: "Remove an empty directory or container"},
{name: "purge", title: "Remove a directory or container and all of its contents"}, {name: "purge", title: "Remove a directory or container and all of its contents"},
{name: "rmdirs", title: "Remove all the empty directories in the path", help: "- leaveRoot - boolean, set to true not to delete the root\n"}, {name: "rmdirs", title: "Remove all the empty directories in the path", help: "- leaveRoot - boolean, set to true not to delete the root\n"},
{name: "delete", title: "Remove files in the path", noRemote: true}, {name: "delete", title: "Remove files in the path", help: "- rmdirs - boolean, set to true to remove empty directories\n- leaveRoot - boolean if rmdirs is set, set to true not to delete the root\n", noRemote: true},
{name: "deletefile", title: "Remove the single file pointed to"}, {name: "deletefile", title: "Remove the single file pointed to"},
{name: "copyurl", title: "Copy the URL to the object", help: "- url - string, URL to read from\n - autoFilename - boolean, set to true to retrieve destination file name from url\n"}, {name: "copyurl", title: "Copy the URL to the object", help: "- url - string, URL to read from\n - autoFilename - boolean, set to true to retrieve destination file name from url\n"},
{name: "uploadfile", title: "Upload file using multiform/form-data", help: "- each part in body represents a file to be uploaded\n", needsRequest: true, noCommand: true}, {name: "uploadfile", title: "Upload file using multiform/form-data", help: "- each part in body represents a file to be uploaded\n", needsRequest: true, noCommand: true},
@@ -267,7 +267,22 @@ func rcSingleCommand(ctx context.Context, in rc.Params, name string, noRemote bo
} }
return nil, Rmdirs(ctx, f, remote, leaveRoot) return nil, Rmdirs(ctx, f, remote, leaveRoot)
case "delete": case "delete":
return nil, Delete(ctx, f) rmdirs, err := in.GetBool("rmdirs")
if rc.NotErrParamNotFound(err) {
return nil, err
}
leaveRoot, err := in.GetBool("leaveRoot")
if rc.NotErrParamNotFound(err) {
return nil, err
}
err = Delete(ctx, f)
if err != nil {
return nil, err
}
if !rmdirs {
return nil, nil
}
return nil, Rmdirs(ctx, f, remote, leaveRoot)
case "deletefile": case "deletefile":
o, err := f.NewObject(ctx, remote) o, err := f.NewObject(ctx, remote)
if err != nil { if err != nil {

View File

@@ -159,21 +159,32 @@ func TestRcCopyurl(t *testing.T) {
// operations/delete: Remove files in the path // operations/delete: Remove files in the path
func TestRcDelete(t *testing.T) { func TestRcDelete(t *testing.T) {
ctx := context.Background()
r, call := rcNewRun(t, "operations/delete") r, call := rcNewRun(t, "operations/delete")
file1 := r.WriteObject(ctx, "subdir/file1", "subdir/file1 contents", t1)
file2 := r.WriteObject(ctx, "file2", "file2 contents", t1)
file1 := r.WriteObject(context.Background(), "small", "1234567890", t2) // 10 bytes fstest.CheckListingWithPrecision(t, r.Fremote, []fstest.Item{file1, file2}, []string{"subdir"}, fs.GetModifyWindow(ctx, r.Fremote))
file2 := r.WriteObject(context.Background(), "medium", "------------------------------------------------------------", t1) // 60 bytes
file3 := r.WriteObject(context.Background(), "large", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", t1) // 100 bytes
r.CheckRemoteItems(t, file1, file2, file3)
in := rc.Params{ in := rc.Params{
"fs": r.FremoteName, "fs": r.FremoteName,
} }
out, err := call.Fn(context.Background(), in) out, err := call.Fn(ctx, in)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, rc.Params(nil), out) assert.Equal(t, rc.Params(nil), out)
r.CheckRemoteItems(t) fstest.CheckListingWithPrecision(t, r.Fremote, []fstest.Item{}, []string{"subdir"}, fs.GetModifyWindow(ctx, r.Fremote))
// Now try with rmdirs=true and leaveRoot=true
in["rmdirs"] = true
in["leaveRoot"] = true
out, err = call.Fn(ctx, in)
require.NoError(t, err)
assert.Equal(t, rc.Params(nil), out)
fstest.CheckListingWithPrecision(t, r.Fremote, []fstest.Item{}, []string{}, fs.GetModifyWindow(ctx, r.Fremote))
// FIXME don't have an easy way of checking the root still exists or not
} }
// operations/deletefile: Remove the single file pointed to // operations/deletefile: Remove the single file pointed to

View File

@@ -523,6 +523,8 @@ func (s *Server) initTLS() error {
func (s *Server) Serve() { func (s *Server) Serve() {
s.wg.Add(len(s.instances)) s.wg.Add(len(s.instances))
for _, ii := range s.instances { for _, ii := range s.instances {
// TODO: decide how/when to log listening url
// log.Printf("listening on %s", ii.url)
go ii.serve(&s.wg) go ii.serve(&s.wg)
} }
// Install an atexit handler to shutdown gracefully // Install an atexit handler to shutdown gracefully

View File

@@ -4,8 +4,6 @@ package pacer
import ( import (
"errors" "errors"
"fmt" "fmt"
"runtime"
"strings"
"sync" "sync"
"time" "time"
@@ -155,13 +153,13 @@ func (p *Pacer) ModifyCalculator(f func(Calculator)) {
// This must be called as a pair with endCall. // This must be called as a pair with endCall.
// //
// This waits for the pacer token // This waits for the pacer token
func (p *Pacer) beginCall(limitConnections bool) { func (p *Pacer) beginCall() {
// pacer starts with a token in and whenever we take one out // pacer starts with a token in and whenever we take one out
// XXX ms later we put another in. We could do this with a // XXX ms later we put another in. We could do this with a
// Ticker more accurately, but then we'd have to work out how // Ticker more accurately, but then we'd have to work out how
// not to run it when it wasn't needed // not to run it when it wasn't needed
<-p.pacer <-p.pacer
if limitConnections { if p.maxConnections > 0 {
<-p.connTokens <-p.connTokens
} }
@@ -178,8 +176,8 @@ func (p *Pacer) beginCall(limitConnections bool) {
// //
// This should calculate a new sleepTime. It takes a boolean as to // This should calculate a new sleepTime. It takes a boolean as to
// whether the operation should be retried or not. // whether the operation should be retried or not.
func (p *Pacer) endCall(retry bool, err error, limitConnections bool) { func (p *Pacer) endCall(retry bool, err error) {
if limitConnections { if p.maxConnections > 0 {
p.connTokens <- struct{}{} p.connTokens <- struct{}{}
} }
p.mu.Lock() p.mu.Lock()
@@ -193,44 +191,13 @@ func (p *Pacer) endCall(retry bool, err error, limitConnections bool) {
p.mu.Unlock() p.mu.Unlock()
} }
// Detect the pacer being called reentrantly.
//
// This looks for Pacer.call in the call stack and returns true if it
// is found.
//
// Ideally we would do this by passing a context about but there are
// an awful lot of Pacer calls!
//
// This is only needed when p.maxConnections > 0 which isn't a common
// configuration so adding a bit of extra slowdown here is not a
// problem.
func pacerReentered() bool {
var pcs [48]uintptr
n := runtime.Callers(3, pcs[:]) // skip runtime.Callers, pacerReentered and call
frames := runtime.CallersFrames(pcs[:n])
for {
f, more := frames.Next()
if strings.HasSuffix(f.Function, "(*Pacer).call") {
return true
}
if !more {
break
}
}
return false
}
// call implements Call but with settable retries // call implements Call but with settable retries
func (p *Pacer) call(fn Paced, retries int) (err error) { func (p *Pacer) call(fn Paced, retries int) (err error) {
var retry bool var retry bool
limitConnections := false
if p.maxConnections > 0 && !pacerReentered() {
limitConnections = true
}
for i := 1; i <= retries; i++ { for i := 1; i <= retries; i++ {
p.beginCall(limitConnections) p.beginCall()
retry, err = p.invoker(i, retries, fn) retry, err = p.invoker(i, retries, fn)
p.endCall(retry, err, limitConnections) p.endCall(retry, err)
if !retry { if !retry {
break break
} }

View File

@@ -108,7 +108,7 @@ func waitForPace(p *Pacer, duration time.Duration) (when time.Time) {
func TestBeginCall(t *testing.T) { func TestBeginCall(t *testing.T) {
p := New(MaxConnectionsOption(10), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond)))) p := New(MaxConnectionsOption(10), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond))))
emptyTokens(p) emptyTokens(p)
go p.beginCall(true) go p.beginCall()
if !waitForPace(p, 10*time.Millisecond).IsZero() { if !waitForPace(p, 10*time.Millisecond).IsZero() {
t.Errorf("beginSleep fired too early #1") t.Errorf("beginSleep fired too early #1")
} }
@@ -131,7 +131,7 @@ func TestBeginCall(t *testing.T) {
func TestBeginCallZeroConnections(t *testing.T) { func TestBeginCallZeroConnections(t *testing.T) {
p := New(MaxConnectionsOption(0), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond)))) p := New(MaxConnectionsOption(0), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond))))
emptyTokens(p) emptyTokens(p)
go p.beginCall(false) go p.beginCall()
if !waitForPace(p, 10*time.Millisecond).IsZero() { if !waitForPace(p, 10*time.Millisecond).IsZero() {
t.Errorf("beginSleep fired too early #1") t.Errorf("beginSleep fired too early #1")
} }
@@ -257,7 +257,7 @@ func TestEndCall(t *testing.T) {
p := New(MaxConnectionsOption(5)) p := New(MaxConnectionsOption(5))
emptyTokens(p) emptyTokens(p)
p.state.ConsecutiveRetries = 1 p.state.ConsecutiveRetries = 1
p.endCall(true, nil, true) p.endCall(true, nil)
assert.Equal(t, 1, len(p.connTokens)) assert.Equal(t, 1, len(p.connTokens))
assert.Equal(t, 2, p.state.ConsecutiveRetries) assert.Equal(t, 2, p.state.ConsecutiveRetries)
} }
@@ -266,7 +266,7 @@ func TestEndCallZeroConnections(t *testing.T) {
p := New(MaxConnectionsOption(0)) p := New(MaxConnectionsOption(0))
emptyTokens(p) emptyTokens(p)
p.state.ConsecutiveRetries = 1 p.state.ConsecutiveRetries = 1
p.endCall(false, nil, false) p.endCall(false, nil)
assert.Equal(t, 0, len(p.connTokens)) assert.Equal(t, 0, len(p.connTokens))
assert.Equal(t, 0, p.state.ConsecutiveRetries) assert.Equal(t, 0, p.state.ConsecutiveRetries)
} }
@@ -353,41 +353,6 @@ func TestCallParallel(t *testing.T) {
wait.Broadcast() wait.Broadcast()
} }
func BenchmarkPacerReentered(b *testing.B) {
for b.Loop() {
_ = pacerReentered()
}
}
func BenchmarkPacerReentered100(b *testing.B) {
var fn func(level int)
fn = func(level int) {
if level > 0 {
fn(level - 1)
return
}
for b.Loop() {
_ = pacerReentered()
}
}
fn(100)
}
func TestCallMaxConnectionsRecursiveDeadlock(t *testing.T) {
p := New(CalculatorOption(NewDefault(MinSleep(1*time.Millisecond), MaxSleep(2*time.Millisecond))))
p.SetMaxConnections(1)
dp := &dummyPaced{retry: false}
err := p.Call(func() (bool, error) {
// check we have taken the connection token
// no tokens left means deadlock on the recursive call
assert.Equal(t, 0, len(p.connTokens))
return false, p.Call(dp.fn)
})
assert.Equal(t, 1, dp.called)
assert.Equal(t, errFoo, err)
}
func TestRetryAfterError_NonNilErr(t *testing.T) { func TestRetryAfterError_NonNilErr(t *testing.T) {
orig := errors.New("test failure") orig := errors.New("test failure")
dur := 2 * time.Second dur := 2 * time.Second