diff --git a/http.go b/http.go index 9a8bedb..b42eb38 100644 --- a/http.go +++ b/http.go @@ -51,12 +51,11 @@ var ( type HTTPListener func(string, http.Handler) error type HTTPServer struct { - StoppedRunning chan bool - Addr string - Port int - formatter *Formatter - AlertMsgs chan AlertMsg - httpListener HTTPListener + Addr string + Port int + formatter *Formatter + AlertMsgs chan AlertMsg + httpListener HTTPListener } func NewHTTPServer(config *Config, alertMsgs chan AlertMsg) ( @@ -71,12 +70,11 @@ func NewHTTPServerForTesting(config *Config, alertMsgs chan AlertMsg, return nil, err } server := &HTTPServer{ - StoppedRunning: make(chan bool), - Addr: config.HTTPHost, - Port: config.HTTPPort, - formatter: formatter, - AlertMsgs: alertMsgs, - httpListener: httpListener, + Addr: config.HTTPHost, + Port: config.HTTPPort, + formatter: formatter, + AlertMsgs: alertMsgs, + httpListener: httpListener, } return server, nil @@ -135,5 +133,4 @@ func (server *HTTPServer) Run() { if err := server.httpListener(listenAddr, router); err != nil { log.Printf("Could not start http server: %s", err) } - server.StoppedRunning <- true } diff --git a/http_test.go b/http_test.go index 4ece468..3e59839 100644 --- a/http_test.go +++ b/http_test.go @@ -77,7 +77,6 @@ func RunHTTPTest(t *testing.T, listener.router.ServeHTTP(responseRecorder, request) listener.StopServing <- true - <-httpServer.StoppedRunning return responseRecorder.Result() } diff --git a/irc.go b/irc.go index 1816817..f30f2e6 100644 --- a/irc.go +++ b/irc.go @@ -15,10 +15,12 @@ package main import ( + "context" "crypto/tls" "log" "strconv" "strings" + "sync" "time" irc "github.com/fluffle/goirc/client" @@ -63,12 +65,13 @@ type ChannelState struct { type IRCNotifier struct { // Nick stores the nickname specified in the config, because irc.Client // might change its copy. - Nick string - NickPassword string - Client *irc.Conn - StopRunning chan bool - StoppedRunning chan bool - AlertMsgs chan AlertMsg + Nick string + NickPassword string + Client *irc.Conn + AlertMsgs chan AlertMsg + + ctx context.Context + stopWg *sync.WaitGroup // irc.Conn has a Connected() method that can tell us wether the TCP // connection is up, and thus if we should trigger connect/disconnect. @@ -88,7 +91,7 @@ type IRCNotifier struct { BackoffCounter Delayer } -func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg) (*IRCNotifier, error) { +func NewIRCNotifier(ctx context.Context, stopWg *sync.WaitGroup, config *Config, alertMsgs chan AlertMsg) (*IRCNotifier, error) { ircConfig := irc.NewConfig(config.IRCNick) ircConfig.Me.Ident = config.IRCNick @@ -113,9 +116,9 @@ func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg) (*IRCNotifier, erro Nick: config.IRCNick, NickPassword: config.IRCNickPass, Client: irc.Client(ircConfig), - StopRunning: make(chan bool), - StoppedRunning: make(chan bool), AlertMsgs: alertMsgs, + ctx: ctx, + stopWg: stopWg, sessionUpSignal: make(chan bool), sessionDownSignal: make(chan bool), PreJoinChannels: config.IRCChannels, @@ -234,19 +237,14 @@ func (notifier *IRCNotifier) MaybeSendAlertMsg(alertMsg *AlertMsg) { } func (notifier *IRCNotifier) Run() { - keepGoing := true - for keepGoing { + defer notifier.stopWg.Done() + + for notifier.ctx.Err() != context.Canceled { if !notifier.Client.Connected() { log.Printf("Connecting to IRC %s", notifier.Client.Config().Server) notifier.BackoffCounter.Delay() if err := notifier.Client.Connect(); err != nil { log.Printf("Could not connect to IRC: %s", err) - select { - case <-notifier.StopRunning: - log.Printf("IRC routine not connected but asked to terminate") - keepGoing = false - default: - } continue } log.Printf("Connected to IRC server, waiting to establish session") @@ -265,9 +263,8 @@ func (notifier *IRCNotifier) Run() { notifier.CleanupChannels() notifier.Client.Quit("see ya") ircConnectedGauge.Set(0) - case <-notifier.StopRunning: + case <-notifier.ctx.Done(): log.Printf("IRC routine asked to terminate") - keepGoing = false } } if notifier.Client.Connected() { @@ -283,5 +280,4 @@ func (notifier *IRCNotifier) Run() { } } } - notifier.StoppedRunning <- true } diff --git a/irc_test.go b/irc_test.go index 30fb085..15d608d 100644 --- a/irc_test.go +++ b/irc_test.go @@ -16,6 +16,7 @@ package main import ( "bufio" + "context" "fmt" "io" "log" @@ -217,22 +218,25 @@ func makeTestIRCConfig(IRCPort int) *Config { } } -func makeTestNotifier(t *testing.T, config *Config) (*IRCNotifier, chan AlertMsg) { +func makeTestNotifier(t *testing.T, config *Config) (*IRCNotifier, chan AlertMsg, context.CancelFunc, *sync.WaitGroup) { alertMsgs := make(chan AlertMsg) - notifier, err := NewIRCNotifier(config, alertMsgs) + ctx, cancel := context.WithCancel(context.Background()) + stopWg := sync.WaitGroup{} + stopWg.Add(1) + notifier, err := NewIRCNotifier(ctx, &stopWg, config, alertMsgs) if err != nil { t.Fatal(fmt.Sprintf("Could not create IRC notifier: %s", err)) } notifier.Client.Config().Flood = true notifier.BackoffCounter = &FakeDelayer{} - return notifier, alertMsgs + return notifier, alertMsgs, cancel, &stopWg } func TestPreJoinChannels(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, _ := makeTestNotifier(t, config) var testStep sync.WaitGroup @@ -250,7 +254,7 @@ func TestPreJoinChannels(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -271,7 +275,7 @@ func TestServerPassword(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) config.IRCHostPass = "hostsecret" - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, _ := makeTestNotifier(t, config) var testStep sync.WaitGroup @@ -289,7 +293,7 @@ func TestServerPassword(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -310,7 +314,7 @@ func TestServerPassword(t *testing.T) { func TestSendAlertOnPreJoinedChannel(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) - notifier, alertMsgs := makeTestNotifier(t, config) + notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config) var testStep sync.WaitGroup @@ -345,7 +349,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -367,7 +371,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) config.UsePrivmsg = true - notifier, alertMsgs := makeTestNotifier(t, config) + notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config) var testStep sync.WaitGroup @@ -402,7 +406,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -423,7 +427,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) { func TestSendAlertAndJoinChannel(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) - notifier, alertMsgs := makeTestNotifier(t, config) + notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config) var testStep sync.WaitGroup @@ -459,7 +463,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -482,7 +486,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) { func TestSendAlertDisconnected(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) - notifier, alertMsgs := makeTestNotifier(t, config) + notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config) var testStep, holdUserStep sync.WaitGroup @@ -535,7 +539,7 @@ func TestSendAlertDisconnected(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -557,7 +561,7 @@ func TestSendAlertDisconnected(t *testing.T) { func TestReconnect(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, _ := makeTestNotifier(t, config) var testStep sync.WaitGroup @@ -583,7 +587,7 @@ func TestReconnect(t *testing.T) { // Wait again until the last pre-joined channel is seen. testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -613,7 +617,7 @@ func TestConnectErrorRetry(t *testing.T) { // Attempt SSL handshake. The server does not support it, resulting in // a connection error. config.IRCUseSSL = true - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, _ := makeTestNotifier(t, config) var testStep, joinStep sync.WaitGroup @@ -643,7 +647,7 @@ func TestConnectErrorRetry(t *testing.T) { joinStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -664,7 +668,7 @@ func TestIdentify(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) config.IRCNickPass = "nickpassword" - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, _ := makeTestNotifier(t, config) notifier.NickservDelayWait = 0 * time.Second var testStep sync.WaitGroup @@ -685,7 +689,7 @@ func TestIdentify(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -707,7 +711,7 @@ func TestGhostAndIdentify(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) config.IRCNickPass = "nickpassword" - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, _ := makeTestNotifier(t, config) notifier.NickservDelayWait = 0 * time.Second var testStep, usedNick, unregisteredNickHandler sync.WaitGroup @@ -746,7 +750,7 @@ func TestGhostAndIdentify(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() server.Stop() expectedCommands := []string{ @@ -770,7 +774,7 @@ func TestGhostAndIdentify(t *testing.T) { func TestStopRunningWhenHalfConnected(t *testing.T) { server, port := makeTestServer(t) config := makeTestIRCConfig(port) - notifier, _ := makeTestNotifier(t, config) + notifier, _, cancel, stopWg := makeTestNotifier(t, config) var testStep, holdQuitWait sync.WaitGroup @@ -797,9 +801,9 @@ func TestStopRunningWhenHalfConnected(t *testing.T) { testStep.Wait() - notifier.StopRunning <- true + cancel() - <-notifier.StoppedRunning + stopWg.Wait() holdQuitWait.Wait() diff --git a/main.go b/main.go index d198042..97a4ea2 100644 --- a/main.go +++ b/main.go @@ -15,21 +15,40 @@ package main import ( + "context" "flag" "log" "os" "os/signal" + "sync" "syscall" ) +func WithSignal(ctx context.Context, s ...os.Signal) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + c := make(chan os.Signal, 1) + signal.Notify(c, s...) + go func() { + select { + case <-c: + log.Printf("Received %s, exiting", s) + cancel() + case <-ctx.Done(): + cancel() + } + signal.Stop(c) + }() + return ctx, cancel +} + func main() { configFile := flag.String("config", "", "Config file path.") flag.Parse() - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + ctx, _ := WithSignal(context.Background(), syscall.SIGINT, syscall.SIGTERM) + stopWg := sync.WaitGroup{} config, err := LoadConfig(*configFile) if err != nil { @@ -39,7 +58,8 @@ func main() { alertMsgs := make(chan AlertMsg, config.AlertBufferSize) - ircNotifier, err := NewIRCNotifier(config, alertMsgs) + stopWg.Add(1) + ircNotifier, err := NewIRCNotifier(ctx, &stopWg, config, alertMsgs) if err != nil { log.Printf("Could not create IRC notifier: %s", err) return @@ -53,15 +73,5 @@ func main() { } go httpServer.Run() - select { - case <-httpServer.StoppedRunning: - log.Printf("Http server terminated, exiting") - case <-ircNotifier.StoppedRunning: - log.Printf("IRC notifier stopped running, exiting") - case s := <-signals: - log.Printf("Received %s, exiting", s) - ircNotifier.StopRunning <- true - log.Printf("Waiting for IRC to quit") - <-ircNotifier.StoppedRunning - } + stopWg.Wait() }