From 4d0f1f26b0d4e309709dd8330f055662a295edb4 Mon Sep 17 00:00:00 2001 From: Luca Bigliardi Date: Sat, 27 Mar 2021 17:29:54 +0100 Subject: [PATCH] Graceful disconnect upon context cancel Make sure the underlying library context cancellation happens only after the session has been shutdown. Signed-off-by: Luca Bigliardi --- context.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ irc.go | 21 ++++++++++++--------- irc_test.go | 1 - main.go | 19 ------------------- 4 files changed, 62 insertions(+), 29 deletions(-) create mode 100644 context.go diff --git a/context.go b/context.go new file mode 100644 index 0000000..a829201 --- /dev/null +++ b/context.go @@ -0,0 +1,50 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "log" + "os" + "os/signal" + "sync" +) + +func WithSignal(ctx context.Context, s ...os.Signal) (context.Context, context.CancelFunc) { + sigCtx, cancel := context.WithCancel(ctx) + c := make(chan os.Signal, 1) + signal.Notify(c, s...) + go func() { + select { + case <-c: + log.Printf("Received %s, exiting", s) + cancel() + case <-ctx.Done(): + cancel() + } + signal.Stop(c) + }() + return sigCtx, cancel +} + +func WithWaitGroup(ctx context.Context, wg *sync.WaitGroup) context.Context { + wgCtx, cancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + wg.Wait() + cancel() + }() + return wgCtx +} diff --git a/irc.go b/irc.go index ccdad83..9b78eae 100644 --- a/irc.go +++ b/irc.go @@ -92,6 +92,7 @@ type IRCNotifier struct { sessionUp bool sessionUpSignal chan bool sessionDownSignal chan bool + sessionWg sync.WaitGroup channelReconciler *ChannelReconciler @@ -212,19 +213,19 @@ func (n *IRCNotifier) SendAlertMsg(ctx context.Context, alertMsg *AlertMsg) { } func (n *IRCNotifier) ShutdownPhase() { - if n.Client.Connected() { + if n.sessionUp { log.Printf("IRC client connected, quitting") n.Client.Quit("see ya") - if n.sessionUp { - log.Printf("Session is up, wait for IRC disconnect to complete") - select { - case <-n.sessionDownSignal: - case <-time.After(n.Client.Config().Timeout): - log.Printf("Timeout while waiting for IRC disconnect to complete, stopping anyway") - } + log.Printf("Wait for IRC disconnect to complete") + select { + case <-n.sessionDownSignal: + case <-time.After(n.Client.Config().Timeout): + log.Printf("Timeout while waiting for IRC disconnect to complete, stopping anyway") } + n.sessionWg.Done() } + log.Printf("IRC shutdown complete") } func (n *IRCNotifier) ConnectedPhase(ctx context.Context) { @@ -233,6 +234,7 @@ func (n *IRCNotifier) ConnectedPhase(ctx context.Context) { n.SendAlertMsg(ctx, &alertMsg) case <-n.sessionDownSignal: n.sessionUp = false + n.sessionWg.Done() n.channelReconciler.Stop() n.Client.Quit("see ya") ircConnectedGauge.Set(0) @@ -247,7 +249,7 @@ func (n *IRCNotifier) SetupPhase(ctx context.Context) { if ok := n.BackoffCounter.DelayContext(ctx); !ok { return } - if err := n.Client.ConnectContext(ctx); err != nil { + if err := n.Client.ConnectContext(WithWaitGroup(ctx, &n.sessionWg)); err != nil { log.Printf("Could not connect to IRC: %s", err) return } @@ -256,6 +258,7 @@ func (n *IRCNotifier) SetupPhase(ctx context.Context) { select { case <-n.sessionUpSignal: n.sessionUp = true + n.sessionWg.Add(1) n.MaybeIdentifyNick() n.channelReconciler.Start(ctx) ircConnectedGauge.Set(1) diff --git a/irc_test.go b/irc_test.go index f61b1fc..a6ac856 100644 --- a/irc_test.go +++ b/irc_test.go @@ -582,7 +582,6 @@ func TestStopRunningWhenHalfConnected(t *testing.T) { expectedCommands := []string{ "NICK foo", "USER foo 12 * :", - "QUIT :see ya", } if !reflect.DeepEqual(expectedCommands, server.Log) { diff --git a/main.go b/main.go index 1a1350d..cfb303a 100644 --- a/main.go +++ b/main.go @@ -18,29 +18,10 @@ import ( "context" "flag" "log" - "os" - "os/signal" "sync" "syscall" ) -func WithSignal(ctx context.Context, s ...os.Signal) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(ctx) - c := make(chan os.Signal, 1) - signal.Notify(c, s...) - go func() { - select { - case <-c: - log.Printf("Received %s, exiting", s) - cancel() - case <-ctx.Done(): - cancel() - } - signal.Stop(c) - }() - return ctx, cancel -} - func main() { configFile := flag.String("config", "", "Config file path.")