diff --git a/lib/pacer/pacer.go b/lib/pacer/pacer.go index 1c7cc51e5..005317821 100644 --- a/lib/pacer/pacer.go +++ b/lib/pacer/pacer.go @@ -4,6 +4,8 @@ package pacer import ( "errors" "fmt" + "runtime" + "strings" "sync" "time" @@ -153,13 +155,13 @@ func (p *Pacer) ModifyCalculator(f func(Calculator)) { // This must be called as a pair with endCall. // // This waits for the pacer token -func (p *Pacer) beginCall() { +func (p *Pacer) beginCall(limitConnections bool) { // pacer starts with a token in and whenever we take one out // XXX ms later we put another in. We could do this with a // Ticker more accurately, but then we'd have to work out how // not to run it when it wasn't needed <-p.pacer - if p.maxConnections > 0 { + if limitConnections { <-p.connTokens } @@ -176,8 +178,8 @@ func (p *Pacer) beginCall() { // // This should calculate a new sleepTime. It takes a boolean as to // whether the operation should be retried or not. -func (p *Pacer) endCall(retry bool, err error) { - if p.maxConnections > 0 { +func (p *Pacer) endCall(retry bool, err error, limitConnections bool) { + if limitConnections { p.connTokens <- struct{}{} } p.mu.Lock() @@ -191,13 +193,44 @@ func (p *Pacer) endCall(retry bool, err error) { p.mu.Unlock() } +// Detect the pacer being called reentrantly. +// +// This looks for Pacer.call in the call stack and returns true if it +// is found. +// +// Ideally we would do this by passing a context about but there are +// an awful lot of Pacer calls! +// +// This is only needed when p.maxConnections > 0 which isn't a common +// configuration so adding a bit of extra slowdown here is not a +// problem. +func pacerReentered() bool { + var pcs [48]uintptr + n := runtime.Callers(3, pcs[:]) // skip runtime.Callers, pacerReentered and call + frames := runtime.CallersFrames(pcs[:n]) + for { + f, more := frames.Next() + if strings.HasSuffix(f.Function, "(*Pacer).call") { + return true + } + if !more { + break + } + } + return false +} + // call implements Call but with settable retries func (p *Pacer) call(fn Paced, retries int) (err error) { var retry bool + limitConnections := false + if p.maxConnections > 0 && !pacerReentered() { + limitConnections = true + } for i := 1; i <= retries; i++ { - p.beginCall() + p.beginCall(limitConnections) retry, err = p.invoker(i, retries, fn) - p.endCall(retry, err) + p.endCall(retry, err, limitConnections) if !retry { break } diff --git a/lib/pacer/pacer_test.go b/lib/pacer/pacer_test.go index 3ac9c3741..76ef4b071 100644 --- a/lib/pacer/pacer_test.go +++ b/lib/pacer/pacer_test.go @@ -108,7 +108,7 @@ func waitForPace(p *Pacer, duration time.Duration) (when time.Time) { func TestBeginCall(t *testing.T) { p := New(MaxConnectionsOption(10), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond)))) emptyTokens(p) - go p.beginCall() + go p.beginCall(true) if !waitForPace(p, 10*time.Millisecond).IsZero() { t.Errorf("beginSleep fired too early #1") } @@ -131,7 +131,7 @@ func TestBeginCall(t *testing.T) { func TestBeginCallZeroConnections(t *testing.T) { p := New(MaxConnectionsOption(0), CalculatorOption(NewDefault(MinSleep(1*time.Millisecond)))) emptyTokens(p) - go p.beginCall() + go p.beginCall(false) if !waitForPace(p, 10*time.Millisecond).IsZero() { t.Errorf("beginSleep fired too early #1") } @@ -257,7 +257,7 @@ func TestEndCall(t *testing.T) { p := New(MaxConnectionsOption(5)) emptyTokens(p) p.state.ConsecutiveRetries = 1 - p.endCall(true, nil) + p.endCall(true, nil, true) assert.Equal(t, 1, len(p.connTokens)) assert.Equal(t, 2, p.state.ConsecutiveRetries) } @@ -266,7 +266,7 @@ func TestEndCallZeroConnections(t *testing.T) { p := New(MaxConnectionsOption(0)) emptyTokens(p) p.state.ConsecutiveRetries = 1 - p.endCall(false, nil) + p.endCall(false, nil, false) assert.Equal(t, 0, len(p.connTokens)) assert.Equal(t, 0, p.state.ConsecutiveRetries) } @@ -353,6 +353,41 @@ func TestCallParallel(t *testing.T) { wait.Broadcast() } +func BenchmarkPacerReentered(b *testing.B) { + for b.Loop() { + _ = pacerReentered() + } +} + +func BenchmarkPacerReentered100(b *testing.B) { + var fn func(level int) + fn = func(level int) { + if level > 0 { + fn(level - 1) + return + } + for b.Loop() { + _ = pacerReentered() + } + + } + fn(100) +} + +func TestCallMaxConnectionsRecursiveDeadlock(t *testing.T) { + p := New(CalculatorOption(NewDefault(MinSleep(1*time.Millisecond), MaxSleep(2*time.Millisecond)))) + p.SetMaxConnections(1) + dp := &dummyPaced{retry: false} + err := p.Call(func() (bool, error) { + // check we have taken the connection token + // no tokens left means deadlock on the recursive call + assert.Equal(t, 0, len(p.connTokens)) + return false, p.Call(dp.fn) + }) + assert.Equal(t, 1, dp.called) + assert.Equal(t, errFoo, err) +} + func TestRetryAfterError_NonNilErr(t *testing.T) { orig := errors.New("test failure") dur := 2 * time.Second