diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go index 009cb2c07..91aff8316 100644 --- a/fs/accounting/accounting.go +++ b/fs/accounting/accounting.go @@ -11,6 +11,7 @@ import ( "unicode/utf8" "github.com/rclone/rclone/fs/rc" + "github.com/rclone/rclone/lib/transferaccounter" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/asyncreader" @@ -312,6 +313,15 @@ func (acc *Account) serverSideEnd(n int64) { } } +// NewServerSideCopyAccounter returns a TransferAccounter for a server +// side copy and a new ctx with it embedded +func (acc *Account) NewServerSideCopyAccounter(ctx context.Context) (context.Context, *transferaccounter.TransferAccounter) { + return transferaccounter.New(ctx, func(n int64) { + acc.stats.AddServerSideCopyBytes(n) + acc.accountReadNoNetwork(n) + }) +} + // ServerSideCopyEnd accounts for a read of n bytes in a server-side copy func (acc *Account) ServerSideCopyEnd(n int64) { acc.stats.AddServerSideCopy(n) @@ -358,6 +368,17 @@ func (acc *Account) accountRead(n int) { acc.limitPerFileBandwidth(n) } +// Account the read if not using network (eg for server side copies) +func (acc *Account) accountReadNoNetwork(n int64) { + // Update Stats + acc.values.mu.Lock() + acc.values.lpBytes += int(n) + acc.values.bytes += n + acc.values.mu.Unlock() + + acc.stats.BytesNoNetwork(n) +} + // read bytes from the io.Reader passed in and account them func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { bytesUntilLimit, err := acc.checkReadBefore() diff --git a/fs/accounting/stats.go b/fs/accounting/stats.go index 64f160953..0fe72e89a 100644 --- a/fs/accounting/stats.go +++ b/fs/accounting/stats.go @@ -938,6 +938,13 @@ func (s *StatsInfo) AddServerSideMove(n int64) { s.mu.Unlock() } +// AddServerSideCopyBytes adds bytes for a server side copy +func (s *StatsInfo) AddServerSideCopyBytes(n int64) { + s.mu.Lock() + s.serverSideCopyBytes += n + s.mu.Unlock() +} + // AddServerSideCopy counts a server side copy func (s *StatsInfo) AddServerSideCopy(n int64) { s.mu.Lock() diff --git a/fs/operations/copy.go b/fs/operations/copy.go index 2d42fac07..8ab14c53e 100644 --- a/fs/operations/copy.go +++ b/fs/operations/copy.go @@ -148,9 +148,14 @@ func (c *copy) serverSideCopy(ctx context.Context) (actionTaken string, newDst f } in := c.tr.Account(ctx, nil) // account the transfer in.ServerSideTransferStart() - newDst, err = doCopy(ctx, c.src, c.remoteForCopy) + newCtx, ta := in.NewServerSideCopyAccounter(ctx) + newDst, err = doCopy(newCtx, c.src, c.remoteForCopy) if err == nil { - in.ServerSideCopyEnd(newDst.Size()) // account the bytes for the server-side transfer + var n int64 + if !ta.Started() { + n = newDst.Size() + } + in.ServerSideCopyEnd(n) // account the bytes for the server-side transfer } _ = in.Close() if errors.Is(err, fs.ErrorCantCopy) { diff --git a/lib/transferaccounter/transferaccounter.go b/lib/transferaccounter/transferaccounter.go new file mode 100644 index 000000000..c33049d7f --- /dev/null +++ b/lib/transferaccounter/transferaccounter.go @@ -0,0 +1,63 @@ +// Package transferaccounter provides utilities for accounting server side transfers. +package transferaccounter + +import "context" + +// Context key type for accounter +type accounterContextKeyType struct{} + +// Context key for accounter +var accounterContextKey = accounterContextKeyType{} + +// TransferAccounter is used to account server side and other transfers. +type TransferAccounter struct { + add func(n int64) + started bool +} + +// New creates a TransferAccounter using the add function passed in. +// +// Note that the add function should be goroutine safe. +// +// It adds the new TransferAccounter to the context. +func New(ctx context.Context, add func(n int64)) (context.Context, *TransferAccounter) { + ta := &TransferAccounter{ + add: add, + } + newCtx := context.WithValue(ctx, accounterContextKey, ta) + return newCtx, ta +} + +// Start the transfer. Call this before calling Add(). +func (ta *TransferAccounter) Start() { + ta.started = true +} + +// Started returns if the transfer has had Start() called or not. +func (ta *TransferAccounter) Started() bool { + return ta.started +} + +// Add n bytes to the transfer +func (ta *TransferAccounter) Add(n int64) { + ta.add(n) +} + +// A transfer accounter which does nothing +var nullAccounter = &TransferAccounter{ + add: func(n int64) {}, +} + +// Get returns a *TransferAccounter from the ctx. +// +// If none is found it will return a dummy one to keep the code simple. +func Get(ctx context.Context) *TransferAccounter { + if ctx == nil { + return nullAccounter + } + c := ctx.Value(accounterContextKey) + if c == nil { + return nullAccounter + } + return c.(*TransferAccounter) +} diff --git a/lib/transferaccounter/transferaccounter_test.go b/lib/transferaccounter/transferaccounter_test.go new file mode 100644 index 000000000..fce30605e --- /dev/null +++ b/lib/transferaccounter/transferaccounter_test.go @@ -0,0 +1,83 @@ +package transferaccounter + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + // Dummy add function + var totalBytes int64 + addFn := func(n int64) { + totalBytes += n + } + + // Create the accounter + ctx := context.Background() + _, ta := New(ctx, addFn) + + // Verify object creation + assert.NotNil(t, ta) + assert.False(t, ta.Started(), "New accounter should not be started by default") + + // Test Start() + ta.Start() + assert.True(t, ta.Started(), "Accounter should be started after calling Start()") + + // Test Add() logic + ta.Add(100) + ta.Add(50) + assert.Equal(t, int64(150), totalBytes, "The add function should have been called with cumulative values") +} + +func TestGet(t *testing.T) { + t.Run("Retrieve existing accounter", func(t *testing.T) { + // Create a specific accounter to identify later + expectedTotal := int64(0) + ctx, originalTa := New(context.Background(), func(n int64) { expectedTotal += n }) + + // Retrieve it + retrievedTa := Get(ctx) + + // Assert it is the exact same pointer + assert.Equal(t, originalTa, retrievedTa) + + // Verify functionality passes through + retrievedTa.Add(10) + assert.Equal(t, int64(10), expectedTotal) + }) + + t.Run("Context does not contain accounter", func(t *testing.T) { + ctx := context.Background() + ta := Get(ctx) + + assert.NotNil(t, ta, "Get should never return nil") + assert.Equal(t, nullAccounter, ta, "Should return the global nullAccounter") + }) + + t.Run("Context is nil", func(t *testing.T) { + ta := Get(nil) //nolint:staticcheck // we want to test this + + assert.NotNil(t, ta, "Get should never return nil") + assert.Equal(t, nullAccounter, ta, "Should return the global nullAccounter") + }) +} + +func TestNullAccounterBehavior(t *testing.T) { + // Ensure the null accounter (returned when context is missing/nil) + // can be called without panicking. + ta := Get(nil) //nolint:staticcheck // we want to test this + + assert.NotPanics(t, func() { + ta.Start() + }) + + // Even after start, it acts as a valid object + assert.True(t, ta.Started()) + + assert.NotPanics(t, func() { + ta.Add(1000) + }) +}