diff --git a/irc/accounts.go b/irc/accounts.go index 4a6298eb..a322989a 100644 --- a/irc/accounts.go +++ b/irc/accounts.go @@ -59,18 +59,15 @@ type AccountManager struct { accountToMethod map[string]NickReservationMethod } -func NewAccountManager(server *Server) *AccountManager { - am := AccountManager{ - accountToClients: make(map[string][]*Client), - nickToAccount: make(map[string]string), - skeletonToAccount: make(map[string]string), - accountToMethod: make(map[string]NickReservationMethod), - server: server, - } +func (am *AccountManager) Initialize(server *Server) { + am.accountToClients = make(map[string][]*Client) + am.nickToAccount = make(map[string]string) + am.skeletonToAccount = make(map[string]string) + am.accountToMethod = make(map[string]NickReservationMethod) + am.server = server am.buildNickToAccountIndex() am.initVHostRequestQueue() - return &am } func (am *AccountManager) buildNickToAccountIndex() { @@ -855,6 +852,7 @@ func (am *AccountManager) Unregister(account string) error { verificationCodeKey := fmt.Sprintf(keyAccountVerificationCode, casefoldedAccount) verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount) nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount) + enforcementKey := fmt.Sprintf(keyAccountEnforcement, casefoldedAccount) vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount) vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount) channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount) @@ -865,14 +863,7 @@ func (am *AccountManager) Unregister(account string) error { // on our way out, unregister all the account's channels and delete them from the db defer func() { for _, channelName := range registeredChannels { - info := am.server.channelRegistry.LoadChannel(channelName) - if info != nil && info.Founder == casefoldedAccount { - am.server.channelRegistry.Delete(channelName, *info) - } - channel := am.server.channels.Get(channelName) - if channel != nil { - channel.SetUnregistered(casefoldedAccount) - } + am.server.channels.SetUnregistered(channelName, casefoldedAccount) } }() @@ -892,6 +883,7 @@ func (am *AccountManager) Unregister(account string) error { tx.Delete(registeredTimeKey) tx.Delete(callbackKey) tx.Delete(verificationCodeKey) + tx.Delete(enforcementKey) rawNicks, _ = tx.Get(nicksKey) tx.Delete(nicksKey) credText, err = tx.Get(credentialsKey) diff --git a/irc/caps/set.go b/irc/caps/set.go index b348cbf8..c1e01398 100644 --- a/irc/caps/set.go +++ b/irc/caps/set.go @@ -16,7 +16,6 @@ type Set [bitsetLen]uint64 // NewSet returns a new Set, with the given capabilities enabled. func NewSet(capabs ...Capability) *Set { var newSet Set - utils.BitsetInitialize(newSet[:]) newSet.Enable(capabs...) return &newSet } diff --git a/irc/channel.go b/irc/channel.go index 89f586ce..2d0ba9c3 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -22,7 +22,7 @@ import ( // Channel represents a channel that clients can join. type Channel struct { - flags *modes.ModeSet + flags modes.ModeSet lists map[modes.Mode]*UserMaskSet key string members MemberSet @@ -33,19 +33,22 @@ type Channel struct { createdTime time.Time registeredFounder string registeredTime time.Time - stateMutex sync.RWMutex // tier 1 - joinPartMutex sync.Mutex // tier 3 topic string topicSetBy string topicSetTime time.Time userLimit int accountToUMode map[string]modes.Mode history history.Buffer + stateMutex sync.RWMutex // tier 1 + writerSemaphore Semaphore // tier 1.5 + joinPartMutex sync.Mutex // tier 3 + ensureLoaded utils.Once // manages loading stored registration info from the database + dirtyBits uint } // NewChannel creates a new channel from a `Server` and a `name` // string, which must be unique on the server. -func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel { +func NewChannel(s *Server, name string, registered bool) *Channel { casefoldedName, err := CasefoldChannel(name) if err != nil { s.logger.Error("internal", "Bad channel name", name, err.Error()) @@ -54,7 +57,6 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel { channel := &Channel{ createdTime: time.Now(), // may be overwritten by applyRegInfo - flags: modes.NewModeSet(), lists: map[modes.Mode]*UserMaskSet{ modes.BanMask: NewUserMaskSet(), modes.ExceptMask: NewUserMaskSet(), @@ -69,21 +71,43 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel { config := s.Config() - if regInfo != nil { - channel.applyRegInfo(regInfo) - } else { + channel.writerSemaphore.Initialize(1) + channel.history.Initialize(config.History.ChannelLength) + + if !registered { for _, mode := range config.Channels.defaultModes { channel.flags.SetMode(mode, true) } - } - - channel.history.Initialize(config.History.ChannelLength) + // no loading to do, so "mark" the load operation as "done": + channel.ensureLoaded.Do(func() {}) + } // else: modes will be loaded before first join return channel } +// EnsureLoaded blocks until the channel's registration info has been loaded +// from the database. +func (channel *Channel) EnsureLoaded() { + channel.ensureLoaded.Do(func() { + nmc := channel.NameCasefolded() + info, err := channel.server.channelRegistry.LoadChannel(nmc) + if err == nil { + channel.applyRegInfo(info) + } else { + channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error()) + } + }) +} + +func (channel *Channel) IsLoaded() bool { + return channel.ensureLoaded.Done() +} + // read in channel state that was persisted in the DB -func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) { +func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) { + channel.stateMutex.Lock() + defer channel.stateMutex.Unlock() + channel.registeredFounder = chanReg.Founder channel.registeredTime = chanReg.RegisteredAt channel.topic = chanReg.Topic @@ -116,6 +140,7 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh defer channel.stateMutex.RUnlock() info.Name = channel.name + info.NameCasefolded = channel.nameCasefolded info.Founder = channel.registeredFounder info.RegisteredAt = channel.registeredTime @@ -149,6 +174,115 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh return } +// begin: asynchronous database writeback implementation, modeled on irc/socket.go + +// MarkDirty marks part (or all) of a channel's data as needing to be written back +// to the database, then starts a writer goroutine if necessary. +// This is the equivalent of Socket.Write(). +func (channel *Channel) MarkDirty(dirtyBits uint) { + channel.stateMutex.Lock() + isRegistered := channel.registeredFounder != "" + channel.dirtyBits = channel.dirtyBits | dirtyBits + channel.stateMutex.Unlock() + if !isRegistered { + return + } + + channel.wakeWriter() +} + +// IsClean returns whether a channel can be safely removed from the server. +// To avoid the obvious TOCTOU race condition, it must be called while holding +// ChannelManager's lock (that way, no one can join and make the channel dirty again +// between this method exiting and the actual deletion). +func (channel *Channel) IsClean() bool { + if !channel.writerSemaphore.TryAcquire() { + // a database write (which may fail) is in progress, the channel cannot be cleaned up + return false + } + defer channel.writerSemaphore.Release() + + channel.stateMutex.RLock() + defer channel.stateMutex.RUnlock() + // the channel must be empty, and either be unregistered or fully written to the DB + return len(channel.members) == 0 && (channel.registeredFounder == "" || channel.dirtyBits == 0) +} + +func (channel *Channel) wakeWriter() { + if channel.writerSemaphore.TryAcquire() { + go channel.writeLoop() + } +} + +// equivalent of Socket.send() +func (channel *Channel) writeLoop() { + for { + // TODO(#357) check the error value of this and implement timed backoff + channel.performWrite(0) + channel.writerSemaphore.Release() + + channel.stateMutex.RLock() + isDirty := channel.dirtyBits != 0 + isEmpty := len(channel.members) == 0 + channel.stateMutex.RUnlock() + + if !isDirty { + if isEmpty { + channel.server.channels.Cleanup(channel) + } + return // nothing to do + } // else: isDirty, so we need to write again + + if !channel.writerSemaphore.TryAcquire() { + return + } + } +} + +// Store writes part (or all) of the channel's data back to the database, +// blocking until the write is complete. This is the equivalent of +// Socket.BlockingWrite. +func (channel *Channel) Store(dirtyBits uint) (err error) { + defer func() { + channel.stateMutex.Lock() + isDirty := channel.dirtyBits != 0 + isEmpty := len(channel.members) == 0 + channel.stateMutex.Unlock() + + if isDirty { + channel.wakeWriter() + } else if isEmpty { + channel.server.channels.Cleanup(channel) + } + }() + + channel.writerSemaphore.Acquire() + defer channel.writerSemaphore.Release() + return channel.performWrite(dirtyBits) +} + +// do an individual write; equivalent of Socket.send() +func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) { + channel.stateMutex.Lock() + dirtyBits := channel.dirtyBits | additionalDirtyBits + channel.dirtyBits = 0 + isRegistered := channel.registeredFounder != "" + channel.stateMutex.Unlock() + + if !isRegistered || dirtyBits == 0 { + return + } + + info := channel.ExportRegistration(dirtyBits) + err = channel.server.channelRegistry.StoreChannel(info, dirtyBits) + if err != nil { + channel.stateMutex.Lock() + channel.dirtyBits = channel.dirtyBits | dirtyBits + channel.stateMutex.Unlock() + } + return +} + // SetRegistered registers the channel, returning an error if it was already registered. func (channel *Channel) SetRegistered(founder string) error { channel.stateMutex.Lock() @@ -698,7 +832,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe } } - go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic) + channel.MarkDirty(IncludeTopic) } // CanSpeak returns true if the client can speak on this channel. diff --git a/irc/channelmanager.go b/irc/channelmanager.go index 5729dcfe..f520619e 100644 --- a/irc/channelmanager.go +++ b/irc/channelmanager.go @@ -19,25 +19,38 @@ type channelManagerEntry struct { // providing synchronization for creation of new channels on first join, // cleanup of empty channels on last part, and renames. type ChannelManager struct { - sync.RWMutex // tier 2 - chans map[string]*channelManagerEntry + sync.RWMutex // tier 2 + chans map[string]*channelManagerEntry + registeredChannels map[string]bool + server *Server } // NewChannelManager returns a new ChannelManager. -func NewChannelManager() *ChannelManager { - return &ChannelManager{ - chans: make(map[string]*channelManagerEntry), +func (cm *ChannelManager) Initialize(server *Server) { + cm.chans = make(map[string]*channelManagerEntry) + cm.server = server + + if server.Config().Channels.Registration.Enabled { + cm.loadRegisteredChannels() } } +func (cm *ChannelManager) loadRegisteredChannels() { + registeredChannels := cm.server.channelRegistry.AllChannels() + cm.Lock() + defer cm.Unlock() + cm.registeredChannels = registeredChannels +} + // Get returns an existing channel with name equivalent to `name`, or nil -func (cm *ChannelManager) Get(name string) *Channel { +func (cm *ChannelManager) Get(name string) (channel *Channel) { name, err := CasefoldChannel(name) if err == nil { cm.RLock() defer cm.RUnlock() entry := cm.chans[name] - if entry != nil { + // if the channel is still loading, pretend we don't have it + if entry != nil && entry.channel.IsLoaded() { return entry.channel } } @@ -55,28 +68,21 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin cm.Lock() entry := cm.chans[casefoldedName] if entry == nil { - // XXX give up the lock to check for a registration, then check again - // to see if we need to create the channel. we could solve this by doing LoadChannel - // outside the lock initially on every join, so this is best thought of as an - // optimization to avoid that. - cm.Unlock() - info := client.server.channelRegistry.LoadChannel(casefoldedName) - cm.Lock() - entry = cm.chans[casefoldedName] - if entry == nil { - entry = &channelManagerEntry{ - channel: NewChannel(server, name, info), - pendingJoins: 0, - } - cm.chans[casefoldedName] = entry + registered := cm.registeredChannels[casefoldedName] + entry = &channelManagerEntry{ + channel: NewChannel(server, name, registered), + pendingJoins: 0, } + cm.chans[casefoldedName] = entry } entry.pendingJoins += 1 + channel := entry.channel cm.Unlock() - entry.channel.Join(client, key, isSajoin, rb) + channel.EnsureLoaded() + channel.Join(client, key, isSajoin, rb) - cm.maybeCleanup(entry.channel, true) + cm.maybeCleanup(channel, true) return nil } @@ -85,7 +91,8 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) { cm.Lock() defer cm.Unlock() - entry := cm.chans[channel.NameCasefolded()] + nameCasefolded := channel.NameCasefolded() + entry := cm.chans[nameCasefolded] if entry == nil || entry.channel != channel { return } @@ -93,23 +100,15 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) { if afterJoin { entry.pendingJoins -= 1 } - // TODO(slingamn) right now, registered channels cannot be cleaned up. - // this is because once ChannelManager becomes the source of truth about a channel, - // we can't move the source of truth back to the database unless we do an ACID - // store while holding the ChannelManager's Lock(). This is pending more decisions - // about where the database transaction lock fits into the overall lock model. - if !entry.channel.IsRegistered() && entry.channel.IsEmpty() && entry.pendingJoins == 0 { - // reread the name, handling the case where the channel was renamed - casefoldedName := entry.channel.NameCasefolded() - delete(cm.chans, casefoldedName) - // invalidate the entry (otherwise, a subsequent cleanup attempt could delete - // a valid, distinct entry under casefoldedName): - entry.channel = nil + if entry.pendingJoins == 0 && entry.channel.IsClean() { + delete(cm.chans, nameCasefolded) } } // Part parts `client` from the channel named `name`, deleting it if it's empty. func (cm *ChannelManager) Part(client *Client, name string, message string, rb *ResponseBuffer) error { + var channel *Channel + casefoldedName, err := CasefoldChannel(name) if err != nil { return errNoSuchChannel @@ -117,12 +116,15 @@ func (cm *ChannelManager) Part(client *Client, name string, message string, rb * cm.RLock() entry := cm.chans[casefoldedName] + if entry != nil { + channel = entry.channel + } cm.RUnlock() - if entry == nil { + if channel == nil { return errNoSuchChannel } - entry.channel.Part(client, message, rb) + channel.Part(client, message, rb) return nil } @@ -130,8 +132,68 @@ func (cm *ChannelManager) Cleanup(channel *Channel) { cm.maybeCleanup(channel, false) } +func (cm *ChannelManager) SetRegistered(channelName string, account string) (err error) { + var channel *Channel + cfname, err := CasefoldChannel(channelName) + if err != nil { + return err + } + + var entry *channelManagerEntry + + defer func() { + if err == nil && channel != nil { + // registration was successful: make the database reflect it + err = channel.Store(IncludeAllChannelAttrs) + } + }() + + cm.Lock() + defer cm.Unlock() + entry = cm.chans[cfname] + if entry == nil { + return errNoSuchChannel + } + channel = entry.channel + err = channel.SetRegistered(account) + if err != nil { + return err + } + cm.registeredChannels[cfname] = true + return nil +} + +func (cm *ChannelManager) SetUnregistered(channelName string, account string) (err error) { + cfname, err := CasefoldChannel(channelName) + if err != nil { + return err + } + + var info RegisteredChannel + + defer func() { + if err == nil { + err = cm.server.channelRegistry.Delete(info) + } + }() + + cm.Lock() + defer cm.Unlock() + entry := cm.chans[cfname] + if entry == nil { + return errNoSuchChannel + } + info = entry.channel.ExportRegistration(0) + if info.Founder != account { + return errChannelNotOwnedByAccount + } + entry.channel.SetUnregistered(account) + delete(cm.registeredChannels, cfname) + return nil +} + // Rename renames a channel (but does not notify the members) -func (cm *ChannelManager) Rename(name string, newname string) error { +func (cm *ChannelManager) Rename(name string, newname string) (err error) { cfname, err := CasefoldChannel(name) if err != nil { return errNoSuchChannel @@ -142,22 +204,37 @@ func (cm *ChannelManager) Rename(name string, newname string) error { return errInvalidChannelName } + var channel *Channel + var info RegisteredChannel + defer func() { + if channel != nil && info.Founder != "" { + channel.Store(IncludeAllChannelAttrs) + // we just flushed the channel under its new name, therefore this delete + // cannot be overwritten by a write to the old name: + cm.server.channelRegistry.Delete(info) + } + }() + cm.Lock() defer cm.Unlock() - if cm.chans[cfnewname] != nil { + if cm.chans[cfnewname] != nil || cm.registeredChannels[cfnewname] { return errChannelNameInUse } entry := cm.chans[cfname] if entry == nil { return errNoSuchChannel } + channel = entry.channel + info = channel.ExportRegistration(IncludeInitial) delete(cm.chans, cfname) cm.chans[cfnewname] = entry - entry.channel.setName(newname) - entry.channel.setNameCasefolded(cfnewname) + if cm.registeredChannels[cfname] { + delete(cm.registeredChannels, cfname) + cm.registeredChannels[cfnewname] = true + } + entry.channel.Rename(newname, cfnewname) return nil - } // Len returns the number of channels @@ -171,8 +248,11 @@ func (cm *ChannelManager) Len() int { func (cm *ChannelManager) Channels() (result []*Channel) { cm.RLock() defer cm.RUnlock() + result = make([]*Channel, 0, len(cm.chans)) for _, entry := range cm.chans { - result = append(result, entry.channel) + if entry.channel.IsLoaded() { + result = append(result, entry.channel) + } } return } diff --git a/irc/channelreg.go b/irc/channelreg.go index 7b05bccc..19cae5b8 100644 --- a/irc/channelreg.go +++ b/irc/channelreg.go @@ -7,7 +7,6 @@ import ( "fmt" "strconv" "strings" - "sync" "time" "encoding/json" @@ -71,6 +70,8 @@ const ( type RegisteredChannel struct { // Name of the channel. Name string + // Casefolded name of the channel. + NameCasefolded string // RegisteredAt represents the time that the channel was registered. RegisteredAt time.Time // Founder indicates the founder of the channel. @@ -97,58 +98,65 @@ type RegisteredChannel struct { // ChannelRegistry manages registered channels. type ChannelRegistry struct { - // This serializes operations of the form (read channel state, synchronously persist it); - // this is enough to guarantee eventual consistency of the database with the - // ChannelManager and Channel objects, which are the source of truth. - // - // We could use the buntdb RW transaction lock for this purpose but we share - // that with all the other modules, so let's not. - sync.Mutex // tier 2 - server *Server + server *Server } // NewChannelRegistry returns a new ChannelRegistry. -func NewChannelRegistry(server *Server) *ChannelRegistry { - return &ChannelRegistry{ - server: server, - } +func (reg *ChannelRegistry) Initialize(server *Server) { + reg.server = server +} + +func (reg *ChannelRegistry) AllChannels() (result map[string]bool) { + result = make(map[string]bool) + + prefix := fmt.Sprintf(keyChannelExists, "") + reg.server.store.View(func(tx *buntdb.Tx) error { + return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool { + if !strings.HasPrefix(key, prefix) { + return false + } + channel := strings.TrimPrefix(key, prefix) + result[channel] = true + return true + }) + }) + + return } // StoreChannel obtains a consistent view of a channel, then persists it to the store. -func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) { +func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) { if !reg.server.ChannelRegistrationEnabled() { return } - reg.Lock() - defer reg.Unlock() - - key := channel.NameCasefolded() - info := channel.ExportRegistration(includeFlags) if info.Founder == "" { // sanity check, don't try to store an unregistered channel return } reg.server.store.Update(func(tx *buntdb.Tx) error { - reg.saveChannel(tx, key, info, includeFlags) + reg.saveChannel(tx, info, includeFlags) return nil }) + + return nil } // LoadChannel loads a channel from the store. -func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *RegisteredChannel) { +func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) { if !reg.server.ChannelRegistrationEnabled() { - return nil + err = errFeatureDisabled + return } channelKey := nameCasefolded // nice to have: do all JSON (de)serialization outside of the buntdb transaction - reg.server.store.View(func(tx *buntdb.Tx) error { - _, err := tx.Get(fmt.Sprintf(keyChannelExists, channelKey)) - if err == buntdb.ErrNotFound { + err = reg.server.store.View(func(tx *buntdb.Tx) error { + _, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey)) + if dberr == buntdb.ErrNotFound { // chan does not already exist, return - return nil + return errNoSuchChannel } // channel exists, load it @@ -181,7 +189,7 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered accountToUMode := make(map[string]modes.Mode) _ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode) - info = &RegisteredChannel{ + info = RegisteredChannel{ Name: name, RegisteredAt: time.Unix(regTimeInt, 0), Founder: founder, @@ -198,46 +206,21 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered return nil }) - return info + return } -func (reg *ChannelRegistry) Delete(casefoldedName string, info RegisteredChannel) { +// Delete deletes a channel corresponding to `info`. If no such channel +// is present in the database, no error is returned. +func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) { if !reg.server.ChannelRegistrationEnabled() { return } - reg.Lock() - defer reg.Unlock() - reg.server.store.Update(func(tx *buntdb.Tx) error { - reg.deleteChannel(tx, casefoldedName, info) - return nil - }) -} - -// Rename handles the persistence part of a channel rename: the channel is -// persisted under its new name, and the old name is cleaned up if necessary. -func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) { - if !reg.server.ChannelRegistrationEnabled() { - return - } - - reg.Lock() - defer reg.Unlock() - - includeFlags := IncludeAllChannelAttrs - oldKey := casefoldedOldName - key := channel.NameCasefolded() - info := channel.ExportRegistration(includeFlags) - if info.Founder == "" { - return - } - - reg.server.store.Update(func(tx *buntdb.Tx) error { - reg.deleteChannel(tx, oldKey, info) - reg.saveChannel(tx, key, info, includeFlags) + reg.deleteChannel(tx, info.NameCasefolded, info) return nil }) + return nil } // delete a channel, unless it was overwritten by another registration of the same channel @@ -274,7 +257,8 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist } // saveChannel saves a channel to the store. -func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) { +func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) { + channelKey := channelInfo.NameCasefolded // maintain the mapping of account -> registered channels chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey) _, existsErr := tx.Get(chanExistsKey) diff --git a/irc/chanserv.go b/irc/chanserv.go index 3a12cc07..b9b79d7a 100644 --- a/irc/chanserv.go +++ b/irc/chanserv.go @@ -232,15 +232,12 @@ func csRegisterHandler(server *Server, client *Client, command string, params [] } // this provides the synchronization that allows exactly one registration of the channel: - err = channelInfo.SetRegistered(account) + err = server.channels.SetRegistered(channelKey, account) if err != nil { csNotice(rb, err.Error()) return } - // registration was successful: make the database reflect it - go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs) - csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName)) server.logger.Info("services", fmt.Sprintf("Client %s registered channel %s", client.nick, channelName)) @@ -297,8 +294,7 @@ func csUnregisterHandler(server *Server, client *Client, command string, params return } - channel.SetUnregistered(founder) - server.channelRegistry.Delete(channelKey, info) + server.channels.SetUnregistered(channelKey, founder) csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey)) } diff --git a/irc/client.go b/irc/client.go index 0e546477..834c877f 100644 --- a/irc/client.go +++ b/irc/client.go @@ -50,7 +50,7 @@ type Client struct { accountName string // display name of the account: uncasefolded, '*' if not logged in atime time.Time awayMessage string - capabilities *caps.Set + capabilities caps.Set capState caps.State capVersion caps.Version certfp string @@ -58,7 +58,7 @@ type Client struct { ctime time.Time exitedSnomaskSent bool fakelag Fakelag - flags *modes.ModeSet + flags modes.ModeSet hasQuit bool hops int hostname string @@ -125,15 +125,13 @@ func RunNewClient(server *Server, conn clientConn) { // give them 1k of grace over the limit: socket := NewSocket(conn.Conn, fullLineLenLimit+1024, config.Server.MaxSendQBytes) client := &Client{ - atime: now, - capabilities: caps.NewSet(), - capState: caps.NoneState, - capVersion: caps.Cap301, - channels: make(ChannelSet), - ctime: now, - flags: modes.NewModeSet(), - isTor: conn.IsTor, - languages: server.Languages().Default(), + atime: now, + capState: caps.NoneState, + capVersion: caps.Cap301, + channels: make(ChannelSet), + ctime: now, + isTor: conn.IsTor, + languages: server.Languages().Default(), loginThrottle: connection_limits.GenericThrottle{ Duration: config.Accounts.LoginThrottling.Duration, Limit: config.Accounts.LoginThrottling.MaxAttempts, @@ -546,7 +544,6 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I // copy applicable state from oldClient to client as part of a resume func (client *Client) copyResumeData(oldClient *Client) { oldClient.stateMutex.RLock() - flags := oldClient.flags history := oldClient.history nick := oldClient.nick nickCasefolded := oldClient.nickCasefolded @@ -560,7 +557,7 @@ func (client *Client) copyResumeData(oldClient *Client) { // resume over plaintext) hasTLS := client.flags.HasMode(modes.TLS) temp := modes.NewModeSet() - temp.Copy(flags) + temp.Copy(&oldClient.flags) temp.SetMode(modes.TLS, hasTLS) client.flags.Copy(temp) diff --git a/irc/client_lookup_set.go b/irc/client_lookup_set.go index ec98fc5f..18e3f33f 100644 --- a/irc/client_lookup_set.go +++ b/irc/client_lookup_set.go @@ -37,12 +37,10 @@ type ClientManager struct { bySkeleton map[string]*Client } -// NewClientManager returns a new ClientManager. -func NewClientManager() *ClientManager { - return &ClientManager{ - byNick: make(map[string]*Client), - bySkeleton: make(map[string]*Client), - } +// Initialize initializes a ClientManager. +func (clients *ClientManager) Initialize() { + clients.byNick = make(map[string]*Client) + clients.bySkeleton = make(map[string]*Client) } // Count returns how many clients are in the manager. diff --git a/irc/errors.go b/irc/errors.go index c3169151..0a5c3c5f 100644 --- a/irc/errors.go +++ b/irc/errors.go @@ -27,6 +27,8 @@ var ( errAccountMustHoldNick = errors.New(`You must hold that nickname in order to register it`) errCallbackFailed = errors.New("Account verification could not be sent") errCertfpAlreadyExists = errors.New(`An account already exists for your certificate fingerprint`) + errChannelNotOwnedByAccount = errors.New("Channel not owned by the specified account") + errChannelDoesNotExist = errors.New("Channel does not exist") errChannelAlreadyRegistered = errors.New("Channel is already registered") errChannelNameInUse = errors.New(`Channel name in use`) errInvalidChannelName = errors.New(`Invalid channel name`) diff --git a/irc/getters.go b/irc/getters.go index d9421bf9..afa089f2 100644 --- a/irc/getters.go +++ b/irc/getters.go @@ -4,6 +4,8 @@ package irc import ( + "time" + "github.com/oragono/oragono/irc/isupport" "github.com/oragono/oragono/irc/languages" "github.com/oragono/oragono/irc/modes" @@ -267,22 +269,20 @@ func (channel *Channel) Name() string { return channel.name } -func (channel *Channel) setName(name string) { - channel.stateMutex.Lock() - defer channel.stateMutex.Unlock() - channel.name = name -} - func (channel *Channel) NameCasefolded() string { channel.stateMutex.RLock() defer channel.stateMutex.RUnlock() return channel.nameCasefolded } -func (channel *Channel) setNameCasefolded(nameCasefolded string) { +func (channel *Channel) Rename(name, nameCasefolded string) { channel.stateMutex.Lock() - defer channel.stateMutex.Unlock() + channel.name = name channel.nameCasefolded = nameCasefolded + if channel.registeredFounder != "" { + channel.registeredTime = time.Now() + } + channel.stateMutex.Unlock() } func (channel *Channel) Members() (result []*Client) { @@ -314,3 +314,10 @@ func (channel *Channel) Founder() string { defer channel.stateMutex.RUnlock() return channel.registeredFounder } + +func (channel *Channel) DirtyBits() (dirtyBits uint) { + channel.stateMutex.Lock() + dirtyBits = channel.dirtyBits + channel.stateMutex.Unlock() + return +} diff --git a/irc/handlers.go b/irc/handlers.go index dde68601..bddefc07 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -1607,8 +1607,8 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res } } - if channel.IsRegistered() && includeFlags != 0 { - go server.channelRegistry.StoreChannel(channel, includeFlags) + if includeFlags != 0 { + channel.MarkDirty(includeFlags) } // send out changes @@ -2167,7 +2167,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), oldName, client.t("No such channel")) return false } - casefoldedOldName := channel.NameCasefolded() if !(channel.ClientIsAtLeast(client, modes.Operator) || client.HasRoleCapabs("chanreg")) { rb.Add(nil, server.name, ERR_CHANOPRIVSNEEDED, client.Nick(), oldName, client.t("You're not a channel operator")) return false @@ -2192,9 +2191,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re return false } - // rename succeeded, persist it - go server.channelRegistry.Rename(channel, casefoldedOldName) - // send RENAME messages clientPrefix := client.NickMaskString() for _, mcl := range channel.Members() { diff --git a/irc/modes.go b/irc/modes.go index f1f3c293..3fd90cb4 100644 --- a/irc/modes.go +++ b/irc/modes.go @@ -290,14 +290,14 @@ func (channel *Channel) ProcessAccountToUmodeChange(client *Client, change modes case modes.Add: if targetModeNow != targetModeAfter { channel.accountToUMode[change.Arg] = change.Mode - go client.server.channelRegistry.StoreChannel(channel, IncludeLists) + channel.MarkDirty(IncludeLists) return []modes.ModeChange{change}, nil } return nil, nil case modes.Remove: if targetModeNow == change.Mode { delete(channel.accountToUMode, change.Arg) - go client.server.channelRegistry.StoreChannel(channel, IncludeLists) + channel.MarkDirty(IncludeLists) return []modes.ModeChange{change}, nil } return nil, nil diff --git a/irc/modes/modes.go b/irc/modes/modes.go index 35d59ec0..fd5341a0 100644 --- a/irc/modes/modes.go +++ b/irc/modes/modes.go @@ -335,7 +335,6 @@ const ( // returns a pointer to a new ModeSet func NewModeSet() *ModeSet { var set ModeSet - utils.BitsetInitialize(set[:]) return &set } diff --git a/irc/semaphores.go b/irc/semaphores.go index a9ce309f..c1968b67 100644 --- a/irc/semaphores.go +++ b/irc/semaphores.go @@ -32,14 +32,13 @@ type ServerSemaphores struct { ClientDestroy Semaphore } -// NewServerSemaphores creates a new ServerSemaphores. -func NewServerSemaphores() (result *ServerSemaphores) { +// Initialize initializes a set of server semaphores. +func (serversem *ServerSemaphores) Initialize() { capacity := runtime.NumCPU() if capacity > MaxServerSemaphoreCapacity { capacity = MaxServerSemaphoreCapacity } - result = new(ServerSemaphores) - result.ClientDestroy.Initialize(capacity) + serversem.ClientDestroy.Initialize(capacity) return } diff --git a/irc/server.go b/irc/server.go index 2d1ab06a..d76ec136 100644 --- a/irc/server.go +++ b/irc/server.go @@ -61,10 +61,10 @@ type ListenerWrapper struct { // Server is the main Oragono server. type Server struct { - accounts *AccountManager - channels *ChannelManager - channelRegistry *ChannelRegistry - clients *ClientManager + accounts AccountManager + channels ChannelManager + channelRegistry ChannelRegistry + clients ClientManager config *Config configFilename string configurableStateMutex sync.RWMutex // tier 1; generic protection for server state modified by rehash() @@ -89,9 +89,9 @@ type Server struct { snomasks *SnoManager store *buntdb.DB torLimiter connection_limits.TorLimiter - whoWas *WhoWasList - stats *Stats - semaphores *ServerSemaphores + whoWas WhoWasList + stats Stats + semaphores ServerSemaphores } var ( @@ -113,8 +113,6 @@ type clientConn struct { func NewServer(config *Config, logger *logger.Manager) (*Server, error) { // initialize data structures server := &Server{ - channels: NewChannelManager(), - clients: NewClientManager(), connectionLimiter: connection_limits.NewLimiter(), connectionThrottler: connection_limits.NewThrottler(), listeners: make(map[string]*ListenerWrapper), @@ -123,12 +121,12 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) { rehashSignal: make(chan os.Signal, 1), signals: make(chan os.Signal, len(ServerExitSignals)), snomasks: NewSnoManager(), - whoWas: NewWhoWasList(config.Limits.WhowasEntries), - stats: NewStats(), - semaphores: NewServerSemaphores(), } + server.clients.Initialize() + server.semaphores.Initialize() server.resumeManager.Initialize(server) + server.whoWas.Initialize(config.Limits.WhowasEntries) if err := server.applyConfig(config, true); err != nil { return nil, err @@ -697,6 +695,12 @@ func (server *Server) applyConfig(config *Config, initial bool) (err error) { server.accounts.initVHostRequestQueue() } + chanRegPreviouslyDisabled := oldConfig != nil && !oldConfig.Channels.Registration.Enabled + chanRegNowEnabled := config.Channels.Registration.Enabled + if chanRegPreviouslyDisabled && chanRegNowEnabled { + server.channels.loadRegisteredChannels() + } + // MaxLine if config.Limits.LineLen.Rest != 512 { SupportedCapabilities.Enable(caps.MaxLine) @@ -922,9 +926,9 @@ func (server *Server) loadDatastore(config *Config) error { server.loadDLines() server.loadKLines() - server.channelRegistry = NewChannelRegistry(server) - - server.accounts = NewAccountManager(server) + server.channelRegistry.Initialize(server) + server.channels.Initialize(server) + server.accounts.Initialize(server) return nil } diff --git a/irc/stats.go b/irc/stats.go index 65e4a67f..0921f41a 100644 --- a/irc/stats.go +++ b/irc/stats.go @@ -13,17 +13,6 @@ type Stats struct { Operators int } -// NewStats creates a new instance of Stats -func NewStats() *Stats { - serverStats := &Stats{ - Total: 0, - Invisible: 0, - Operators: 0, - } - - return serverStats -} - // ChangeTotal increments the total user count on server func (s *Stats) ChangeTotal(i int) { s.Lock() diff --git a/irc/utils/bitset.go b/irc/utils/bitset.go index 9e0014a8..adb86997 100644 --- a/irc/utils/bitset.go +++ b/irc/utils/bitset.go @@ -9,17 +9,6 @@ import "sync/atomic" // For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a // slice to use these functions. -// BitsetInitialize initializes a bitset. -func BitsetInitialize(set []uint64) { - // XXX re-zero the bitset using atomic stores. it's unclear whether this is required, - // however, golang issue #5045 suggests that you shouldn't mix atomic operations - // with non-atomic operations (such as the runtime's automatic zero-initialization) on - // the same word - for i := 0; i < len(set); i++ { - atomic.StoreUint64(&set[i], 0) - } -} - // BitsetGet returns whether a given bit of the bitset is set. func BitsetGet(set []uint64, position uint) bool { idx := position / 64 diff --git a/irc/utils/bitset_test.go b/irc/utils/bitset_test.go index a34a0097..612e48cb 100644 --- a/irc/utils/bitset_test.go +++ b/irc/utils/bitset_test.go @@ -10,7 +10,6 @@ type testBitset [2]uint64 func TestSets(t *testing.T) { var t1 testBitset t1s := t1[:] - BitsetInitialize(t1s) if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) { t.Error("no bits should be set in a newly initialized bitset") @@ -47,7 +46,6 @@ func TestSets(t *testing.T) { var t2 testBitset t2s := t2[:] - BitsetInitialize(t2s) for i = 0; i < 128; i++ { if i%2 == 1 { diff --git a/irc/utils/sync.go b/irc/utils/sync.go new file mode 100644 index 00000000..563f6185 --- /dev/null +++ b/irc/utils/sync.go @@ -0,0 +1,35 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "sync" + "sync/atomic" +) + +// Once is a fork of sync.Once to expose a Done() method. +type Once struct { + done uint32 + m sync.Mutex +} + +func (o *Once) Do(f func()) { + if atomic.LoadUint32(&o.done) == 0 { + o.doSlow(f) + } +} + +func (o *Once) doSlow(f func()) { + o.m.Lock() + defer o.m.Unlock() + if o.done == 0 { + defer atomic.StoreUint32(&o.done, 1) + f() + } +} + +func (o *Once) Done() bool { + return atomic.LoadUint32(&o.done) == 1 +} diff --git a/irc/whowas.go b/irc/whowas.go index e12b587c..bf8dca67 100644 --- a/irc/whowas.go +++ b/irc/whowas.go @@ -23,12 +23,10 @@ type WhoWasList struct { } // NewWhoWasList returns a new WhoWasList -func NewWhoWasList(size int) *WhoWasList { - return &WhoWasList{ - buffer: make([]WhoWas, size), - start: -1, - end: -1, - } +func (list *WhoWasList) Initialize(size int) { + list.buffer = make([]WhoWas, size) + list.start = -1 + list.end = -1 } // Append adds an entry to the WhoWasList. diff --git a/irc/whowas_test.go b/irc/whowas_test.go index 2ee76536..2309aec3 100644 --- a/irc/whowas_test.go +++ b/irc/whowas_test.go @@ -23,7 +23,8 @@ func makeTestWhowas(nick string) WhoWas { func TestWhoWas(t *testing.T) { var results []WhoWas - wwl := NewWhoWasList(3) + var wwl WhoWasList + wwl.Initialize(3) // test Find on empty list results = wwl.Find("nobody", 10) if len(results) != 0 { @@ -88,7 +89,8 @@ func TestWhoWas(t *testing.T) { func TestEmptyWhoWas(t *testing.T) { // stupid edge case; setting an empty whowas buffer should not panic - wwl := NewWhoWasList(0) + var wwl WhoWasList + wwl.Initialize(0) results := wwl.Find("slingamn", 10) if len(results) != 0 { t.Fatalf("incorrect whowas results: %v", results)