Add Context support to Backoff

Signed-off-by: Luca Bigliardi <shammash@google.com>
This commit is contained in:
Luca Bigliardi 2021-02-24 17:14:40 +01:00
parent bde6681de9
commit 82af7c1f69
4 changed files with 82 additions and 18 deletions

View File

@ -15,6 +15,7 @@
package main package main
import ( import (
"context"
"log" "log"
"math" "math"
"math/rand" "math/rand"
@ -23,10 +24,9 @@ import (
type JitterFunc func(int) int type JitterFunc func(int) int
type TimeFunc func() time.Time
type Delayer interface { type Delayer interface {
Delay() Delay()
DelayContext(context.Context) bool
} }
type Backoff struct { type Backoff struct {
@ -36,7 +36,7 @@ type Backoff struct {
lastAttempt time.Time lastAttempt time.Time
durationUnit time.Duration durationUnit time.Duration
jitterer JitterFunc jitterer JitterFunc
timeGetter TimeFunc timeTeller TimeTeller
} }
func jitterFunc(input int) int { func jitterFunc(input int) int {
@ -46,27 +46,44 @@ func jitterFunc(input int) int {
return rand.Intn(input) return rand.Intn(input)
} }
// TimeTeller interface allows injection of fake time during testing
type TimeTeller interface {
Now() time.Time
After(time.Duration) <-chan time.Time
}
type RealTime struct{}
func (r *RealTime) Now() time.Time {
return time.Now()
}
func (r *RealTime) After(d time.Duration) <-chan time.Time {
return time.After(d)
}
func NewBackoff(maxBackoff float64, resetDelta float64, func NewBackoff(maxBackoff float64, resetDelta float64,
durationUnit time.Duration) *Backoff { durationUnit time.Duration) *Backoff {
timeTeller := &RealTime{}
return NewBackoffForTesting( return NewBackoffForTesting(
maxBackoff, resetDelta, durationUnit, jitterFunc, time.Now) maxBackoff, resetDelta, durationUnit, jitterFunc, timeTeller)
} }
func NewBackoffForTesting(maxBackoff float64, resetDelta float64, func NewBackoffForTesting(maxBackoff float64, resetDelta float64,
durationUnit time.Duration, jitterer JitterFunc, timeGetter TimeFunc) *Backoff { durationUnit time.Duration, jitterer JitterFunc, timeTeller TimeTeller) *Backoff {
return &Backoff{ return &Backoff{
step: 0, step: 0,
maxBackoff: maxBackoff, maxBackoff: maxBackoff,
resetDelta: resetDelta, resetDelta: resetDelta,
lastAttempt: timeGetter(), lastAttempt: timeTeller.Now(),
durationUnit: durationUnit, durationUnit: durationUnit,
jitterer: jitterer, jitterer: jitterer,
timeGetter: timeGetter, timeTeller: timeTeller,
} }
} }
func (b *Backoff) maybeReset() { func (b *Backoff) maybeReset() {
now := b.timeGetter() now := b.timeTeller.Now()
lastAttemptDelta := float64(now.Sub(b.lastAttempt) / b.durationUnit) lastAttemptDelta := float64(now.Sub(b.lastAttempt) / b.durationUnit)
b.lastAttempt = now b.lastAttempt = now
@ -96,7 +113,18 @@ func (b *Backoff) GetDelay() time.Duration {
} }
func (b *Backoff) Delay() { func (b *Backoff) Delay() {
delay := b.GetDelay() b.DelayContext(context.Background())
log.Printf("Backoff for %s", delay) }
time.Sleep(delay)
func (b *Backoff) DelayContext(ctx context.Context) bool {
delay := b.GetDelay()
log.Printf("Backoff for %s starts", delay)
select {
case <-b.timeTeller.After(delay):
log.Printf("Backoff for %s ends", delay)
case <-ctx.Done():
log.Printf("Backoff for %s canceled by context", delay)
return false
}
return true
} }

View File

@ -15,6 +15,7 @@
package main package main
import ( import (
"context"
"testing" "testing"
"time" "time"
) )
@ -23,29 +24,39 @@ type FakeTime struct {
timeseries []int timeseries []int
lastIndex int lastIndex int
durationUnit time.Duration durationUnit time.Duration
afterChan chan time.Time
} }
func (f *FakeTime) GetTime() time.Time { func (f *FakeTime) Now() time.Time {
timeDelta := time.Duration(f.timeseries[f.lastIndex]) * f.durationUnit timeDelta := time.Duration(f.timeseries[f.lastIndex]) * f.durationUnit
fakeTime := time.Unix(0, 0).Add(timeDelta) fakeTime := time.Unix(0, 0).Add(timeDelta)
f.lastIndex++ f.lastIndex++
return fakeTime return fakeTime
} }
func (f *FakeTime) After(d time.Duration) <-chan time.Time {
return f.afterChan
}
func FakeJitter(input int) int { func FakeJitter(input int) int {
return input return input
} }
func RunBackoffTest(t *testing.T, func MakeTestingBackoff(maxBackoff float64, resetDelta float64, elapsedTime []int) (*Backoff, *FakeTime) {
maxBackoff float64, resetDelta float64,
elapsedTime []int, expectedDelays []int) {
fakeTime := &FakeTime{ fakeTime := &FakeTime{
timeseries: elapsedTime, timeseries: elapsedTime,
lastIndex: 0, lastIndex: 0,
durationUnit: time.Millisecond, durationUnit: time.Millisecond,
afterChan: make(chan time.Time, 1),
} }
backoff := NewBackoffForTesting(maxBackoff, resetDelta, time.Millisecond, backoff := NewBackoffForTesting(maxBackoff, resetDelta, time.Millisecond,
FakeJitter, fakeTime.GetTime) FakeJitter, fakeTime)
return backoff, fakeTime
}
func RunBackoffTest(t *testing.T, maxBackoff float64, resetDelta float64, elapsedTime []int, expectedDelays []int) {
backoff, _ := MakeTestingBackoff(maxBackoff, resetDelta, elapsedTime)
for i, value := range expectedDelays { for i, value := range expectedDelays {
expected_delay := time.Duration(value) * time.Millisecond expected_delay := time.Duration(value) * time.Millisecond
@ -78,3 +89,19 @@ func TestBackoffReset(t *testing.T) {
[]int{0, 2, 4, 0, 2, 0, 2, 4}, []int{0, 2, 4, 0, 2, 0, 2, 4},
) )
} }
func TestBackoffDelayContext(t *testing.T) {
backoff, fakeTime := MakeTestingBackoff(8, 32, []int{0, 0, 0})
ctx, cancel := context.WithCancel(context.Background())
fakeTime.afterChan <- time.Now()
if ok := backoff.DelayContext(ctx); !ok {
t.Errorf("Expired time does not return true")
}
cancel()
if ok := backoff.DelayContext(ctx); ok {
t.Errorf("Canceled context does not return false")
}
}

8
irc.go
View File

@ -164,7 +164,9 @@ func (notifier *IRCNotifier) HandleKick(nick string, channel string) {
} }
log.Printf("Being kicked out of %s, re-joining", channel) log.Printf("Being kicked out of %s, re-joining", channel)
go func() { go func() {
state.BackoffCounter.Delay() if ok := state.BackoffCounter.DelayContext(notifier.ctx); !ok {
return
}
notifier.Client.Join(state.Channel.Name, state.Channel.Password) notifier.Client.Join(state.Channel.Name, state.Channel.Password)
}() }()
@ -242,7 +244,9 @@ func (notifier *IRCNotifier) Run() {
for notifier.ctx.Err() != context.Canceled { for notifier.ctx.Err() != context.Canceled {
if !notifier.Client.Connected() { if !notifier.Client.Connected() {
log.Printf("Connecting to IRC %s", notifier.Client.Config().Server) log.Printf("Connecting to IRC %s", notifier.Client.Config().Server)
notifier.BackoffCounter.Delay() if ok := notifier.BackoffCounter.DelayContext(notifier.ctx); !ok {
continue
}
if err := notifier.Client.Connect(); err != nil { if err := notifier.Client.Connect(); err != nil {
log.Printf("Could not connect to IRC: %s", err) log.Printf("Could not connect to IRC: %s", err)
continue continue

View File

@ -202,6 +202,11 @@ func (f *FakeDelayer) Delay() {
log.Printf("Faking Backoff") log.Printf("Faking Backoff")
} }
func (f *FakeDelayer) DelayContext(ctx context.Context) bool {
log.Printf("Faking Backoff")
return true
}
func makeTestIRCConfig(IRCPort int) *Config { func makeTestIRCConfig(IRCPort int) *Config {
return &Config{ return &Config{
IRCNick: "foo", IRCNick: "foo",