diff --git a/irc/channel.go b/irc/channel.go index 1fab08d8..52f50fab 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -7,7 +7,6 @@ package irc import ( "fmt" - "log" "strconv" "time" @@ -41,7 +40,7 @@ type Channel struct { func NewChannel(s *Server, name string, addDefaultModes bool) *Channel { casefoldedName, err := CasefoldChannel(name) if err != nil { - log.Println(fmt.Sprintf("ERROR: Channel name is bad: [%s]", name), err.Error()) + s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err)) return nil } @@ -59,13 +58,11 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel { } if addDefaultModes { - for _, mode := range s.GetDefaultChannelModes() { + for _, mode := range s.DefaultChannelModes() { channel.flags[mode] = true } } - s.channels.Add(channel) - return channel } @@ -281,6 +278,12 @@ func (channel *Channel) CheckKey(key string) bool { return (channel.key == "") || (channel.key == key) } +func (channel *Channel) IsEmpty() bool { + channel.stateMutex.RLock() + defer channel.stateMutex.RUnlock() + return len(channel.members) == 0 +} + // Join joins the given client to this channel (if they can be joined). //TODO(dan): /SAJOIN and maybe a ForceJoin function? func (channel *Channel) Join(client *Client, key string) { @@ -684,16 +687,10 @@ func (channel *Channel) applyModeMask(client *Client, mode Mode, op ModeOp, mask func (channel *Channel) Quit(client *Client) { channel.stateMutex.Lock() channel.members.Remove(client) - empty := len(channel.members) == 0 channel.stateMutex.Unlock() channel.regenerateMembersCache() client.removeChannel(channel) - - //TODO(slingamn) fold this operation into a channelmanager type - if empty { - channel.server.channels.Remove(channel) - } } func (channel *Channel) Kick(client *Client, target *Client, comment string) { diff --git a/irc/channelmanager.go b/irc/channelmanager.go new file mode 100644 index 00000000..2e5ab072 --- /dev/null +++ b/irc/channelmanager.go @@ -0,0 +1,162 @@ +// Copyright (c) 2017 Shivaram Lingamneni +// released under the MIT license + +package irc + +import ( + "errors" + "sync" +) + +var ( + InvalidChannelName = errors.New("Invalid channel name") + NoSuchChannel = errors.New("No such channel") + ChannelNameInUse = errors.New("Channel name in use") +) + +type channelManagerEntry struct { + channel *Channel + // this is a refcount for joins, so we can avoid a race where we incorrectly + // think the channel is empty (without holding a lock across the entire Channel.Join() + // call) + pendingJoins int +} + +// ChannelManager keeps track of all the channels on the server, +// providing synchronization for creation of new channels on first join, +// cleanup of empty channels on last part, and renames. +type ChannelManager struct { + sync.RWMutex // tier 2 + chans map[string]*channelManagerEntry +} + +// NewChannelManager returns a new ChannelManager. +func NewChannelManager() *ChannelManager { + return &ChannelManager{ + chans: make(map[string]*channelManagerEntry), + } +} + +// Get returns an existing channel with name equivalent to `name`, or nil +func (cm *ChannelManager) Get(name string) *Channel { + name, err := CasefoldChannel(name) + if err == nil { + cm.RLock() + defer cm.RUnlock() + return cm.chans[name].channel + } + return nil +} + +// Join causes `client` to join the channel named `name`, creating it if necessary. +func (cm *ChannelManager) Join(client *Client, name string, key string) error { + server := client.server + casefoldedName, err := CasefoldChannel(name) + if err != nil || len(casefoldedName) > server.getLimits().ChannelLen { + return NoSuchChannel + } + + cm.Lock() + entry := cm.chans[casefoldedName] + if entry == nil { + entry = &channelManagerEntry{ + channel: NewChannel(server, name, true), + pendingJoins: 0, + } + cm.chans[casefoldedName] = entry + } + entry.pendingJoins += 1 + cm.Unlock() + + entry.channel.Join(client, key) + + cm.maybeCleanup(entry, true) + + return nil +} + +func (cm *ChannelManager) maybeCleanup(entry *channelManagerEntry, afterJoin bool) { + cm.Lock() + defer cm.Unlock() + + if entry.channel == nil { + return + } + if afterJoin { + entry.pendingJoins -= 1 + } + if entry.channel.IsEmpty() && entry.pendingJoins == 0 { + // reread the name, handling the case where the channel was renamed + casefoldedName := entry.channel.NameCasefolded() + delete(cm.chans, casefoldedName) + // invalidate the entry (otherwise, a subsequent cleanup attempt could delete + // a valid, distinct entry under casefoldedName): + entry.channel = nil + } +} + +// Part parts `client` from the channel named `name`, deleting it if it's empty. +func (cm *ChannelManager) Part(client *Client, name string, message string) error { + casefoldedName, err := CasefoldChannel(name) + if err != nil { + return NoSuchChannel + } + + cm.RLock() + entry := cm.chans[casefoldedName] + cm.RUnlock() + + if entry == nil { + return NoSuchChannel + } + entry.channel.Part(client, message) + cm.maybeCleanup(entry, false) + return nil +} + +// Rename renames a channel (but does not notify the members) +func (cm *ChannelManager) Rename(name string, newname string) error { + cfname, err := CasefoldChannel(name) + if err != nil { + return NoSuchChannel + } + + cfnewname, err := CasefoldChannel(newname) + if err != nil { + return InvalidChannelName + } + + cm.Lock() + defer cm.Unlock() + + if cm.chans[cfnewname] != nil { + return ChannelNameInUse + } + entry := cm.chans[cfname] + if entry == nil { + return NoSuchChannel + } + delete(cm.chans, cfname) + cm.chans[cfnewname] = entry + entry.channel.setName(newname) + entry.channel.setNameCasefolded(cfnewname) + return nil + +} + +// Len returns the number of channels +func (cm *ChannelManager) Len() int { + cm.RLock() + defer cm.RUnlock() + return len(cm.chans) +} + +// Channels returns a slice containing all current channels +func (cm *ChannelManager) Channels() (result []*Channel) { + cm.RLock() + defer cm.RUnlock() + for _, entry := range cm.chans { + result = append(result, entry.channel) + } + return +} diff --git a/irc/client.go b/irc/client.go index 5dcef5e6..0e23b1fb 100644 --- a/irc/client.go +++ b/irc/client.go @@ -548,14 +548,12 @@ func (client *Client) destroy() { client.server.monitorManager.RemoveAll(client) // clean up channels - client.server.channelJoinPartMutex.Lock() - for channel := range client.channels { + for _, channel := range client.Channels() { channel.Quit(client) for _, member := range channel.Members() { friends.Add(member) } } - client.server.channelJoinPartMutex.Unlock() // clean up server client.server.clients.Remove(client) diff --git a/irc/getters.go b/irc/getters.go index 92c11528..ecc1ac7f 100644 --- a/irc/getters.go +++ b/irc/getters.go @@ -41,6 +41,12 @@ func (server *Server) WebIRCConfig() []webircConfig { return server.webirc } +func (server *Server) DefaultChannelModes() Modes { + server.configurableStateMutex.RLock() + defer server.configurableStateMutex.RUnlock() + return server.defaultChannelModes +} + func (client *Client) getNick() string { client.stateMutex.RLock() defer client.stateMutex.RUnlock() @@ -114,6 +120,24 @@ func (channel *Channel) Name() string { return channel.name } +func (channel *Channel) setName(name string) { + channel.stateMutex.Lock() + defer channel.stateMutex.Unlock() + channel.name = name +} + +func (channel *Channel) NameCasefolded() string { + channel.stateMutex.RLock() + defer channel.stateMutex.RUnlock() + return channel.nameCasefolded +} + +func (channel *Channel) setNameCasefolded(nameCasefolded string) { + channel.stateMutex.Lock() + defer channel.stateMutex.Unlock() + channel.nameCasefolded = nameCasefolded +} + func (channel *Channel) Members() (result []*Client) { channel.stateMutex.RLock() defer channel.stateMutex.RUnlock() diff --git a/irc/monitor.go b/irc/monitor.go index 3674a845..c3c2134a 100644 --- a/irc/monitor.go +++ b/irc/monitor.go @@ -52,12 +52,9 @@ func (manager *MonitorManager) AlertAbout(client *Client, online bool) { command = RPL_MONONLINE } - // asynchronously send all the notifications - go func() { - for _, mClient := range watchers { - mClient.Send(nil, client.server.name, command, mClient.getNick(), nick) - } - }() + for _, mClient := range watchers { + mClient.Send(nil, client.server.name, command, mClient.getNick(), nick) + } } // Add registers `client` to receive notifications about `nick`. diff --git a/irc/server.go b/irc/server.go index d5efe435..4eaece87 100644 --- a/irc/server.go +++ b/irc/server.go @@ -9,6 +9,7 @@ import ( "bufio" "crypto/tls" "encoding/base64" + "errors" "fmt" "log" "math/rand" @@ -39,6 +40,8 @@ var ( // common error responses couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line() + + RenamePrivsNeeded = errors.New("Only chanops can rename channels") ) const ( @@ -80,8 +83,7 @@ type Server struct { accountRegistration *AccountRegistration accounts map[string]*ClientAccount channelRegistrationEnabled bool - channels ChannelNameMap - channelJoinPartMutex sync.Mutex // used when joining/parting channels to prevent stomping over each others' access and all + channels *ChannelManager checkIdent bool clients *ClientLookupSet commands chan Command @@ -147,7 +149,7 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) { // initialize data structures server := &Server{ accounts: make(map[string]*ClientAccount), - channels: *NewChannelNameMap(), + channels: NewChannelManager(), clients: NewClientLookupSet(), commands: make(chan Command), connectionLimiter: connection_limits.NewLimiter(), @@ -553,53 +555,62 @@ func pongHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { } // RENAME [] -//TODO(dan): Clean up this function so it doesn't look like an eldrich horror... prolly by putting it into a server.renameChannel function. -func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { - // get lots of locks... make sure nobody touches anything while we're doing this +func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) (result bool) { + result = false + + // TODO(slingamn, #152) clean up locking here server.registeredChannelsMutex.Lock() defer server.registeredChannelsMutex.Unlock() - server.channels.ChansLock.Lock() - defer server.channels.ChansLock.Unlock() + + errorResponse := func(err error, name string) { + // TODO: send correct error codes, e.g., ERR_CANNOTRENAME, ERR_CHANNAMEINUSE + var code string + switch err { + case NoSuchChannel: + code = ERR_NOSUCHCHANNEL + case RenamePrivsNeeded: + code = ERR_CHANOPRIVSNEEDED + case InvalidChannelName: + code = ERR_UNKNOWNERROR + case ChannelNameInUse: + code = ERR_UNKNOWNERROR + default: + code = ERR_UNKNOWNERROR + } + client.Send(nil, server.name, code, client.getNick(), "RENAME", name, err.Error()) + } oldName := strings.TrimSpace(msg.Params[0]) newName := strings.TrimSpace(msg.Params[1]) + if oldName == "" || newName == "" { + errorResponse(InvalidChannelName, "") + return + } + casefoldedOldName, err := CasefoldChannel(oldName) + if err != nil { + errorResponse(InvalidChannelName, oldName) + return + } + casefoldedNewName, err := CasefoldChannel(newName) + if err != nil { + errorResponse(InvalidChannelName, newName) + return + } + reason := "No reason" if 2 < len(msg.Params) { reason = msg.Params[2] } - // check for all the reasons why the rename couldn't happen - casefoldedOldName, err := CasefoldChannel(oldName) - if err != nil { - //TODO(dan): Change this to ERR_CANNOTRENAME - client.Send(nil, server.name, ERR_UNKNOWNERROR, client.nick, "RENAME", oldName, "Old channel name is invalid") - return false - } - - channel := server.channels.Chans[casefoldedOldName] + channel := server.channels.Get(oldName) if channel == nil { - client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, oldName, "No such channel") - return false + errorResponse(NoSuchChannel, oldName) + return } - //TODO(dan): allow IRCops to do this? if !channel.ClientIsAtLeast(client, Operator) { - client.Send(nil, server.name, ERR_CHANOPRIVSNEEDED, client.nick, oldName, "Only chanops can rename channels") - return false - } - - casefoldedNewName, err := CasefoldChannel(newName) - if err != nil { - //TODO(dan): Change this to ERR_CANNOTRENAME - client.Send(nil, server.name, ERR_UNKNOWNERROR, client.nick, "RENAME", newName, "New channel name is invalid") - return false - } - - newChannel := server.channels.Chans[casefoldedNewName] - if newChannel != nil { - //TODO(dan): Change this to ERR_CHANNAMEINUSE - client.Send(nil, server.name, ERR_UNKNOWNERROR, client.nick, "RENAME", newName, "New channel name is in use") - return false + errorResponse(RenamePrivsNeeded, oldName) + return } var canEdit bool @@ -622,11 +633,11 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { } // perform the channel rename - server.channels.Chans[casefoldedOldName] = nil - server.channels.Chans[casefoldedNewName] = channel - - channel.name = strings.TrimSpace(msg.Params[1]) - channel.nameCasefolded = casefoldedNewName + err = server.channels.Rename(oldName, newName) + if err != nil { + errorResponse(err, newName) + return + } // rename stored channel info if any exists server.store.Update(func(tx *buntdb.Tx) error { @@ -679,34 +690,15 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { keys = strings.Split(msg.Params[1], ",") } - // get lock - server.channelJoinPartMutex.Lock() - defer server.channelJoinPartMutex.Unlock() - for i, name := range channels { - casefoldedName, err := CasefoldChannel(name) - if err != nil { - if len(name) > 0 { - client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, name, "No such channel") - } - continue - } - - channel := server.channels.Get(casefoldedName) - if channel == nil { - if len(casefoldedName) > server.getLimits().ChannelLen { - client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, name, "No such channel") - continue - } - channel = NewChannel(server, name, true) - } - var key string if len(keys) > i { key = keys[i] } - - channel.Join(client, key) + err := server.channels.Join(client, name, key) + if err == NoSuchChannel { + client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.getNick(), name, "No such channel") + } } return false } @@ -719,22 +711,11 @@ func partHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { reason = msg.Params[1] } - // get lock - server.channelJoinPartMutex.Lock() - defer server.channelJoinPartMutex.Unlock() - for _, chname := range channels { - casefoldedChannelName, err := CasefoldChannel(chname) - channel := server.channels.Get(casefoldedChannelName) - - if err != nil || channel == nil { - if len(chname) > 0 { - client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, chname, "No such channel") - } - continue + err := server.channels.Part(client, chname, reason) + if err == NoSuchChannel { + client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, chname, "No such channel") } - - channel.Part(client, reason) } return false } @@ -1096,11 +1077,9 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { //} if mask == "" { - server.channels.ChansLock.RLock() - for _, channel := range server.channels.Chans { + for _, channel := range server.channels.Channels() { whoChannel(client, channel, friends) } - server.channels.ChansLock.RUnlock() } else if mask[0] == '#' { // TODO implement wildcard matching //TODO(dan): ^ only for opers @@ -1859,8 +1838,7 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { } if len(channels) == 0 { - server.channels.ChansLock.RLock() - for _, channel := range server.channels.Chans { + for _, channel := range server.channels.Channels() { if !client.flags[Operator] && channel.flags[Secret] { continue } @@ -1868,7 +1846,6 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { client.RplList(channel) } } - server.channels.ChansLock.RUnlock() } else { // limit regular users to only listing one channel if !client.flags[Operator] { @@ -1922,11 +1899,9 @@ func namesHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { //} if len(channels) == 0 { - server.channels.ChansLock.RLock() - for _, channel := range server.channels.Chans { + for _, channel := range server.channels.Channels() { channel.Names(client) } - server.channels.ChansLock.RUnlock() return false } diff --git a/irc/types.go b/irc/types.go index c7474177..3595a021 100644 --- a/irc/types.go +++ b/irc/types.go @@ -6,64 +6,9 @@ package irc import ( - "fmt" "strings" - "sync" ) -// ChannelNameMap is a map that converts channel names to actual channel objects. -type ChannelNameMap struct { - ChansLock sync.RWMutex - Chans map[string]*Channel -} - -// NewChannelNameMap returns a new ChannelNameMap. -func NewChannelNameMap() *ChannelNameMap { - var channels ChannelNameMap - channels.Chans = make(map[string]*Channel) - return &channels -} - -// Get returns the given channel if it exists. -func (channels *ChannelNameMap) Get(name string) *Channel { - name, err := CasefoldChannel(name) - if err == nil { - channels.ChansLock.RLock() - defer channels.ChansLock.RUnlock() - return channels.Chans[name] - } - return nil -} - -// Add adds the given channel to our map. -func (channels *ChannelNameMap) Add(channel *Channel) error { - channels.ChansLock.Lock() - defer channels.ChansLock.Unlock() - if channels.Chans[channel.nameCasefolded] != nil { - return fmt.Errorf("%s: already set", channel.name) - } - channels.Chans[channel.nameCasefolded] = channel - return nil -} - -// Remove removes the given channel from our map. -func (channels *ChannelNameMap) Remove(channel *Channel) error { - channels.ChansLock.Lock() - defer channels.ChansLock.Unlock() - if channel != channels.Chans[channel.nameCasefolded] { - return fmt.Errorf("%s: mismatch", channel.name) - } - delete(channels.Chans, channel.nameCasefolded) - return nil -} - -// Len returns how many channels we have. -func (channels *ChannelNameMap) Len() int { - channels.ChansLock.RLock() - defer channels.ChansLock.RUnlock() - return len(channels.Chans) -} - // ModeSet holds a set of modes. type ModeSet map[Mode]bool