Graceful disconnect upon context cancel

Make sure the underlying library context cancellation happens only
after the session has been shutdown.

Signed-off-by: Luca Bigliardi <shammash@google.com>
This commit is contained in:
Luca Bigliardi 2021-03-27 17:29:54 +01:00
parent 882cecd6a6
commit 4d0f1f26b0
4 changed files with 62 additions and 29 deletions

50
context.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"context"
"log"
"os"
"os/signal"
"sync"
)
func WithSignal(ctx context.Context, s ...os.Signal) (context.Context, context.CancelFunc) {
sigCtx, cancel := context.WithCancel(ctx)
c := make(chan os.Signal, 1)
signal.Notify(c, s...)
go func() {
select {
case <-c:
log.Printf("Received %s, exiting", s)
cancel()
case <-ctx.Done():
cancel()
}
signal.Stop(c)
}()
return sigCtx, cancel
}
func WithWaitGroup(ctx context.Context, wg *sync.WaitGroup) context.Context {
wgCtx, cancel := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
wg.Wait()
cancel()
}()
return wgCtx
}

21
irc.go
View File

@ -92,6 +92,7 @@ type IRCNotifier struct {
sessionUp bool
sessionUpSignal chan bool
sessionDownSignal chan bool
sessionWg sync.WaitGroup
channelReconciler *ChannelReconciler
@ -212,19 +213,19 @@ func (n *IRCNotifier) SendAlertMsg(ctx context.Context, alertMsg *AlertMsg) {
}
func (n *IRCNotifier) ShutdownPhase() {
if n.Client.Connected() {
if n.sessionUp {
log.Printf("IRC client connected, quitting")
n.Client.Quit("see ya")
if n.sessionUp {
log.Printf("Session is up, wait for IRC disconnect to complete")
select {
case <-n.sessionDownSignal:
case <-time.After(n.Client.Config().Timeout):
log.Printf("Timeout while waiting for IRC disconnect to complete, stopping anyway")
}
log.Printf("Wait for IRC disconnect to complete")
select {
case <-n.sessionDownSignal:
case <-time.After(n.Client.Config().Timeout):
log.Printf("Timeout while waiting for IRC disconnect to complete, stopping anyway")
}
n.sessionWg.Done()
}
log.Printf("IRC shutdown complete")
}
func (n *IRCNotifier) ConnectedPhase(ctx context.Context) {
@ -233,6 +234,7 @@ func (n *IRCNotifier) ConnectedPhase(ctx context.Context) {
n.SendAlertMsg(ctx, &alertMsg)
case <-n.sessionDownSignal:
n.sessionUp = false
n.sessionWg.Done()
n.channelReconciler.Stop()
n.Client.Quit("see ya")
ircConnectedGauge.Set(0)
@ -247,7 +249,7 @@ func (n *IRCNotifier) SetupPhase(ctx context.Context) {
if ok := n.BackoffCounter.DelayContext(ctx); !ok {
return
}
if err := n.Client.ConnectContext(ctx); err != nil {
if err := n.Client.ConnectContext(WithWaitGroup(ctx, &n.sessionWg)); err != nil {
log.Printf("Could not connect to IRC: %s", err)
return
}
@ -256,6 +258,7 @@ func (n *IRCNotifier) SetupPhase(ctx context.Context) {
select {
case <-n.sessionUpSignal:
n.sessionUp = true
n.sessionWg.Add(1)
n.MaybeIdentifyNick()
n.channelReconciler.Start(ctx)
ircConnectedGauge.Set(1)

View File

@ -582,7 +582,6 @@ func TestStopRunningWhenHalfConnected(t *testing.T) {
expectedCommands := []string{
"NICK foo",
"USER foo 12 * :",
"QUIT :see ya",
}
if !reflect.DeepEqual(expectedCommands, server.Log) {

19
main.go
View File

@ -18,29 +18,10 @@ import (
"context"
"flag"
"log"
"os"
"os/signal"
"sync"
"syscall"
)
func WithSignal(ctx context.Context, s ...os.Signal) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
c := make(chan os.Signal, 1)
signal.Notify(c, s...)
go func() {
select {
case <-c:
log.Printf("Received %s, exiting", s)
cancel()
case <-ctx.Done():
cancel()
}
signal.Stop(c)
}()
return ctx, cancel
}
func main() {
configFile := flag.String("config", "", "Config file path.")