new channel management logic

this should handle bans and kicks a bit better

Signed-off-by: Luca Bigliardi <shammash@google.com>
This commit is contained in:
Luca Bigliardi 2021-03-27 00:49:16 +01:00
parent c22e7a0c84
commit 0b2fbef1f2
5 changed files with 255 additions and 59 deletions

36
irc.go
View File

@ -177,14 +177,36 @@ func (n *IRCNotifier) MaybeIdentifyNick() {
time.Sleep(n.NickservDelayWait)
}
func (n *IRCNotifier) MaybeSendAlertMsg(alertMsg *AlertMsg) {
func (n *IRCNotifier) ChannelJoined(channel string) bool {
isJoined, waitJoined := n.channelReconciler.JoinChannel(channel)
if isJoined {
return true
}
select {
case <-waitJoined:
return true
case <-time.After(ircJoinWaitSecs * time.Second):
log.Printf("Channel %s not joined after %d seconds, giving bad news to caller", channel, ircJoinWaitSecs)
return false
case <-n.stopCtx.Done():
log.Printf("Context canceled while waiting for join on channel %s", channel)
return false
}
}
func (n *IRCNotifier) SendAlertMsg(alertMsg *AlertMsg) {
if !n.sessionUp {
log.Printf("Cannot send alert to %s : IRC not connected",
alertMsg.Channel)
log.Printf("Cannot send alert to %s : IRC not connected", alertMsg.Channel)
ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_connected").Inc()
return
}
n.channelReconciler.JoinChannel(&IRCChannel{Name: alertMsg.Channel})
if !n.ChannelJoined(alertMsg.Channel) {
log.Printf("Cannot send alert to %s : cannot join channel", alertMsg.Channel)
ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_joined").Inc()
return
}
if n.UsePrivmsg {
n.Client.Privmsg(alertMsg.Channel, alertMsg.Alert)
@ -213,10 +235,10 @@ func (n *IRCNotifier) ShutdownPhase() {
func (n *IRCNotifier) ConnectedPhase() {
select {
case alertMsg := <-n.AlertMsgs:
n.MaybeSendAlertMsg(&alertMsg)
n.SendAlertMsg(&alertMsg)
case <-n.sessionDownSignal:
n.sessionUp = false
n.channelReconciler.CleanupChannels()
n.channelReconciler.Stop()
n.Client.Quit("see ya")
ircConnectedGauge.Set(0)
case <-n.stopCtx.Done():
@ -240,7 +262,7 @@ func (n *IRCNotifier) SetupPhase() {
case <-n.sessionUpSignal:
n.sessionUp = true
n.MaybeIdentifyNick()
n.channelReconciler.JoinChannels()
n.channelReconciler.Start(n.stopCtx)
ircConnectedGauge.Set(1)
case <-n.sessionDownSignal:
log.Printf("Receiving a session down before the session is up, this is odd")

View File

@ -108,7 +108,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
if line.Args[0] == testChannel {
testStep.Done()
}
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinedHandler)
@ -117,7 +117,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
testStep.Wait()
server.SetHandler("JOIN", nil)
server.SetHandler("JOIN", hJOIN)
noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
@ -163,7 +163,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
if line.Args[0] == testChannel {
testStep.Done()
}
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinedHandler)
@ -172,7 +172,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
testStep.Wait()
server.SetHandler("JOIN", nil)
server.SetHandler("JOIN", hJOIN)
privmsgHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
@ -215,7 +215,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
// ordering.
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)
@ -224,7 +224,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
testStep.Wait()
server.SetHandler("JOIN", nil)
server.SetHandler("JOIN", hJOIN)
noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
@ -295,7 +295,7 @@ func TestSendAlertDisconnected(t *testing.T) {
testStep.Add(1)
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)
@ -339,7 +339,7 @@ func TestReconnect(t *testing.T) {
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)
@ -409,7 +409,7 @@ func TestConnectErrorRetry(t *testing.T) {
joinStep.Add(1)
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
joinStep.Done()
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)
server.SetCloseEarly(nil)
@ -446,7 +446,7 @@ func TestIdentify(t *testing.T) {
// after identification).
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)
@ -498,7 +498,7 @@ func TestGhostAndIdentify(t *testing.T) {
// after identification).
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
testStep.Done()
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)

View File

@ -29,6 +29,12 @@ import (
type LineHandlerFunc func(*bufio.ReadWriter, *irc.Line) error
func hJOIN(conn *bufio.ReadWriter, line *irc.Line) error {
r := fmt.Sprintf(":foo!foo@example.com JOIN :%s\n", line.Args[0])
_, err := conn.WriteString(r)
return err
}
func hUSER(conn *bufio.ReadWriter, line *irc.Line) error {
r := fmt.Sprintf(":example.com 001 %s :Welcome\n", line.Args[0])
_, err := conn.WriteString(r)
@ -61,6 +67,7 @@ func (s *testServer) setDefaultHandlers() {
if s.lineHandlers == nil {
s.lineHandlers = make(map[string]LineHandlerFunc)
}
s.lineHandlers["JOIN"] = hJOIN
s.lineHandlers["USER"] = hUSER
s.lineHandlers["QUIT"] = hQUIT
}

View File

@ -23,9 +23,123 @@ import (
irc "github.com/fluffle/goirc/client"
)
const (
ircJoinWaitSecs = 10
ircJoinMaxBackoffSecs = 300
ircJoinBackoffResetSecs = 1800
)
type channelState struct {
Channel IRCChannel
BackoffCounter Delayer
channel IRCChannel
client *irc.Conn
delayer Delayer
joinDone chan struct{} // joined when channel is closed
joined bool
joinUnsetSignal chan bool
mu sync.Mutex
}
func newChannelState(channel *IRCChannel, client *irc.Conn, delayerMaker DelayerMaker) *channelState {
delayer := delayerMaker.NewDelayer(ircJoinMaxBackoffSecs, ircJoinBackoffResetSecs, time.Second)
return &channelState{
channel: *channel,
client: client,
delayer: delayer,
joinDone: make(chan struct{}),
joined: false,
joinUnsetSignal: make(chan bool),
}
}
func (c *channelState) JoinDone() <-chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
return c.joinDone
}
func (c *channelState) SetJoined() {
c.mu.Lock()
defer c.mu.Unlock()
if c.joined == true {
log.Printf("Not setting JOIN state on channel %s: already set", c.channel.Name)
return
}
log.Printf("Setting JOIN state on channel %s", c.channel.Name)
c.joined = true
close(c.joinDone)
}
func (c *channelState) UnsetJoined() {
c.mu.Lock()
defer c.mu.Unlock()
if c.joined == false {
log.Printf("Not removing JOIN state on channel %s: already not set", c.channel.Name)
return
}
log.Printf("Removing JOIN state on channel %s", c.channel.Name)
c.joined = false
c.joinDone = make(chan struct{})
// eventually poke monitor routine
select {
case c.joinUnsetSignal <- true:
default:
}
}
func (c *channelState) join(ctx context.Context) {
log.Printf("Channel %s monitor: waiting to join", c.channel.Name)
if ok := c.delayer.DelayContext(ctx); !ok {
return
}
c.client.Join(c.channel.Name, c.channel.Password)
log.Printf("Channel %s monitor: join request sent", c.channel.Name)
select {
case <-c.JoinDone():
log.Printf("Channel %s monitor: join succeeded", c.channel.Name)
case <-time.After(ircJoinWaitSecs * time.Second):
log.Printf("Channel %s monitor: could not join after %d seconds, will retry", c.channel.Name, ircJoinWaitSecs)
case <-ctx.Done():
log.Printf("Channel %s monitor: context canceled while waiting for join", c.channel.Name)
}
}
func (c *channelState) monitorJoinUnset(ctx context.Context) {
select {
case <-c.joinUnsetSignal:
log.Printf("Channel %s monitor: channel no longer joined", c.channel.Name)
case <-ctx.Done():
log.Printf("Channel %s monitor: context canceled while monitoring", c.channel.Name)
}
}
func (c *channelState) Monitor(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
joined := func() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.joined
}
for ctx.Err() != context.Canceled {
if !joined() {
c.join(ctx)
} else {
c.monitorJoinUnset(ctx)
}
}
}
type ChannelReconciler struct {
@ -57,53 +171,107 @@ func NewChannelReconciler(config *Config, client *irc.Conn, delayerMaker Delayer
}
func (r *ChannelReconciler) registerHandlers() {
r.client.HandleFunc(irc.JOIN,
func(_ *irc.Conn, line *irc.Line) {
r.HandleJoin(line.Nick, line.Args[0])
})
r.client.HandleFunc(irc.KICK,
func(_ *irc.Conn, line *irc.Line) {
r.HandleKick(line.Args[1], line.Args[0])
})
}
func (r *ChannelReconciler) HandleJoin(nick string, channel string) {
r.mu.Lock()
defer r.mu.Unlock()
if nick != r.client.Me().Nick {
// received join info for somebody else
return
}
log.Printf("Received JOIN confirmation for channel %s", channel)
c, ok := r.channels[channel]
if !ok {
log.Printf("Not processing JOIN for channel %s: unknown channel", channel)
return
}
c.SetJoined()
}
func (r *ChannelReconciler) HandleKick(nick string, channel string) {
r.mu.Lock()
defer r.mu.Unlock()
if nick != r.client.Me().Nick {
// received kick info for somebody else
return
}
state, ok := r.channels[channel]
log.Printf("Received KICK for channel %s", channel)
c, ok := r.channels[channel]
if !ok {
log.Printf("Being kicked out of non-joined channel (%s), ignoring", channel)
log.Printf("Not processing KICK for channel %s: unknown channel", channel)
return
}
log.Printf("Being kicked out of %s, re-joining", channel)
go func() {
if ok := state.BackoffCounter.DelayContext(r.stopCtx); !ok {
return
}
r.client.Join(state.Channel.Name, state.Channel.Password)
}()
c.UnsetJoined()
}
func (r *ChannelReconciler) CleanupChannels() {
log.Printf("Deregistering all channels.")
func (r *ChannelReconciler) unsafeAddChannel(channel *IRCChannel) *channelState {
c := newChannelState(channel, r.client, r.delayerMaker)
r.stopWg.Add(1)
go c.Monitor(r.stopCtx, &r.stopWg)
r.channels[channel.Name] = c
return c
}
func (r *ChannelReconciler) JoinChannel(channel string) (bool, <-chan struct{}) {
r.mu.Lock()
defer r.mu.Unlock()
c, ok := r.channels[channel]
if !ok {
log.Printf("Request to JOIN new channel %s", channel)
c = r.unsafeAddChannel(&IRCChannel{Name: channel})
}
select {
case <-c.JoinDone():
return true, nil
default:
return false, c.JoinDone()
}
}
func (r *ChannelReconciler) unsafeStop() {
if r.stopCtxCancel == nil {
// calling stop before first start, ignoring
return
}
r.stopCtxCancel()
r.stopWg.Wait()
r.channels = make(map[string]*channelState)
}
func (r *ChannelReconciler) JoinChannel(channel *IRCChannel) {
if _, joined := r.channels[channel.Name]; joined {
return
}
log.Printf("Joining %s", channel.Name)
r.client.Join(channel.Name, channel.Password)
state := &channelState{
Channel: *channel,
BackoffCounter: r.delayerMaker.NewDelayer(
ircConnectMaxBackoffSecs, ircConnectBackoffResetSecs,
time.Second),
}
r.channels[channel.Name] = state
func (r *ChannelReconciler) Stop() {
r.mu.Lock()
defer r.mu.Unlock()
r.unsafeStop()
}
func (r *ChannelReconciler) JoinChannels() {
func (r *ChannelReconciler) Start(ctx context.Context) {
r.mu.Lock()
defer r.mu.Unlock()
r.unsafeStop()
r.stopCtx, r.stopCtxCancel = context.WithCancel(ctx)
for _, channel := range r.preJoinChannels {
r.JoinChannel(&channel)
r.unsafeAddChannel(&channel)
}
}

View File

@ -16,7 +16,9 @@ package main
import (
"bufio"
"context"
"reflect"
"sort"
"sync"
"testing"
@ -63,12 +65,14 @@ func TestPreJoinChannels(t *testing.T) {
var testStep sync.WaitGroup
joinedChannels := []string{}
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
// #baz is configured as the last channel to pre-join
if line.Args[0] == "#baz" {
joinedChannels = append(joinedChannels, line.Args[0])
if len(joinedChannels) == 3 {
testStep.Done()
}
return nil
return hJOIN(conn, line)
}
server.SetHandler("JOIN", joinHandler)
@ -77,25 +81,20 @@ func TestPreJoinChannels(t *testing.T) {
reconciler.client.Connect()
<-sessionUp
reconciler.JoinChannels()
reconciler.Start(context.Background())
testStep.Wait()
reconciler.client.Quit("see ya")
<-sessionDown
reconciler.Stop()
server.Stop()
expectedCommands := []string{
"NICK foo",
"USER foo 12 * :",
"JOIN #foo",
"JOIN #bar",
"JOIN #baz",
"QUIT :see ya",
}
expectedJoinedChannels := []string{"#bar", "#baz", "#foo"}
sort.Strings(joinedChannels)
if !reflect.DeepEqual(expectedCommands, server.Log) {
if !reflect.DeepEqual(expectedJoinedChannels, joinedChannels) {
t.Error("Did not pre-join channels")
}
}