diff --git a/backoff.go b/backoff.go index 2d16ffe..2d40286 100644 --- a/backoff.go +++ b/backoff.go @@ -50,22 +50,6 @@ func jitterFunc(input int) int { return rand.Intn(input) } -// TimeTeller interface allows injection of fake time during testing -type TimeTeller interface { - Now() time.Time - After(time.Duration) <-chan time.Time -} - -type RealTime struct{} - -func (r *RealTime) Now() time.Time { - return time.Now() -} - -func (r *RealTime) After(d time.Duration) <-chan time.Time { - return time.After(d) -} - type BackoffMaker struct{} func (bm *BackoffMaker) NewDelayer(maxBackoff float64, resetDelta float64, durationUnit time.Duration) Delayer { diff --git a/backoff_test.go b/backoff_test.go index f262ea3..8e21894 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -20,24 +20,6 @@ import ( "time" ) -type FakeTime struct { - timeseries []int - lastIndex int - durationUnit time.Duration - afterChan chan time.Time -} - -func (f *FakeTime) Now() time.Time { - timeDelta := time.Duration(f.timeseries[f.lastIndex]) * f.durationUnit - fakeTime := time.Unix(0, 0).Add(timeDelta) - f.lastIndex++ - return fakeTime -} - -func (f *FakeTime) After(d time.Duration) <-chan time.Time { - return f.afterChan -} - func FakeJitter(input int) int { return input } diff --git a/fake_timeteller.go b/fake_timeteller.go new file mode 100644 index 0000000..f7d4b7a --- /dev/null +++ b/fake_timeteller.go @@ -0,0 +1,37 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "time" +) + +type FakeTime struct { + timeseries []int + lastIndex int + durationUnit time.Duration + afterChan chan time.Time +} + +func (f *FakeTime) Now() time.Time { + timeDelta := time.Duration(f.timeseries[f.lastIndex]) * f.durationUnit + fakeTime := time.Unix(0, 0).Add(timeDelta) + f.lastIndex++ + return fakeTime +} + +func (f *FakeTime) After(d time.Duration) <-chan time.Time { + return f.afterChan +} diff --git a/irc.go b/irc.go index 9b78eae..dc049bc 100644 --- a/irc.go +++ b/irc.go @@ -100,9 +100,10 @@ type IRCNotifier struct { NickservDelayWait time.Duration BackoffCounter Delayer + timeTeller TimeTeller } -func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg, delayerMaker DelayerMaker) (*IRCNotifier, error) { +func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg, delayerMaker DelayerMaker, timeTeller TimeTeller) (*IRCNotifier, error) { ircConfig := makeGOIRCConfig(config) @@ -112,7 +113,7 @@ func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg, delayerMaker Delaye ircConnectMaxBackoffSecs, ircConnectBackoffResetSecs, time.Second) - channelReconciler := NewChannelReconciler(config, client, delayerMaker) + channelReconciler := NewChannelReconciler(config, client, delayerMaker, timeTeller) notifier := &IRCNotifier{ Nick: config.IRCNick, @@ -125,6 +126,7 @@ func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg, delayerMaker Delaye UsePrivmsg: config.UsePrivmsg, NickservDelayWait: nickservWaitSecs * time.Second, BackoffCounter: backoffCounter, + timeTeller: timeTeller, } notifier.registerHandlers() @@ -183,7 +185,7 @@ func (n *IRCNotifier) ChannelJoined(ctx context.Context, channel string) bool { select { case <-waitJoined: return true - case <-time.After(ircJoinWaitSecs * time.Second): + case <-n.timeTeller.After(ircJoinWaitSecs * time.Second): log.Printf("Channel %s not joined after %d seconds, giving bad news to caller", channel, ircJoinWaitSecs) return false case <-ctx.Done(): @@ -220,7 +222,7 @@ func (n *IRCNotifier) ShutdownPhase() { log.Printf("Wait for IRC disconnect to complete") select { case <-n.sessionDownSignal: - case <-time.After(n.Client.Config().Timeout): + case <-n.timeTeller.After(n.Client.Config().Timeout): log.Printf("Timeout while waiting for IRC disconnect to complete, stopping anyway") } n.sessionWg.Done() diff --git a/irc_server_for_test.go b/irc_server_for_test.go index 80791af..2494ddc 100644 --- a/irc_server_for_test.go +++ b/irc_server_for_test.go @@ -16,6 +16,7 @@ package main import ( "bufio" + "errors" "fmt" "io" "log" @@ -130,6 +131,17 @@ func (s *testServer) handleConnection(conn net.Conn) { } } +func (s *testServer) SendMsg(msg string) error { + if s.Client == nil { + return errors.New("Cannot write without client connected") + } + bufConn := bufio.NewWriter(s.Client) + log.Printf("=Server= sending to client: %s", msg) + _, err := bufConn.WriteString(msg) + bufConn.Flush() + return err +} + func (s *testServer) SetCloseEarly(h closeEarlyHandler) { s.closeEarlyMu.Lock() defer s.closeEarlyMu.Unlock() diff --git a/irc_test.go b/irc_test.go index a6ac856..0a89728 100644 --- a/irc_test.go +++ b/irc_test.go @@ -44,11 +44,14 @@ func makeTestIRCConfig(IRCPort int) *Config { func makeTestNotifier(t *testing.T, config *Config) (*IRCNotifier, chan AlertMsg, context.Context, context.CancelFunc, *sync.WaitGroup) { fakeDelayerMaker := &FakeDelayerMaker{} + fakeTime := &FakeTime{ + afterChan: make(chan time.Time, 1), + } alertMsgs := make(chan AlertMsg) ctx, cancel := context.WithCancel(context.Background()) stopWg := sync.WaitGroup{} stopWg.Add(1) - notifier, err := NewIRCNotifier(config, alertMsgs, fakeDelayerMaker) + notifier, err := NewIRCNotifier(config, alertMsgs, fakeDelayerMaker, fakeTime) if err != nil { t.Fatal(fmt.Sprintf("Could not create IRC notifier: %s", err)) } diff --git a/main.go b/main.go index cfb303a..46cf21b 100644 --- a/main.go +++ b/main.go @@ -40,7 +40,7 @@ func main() { alertMsgs := make(chan AlertMsg, config.AlertBufferSize) stopWg.Add(1) - ircNotifier, err := NewIRCNotifier(config, alertMsgs, &BackoffMaker{}) + ircNotifier, err := NewIRCNotifier(config, alertMsgs, &BackoffMaker{}, &RealTime{}) if err != nil { log.Printf("Could not create IRC notifier: %s", err) return diff --git a/reconciler.go b/reconciler.go index d05eafc..e372da2 100644 --- a/reconciler.go +++ b/reconciler.go @@ -32,7 +32,9 @@ const ( type channelState struct { channel IRCChannel client *irc.Conn - delayer Delayer + + delayer Delayer + timeTeller TimeTeller joinDone chan struct{} // joined when channel is closed joined bool @@ -42,13 +44,14 @@ type channelState struct { mu sync.Mutex } -func newChannelState(channel *IRCChannel, client *irc.Conn, delayerMaker DelayerMaker) *channelState { +func newChannelState(channel *IRCChannel, client *irc.Conn, delayerMaker DelayerMaker, timeTeller TimeTeller) *channelState { delayer := delayerMaker.NewDelayer(ircJoinMaxBackoffSecs, ircJoinBackoffResetSecs, time.Second) return &channelState{ channel: *channel, client: client, delayer: delayer, + timeTeller: timeTeller, joinDone: make(chan struct{}), joined: false, joinUnsetSignal: make(chan bool), @@ -108,7 +111,7 @@ func (c *channelState) join(ctx context.Context) { select { case <-c.JoinDone(): log.Printf("Channel %s monitor: join succeeded", c.channel.Name) - case <-time.After(ircJoinWaitSecs * time.Second): + case <-c.timeTeller.After(ircJoinWaitSecs * time.Second): log.Printf("Channel %s monitor: could not join after %d seconds, will retry", c.channel.Name, ircJoinWaitSecs) case <-ctx.Done(): log.Printf("Channel %s monitor: context canceled while waiting for join", c.channel.Name) @@ -147,6 +150,7 @@ type ChannelReconciler struct { client *irc.Conn delayerMaker DelayerMaker + timeTeller TimeTeller channels map[string]*channelState @@ -157,11 +161,12 @@ type ChannelReconciler struct { mu sync.Mutex } -func NewChannelReconciler(config *Config, client *irc.Conn, delayerMaker DelayerMaker) *ChannelReconciler { +func NewChannelReconciler(config *Config, client *irc.Conn, delayerMaker DelayerMaker, timeTeller TimeTeller) *ChannelReconciler { reconciler := &ChannelReconciler{ preJoinChannels: config.IRCChannels, client: client, delayerMaker: delayerMaker, + timeTeller: timeTeller, channels: make(map[string]*channelState), } @@ -219,7 +224,7 @@ func (r *ChannelReconciler) HandleKick(nick string, channel string) { } func (r *ChannelReconciler) unsafeAddChannel(channel *IRCChannel) *channelState { - c := newChannelState(channel, r.client, r.delayerMaker) + c := newChannelState(channel, r.client, r.delayerMaker, r.timeTeller) r.stopWg.Add(1) go c.Monitor(r.stopCtx, &r.stopWg) diff --git a/reconciler_test.go b/reconciler_test.go index d9cb09e..d039bc3 100644 --- a/reconciler_test.go +++ b/reconciler_test.go @@ -21,21 +21,12 @@ import ( "sort" "sync" "testing" + "time" irc "github.com/fluffle/goirc/client" ) -func makeReconcilerTestIRCConfig(IRCPort int) *Config { - config := makeTestIRCConfig(IRCPort) - config.IRCChannels = []IRCChannel{ - IRCChannel{Name: "#foo"}, - IRCChannel{Name: "#bar"}, - IRCChannel{Name: "#baz"}, - } - return config -} - -func makeTestReconciler(config *Config) (*ChannelReconciler, chan bool, chan bool) { +func makeTestReconciler(config *Config) (*ChannelReconciler, chan bool, chan bool, *FakeTime) { sessionUp := make(chan bool) sessionDown := make(chan bool) @@ -53,15 +44,23 @@ func makeTestReconciler(config *Config) (*ChannelReconciler, chan bool, chan boo }) fakeDelayerMaker := &FakeDelayerMaker{} - reconciler := NewChannelReconciler(config, client, fakeDelayerMaker) + fakeTime := &FakeTime{ + afterChan: make(chan time.Time, 1), + } + reconciler := NewChannelReconciler(config, client, fakeDelayerMaker, fakeTime) - return reconciler, sessionUp, sessionDown + return reconciler, sessionUp, sessionDown, fakeTime } func TestPreJoinChannels(t *testing.T) { server, port := makeTestServer(t) - config := makeReconcilerTestIRCConfig(port) - reconciler, sessionUp, sessionDown := makeTestReconciler(config) + config := makeTestIRCConfig(port) + config.IRCChannels = []IRCChannel{ + IRCChannel{Name: "#foo"}, + IRCChannel{Name: "#bar"}, + IRCChannel{Name: "#baz"}, + } + reconciler, sessionUp, sessionDown, _ := makeTestReconciler(config) var testStep sync.WaitGroup @@ -98,3 +97,86 @@ func TestPreJoinChannels(t *testing.T) { t.Error("Did not pre-join channels") } } + +func TestKeepJoining(t *testing.T) { + server, port := makeTestServer(t) + config := makeTestIRCConfig(port) + reconciler, sessionUp, sessionDown, fakeTime := makeTestReconciler(config) + + var testStep sync.WaitGroup + + var joinedCounter int + + // Confirm join only after a few attempts + joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { + joinedCounter++ + + if joinedCounter == 3 { + testStep.Done() + return hJOIN(conn, line) + } else { + fakeTime.afterChan <- time.Now() + } + return nil + } + server.SetHandler("JOIN", joinHandler) + + testStep.Add(1) + + reconciler.client.Connect() + + <-sessionUp + reconciler.Start(context.Background()) + + testStep.Wait() + + reconciler.client.Quit("see ya") + <-sessionDown + reconciler.Stop() + + server.Stop() + + expectedJoinedCounter := 3 + + if !reflect.DeepEqual(expectedJoinedCounter, joinedCounter) { + t.Error("Did not keep joining") + } +} + +func TestKickRejoin(t *testing.T) { + server, port := makeTestServer(t) + config := makeTestIRCConfig(port) + reconciler, sessionUp, sessionDown, _ := makeTestReconciler(config) + + var testStep sync.WaitGroup + + // Wait for channel to be joined + joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { + hJOIN(conn, line) + testStep.Done() + return nil + } + server.SetHandler("JOIN", joinHandler) + + testStep.Add(1) + + reconciler.client.Connect() + + <-sessionUp + reconciler.Start(context.Background()) + + testStep.Wait() + + // Kick and wait for channel to be joined again + testStep.Add(1) + server.SendMsg(":test!~test@example.com KICK #foo foo :Bye!\n") + + testStep.Wait() + + reconciler.client.Quit("see ya") + <-sessionDown + reconciler.Stop() + + server.Stop() + +} diff --git a/time.go b/time.go new file mode 100644 index 0000000..a7420d2 --- /dev/null +++ b/time.go @@ -0,0 +1,35 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "time" +) + +// TimeTeller interface allows injection of fake time during testing +type TimeTeller interface { + Now() time.Time + After(time.Duration) <-chan time.Time +} + +type RealTime struct{} + +func (r *RealTime) Now() time.Time { + return time.Now() +} + +func (r *RealTime) After(d time.Duration) <-chan time.Time { + return time.After(d) +}