diff --git a/irc/channel.go b/irc/channel.go index 3ae6b9ea..0de004ee 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -53,13 +53,7 @@ type Channel struct { // NewChannel creates a new channel from a `Server` and a `name` // string, which must be unique on the server. -func NewChannel(s *Server, name string, registered bool) *Channel { - casefoldedName, err := CasefoldChannel(name) - if err != nil { - s.logger.Error("internal", "Bad channel name", name, err.Error()) - return nil - } - +func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channel { config := s.Config() channel := &Channel{ diff --git a/irc/channelmanager.go b/irc/channelmanager.go index 8a12ed7c..507223f8 100644 --- a/irc/channelmanager.go +++ b/irc/channelmanager.go @@ -13,36 +13,54 @@ type channelManagerEntry struct { // think the channel is empty (without holding a lock across the entire Channel.Join() // call) pendingJoins int + skeleton string } // ChannelManager keeps track of all the channels on the server, // providing synchronization for creation of new channels on first join, // cleanup of empty channels on last part, and renames. type ChannelManager struct { - sync.RWMutex // tier 2 - chans map[string]*channelManagerEntry - registeredChannels map[string]bool - purgedChannels map[string]empty - server *Server + sync.RWMutex // tier 2 + // chans is the main data structure, mapping casefolded name -> *Channel + chans map[string]*channelManagerEntry + chansSkeletons StringSet // skeletons of *unregistered* chans + registeredChannels StringSet // casefolds of registered chans + registeredSkeletons StringSet // skeletons of registered chans + purgedChannels StringSet // casefolds of purged chans + server *Server } // NewChannelManager returns a new ChannelManager. func (cm *ChannelManager) Initialize(server *Server) { cm.chans = make(map[string]*channelManagerEntry) + cm.chansSkeletons = make(StringSet) cm.server = server if server.Config().Channels.Registration.Enabled { cm.loadRegisteredChannels() } + // purging should work even if registration is disabled + cm.purgedChannels = cm.server.channelRegistry.PurgedChannels() } func (cm *ChannelManager) loadRegisteredChannels() { - registeredChannels := cm.server.channelRegistry.AllChannels() - purgedChannels := cm.server.channelRegistry.PurgedChannels() + rawNames := cm.server.channelRegistry.AllChannels() + registeredChannels := make(StringSet, len(rawNames)) + registeredSkeletons := make(StringSet, len(rawNames)) + for _, name := range rawNames { + cfname, err := CasefoldChannel(name) + if err == nil { + registeredChannels.Add(cfname) + } + skeleton, err := Skeleton(name) + if err == nil { + registeredSkeletons.Add(skeleton) + } + } cm.Lock() defer cm.Unlock() cm.registeredChannels = registeredChannels - cm.purgedChannels = purgedChannels + cm.registeredSkeletons = registeredSkeletons } // Get returns an existing channel with name equivalent to `name`, or nil @@ -64,37 +82,49 @@ func (cm *ChannelManager) Get(name string) (channel *Channel) { func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin bool, rb *ResponseBuffer) error { server := client.server casefoldedName, err := CasefoldChannel(name) - if err != nil || len(casefoldedName) > server.Config().Limits.ChannelLen { + skeleton, skerr := Skeleton(name) + if err != nil || skerr != nil || len(casefoldedName) > server.Config().Limits.ChannelLen { return errNoSuchChannel } - channel := func() *Channel { + channel, err := func() (*Channel, error) { cm.Lock() defer cm.Unlock() - _, purged := cm.purgedChannels[casefoldedName] - if purged { - return nil + if cm.purgedChannels.Has(casefoldedName) { + return nil, errChannelPurged } entry := cm.chans[casefoldedName] if entry == nil { - registered := cm.registeredChannels[casefoldedName] + registered := cm.registeredChannels.Has(casefoldedName) // enforce OpOnlyCreation if !registered && server.Config().Channels.OpOnlyCreation && !client.HasRoleCapabs("chanreg") { - return nil + return nil, errInsufficientPrivs + } + // enforce confusables + if cm.chansSkeletons.Has(skeleton) || (!registered && cm.registeredSkeletons.Has(skeleton)) { + return nil, errConfusableIdentifier } entry = &channelManagerEntry{ - channel: NewChannel(server, name, registered), + channel: NewChannel(server, name, casefoldedName, registered), pendingJoins: 0, } + if !registered { + // for an unregistered channel, we already have the correct unfolded name + // and therefore the final skeleton. for a registered channel, we don't have + // the unfolded name yet (it needs to be loaded from the db), but we already + // have the final skeleton in `registeredSkeletons` so we don't need to track it + cm.chansSkeletons.Add(skeleton) + entry.skeleton = skeleton + } cm.chans[casefoldedName] = entry } entry.pendingJoins += 1 - return entry.channel + return entry.channel, nil }() - if channel == nil { - return errNoSuchChannel + if err != nil { + return err } channel.EnsureLoaded() @@ -109,8 +139,9 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) { cm.Lock() defer cm.Unlock() - nameCasefolded := channel.NameCasefolded() - entry := cm.chans[nameCasefolded] + cfname := channel.NameCasefolded() + + entry := cm.chans[cfname] if entry == nil || entry.channel != channel { return } @@ -119,7 +150,10 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) { entry.pendingJoins -= 1 } if entry.pendingJoins == 0 && entry.channel.IsClean() { - delete(cm.chans, nameCasefolded) + delete(cm.chans, cfname) + if entry.skeleton != "" { + delete(cm.chansSkeletons, entry.skeleton) + } } } @@ -177,7 +211,7 @@ func (cm *ChannelManager) SetRegistered(channelName string, account string) (err if err != nil { return err } - cm.registeredChannels[cfname] = true + cm.registeredChannels.Add(cfname) return nil } @@ -211,13 +245,17 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e } // Rename renames a channel (but does not notify the members) -func (cm *ChannelManager) Rename(name string, newname string) (err error) { +func (cm *ChannelManager) Rename(name string, newName string) (err error) { cfname, err := CasefoldChannel(name) if err != nil { return errNoSuchChannel } - cfnewname, err := CasefoldChannel(newname) + newCfname, err := CasefoldChannel(newName) + if err != nil { + return errInvalidChannelName + } + newSkeleton, err := Skeleton(newName) if err != nil { return errInvalidChannelName } @@ -236,22 +274,35 @@ func (cm *ChannelManager) Rename(name string, newname string) (err error) { cm.Lock() defer cm.Unlock() - if cm.chans[cfnewname] != nil || cm.registeredChannels[cfnewname] { + if cm.chans[newCfname] != nil || cm.registeredChannels.Has(newCfname) { + return errChannelNameInUse + } + if cm.chansSkeletons.Has(newSkeleton) || cm.registeredSkeletons.Has(newSkeleton) { return errChannelNameInUse } entry := cm.chans[cfname] - if entry == nil { + if entry == nil || !entry.channel.IsLoaded() { return errNoSuchChannel } channel = entry.channel info = channel.ExportRegistration(IncludeInitial) + registered := info.Founder != "" delete(cm.chans, cfname) - cm.chans[cfnewname] = entry - if cm.registeredChannels[cfname] { + cm.chans[newCfname] = entry + if registered { delete(cm.registeredChannels, cfname) - cm.registeredChannels[cfnewname] = true + if oldSkeleton, err := Skeleton(info.Name); err == nil { + delete(cm.registeredSkeletons, oldSkeleton) + } + cm.registeredChannels.Add(newCfname) + cm.registeredSkeletons.Add(newSkeleton) + } else { + delete(cm.chansSkeletons, entry.skeleton) + cm.chansSkeletons.Add(newSkeleton) + entry.skeleton = newSkeleton + cm.chans[cfname] = entry } - entry.channel.Rename(newname, cfnewname) + entry.channel.Rename(newName, newCfname) return nil } @@ -283,7 +334,7 @@ func (cm *ChannelManager) Purge(chname string, record ChannelPurgeRecord) (err e } cm.Lock() - cm.purgedChannels[chname] = empty{} + cm.purgedChannels.Add(chname) cm.Unlock() cm.server.channelRegistry.PurgeChannel(chname, record) @@ -298,7 +349,7 @@ func (cm *ChannelManager) IsPurged(chname string) (result bool) { } cm.Lock() - _, result = cm.purgedChannels[chname] + result = cm.purgedChannels.Has(chname) cm.Unlock() return } @@ -311,7 +362,7 @@ func (cm *ChannelManager) Unpurge(chname string) (err error) { } cm.Lock() - _, found := cm.purgedChannels[chname] + found := cm.purgedChannels.Has(chname) delete(cm.purgedChannels, chname) cm.Unlock() diff --git a/irc/channelreg.go b/irc/channelreg.go index ad54affd..e0ea9b25 100644 --- a/irc/channelreg.go +++ b/irc/channelreg.go @@ -114,17 +114,15 @@ func (reg *ChannelRegistry) Initialize(server *Server) { reg.server = server } -func (reg *ChannelRegistry) AllChannels() (result map[string]bool) { - result = make(map[string]bool) - - prefix := fmt.Sprintf(keyChannelExists, "") +// AllChannels returns the uncasefolded names of all registered channels. +func (reg *ChannelRegistry) AllChannels() (result []string) { + prefix := fmt.Sprintf(keyChannelName, "") reg.server.store.View(func(tx *buntdb.Tx) error { return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool { if !strings.HasPrefix(key, prefix) { return false } - channel := strings.TrimPrefix(key, prefix) - result[channel] = true + result = append(result, value) return true }) }) @@ -132,7 +130,7 @@ func (reg *ChannelRegistry) AllChannels() (result map[string]bool) { return } -// PurgedChannels returns the set of all channel names that have been purged +// PurgedChannels returns the set of all casefolded channel names that have been purged func (reg *ChannelRegistry) PurgedChannels() (result map[string]empty) { result = make(map[string]empty) diff --git a/irc/errors.go b/irc/errors.go index f3dafd0f..00d64851 100644 --- a/irc/errors.go +++ b/irc/errors.go @@ -41,6 +41,8 @@ var ( errNicknameReserved = errors.New("nickname is reserved") errNoExistingBan = errors.New("Ban does not exist") errNoSuchChannel = errors.New(`No such channel`) + errChannelPurged = errors.New(`This channel was purged by the server operators and cannot be used`) + errConfusableIdentifier = errors.New("This identifier is confusable with one already in use") errInsufficientPrivs = errors.New("Insufficient privileges") errInvalidUsername = errors.New("Invalid username") errFeatureDisabled = errors.New(`That feature is disabled`) diff --git a/irc/handlers.go b/irc/handlers.go index 8bab0a04..661b23f3 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -1290,13 +1290,28 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp key = keys[i] } err := server.channels.Join(client, name, key, false, rb) - if err == errNoSuchChannel { - rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), utils.SafeErrorParam(name), client.t("No such channel")) + if err != nil { + sendJoinError(client, name, rb, err) } } return false } +func sendJoinError(client *Client, name string, rb *ResponseBuffer, err error) { + var errMsg string + switch err { + case errInsufficientPrivs: + errMsg = `Only server operators can create new channels` + case errConfusableIdentifier: + errMsg = `That channel name is too close to the name of another channel` + case errChannelPurged: + errMsg = err.Error() + default: + errMsg = `No such channel` + } + rb.Add(nil, client.server.name, ERR_NOSUCHCHANNEL, client.Nick(), utils.SafeErrorParam(name), client.t(errMsg)) +} + // SAJOIN [nick] #channel{,#channel} func sajoinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool { var target *Client @@ -1306,7 +1321,7 @@ func sajoinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re channelString = msg.Params[0] } else { if len(msg.Params) == 1 { - rb.Add(nil, server.name, ERR_NEEDMOREPARAMS, client.Nick(), "KICK", client.t("Not enough parameters")) + rb.Add(nil, server.name, ERR_NEEDMOREPARAMS, client.Nick(), "SAJOIN", client.t("Not enough parameters")) return false } else { target = server.clients.Get(msg.Params[0]) @@ -1320,7 +1335,10 @@ func sajoinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re channels := strings.Split(channelString, ",") for _, chname := range channels { - server.channels.Join(target, chname, "", true, rb) + err := server.channels.Join(target, chname, "", true, rb) + if err != nil { + sendJoinError(client, chname, rb, err) + } } return false } diff --git a/irc/types.go b/irc/types.go index 756f1b61..6e1a1110 100644 --- a/irc/types.go +++ b/irc/types.go @@ -28,6 +28,17 @@ func (clients ClientSet) Has(client *Client) bool { return ok } +type StringSet map[string]empty + +func (s StringSet) Has(str string) bool { + _, ok := s[str] + return ok +} + +func (s StringSet) Add(str string) { + s[str] = empty{} +} + // MemberSet is a set of members with modes. type MemberSet map[*Client]*modes.ModeSet