mirror of
https://github.com/google/alertmanager-irc-relay.git
synced 2024-11-16 16:19:21 +01:00
new channel management logic
this should handle bans and kicks a bit better Signed-off-by: Luca Bigliardi <shammash@google.com>
This commit is contained in:
parent
c22e7a0c84
commit
0b2fbef1f2
36
irc.go
36
irc.go
@ -177,14 +177,36 @@ func (n *IRCNotifier) MaybeIdentifyNick() {
|
||||
time.Sleep(n.NickservDelayWait)
|
||||
}
|
||||
|
||||
func (n *IRCNotifier) MaybeSendAlertMsg(alertMsg *AlertMsg) {
|
||||
func (n *IRCNotifier) ChannelJoined(channel string) bool {
|
||||
|
||||
isJoined, waitJoined := n.channelReconciler.JoinChannel(channel)
|
||||
if isJoined {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-waitJoined:
|
||||
return true
|
||||
case <-time.After(ircJoinWaitSecs * time.Second):
|
||||
log.Printf("Channel %s not joined after %d seconds, giving bad news to caller", channel, ircJoinWaitSecs)
|
||||
return false
|
||||
case <-n.stopCtx.Done():
|
||||
log.Printf("Context canceled while waiting for join on channel %s", channel)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (n *IRCNotifier) SendAlertMsg(alertMsg *AlertMsg) {
|
||||
if !n.sessionUp {
|
||||
log.Printf("Cannot send alert to %s : IRC not connected",
|
||||
alertMsg.Channel)
|
||||
log.Printf("Cannot send alert to %s : IRC not connected", alertMsg.Channel)
|
||||
ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_connected").Inc()
|
||||
return
|
||||
}
|
||||
n.channelReconciler.JoinChannel(&IRCChannel{Name: alertMsg.Channel})
|
||||
if !n.ChannelJoined(alertMsg.Channel) {
|
||||
log.Printf("Cannot send alert to %s : cannot join channel", alertMsg.Channel)
|
||||
ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_joined").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
if n.UsePrivmsg {
|
||||
n.Client.Privmsg(alertMsg.Channel, alertMsg.Alert)
|
||||
@ -213,10 +235,10 @@ func (n *IRCNotifier) ShutdownPhase() {
|
||||
func (n *IRCNotifier) ConnectedPhase() {
|
||||
select {
|
||||
case alertMsg := <-n.AlertMsgs:
|
||||
n.MaybeSendAlertMsg(&alertMsg)
|
||||
n.SendAlertMsg(&alertMsg)
|
||||
case <-n.sessionDownSignal:
|
||||
n.sessionUp = false
|
||||
n.channelReconciler.CleanupChannels()
|
||||
n.channelReconciler.Stop()
|
||||
n.Client.Quit("see ya")
|
||||
ircConnectedGauge.Set(0)
|
||||
case <-n.stopCtx.Done():
|
||||
@ -240,7 +262,7 @@ func (n *IRCNotifier) SetupPhase() {
|
||||
case <-n.sessionUpSignal:
|
||||
n.sessionUp = true
|
||||
n.MaybeIdentifyNick()
|
||||
n.channelReconciler.JoinChannels()
|
||||
n.channelReconciler.Start(n.stopCtx)
|
||||
ircConnectedGauge.Set(1)
|
||||
case <-n.sessionDownSignal:
|
||||
log.Printf("Receiving a session down before the session is up, this is odd")
|
||||
|
22
irc_test.go
22
irc_test.go
@ -108,7 +108,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
|
||||
if line.Args[0] == testChannel {
|
||||
testStep.Done()
|
||||
}
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinedHandler)
|
||||
|
||||
@ -117,7 +117,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
|
||||
|
||||
testStep.Wait()
|
||||
|
||||
server.SetHandler("JOIN", nil)
|
||||
server.SetHandler("JOIN", hJOIN)
|
||||
|
||||
noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
@ -163,7 +163,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
|
||||
if line.Args[0] == testChannel {
|
||||
testStep.Done()
|
||||
}
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinedHandler)
|
||||
|
||||
@ -172,7 +172,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
|
||||
|
||||
testStep.Wait()
|
||||
|
||||
server.SetHandler("JOIN", nil)
|
||||
server.SetHandler("JOIN", hJOIN)
|
||||
|
||||
privmsgHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
@ -215,7 +215,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
|
||||
// ordering.
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
|
||||
@ -224,7 +224,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
|
||||
|
||||
testStep.Wait()
|
||||
|
||||
server.SetHandler("JOIN", nil)
|
||||
server.SetHandler("JOIN", hJOIN)
|
||||
|
||||
noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
@ -295,7 +295,7 @@ func TestSendAlertDisconnected(t *testing.T) {
|
||||
testStep.Add(1)
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
|
||||
@ -339,7 +339,7 @@ func TestReconnect(t *testing.T) {
|
||||
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
|
||||
@ -409,7 +409,7 @@ func TestConnectErrorRetry(t *testing.T) {
|
||||
joinStep.Add(1)
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
joinStep.Done()
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
server.SetCloseEarly(nil)
|
||||
@ -446,7 +446,7 @@ func TestIdentify(t *testing.T) {
|
||||
// after identification).
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
|
||||
@ -498,7 +498,7 @@ func TestGhostAndIdentify(t *testing.T) {
|
||||
// after identification).
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
testStep.Done()
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
|
||||
|
@ -29,6 +29,12 @@ import (
|
||||
|
||||
type LineHandlerFunc func(*bufio.ReadWriter, *irc.Line) error
|
||||
|
||||
func hJOIN(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
r := fmt.Sprintf(":foo!foo@example.com JOIN :%s\n", line.Args[0])
|
||||
_, err := conn.WriteString(r)
|
||||
return err
|
||||
}
|
||||
|
||||
func hUSER(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
r := fmt.Sprintf(":example.com 001 %s :Welcome\n", line.Args[0])
|
||||
_, err := conn.WriteString(r)
|
||||
@ -61,6 +67,7 @@ func (s *testServer) setDefaultHandlers() {
|
||||
if s.lineHandlers == nil {
|
||||
s.lineHandlers = make(map[string]LineHandlerFunc)
|
||||
}
|
||||
s.lineHandlers["JOIN"] = hJOIN
|
||||
s.lineHandlers["USER"] = hUSER
|
||||
s.lineHandlers["QUIT"] = hQUIT
|
||||
}
|
||||
|
224
reconciler.go
224
reconciler.go
@ -23,9 +23,123 @@ import (
|
||||
irc "github.com/fluffle/goirc/client"
|
||||
)
|
||||
|
||||
const (
|
||||
ircJoinWaitSecs = 10
|
||||
ircJoinMaxBackoffSecs = 300
|
||||
ircJoinBackoffResetSecs = 1800
|
||||
)
|
||||
|
||||
type channelState struct {
|
||||
Channel IRCChannel
|
||||
BackoffCounter Delayer
|
||||
channel IRCChannel
|
||||
client *irc.Conn
|
||||
delayer Delayer
|
||||
|
||||
joinDone chan struct{} // joined when channel is closed
|
||||
joined bool
|
||||
|
||||
joinUnsetSignal chan bool
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newChannelState(channel *IRCChannel, client *irc.Conn, delayerMaker DelayerMaker) *channelState {
|
||||
delayer := delayerMaker.NewDelayer(ircJoinMaxBackoffSecs, ircJoinBackoffResetSecs, time.Second)
|
||||
|
||||
return &channelState{
|
||||
channel: *channel,
|
||||
client: client,
|
||||
delayer: delayer,
|
||||
joinDone: make(chan struct{}),
|
||||
joined: false,
|
||||
joinUnsetSignal: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelState) JoinDone() <-chan struct{} {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.joinDone
|
||||
}
|
||||
|
||||
func (c *channelState) SetJoined() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.joined == true {
|
||||
log.Printf("Not setting JOIN state on channel %s: already set", c.channel.Name)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Setting JOIN state on channel %s", c.channel.Name)
|
||||
c.joined = true
|
||||
close(c.joinDone)
|
||||
}
|
||||
|
||||
func (c *channelState) UnsetJoined() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.joined == false {
|
||||
log.Printf("Not removing JOIN state on channel %s: already not set", c.channel.Name)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Removing JOIN state on channel %s", c.channel.Name)
|
||||
c.joined = false
|
||||
c.joinDone = make(chan struct{})
|
||||
|
||||
// eventually poke monitor routine
|
||||
select {
|
||||
case c.joinUnsetSignal <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelState) join(ctx context.Context) {
|
||||
log.Printf("Channel %s monitor: waiting to join", c.channel.Name)
|
||||
if ok := c.delayer.DelayContext(ctx); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.client.Join(c.channel.Name, c.channel.Password)
|
||||
log.Printf("Channel %s monitor: join request sent", c.channel.Name)
|
||||
|
||||
select {
|
||||
case <-c.JoinDone():
|
||||
log.Printf("Channel %s monitor: join succeeded", c.channel.Name)
|
||||
case <-time.After(ircJoinWaitSecs * time.Second):
|
||||
log.Printf("Channel %s monitor: could not join after %d seconds, will retry", c.channel.Name, ircJoinWaitSecs)
|
||||
case <-ctx.Done():
|
||||
log.Printf("Channel %s monitor: context canceled while waiting for join", c.channel.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelState) monitorJoinUnset(ctx context.Context) {
|
||||
select {
|
||||
case <-c.joinUnsetSignal:
|
||||
log.Printf("Channel %s monitor: channel no longer joined", c.channel.Name)
|
||||
case <-ctx.Done():
|
||||
log.Printf("Channel %s monitor: context canceled while monitoring", c.channel.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelState) Monitor(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
joined := func() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.joined
|
||||
}
|
||||
|
||||
for ctx.Err() != context.Canceled {
|
||||
if !joined() {
|
||||
c.join(ctx)
|
||||
} else {
|
||||
c.monitorJoinUnset(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ChannelReconciler struct {
|
||||
@ -57,53 +171,107 @@ func NewChannelReconciler(config *Config, client *irc.Conn, delayerMaker Delayer
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) registerHandlers() {
|
||||
r.client.HandleFunc(irc.JOIN,
|
||||
func(_ *irc.Conn, line *irc.Line) {
|
||||
r.HandleJoin(line.Nick, line.Args[0])
|
||||
})
|
||||
|
||||
r.client.HandleFunc(irc.KICK,
|
||||
func(_ *irc.Conn, line *irc.Line) {
|
||||
r.HandleKick(line.Args[1], line.Args[0])
|
||||
})
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) HandleJoin(nick string, channel string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if nick != r.client.Me().Nick {
|
||||
// received join info for somebody else
|
||||
return
|
||||
}
|
||||
log.Printf("Received JOIN confirmation for channel %s", channel)
|
||||
|
||||
c, ok := r.channels[channel]
|
||||
if !ok {
|
||||
log.Printf("Not processing JOIN for channel %s: unknown channel", channel)
|
||||
return
|
||||
}
|
||||
c.SetJoined()
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) HandleKick(nick string, channel string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if nick != r.client.Me().Nick {
|
||||
// received kick info for somebody else
|
||||
return
|
||||
}
|
||||
state, ok := r.channels[channel]
|
||||
log.Printf("Received KICK for channel %s", channel)
|
||||
|
||||
c, ok := r.channels[channel]
|
||||
if !ok {
|
||||
log.Printf("Being kicked out of non-joined channel (%s), ignoring", channel)
|
||||
log.Printf("Not processing KICK for channel %s: unknown channel", channel)
|
||||
return
|
||||
}
|
||||
log.Printf("Being kicked out of %s, re-joining", channel)
|
||||
go func() {
|
||||
if ok := state.BackoffCounter.DelayContext(r.stopCtx); !ok {
|
||||
return
|
||||
}
|
||||
r.client.Join(state.Channel.Name, state.Channel.Password)
|
||||
}()
|
||||
c.UnsetJoined()
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) CleanupChannels() {
|
||||
log.Printf("Deregistering all channels.")
|
||||
func (r *ChannelReconciler) unsafeAddChannel(channel *IRCChannel) *channelState {
|
||||
c := newChannelState(channel, r.client, r.delayerMaker)
|
||||
|
||||
r.stopWg.Add(1)
|
||||
go c.Monitor(r.stopCtx, &r.stopWg)
|
||||
|
||||
r.channels[channel.Name] = c
|
||||
return c
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) JoinChannel(channel string) (bool, <-chan struct{}) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
c, ok := r.channels[channel]
|
||||
if !ok {
|
||||
log.Printf("Request to JOIN new channel %s", channel)
|
||||
c = r.unsafeAddChannel(&IRCChannel{Name: channel})
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.JoinDone():
|
||||
return true, nil
|
||||
default:
|
||||
return false, c.JoinDone()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) unsafeStop() {
|
||||
if r.stopCtxCancel == nil {
|
||||
// calling stop before first start, ignoring
|
||||
return
|
||||
}
|
||||
r.stopCtxCancel()
|
||||
r.stopWg.Wait()
|
||||
r.channels = make(map[string]*channelState)
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) JoinChannel(channel *IRCChannel) {
|
||||
if _, joined := r.channels[channel.Name]; joined {
|
||||
return
|
||||
}
|
||||
log.Printf("Joining %s", channel.Name)
|
||||
r.client.Join(channel.Name, channel.Password)
|
||||
state := &channelState{
|
||||
Channel: *channel,
|
||||
BackoffCounter: r.delayerMaker.NewDelayer(
|
||||
ircConnectMaxBackoffSecs, ircConnectBackoffResetSecs,
|
||||
time.Second),
|
||||
}
|
||||
r.channels[channel.Name] = state
|
||||
func (r *ChannelReconciler) Stop() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.unsafeStop()
|
||||
}
|
||||
|
||||
func (r *ChannelReconciler) JoinChannels() {
|
||||
func (r *ChannelReconciler) Start(ctx context.Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.unsafeStop()
|
||||
|
||||
r.stopCtx, r.stopCtxCancel = context.WithCancel(ctx)
|
||||
|
||||
for _, channel := range r.preJoinChannels {
|
||||
r.JoinChannel(&channel)
|
||||
r.unsafeAddChannel(&channel)
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,9 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@ -63,12 +65,14 @@ func TestPreJoinChannels(t *testing.T) {
|
||||
|
||||
var testStep sync.WaitGroup
|
||||
|
||||
joinedChannels := []string{}
|
||||
|
||||
joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
|
||||
// #baz is configured as the last channel to pre-join
|
||||
if line.Args[0] == "#baz" {
|
||||
joinedChannels = append(joinedChannels, line.Args[0])
|
||||
if len(joinedChannels) == 3 {
|
||||
testStep.Done()
|
||||
}
|
||||
return nil
|
||||
return hJOIN(conn, line)
|
||||
}
|
||||
server.SetHandler("JOIN", joinHandler)
|
||||
|
||||
@ -77,25 +81,20 @@ func TestPreJoinChannels(t *testing.T) {
|
||||
reconciler.client.Connect()
|
||||
|
||||
<-sessionUp
|
||||
reconciler.JoinChannels()
|
||||
reconciler.Start(context.Background())
|
||||
|
||||
testStep.Wait()
|
||||
|
||||
reconciler.client.Quit("see ya")
|
||||
<-sessionDown
|
||||
reconciler.Stop()
|
||||
|
||||
server.Stop()
|
||||
|
||||
expectedCommands := []string{
|
||||
"NICK foo",
|
||||
"USER foo 12 * :",
|
||||
"JOIN #foo",
|
||||
"JOIN #bar",
|
||||
"JOIN #baz",
|
||||
"QUIT :see ya",
|
||||
}
|
||||
expectedJoinedChannels := []string{"#bar", "#baz", "#foo"}
|
||||
sort.Strings(joinedChannels)
|
||||
|
||||
if !reflect.DeepEqual(expectedCommands, server.Log) {
|
||||
if !reflect.DeepEqual(expectedJoinedChannels, joinedChannels) {
|
||||
t.Error("Did not pre-join channels")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user