diff --git a/backend/smb/filepool.go b/backend/smb/filepool.go new file mode 100644 index 000000000..0061ec1ff --- /dev/null +++ b/backend/smb/filepool.go @@ -0,0 +1,99 @@ +package smb + +import ( + "context" + "fmt" + "os" + "sync" + + "github.com/cloudsoda/go-smb2" + "golang.org/x/sync/errgroup" +) + +// FsInterface defines the methods that filePool needs from Fs +type FsInterface interface { + getConnection(ctx context.Context, share string) (*conn, error) + putConnection(pc **conn, err error) + removeSession() +} + +type file struct { + *smb2.File + c *conn +} + +type filePool struct { + ctx context.Context + fs FsInterface + share string + path string + + mu sync.Mutex + pool []*file +} + +func newFilePool(ctx context.Context, fs FsInterface, share, path string) *filePool { + return &filePool{ + ctx: ctx, + fs: fs, + share: share, + path: path, + } +} + +func (p *filePool) get() (*file, error) { + p.mu.Lock() + if len(p.pool) > 0 { + f := p.pool[len(p.pool)-1] + p.pool = p.pool[:len(p.pool)-1] + p.mu.Unlock() + return f, nil + } + p.mu.Unlock() + + c, err := p.fs.getConnection(p.ctx, p.share) + if err != nil { + return nil, err + } + + fl, err := c.smbShare.OpenFile(p.path, os.O_WRONLY, 0o644) + if err != nil { + p.fs.putConnection(&c, err) + return nil, fmt.Errorf("failed to open: %w", err) + } + + return &file{File: fl, c: c}, nil +} + +func (p *filePool) put(f *file, err error) { + if f == nil { + return + } + + if err != nil { + _ = f.Close() + p.fs.putConnection(&f.c, err) + return + } + + p.mu.Lock() + p.pool = append(p.pool, f) + p.mu.Unlock() +} + +func (p *filePool) drain() error { + p.mu.Lock() + files := p.pool + p.pool = nil + p.mu.Unlock() + + g, _ := errgroup.WithContext(p.ctx) + for _, f := range files { + g.Go(func() error { + err := f.Close() + p.fs.putConnection(&f.c, err) + return err + }) + } + return g.Wait() +} diff --git a/backend/smb/filepool_test.go b/backend/smb/filepool_test.go new file mode 100644 index 000000000..7cf90988b --- /dev/null +++ b/backend/smb/filepool_test.go @@ -0,0 +1,228 @@ +package smb + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/cloudsoda/go-smb2" + "github.com/stretchr/testify/assert" +) + +// Mock Fs that implements FsInterface +type mockFs struct { + mu sync.Mutex + putConnectionCalled bool + putConnectionErr error + getConnectionCalled bool + getConnectionErr error + getConnectionResult *conn + removeSessionCalled bool +} + +func (m *mockFs) putConnection(pc **conn, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.putConnectionCalled = true + m.putConnectionErr = err +} + +func (m *mockFs) getConnection(ctx context.Context, share string) (*conn, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getConnectionCalled = true + if m.getConnectionErr != nil { + return nil, m.getConnectionErr + } + if m.getConnectionResult != nil { + return m.getConnectionResult, nil + } + return &conn{}, nil +} + +func (m *mockFs) removeSession() { + m.mu.Lock() + defer m.mu.Unlock() + m.removeSessionCalled = true +} + +func (m *mockFs) isPutConnectionCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.putConnectionCalled +} + +func (m *mockFs) getPutConnectionErr() error { + m.mu.Lock() + defer m.mu.Unlock() + return m.putConnectionErr +} + +func (m *mockFs) isGetConnectionCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.getConnectionCalled +} + +func newMockFs() *mockFs { + return &mockFs{} +} + +// Helper function to create a mock file +func newMockFile() *file { + return &file{ + File: &smb2.File{}, + c: &conn{}, + } +} + +// Test filePool creation +func TestNewFilePool(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + share := "testshare" + path := "/test/path" + + pool := newFilePool(ctx, fs, share, path) + + assert.NotNil(t, pool) + assert.Equal(t, ctx, pool.ctx) + assert.Equal(t, fs, pool.fs) + assert.Equal(t, share, pool.share) + assert.Equal(t, path, pool.path) + assert.Empty(t, pool.pool) +} + +// Test getting file from pool when pool has files +func TestFilePool_Get_FromPool(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + // Add a mock file to the pool + mockFile := newMockFile() + pool.pool = append(pool.pool, mockFile) + + // Get file from pool + f, err := pool.get() + + assert.NoError(t, err) + assert.NotNil(t, f) + assert.Equal(t, mockFile, f) + assert.Empty(t, pool.pool) +} + +// Test getting file when pool is empty +func TestFilePool_Get_EmptyPool(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + + // Set up the mock to return an error from getConnection + // This tests that the pool calls getConnection when empty + fs.getConnectionErr = errors.New("connection failed") + + pool := newFilePool(ctx, fs, "testshare", "test/path") + + // This should call getConnection and return the error + f, err := pool.get() + assert.Error(t, err) + assert.Nil(t, f) + assert.True(t, fs.isGetConnectionCalled()) + assert.Equal(t, "connection failed", err.Error()) +} + +// Test putting file successfully +func TestFilePool_Put_Success(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + mockFile := newMockFile() + + pool.put(mockFile, nil) + + assert.Len(t, pool.pool, 1) + assert.Equal(t, mockFile, pool.pool[0]) +} + +// Test putting file with error +func TestFilePool_Put_WithError(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + mockFile := newMockFile() + + pool.put(mockFile, errors.New("write error")) + + // Should call putConnection with error + assert.True(t, fs.isPutConnectionCalled()) + assert.Equal(t, errors.New("write error"), fs.getPutConnectionErr()) + assert.Empty(t, pool.pool) +} + +// Test putting nil file +func TestFilePool_Put_NilFile(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + // Should not panic + pool.put(nil, nil) + pool.put(nil, errors.New("some error")) + + assert.Empty(t, pool.pool) +} + +// Test draining pool with files +func TestFilePool_Drain_WithFiles(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + // Add mock files to pool + mockFile1 := newMockFile() + mockFile2 := newMockFile() + pool.pool = append(pool.pool, mockFile1, mockFile2) + + // Before draining + assert.Len(t, pool.pool, 2) + + _ = pool.drain() + assert.Empty(t, pool.pool) +} + +// Test concurrent access to pool +func TestFilePool_ConcurrentAccess(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + const numGoroutines = 10 + for i := 0; i < numGoroutines; i++ { + mockFile := newMockFile() + pool.pool = append(pool.pool, mockFile) + } + + // Test concurrent get operations + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer func() { done <- true }() + + f, err := pool.get() + if err == nil { + pool.put(f, nil) + } + }() + } + + for i := 0; i < numGoroutines; i++ { + <-done + } + + // Pool should be in a consistent after the concurrence access + assert.Len(t, pool.pool, numGoroutines) +} diff --git a/backend/smb/smb.go b/backend/smb/smb.go index d102233f5..993b2ec7d 100644 --- a/backend/smb/smb.go +++ b/backend/smb/smb.go @@ -3,6 +3,7 @@ package smb import ( "context" + "errors" "fmt" "io" "os" @@ -503,13 +504,73 @@ func (f *Fs) About(ctx context.Context) (_ *fs.Usage, err error) { return usage, nil } +type smbWriterAt struct { + pool *filePool + closed bool + closeMu sync.Mutex + wg sync.WaitGroup +} + +func (w *smbWriterAt) WriteAt(p []byte, off int64) (int, error) { + w.closeMu.Lock() + if w.closed { + w.closeMu.Unlock() + return 0, errors.New("writer already closed") + } + w.wg.Add(1) + w.closeMu.Unlock() + defer w.wg.Done() + + f, err := w.pool.get() + if err != nil { + return 0, fmt.Errorf("failed to get file from pool: %w", err) + } + + n, writeErr := f.WriteAt(p, off) + w.pool.put(f, writeErr) + + if writeErr != nil { + return n, fmt.Errorf("failed to write at offset %d: %w", off, writeErr) + } + + return n, writeErr +} + +func (w *smbWriterAt) Close() error { + w.closeMu.Lock() + defer w.closeMu.Unlock() + + if w.closed { + return nil + } + w.closed = true + + // Wait for all pending writes to finish + w.wg.Wait() + + var errs []error + + // Drain the pool + if err := w.pool.drain(); err != nil { + errs = append(errs, fmt.Errorf("failed to drain file pool: %w", err)) + } + + // Remove session + w.pool.fs.removeSession() + + if len(errs) > 0 { + return errors.Join(errs...) + } + + return nil +} + // OpenWriterAt opens with a handle for random access writes // // Pass in the remote desired and the size if known. // // It truncates any existing object func (f *Fs) OpenWriterAt(ctx context.Context, remote string, size int64) (fs.WriterAtCloser, error) { - var err error o := &Object{ fs: f, remote: remote, @@ -519,27 +580,42 @@ func (f *Fs) OpenWriterAt(ctx context.Context, remote string, size int64) (fs.Wr return nil, fs.ErrorIsDir } - err = o.fs.ensureDirectory(ctx, share, filename) + err := o.fs.ensureDirectory(ctx, share, filename) if err != nil { return nil, fmt.Errorf("failed to make parent directories: %w", err) } - filename = o.fs.toSambaPath(filename) - - o.fs.addSession() // Show session in use - defer o.fs.removeSession() + smbPath := o.fs.toSambaPath(filename) + // One-time truncate cn, err := o.fs.getConnection(ctx, share) if err != nil { return nil, err } - - fl, err := cn.smbShare.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + file, err := cn.smbShare.OpenFile(smbPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) if err != nil { - return nil, fmt.Errorf("failed to open: %w", err) + o.fs.putConnection(&cn, err) + return nil, err } + if size > 0 { + if truncateErr := file.Truncate(size); truncateErr != nil { + _ = file.Close() + o.fs.putConnection(&cn, truncateErr) + return nil, fmt.Errorf("failed to truncate file: %w", truncateErr) + } + } + if closeErr := file.Close(); closeErr != nil { + o.fs.putConnection(&cn, closeErr) + return nil, fmt.Errorf("failed to close file after truncate: %w", closeErr) + } + o.fs.putConnection(&cn, nil) - return fl, nil + // Add a new session + o.fs.addSession() + + return &smbWriterAt{ + pool: newFilePool(ctx, o.fs, share, smbPath), + }, nil } // Shutdown the backend, closing any background tasks and any