From d516515dfe9d0ef5445027317cef9b2bb132d4b7 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Thu, 12 Feb 2026 16:32:09 +0000 Subject: [PATCH] operations: add method to real time account server side copy Before this change server side copies would show at 0% until they were done then show at 100%. With support from the backend, server side copies can now be accounted in real time. This will only work for backends which have been modified and themselves get feedback about how copies are going. --- fs/accounting/accounting.go | 21 +++++ fs/accounting/stats.go | 7 ++ fs/operations/copy.go | 9 +- lib/transferaccounter/transferaccounter.go | 63 ++++++++++++++ .../transferaccounter_test.go | 83 +++++++++++++++++++ 5 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 lib/transferaccounter/transferaccounter.go create mode 100644 lib/transferaccounter/transferaccounter_test.go diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go index 009cb2c07..91aff8316 100644 --- a/fs/accounting/accounting.go +++ b/fs/accounting/accounting.go @@ -11,6 +11,7 @@ import ( "unicode/utf8" "github.com/rclone/rclone/fs/rc" + "github.com/rclone/rclone/lib/transferaccounter" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/asyncreader" @@ -312,6 +313,15 @@ func (acc *Account) serverSideEnd(n int64) { } } +// NewServerSideCopyAccounter returns a TransferAccounter for a server +// side copy and a new ctx with it embedded +func (acc *Account) NewServerSideCopyAccounter(ctx context.Context) (context.Context, *transferaccounter.TransferAccounter) { + return transferaccounter.New(ctx, func(n int64) { + acc.stats.AddServerSideCopyBytes(n) + acc.accountReadNoNetwork(n) + }) +} + // ServerSideCopyEnd accounts for a read of n bytes in a server-side copy func (acc *Account) ServerSideCopyEnd(n int64) { acc.stats.AddServerSideCopy(n) @@ -358,6 +368,17 @@ func (acc *Account) accountRead(n int) { acc.limitPerFileBandwidth(n) } +// Account the read if not using network (eg for server side copies) +func (acc *Account) accountReadNoNetwork(n int64) { + // Update Stats + acc.values.mu.Lock() + acc.values.lpBytes += int(n) + acc.values.bytes += n + acc.values.mu.Unlock() + + acc.stats.BytesNoNetwork(n) +} + // read bytes from the io.Reader passed in and account them func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { bytesUntilLimit, err := acc.checkReadBefore() diff --git a/fs/accounting/stats.go b/fs/accounting/stats.go index 64f160953..0fe72e89a 100644 --- a/fs/accounting/stats.go +++ b/fs/accounting/stats.go @@ -938,6 +938,13 @@ func (s *StatsInfo) AddServerSideMove(n int64) { s.mu.Unlock() } +// AddServerSideCopyBytes adds bytes for a server side copy +func (s *StatsInfo) AddServerSideCopyBytes(n int64) { + s.mu.Lock() + s.serverSideCopyBytes += n + s.mu.Unlock() +} + // AddServerSideCopy counts a server side copy func (s *StatsInfo) AddServerSideCopy(n int64) { s.mu.Lock() diff --git a/fs/operations/copy.go b/fs/operations/copy.go index 2d42fac07..8ab14c53e 100644 --- a/fs/operations/copy.go +++ b/fs/operations/copy.go @@ -148,9 +148,14 @@ func (c *copy) serverSideCopy(ctx context.Context) (actionTaken string, newDst f } in := c.tr.Account(ctx, nil) // account the transfer in.ServerSideTransferStart() - newDst, err = doCopy(ctx, c.src, c.remoteForCopy) + newCtx, ta := in.NewServerSideCopyAccounter(ctx) + newDst, err = doCopy(newCtx, c.src, c.remoteForCopy) if err == nil { - in.ServerSideCopyEnd(newDst.Size()) // account the bytes for the server-side transfer + var n int64 + if !ta.Started() { + n = newDst.Size() + } + in.ServerSideCopyEnd(n) // account the bytes for the server-side transfer } _ = in.Close() if errors.Is(err, fs.ErrorCantCopy) { diff --git a/lib/transferaccounter/transferaccounter.go b/lib/transferaccounter/transferaccounter.go new file mode 100644 index 000000000..c33049d7f --- /dev/null +++ b/lib/transferaccounter/transferaccounter.go @@ -0,0 +1,63 @@ +// Package transferaccounter provides utilities for accounting server side transfers. +package transferaccounter + +import "context" + +// Context key type for accounter +type accounterContextKeyType struct{} + +// Context key for accounter +var accounterContextKey = accounterContextKeyType{} + +// TransferAccounter is used to account server side and other transfers. +type TransferAccounter struct { + add func(n int64) + started bool +} + +// New creates a TransferAccounter using the add function passed in. +// +// Note that the add function should be goroutine safe. +// +// It adds the new TransferAccounter to the context. +func New(ctx context.Context, add func(n int64)) (context.Context, *TransferAccounter) { + ta := &TransferAccounter{ + add: add, + } + newCtx := context.WithValue(ctx, accounterContextKey, ta) + return newCtx, ta +} + +// Start the transfer. Call this before calling Add(). +func (ta *TransferAccounter) Start() { + ta.started = true +} + +// Started returns if the transfer has had Start() called or not. +func (ta *TransferAccounter) Started() bool { + return ta.started +} + +// Add n bytes to the transfer +func (ta *TransferAccounter) Add(n int64) { + ta.add(n) +} + +// A transfer accounter which does nothing +var nullAccounter = &TransferAccounter{ + add: func(n int64) {}, +} + +// Get returns a *TransferAccounter from the ctx. +// +// If none is found it will return a dummy one to keep the code simple. +func Get(ctx context.Context) *TransferAccounter { + if ctx == nil { + return nullAccounter + } + c := ctx.Value(accounterContextKey) + if c == nil { + return nullAccounter + } + return c.(*TransferAccounter) +} diff --git a/lib/transferaccounter/transferaccounter_test.go b/lib/transferaccounter/transferaccounter_test.go new file mode 100644 index 000000000..fce30605e --- /dev/null +++ b/lib/transferaccounter/transferaccounter_test.go @@ -0,0 +1,83 @@ +package transferaccounter + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + // Dummy add function + var totalBytes int64 + addFn := func(n int64) { + totalBytes += n + } + + // Create the accounter + ctx := context.Background() + _, ta := New(ctx, addFn) + + // Verify object creation + assert.NotNil(t, ta) + assert.False(t, ta.Started(), "New accounter should not be started by default") + + // Test Start() + ta.Start() + assert.True(t, ta.Started(), "Accounter should be started after calling Start()") + + // Test Add() logic + ta.Add(100) + ta.Add(50) + assert.Equal(t, int64(150), totalBytes, "The add function should have been called with cumulative values") +} + +func TestGet(t *testing.T) { + t.Run("Retrieve existing accounter", func(t *testing.T) { + // Create a specific accounter to identify later + expectedTotal := int64(0) + ctx, originalTa := New(context.Background(), func(n int64) { expectedTotal += n }) + + // Retrieve it + retrievedTa := Get(ctx) + + // Assert it is the exact same pointer + assert.Equal(t, originalTa, retrievedTa) + + // Verify functionality passes through + retrievedTa.Add(10) + assert.Equal(t, int64(10), expectedTotal) + }) + + t.Run("Context does not contain accounter", func(t *testing.T) { + ctx := context.Background() + ta := Get(ctx) + + assert.NotNil(t, ta, "Get should never return nil") + assert.Equal(t, nullAccounter, ta, "Should return the global nullAccounter") + }) + + t.Run("Context is nil", func(t *testing.T) { + ta := Get(nil) //nolint:staticcheck // we want to test this + + assert.NotNil(t, ta, "Get should never return nil") + assert.Equal(t, nullAccounter, ta, "Should return the global nullAccounter") + }) +} + +func TestNullAccounterBehavior(t *testing.T) { + // Ensure the null accounter (returned when context is missing/nil) + // can be called without panicking. + ta := Get(nil) //nolint:staticcheck // we want to test this + + assert.NotPanics(t, func() { + ta.Start() + }) + + // Even after start, it acts as a valid object + assert.True(t, ta.Started()) + + assert.NotPanics(t, func() { + ta.Add(1000) + }) +}