diff --git a/config.go b/config.go index 3308595..8b941d2 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,7 @@ type Config struct { IRCHostPass string `yaml:"irc_host_password"` IRCUseSSL bool `yaml:"irc_use_ssl"` IRCVerifySSL bool `yaml:"irc_verify_ssl"` + IRCPingSecs int `yaml:"irc_ping_secs"` IRCChannels []IRCChannel `yaml:"irc_channels"` MsgTemplate string `yaml:"msg_template"` MsgOnce bool `yaml:"msg_once_per_alert_group"` @@ -66,6 +67,7 @@ func LoadConfig(configFile string) (*Config, error) { IRCHostPass: "", IRCUseSSL: true, IRCVerifySSL: true, + IRCPingSecs: 60, IRCChannels: []IRCChannel{}, MsgOnce: false, UsePrivmsg: false, diff --git a/irc.go b/irc.go index 7ea284f..f3ebd80 100644 --- a/irc.go +++ b/irc.go @@ -29,7 +29,6 @@ import ( ) const ( - pingFrequencySecs = 60 connectionTimeoutSecs = 30 nickservWaitSecs = 10 ircConnectMaxBackoffSecs = 300 @@ -69,7 +68,7 @@ func makeGOIRCConfig(config *Config) *irc.Config { ServerName: config.IRCHost, InsecureSkipVerify: !config.IRCVerifySSL, } - ircConfig.PingFreq = pingFrequencySecs * time.Second + ircConfig.PingFreq = time.Duration(config.IRCPingSecs) * time.Second ircConfig.Timeout = connectionTimeoutSecs * time.Second ircConfig.NewNick = func(n string) string { return n + "^" } @@ -102,6 +101,9 @@ type IRCNotifier struct { sessionUp bool sessionUpSignal chan bool sessionDownSignal chan bool + sessionPongSignal chan bool + sessionPingOnce sync.Once + sessionLastPong time.Time sessionWg sync.WaitGroup channelReconciler *ChannelReconciler @@ -136,6 +138,7 @@ func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg, delayerMaker Delaye AlertMsgs: alertMsgs, sessionUpSignal: make(chan bool), sessionDownSignal: make(chan bool), + sessionPongSignal: make(chan bool), channelReconciler: channelReconciler, UsePrivmsg: config.UsePrivmsg, NickservDelayWait: nickservWaitSecs * time.Second, @@ -166,6 +169,11 @@ func (n *IRCNotifier) registerHandlers() { n.HandleNotice(line.Nick, line.Text()) }) + n.Client.HandleFunc(irc.PONG, + func(_ *irc.Conn, line *irc.Line) { + n.sessionPongSignal <- true + }) + for _, event := range []string{"433"} { n.Client.HandleFunc(event, loggerHandler) } @@ -282,6 +290,7 @@ func (n *IRCNotifier) ShutdownPhase() { logging.Info("Wait for IRC disconnect to complete") select { case <-n.sessionDownSignal: + case <-n.sessionPongSignal: case <-n.timeTeller.After(n.Client.Config().Timeout): logging.Warn("Timeout while waiting for IRC disconnect to complete, stopping anyway") } @@ -294,6 +303,23 @@ func (n *IRCNotifier) ConnectedPhase(ctx context.Context) { select { case alertMsg := <-n.AlertMsgs: n.SendAlertMsg(ctx, &alertMsg) + case <-n.sessionPongSignal: + logging.Debug("Received a PONG message; prev PONG was at %v", n.sessionLastPong) + n.sessionLastPong = time.Now() + case <-time.After(2*n.IrcConfig.PingFreq - time.Since(n.sessionLastPong)): + // Calling n.Client.Close() will trigger n.sessionDownSignal. However, as + // this also dispatches a hook, which we will catch as sessionDownSignal, + // this needs to be done in a concurrent fashion if we don't want to + // deadlock ourself. + // + // Furthermore, as this time.After(...) interval is now zero, it will also + // trigger when visiting this select the next time. To mitigate multiple + // Close() calls, it is wrapped within an sync.Once which will be reset + // during SetupPhase's sessionUpSignal. + n.sessionPingOnce.Do(func() { + logging.Error("Haven't received a PONG after twice the PING period") + go n.Client.Close() + }) case <-n.sessionDownSignal: n.sessionUp = false n.sessionWg.Done() @@ -325,6 +351,8 @@ func (n *IRCNotifier) SetupPhase(ctx context.Context) { select { case <-n.sessionUpSignal: n.sessionUp = true + n.sessionPingOnce = sync.Once{} + n.sessionLastPong = time.Now() n.sessionWg.Add(1) n.MaybeGhostNick() n.MaybeWaitForNickserv() @@ -332,6 +360,8 @@ func (n *IRCNotifier) SetupPhase(ctx context.Context) { ircConnectedGauge.Set(1) case <-n.sessionDownSignal: logging.Warn("Receiving a session down before the session is up, this is odd") + case <-n.sessionPongSignal: + logging.Warn("Receiving a PONG before the session is up, this is odd") case <-ctx.Done(): logging.Info("IRC routine asked to terminate") } diff --git a/irc_test.go b/irc_test.go index 8e6b915..6ad9294 100644 --- a/irc_test.go +++ b/irc_test.go @@ -35,6 +35,7 @@ func makeTestIRCConfig(IRCPort int) *Config { IRCHost: "127.0.0.1", IRCPort: IRCPort, IRCUseSSL: false, + IRCPingSecs: 60, IRCChannels: []IRCChannel{ IRCChannel{Name: "#foo"}, }, @@ -466,6 +467,66 @@ func TestReconnectNickIdentChange(t *testing.T) { } } +func TestReconnectMissingPong(t *testing.T) { + server, port := makeTestServer(t) + config := makeTestIRCConfig(port) + config.IRCPingSecs = 2 + notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config) + + var testStep sync.WaitGroup + + joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { + testStep.Done() + return hJOIN(conn, line) + } + server.SetHandler("JOIN", joinHandler) + + testStep.Add(1) + go notifier.Run(ctx, stopWg) + + // Wait until the pre-joined channel is seen. + testStep.Wait() + + // Wait for a client disconnect due to missing PONGs... + testStep.Add(1) + + // Wait again until the pre-joined channel is seen. + testStep.Wait() + + cancel() + stopWg.Wait() + + server.Stop() + + expectedCommands := []string{ + // Commands from first connection + "NICK foo", + "USER foo 12 * :", + "PRIVMSG ChanServ :UNBAN #foo", + "JOIN #foo", + // Ping commands contain timestamps; note the modified check below! + // Commands from reconnection + "NICK foo", + "USER foo 12 * :", + "PRIVMSG ChanServ :UNBAN #foo", + "JOIN #foo", + "QUIT :see ya", + } + + logPos := 0 + for _, cmd := range server.Log { + if !strings.HasPrefix(cmd, "PING :") { + server.Log[logPos] = cmd + logPos++ + } + } + server.Log = server.Log[:logPos] + + if !reflect.DeepEqual(expectedCommands, server.Log) { + t.Error("Reconnection did not happen correctly. Received commands:\n", strings.Join(server.Log, "\n")) + } +} + func TestConnectErrorRetry(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) @@ -674,3 +735,54 @@ func TestStopRunningWhenHalfConnected(t *testing.T) { t.Error("Alert not sent correctly. Received commands:\n", strings.Join(server.Log, "\n")) } } + +func TestPingPong(t *testing.T) { + server, port := makeTestServer(t) + config := makeTestIRCConfig(port) + config.IRCPingSecs = 1 + notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config) + + var testStep sync.WaitGroup + + pingHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { + testStep.Done() + r := fmt.Sprintf(":example.com PONG example.com :%s", line.Args[0]) + _, err := conn.WriteString(r) + return err + } + server.SetHandler("PING", pingHandler) + + testStep.Add(3) + go notifier.Run(ctx, stopWg) + + // Wait until three PING-PONGs have been exchanged.. + testStep.Wait() + + cancel() + stopWg.Wait() + + server.Stop() + + expectedCommands := []string{ + // Commands from first connection + "NICK foo", + "USER foo 12 * :", + "PRIVMSG ChanServ :UNBAN #foo", + "JOIN #foo", + // Ping commands contain timestamps; note the modified check below! + "PING :__TS1__", + "PING :__TS2__", + "PING :__TS3__", + "QUIT :see ya", + } + + expectedCommandsCheck := len(expectedCommands) == len(server.Log) && + reflect.DeepEqual(expectedCommands[:4], server.Log[:4]) && + strings.HasPrefix(server.Log[4], "PING :") && + strings.HasPrefix(server.Log[5], "PING :") && + strings.HasPrefix(server.Log[6], "PING :") && + reflect.DeepEqual(expectedCommands[7:], server.Log[7:]) + if !expectedCommandsCheck { + t.Error("Reconnection did not happen correctly. Received commands:\n", strings.Join(server.Log, "\n")) + } +}