mirror of
				https://github.com/google/alertmanager-irc-relay.git
				synced 2025-11-04 07:57:24 +01:00 
			
		
		
		
	new channel management logic
this should handle bans and kicks a bit better Signed-off-by: Luca Bigliardi <shammash@google.com>
This commit is contained in:
		
							parent
							
								
									c22e7a0c84
								
							
						
					
					
						commit
						0b2fbef1f2
					
				
							
								
								
									
										36
									
								
								irc.go
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								irc.go
									
									
									
									
									
								
							@ -177,14 +177,36 @@ func (n *IRCNotifier) MaybeIdentifyNick() {
 | 
			
		||||
	time.Sleep(n.NickservDelayWait)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *IRCNotifier) MaybeSendAlertMsg(alertMsg *AlertMsg) {
 | 
			
		||||
func (n *IRCNotifier) ChannelJoined(channel string) bool {
 | 
			
		||||
 | 
			
		||||
	isJoined, waitJoined := n.channelReconciler.JoinChannel(channel)
 | 
			
		||||
	if isJoined {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-waitJoined:
 | 
			
		||||
		return true
 | 
			
		||||
	case <-time.After(ircJoinWaitSecs * time.Second):
 | 
			
		||||
		log.Printf("Channel %s not joined after %d seconds, giving bad news to caller", channel, ircJoinWaitSecs)
 | 
			
		||||
		return false
 | 
			
		||||
	case <-n.stopCtx.Done():
 | 
			
		||||
		log.Printf("Context canceled while waiting for join on channel %s", channel)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *IRCNotifier) SendAlertMsg(alertMsg *AlertMsg) {
 | 
			
		||||
	if !n.sessionUp {
 | 
			
		||||
		log.Printf("Cannot send alert to %s : IRC not connected",
 | 
			
		||||
			alertMsg.Channel)
 | 
			
		||||
		log.Printf("Cannot send alert to %s : IRC not connected", alertMsg.Channel)
 | 
			
		||||
		ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_connected").Inc()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	n.channelReconciler.JoinChannel(&IRCChannel{Name: alertMsg.Channel})
 | 
			
		||||
	if !n.ChannelJoined(alertMsg.Channel) {
 | 
			
		||||
		log.Printf("Cannot send alert to %s : cannot join channel", alertMsg.Channel)
 | 
			
		||||
		ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_joined").Inc()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if n.UsePrivmsg {
 | 
			
		||||
		n.Client.Privmsg(alertMsg.Channel, alertMsg.Alert)
 | 
			
		||||
@ -213,10 +235,10 @@ func (n *IRCNotifier) ShutdownPhase() {
 | 
			
		||||
func (n *IRCNotifier) ConnectedPhase() {
 | 
			
		||||
	select {
 | 
			
		||||
	case alertMsg := <-n.AlertMsgs:
 | 
			
		||||
		n.MaybeSendAlertMsg(&alertMsg)
 | 
			
		||||
		n.SendAlertMsg(&alertMsg)
 | 
			
		||||
	case <-n.sessionDownSignal:
 | 
			
		||||
		n.sessionUp = false
 | 
			
		||||
		n.channelReconciler.CleanupChannels()
 | 
			
		||||
		n.channelReconciler.Stop()
 | 
			
		||||
		n.Client.Quit("see ya")
 | 
			
		||||
		ircConnectedGauge.Set(0)
 | 
			
		||||
	case <-n.stopCtx.Done():
 | 
			
		||||
@ -240,7 +262,7 @@ func (n *IRCNotifier) SetupPhase() {
 | 
			
		||||
	case <-n.sessionUpSignal:
 | 
			
		||||
		n.sessionUp = true
 | 
			
		||||
		n.MaybeIdentifyNick()
 | 
			
		||||
		n.channelReconciler.JoinChannels()
 | 
			
		||||
		n.channelReconciler.Start(n.stopCtx)
 | 
			
		||||
		ircConnectedGauge.Set(1)
 | 
			
		||||
	case <-n.sessionDownSignal:
 | 
			
		||||
		log.Printf("Receiving a session down before the session is up, this is odd")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										22
									
								
								irc_test.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								irc_test.go
									
									
									
									
									
								
							@ -108,7 +108,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
 | 
			
		||||
		if line.Args[0] == testChannel {
 | 
			
		||||
			testStep.Done()
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinedHandler)
 | 
			
		||||
 | 
			
		||||
@ -117,7 +117,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	testStep.Wait()
 | 
			
		||||
 | 
			
		||||
	server.SetHandler("JOIN", nil)
 | 
			
		||||
	server.SetHandler("JOIN", hJOIN)
 | 
			
		||||
 | 
			
		||||
	noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
@ -163,7 +163,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
 | 
			
		||||
		if line.Args[0] == testChannel {
 | 
			
		||||
			testStep.Done()
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinedHandler)
 | 
			
		||||
 | 
			
		||||
@ -172,7 +172,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	testStep.Wait()
 | 
			
		||||
 | 
			
		||||
	server.SetHandler("JOIN", nil)
 | 
			
		||||
	server.SetHandler("JOIN", hJOIN)
 | 
			
		||||
 | 
			
		||||
	privmsgHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
@ -215,7 +215,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
 | 
			
		||||
	// ordering.
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
 | 
			
		||||
@ -224,7 +224,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	testStep.Wait()
 | 
			
		||||
 | 
			
		||||
	server.SetHandler("JOIN", nil)
 | 
			
		||||
	server.SetHandler("JOIN", hJOIN)
 | 
			
		||||
 | 
			
		||||
	noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
@ -295,7 +295,7 @@ func TestSendAlertDisconnected(t *testing.T) {
 | 
			
		||||
	testStep.Add(1)
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
 | 
			
		||||
@ -339,7 +339,7 @@ func TestReconnect(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
 | 
			
		||||
@ -409,7 +409,7 @@ func TestConnectErrorRetry(t *testing.T) {
 | 
			
		||||
	joinStep.Add(1)
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		joinStep.Done()
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
	server.SetCloseEarly(nil)
 | 
			
		||||
@ -446,7 +446,7 @@ func TestIdentify(t *testing.T) {
 | 
			
		||||
	// after identification).
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
 | 
			
		||||
@ -498,7 +498,7 @@ func TestGhostAndIdentify(t *testing.T) {
 | 
			
		||||
	// after identification).
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		testStep.Done()
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,12 @@ import (
 | 
			
		||||
 | 
			
		||||
type LineHandlerFunc func(*bufio.ReadWriter, *irc.Line) error
 | 
			
		||||
 | 
			
		||||
func hJOIN(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
	r := fmt.Sprintf(":foo!foo@example.com JOIN :%s\n", line.Args[0])
 | 
			
		||||
	_, err := conn.WriteString(r)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func hUSER(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
	r := fmt.Sprintf(":example.com 001 %s :Welcome\n", line.Args[0])
 | 
			
		||||
	_, err := conn.WriteString(r)
 | 
			
		||||
@ -61,6 +67,7 @@ func (s *testServer) setDefaultHandlers() {
 | 
			
		||||
	if s.lineHandlers == nil {
 | 
			
		||||
		s.lineHandlers = make(map[string]LineHandlerFunc)
 | 
			
		||||
	}
 | 
			
		||||
	s.lineHandlers["JOIN"] = hJOIN
 | 
			
		||||
	s.lineHandlers["USER"] = hUSER
 | 
			
		||||
	s.lineHandlers["QUIT"] = hQUIT
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										224
									
								
								reconciler.go
									
									
									
									
									
								
							
							
						
						
									
										224
									
								
								reconciler.go
									
									
									
									
									
								
							@ -23,9 +23,123 @@ import (
 | 
			
		||||
	irc "github.com/fluffle/goirc/client"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ircJoinWaitSecs         = 10
 | 
			
		||||
	ircJoinMaxBackoffSecs   = 300
 | 
			
		||||
	ircJoinBackoffResetSecs = 1800
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type channelState struct {
 | 
			
		||||
	Channel        IRCChannel
 | 
			
		||||
	BackoffCounter Delayer
 | 
			
		||||
	channel IRCChannel
 | 
			
		||||
	client  *irc.Conn
 | 
			
		||||
	delayer Delayer
 | 
			
		||||
 | 
			
		||||
	joinDone chan struct{} // joined when channel is closed
 | 
			
		||||
	joined   bool
 | 
			
		||||
 | 
			
		||||
	joinUnsetSignal chan bool
 | 
			
		||||
 | 
			
		||||
	mu sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newChannelState(channel *IRCChannel, client *irc.Conn, delayerMaker DelayerMaker) *channelState {
 | 
			
		||||
	delayer := delayerMaker.NewDelayer(ircJoinMaxBackoffSecs, ircJoinBackoffResetSecs, time.Second)
 | 
			
		||||
 | 
			
		||||
	return &channelState{
 | 
			
		||||
		channel:         *channel,
 | 
			
		||||
		client:          client,
 | 
			
		||||
		delayer:         delayer,
 | 
			
		||||
		joinDone:        make(chan struct{}),
 | 
			
		||||
		joined:          false,
 | 
			
		||||
		joinUnsetSignal: make(chan bool),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *channelState) JoinDone() <-chan struct{} {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	defer c.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	return c.joinDone
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *channelState) SetJoined() {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	defer c.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if c.joined == true {
 | 
			
		||||
		log.Printf("Not setting JOIN state on channel %s: already set", c.channel.Name)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Printf("Setting JOIN state on channel %s", c.channel.Name)
 | 
			
		||||
	c.joined = true
 | 
			
		||||
	close(c.joinDone)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *channelState) UnsetJoined() {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	defer c.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if c.joined == false {
 | 
			
		||||
		log.Printf("Not removing JOIN state on channel %s: already not set", c.channel.Name)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Printf("Removing JOIN state on channel %s", c.channel.Name)
 | 
			
		||||
	c.joined = false
 | 
			
		||||
	c.joinDone = make(chan struct{})
 | 
			
		||||
 | 
			
		||||
	// eventually poke monitor routine
 | 
			
		||||
	select {
 | 
			
		||||
	case c.joinUnsetSignal <- true:
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *channelState) join(ctx context.Context) {
 | 
			
		||||
	log.Printf("Channel %s monitor: waiting to join", c.channel.Name)
 | 
			
		||||
	if ok := c.delayer.DelayContext(ctx); !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.client.Join(c.channel.Name, c.channel.Password)
 | 
			
		||||
	log.Printf("Channel %s monitor: join request sent", c.channel.Name)
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-c.JoinDone():
 | 
			
		||||
		log.Printf("Channel %s monitor: join succeeded", c.channel.Name)
 | 
			
		||||
	case <-time.After(ircJoinWaitSecs * time.Second):
 | 
			
		||||
		log.Printf("Channel %s monitor: could not join after %d seconds, will retry", c.channel.Name, ircJoinWaitSecs)
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		log.Printf("Channel %s monitor: context canceled while waiting for join", c.channel.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *channelState) monitorJoinUnset(ctx context.Context) {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-c.joinUnsetSignal:
 | 
			
		||||
		log.Printf("Channel %s monitor: channel no longer joined", c.channel.Name)
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		log.Printf("Channel %s monitor: context canceled while monitoring", c.channel.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *channelState) Monitor(ctx context.Context, wg *sync.WaitGroup) {
 | 
			
		||||
	defer wg.Done()
 | 
			
		||||
 | 
			
		||||
	joined := func() bool {
 | 
			
		||||
		c.mu.Lock()
 | 
			
		||||
		defer c.mu.Unlock()
 | 
			
		||||
		return c.joined
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for ctx.Err() != context.Canceled {
 | 
			
		||||
		if !joined() {
 | 
			
		||||
			c.join(ctx)
 | 
			
		||||
		} else {
 | 
			
		||||
			c.monitorJoinUnset(ctx)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChannelReconciler struct {
 | 
			
		||||
@ -57,53 +171,107 @@ func NewChannelReconciler(config *Config, client *irc.Conn, delayerMaker Delayer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) registerHandlers() {
 | 
			
		||||
	r.client.HandleFunc(irc.JOIN,
 | 
			
		||||
		func(_ *irc.Conn, line *irc.Line) {
 | 
			
		||||
			r.HandleJoin(line.Nick, line.Args[0])
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
	r.client.HandleFunc(irc.KICK,
 | 
			
		||||
		func(_ *irc.Conn, line *irc.Line) {
 | 
			
		||||
			r.HandleKick(line.Args[1], line.Args[0])
 | 
			
		||||
		})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) HandleJoin(nick string, channel string) {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if nick != r.client.Me().Nick {
 | 
			
		||||
		// received join info for somebody else
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	log.Printf("Received JOIN confirmation for channel %s", channel)
 | 
			
		||||
 | 
			
		||||
	c, ok := r.channels[channel]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		log.Printf("Not processing JOIN for channel %s: unknown channel", channel)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.SetJoined()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) HandleKick(nick string, channel string) {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if nick != r.client.Me().Nick {
 | 
			
		||||
		// received kick info for somebody else
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	state, ok := r.channels[channel]
 | 
			
		||||
	log.Printf("Received KICK for channel %s", channel)
 | 
			
		||||
 | 
			
		||||
	c, ok := r.channels[channel]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		log.Printf("Being kicked out of non-joined channel (%s), ignoring", channel)
 | 
			
		||||
		log.Printf("Not processing KICK for channel %s: unknown channel", channel)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	log.Printf("Being kicked out of %s, re-joining", channel)
 | 
			
		||||
	go func() {
 | 
			
		||||
		if ok := state.BackoffCounter.DelayContext(r.stopCtx); !ok {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		r.client.Join(state.Channel.Name, state.Channel.Password)
 | 
			
		||||
	}()
 | 
			
		||||
	c.UnsetJoined()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) CleanupChannels() {
 | 
			
		||||
	log.Printf("Deregistering all channels.")
 | 
			
		||||
func (r *ChannelReconciler) unsafeAddChannel(channel *IRCChannel) *channelState {
 | 
			
		||||
	c := newChannelState(channel, r.client, r.delayerMaker)
 | 
			
		||||
 | 
			
		||||
	r.stopWg.Add(1)
 | 
			
		||||
	go c.Monitor(r.stopCtx, &r.stopWg)
 | 
			
		||||
 | 
			
		||||
	r.channels[channel.Name] = c
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) JoinChannel(channel string) (bool, <-chan struct{}) {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	c, ok := r.channels[channel]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		log.Printf("Request to JOIN new channel %s", channel)
 | 
			
		||||
		c = r.unsafeAddChannel(&IRCChannel{Name: channel})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-c.JoinDone():
 | 
			
		||||
		return true, nil
 | 
			
		||||
	default:
 | 
			
		||||
		return false, c.JoinDone()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) unsafeStop() {
 | 
			
		||||
	if r.stopCtxCancel == nil {
 | 
			
		||||
		// calling stop before first start, ignoring
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.stopCtxCancel()
 | 
			
		||||
	r.stopWg.Wait()
 | 
			
		||||
	r.channels = make(map[string]*channelState)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) JoinChannel(channel *IRCChannel) {
 | 
			
		||||
	if _, joined := r.channels[channel.Name]; joined {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	log.Printf("Joining %s", channel.Name)
 | 
			
		||||
	r.client.Join(channel.Name, channel.Password)
 | 
			
		||||
	state := &channelState{
 | 
			
		||||
		Channel: *channel,
 | 
			
		||||
		BackoffCounter: r.delayerMaker.NewDelayer(
 | 
			
		||||
			ircConnectMaxBackoffSecs, ircConnectBackoffResetSecs,
 | 
			
		||||
			time.Second),
 | 
			
		||||
	}
 | 
			
		||||
	r.channels[channel.Name] = state
 | 
			
		||||
func (r *ChannelReconciler) Stop() {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	r.unsafeStop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ChannelReconciler) JoinChannels() {
 | 
			
		||||
func (r *ChannelReconciler) Start(ctx context.Context) {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	r.unsafeStop()
 | 
			
		||||
 | 
			
		||||
	r.stopCtx, r.stopCtxCancel = context.WithCancel(ctx)
 | 
			
		||||
 | 
			
		||||
	for _, channel := range r.preJoinChannels {
 | 
			
		||||
		r.JoinChannel(&channel)
 | 
			
		||||
		r.unsafeAddChannel(&channel)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,9 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
@ -63,12 +65,14 @@ func TestPreJoinChannels(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	var testStep sync.WaitGroup
 | 
			
		||||
 | 
			
		||||
	joinedChannels := []string{}
 | 
			
		||||
 | 
			
		||||
	joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error {
 | 
			
		||||
		// #baz is configured as the last channel to pre-join
 | 
			
		||||
		if line.Args[0] == "#baz" {
 | 
			
		||||
		joinedChannels = append(joinedChannels, line.Args[0])
 | 
			
		||||
		if len(joinedChannels) == 3 {
 | 
			
		||||
			testStep.Done()
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
		return hJOIN(conn, line)
 | 
			
		||||
	}
 | 
			
		||||
	server.SetHandler("JOIN", joinHandler)
 | 
			
		||||
 | 
			
		||||
@ -77,25 +81,20 @@ func TestPreJoinChannels(t *testing.T) {
 | 
			
		||||
	reconciler.client.Connect()
 | 
			
		||||
 | 
			
		||||
	<-sessionUp
 | 
			
		||||
	reconciler.JoinChannels()
 | 
			
		||||
	reconciler.Start(context.Background())
 | 
			
		||||
 | 
			
		||||
	testStep.Wait()
 | 
			
		||||
 | 
			
		||||
	reconciler.client.Quit("see ya")
 | 
			
		||||
	<-sessionDown
 | 
			
		||||
	reconciler.Stop()
 | 
			
		||||
 | 
			
		||||
	server.Stop()
 | 
			
		||||
 | 
			
		||||
	expectedCommands := []string{
 | 
			
		||||
		"NICK foo",
 | 
			
		||||
		"USER foo 12 * :",
 | 
			
		||||
		"JOIN #foo",
 | 
			
		||||
		"JOIN #bar",
 | 
			
		||||
		"JOIN #baz",
 | 
			
		||||
		"QUIT :see ya",
 | 
			
		||||
	}
 | 
			
		||||
	expectedJoinedChannels := []string{"#bar", "#baz", "#foo"}
 | 
			
		||||
	sort.Strings(joinedChannels)
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(expectedCommands, server.Log) {
 | 
			
		||||
	if !reflect.DeepEqual(expectedJoinedChannels, joinedChannels) {
 | 
			
		||||
		t.Error("Did not pre-join channels")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user