diff --git a/irc/channel.go b/irc/channel.go index 1c7c351f..6a8db288 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -255,7 +255,7 @@ func (channel *Channel) writeLoop() { // Store writes part (or all) of the channel's data back to the database, // blocking until the write is complete. This is the equivalent of // Socket.BlockingWrite. -func (channel *Channel) Store(dirtyBits uint) (err error) { +func (channel *Channel) Store(additionalDirtyBits uint) (err error) { defer func() { channel.stateMutex.Lock() isDirty := channel.dirtyBits != 0 @@ -271,7 +271,7 @@ func (channel *Channel) Store(dirtyBits uint) (err error) { channel.writebackLock.Lock() defer channel.writebackLock.Unlock() - return channel.performWrite(dirtyBits) + return channel.performWrite(additionalDirtyBits) } // do an individual write; equivalent of Socket.send() @@ -802,7 +802,8 @@ func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *Resp } } - if joinErr := client.addChannel(channel, rb == nil); joinErr != nil { + alwaysOn, joinErr := client.addChannel(channel) + if joinErr != nil { return joinErr, "" } @@ -834,6 +835,13 @@ func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *Resp return }() + if alwaysOn { + // skip this for simulated join of always-on clients on server startup: + if rb != nil { + client.markDirty(IncludeChannels) + } + } + var message utils.SplitMessage respectAuditorium := givenMode == modes.Mode(0) && channel.flags.HasMode(modes.Auditorium) message = utils.MakeMessage("") diff --git a/irc/client.go b/irc/client.go index 956fc68c..d9aac5d9 100644 --- a/irc/client.go +++ b/irc/client.go @@ -1639,11 +1639,12 @@ func (session *Session) Notice(text string) { // `simulated` is for the fake join of an always-on client // (we just read the channel name from the database, there's no need to write it back) -func (client *Client) addChannel(channel *Channel, simulated bool) (err error) { +func (client *Client) addChannel(channel *Channel) (alwaysOn bool, err error) { config := client.server.Config() client.stateMutex.Lock() - alwaysOn := client.alwaysOn + defer client.stateMutex.Unlock() + alwaysOn = client.alwaysOn if client.destroyed { err = errClientDestroyed } else if client.oper == nil && len(client.channels) >= config.Channels.MaxChannelsPerClient { @@ -1651,11 +1652,8 @@ func (client *Client) addChannel(channel *Channel, simulated bool) (err error) { } else { client.channels.Add(channel) // success } - client.stateMutex.Unlock() - - if err == nil && alwaysOn && !simulated { - client.markDirty(IncludeChannels) - } + // XXX don't markDirty here; we need to wait for the change to go through + // on the channel side, so we can correctly record whatever mode was granted return } @@ -1967,8 +1965,11 @@ func (client *Client) performWrite(additionalDirtyBits uint) { } // Blocking store; see Channel.Store and Socket.BlockingWrite -func (client *Client) Store(dirtyBits uint) (err error) { +func (client *Client) Store(dirtyBits uint, shutdown bool) (err error) { defer func() { + if shutdown { + return // no need to restart the loop if we're shutting down + } client.stateMutex.Lock() isDirty := client.dirtyBits != 0 client.stateMutex.Unlock() diff --git a/irc/server.go b/irc/server.go index 6840189c..aeeb639c 100644 --- a/irc/server.go +++ b/irc/server.go @@ -163,7 +163,7 @@ func (server *Server) Shutdown() { } // flush data associated with always-on clients: - server.performAlwaysOnMaintenance(false, true) + server.performAlwaysOnMaintenance(true) if err := server.store.Close(); err != nil { server.logger.Error("shutdown", "Could not close datastore", err.Error()) @@ -285,20 +285,28 @@ func (server *Server) periodicAlwaysOnMaintenance() { defer server.HandlePanic(nil) server.logger.Info("accounts", "Performing periodic always-on client checks") - server.performAlwaysOnMaintenance(true, true) + server.performAlwaysOnMaintenance(false) } -func (server *Server) performAlwaysOnMaintenance(checkExpiration, flushTimestamps bool) { +func (server *Server) performAlwaysOnMaintenance(shutdown bool) { config := server.Config() for _, client := range server.clients.AllClients() { - if checkExpiration && client.IsExpiredAlwaysOn(config) { + if !shutdown && client.IsExpiredAlwaysOn(config) { // TODO save the channels list, use it for autojoin if/when they return? server.logger.Info("accounts", "Expiring always-on client", client.AccountName()) client.destroy(nil) continue } - if flushTimestamps && client.shouldFlushTimestamps() { + // synchronously flush channel memberships, etc., avoiding a race between + // immediate but asynchronous writeback of those fields and server shutdown + if shutdown && client.AlwaysOn() { + client.Store(0, shutdown) + } + + // flush the timestamps (which are not written back immediately, for debouncing + // reasons), either as periodic maintenance or on shutdown + if client.shouldFlushTimestamps() { account := client.Account() server.accounts.saveLastSeen(account, client.copyLastSeen()) server.accounts.saveReadMarkers(account, client.copyReadMarkers()) @@ -343,7 +351,7 @@ func (server *Server) performPushMaintenance() { } } // persist all push subscriptions on the assumption that the timestamps have changed - client.Store(IncludePushSubscriptions) + client.Store(IncludePushSubscriptions, false) } }