diff --git a/irc/channel.go b/irc/channel.go index f1057e1f..960b1a04 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -660,7 +660,7 @@ func (channel *Channel) AddHistoryItem(item history.Item, account string) (err e } // Join joins the given client to this channel (if they can be joined). -func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *ResponseBuffer) { +func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *ResponseBuffer) error { details := client.Details() channel.stateMutex.RLock() @@ -676,39 +676,43 @@ func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *Resp if alreadyJoined { // no message needs to be sent - return + return nil } - // the founder can always join (even if they disabled auto +q on join); - // anyone who automatically receives halfop or higher can always join - hasPrivs := isSajoin || (founder != "" && founder == details.account) || (persistentMode != 0 && persistentMode != modes.Voice) + // 0. SAJOIN always succeeds + // 1. the founder can always join (even if they disabled auto +q on join) + // 2. anyone who automatically receives halfop or higher can always join + // 3. people invited with INVITE can join + hasPrivs := isSajoin || (founder != "" && founder == details.account) || + (persistentMode != 0 && persistentMode != modes.Voice) || + client.CheckInvited(chcfname) + if !hasPrivs { + if limit != 0 && chcount >= limit { + return errLimitExceeded + } - if !hasPrivs && limit != 0 && chcount >= limit { - rb.Add(nil, client.server.name, ERR_CHANNELISFULL, details.nick, chname, fmt.Sprintf(client.t("Cannot join channel (+%s)"), "l")) - return + if chkey != "" && !utils.SecretTokensMatch(chkey, key) { + return errWrongChannelKey + } + + if channel.flags.HasMode(modes.InviteOnly) && + !channel.lists[modes.InviteMask].Match(details.nickMaskCasefolded) { + return errInviteOnly + } + + if channel.lists[modes.BanMask].Match(details.nickMaskCasefolded) && + !channel.lists[modes.ExceptMask].Match(details.nickMaskCasefolded) && + !channel.lists[modes.InviteMask].Match(details.nickMaskCasefolded) { + return errBanned + } + + if channel.flags.HasMode(modes.RegisteredOnly) && details.account == "" { + return errRegisteredOnly + } } - if !hasPrivs && chkey != "" && !utils.SecretTokensMatch(chkey, key) { - rb.Add(nil, client.server.name, ERR_BADCHANNELKEY, details.nick, chname, fmt.Sprintf(client.t("Cannot join channel (+%s)"), "k")) - return - } - - isInvited := client.CheckInvited(chcfname) || channel.lists[modes.InviteMask].Match(details.nickMaskCasefolded) - if !hasPrivs && channel.flags.HasMode(modes.InviteOnly) && !isInvited { - rb.Add(nil, client.server.name, ERR_INVITEONLYCHAN, details.nick, chname, fmt.Sprintf(client.t("Cannot join channel (+%s)"), "i")) - return - } - - if !hasPrivs && channel.lists[modes.BanMask].Match(details.nickMaskCasefolded) && - !isInvited && - !channel.lists[modes.ExceptMask].Match(details.nickMaskCasefolded) { - rb.Add(nil, client.server.name, ERR_BANNEDFROMCHAN, details.nick, chname, fmt.Sprintf(client.t("Cannot join channel (+%s)"), "b")) - return - } - - if !hasPrivs && channel.flags.HasMode(modes.RegisteredOnly) && details.account == "" && !isInvited { - rb.Add(nil, client.server.name, ERR_NEEDREGGEDNICK, details.nick, chname, client.t("You must be registered to join that channel")) - return + if joinErr := client.addChannel(channel, rb == nil); joinErr != nil { + return joinErr } client.server.logger.Debug("join", fmt.Sprintf("%s joined channel %s", details.nick, chname)) @@ -753,10 +757,8 @@ func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *Resp channel.AddHistoryItem(histItem, details.account) } - client.addChannel(channel, rb == nil) - if rb == nil { - return + return nil } var modestr string @@ -799,6 +801,7 @@ func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *Resp rb.Flush(true) channel.autoReplayHistory(client, rb, message.Msgid) + return nil } func (channel *Channel) autoReplayHistory(client *Client, rb *ResponseBuffer, skipMsgid string) { @@ -1437,9 +1440,7 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client, rb *ResponseBuf return } - if channel.flags.HasMode(modes.InviteOnly) { - invitee.Invite(channel.NameCasefolded()) - } + invitee.Invite(channel.NameCasefolded()) for _, member := range channel.Members() { if member == inviter || member == invitee || !channel.ClientIsAtLeast(member, modes.Halfop) { diff --git a/irc/channelmanager.go b/irc/channelmanager.go index 1b32ffaa..fce4b0b6 100644 --- a/irc/channelmanager.go +++ b/irc/channelmanager.go @@ -130,11 +130,11 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin } channel.EnsureLoaded() - channel.Join(client, key, isSajoin, rb) + err = channel.Join(client, key, isSajoin, rb) cm.maybeCleanup(channel, true) - return nil + return err } func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) { diff --git a/irc/client.go b/irc/client.go index 8dc4e9dd..9f772ca3 100644 --- a/irc/client.go +++ b/irc/client.go @@ -62,7 +62,7 @@ type Client struct { exitedSnomaskSent bool modes modes.ModeSet hostname string - invitedTo map[string]bool + invitedTo StringSet isSTSOnly bool languages []string lastActive time.Time // last time they sent a command that wasn't PONG or similar @@ -1595,15 +1595,24 @@ func (session *Session) Notice(text string) { // `simulated` is for the fake join of an always-on client // (we just read the channel name from the database, there's no need to write it back) -func (client *Client) addChannel(channel *Channel, simulated bool) { +func (client *Client) addChannel(channel *Channel, simulated bool) (err error) { + config := client.server.Config() + client.stateMutex.Lock() - client.channels[channel] = true alwaysOn := client.alwaysOn + if client.destroyed { + err = errClientDestroyed + } else if client.oper == nil && len(client.channels) >= config.Channels.MaxChannelsPerClient { + err = errTooManyChannels + } else { + client.channels[channel] = empty{} // success + } client.stateMutex.Unlock() - if alwaysOn && !simulated { + if err == nil && alwaysOn && !simulated { client.markDirty(IncludeChannels) } + return } func (client *Client) removeChannel(channel *Channel) { @@ -1623,10 +1632,10 @@ func (client *Client) Invite(casefoldedChannel string) { defer client.stateMutex.Unlock() if client.invitedTo == nil { - client.invitedTo = make(map[string]bool) + client.invitedTo = make(StringSet) } - client.invitedTo[casefoldedChannel] = true + client.invitedTo.Add(casefoldedChannel) } // Checks that the client was invited to join a given channel @@ -1634,7 +1643,7 @@ func (client *Client) CheckInvited(casefoldedChannel string) (invited bool) { client.stateMutex.Lock() defer client.stateMutex.Unlock() - invited = client.invitedTo[casefoldedChannel] + invited = client.invitedTo.Has(casefoldedChannel) // joining an invited channel "uses up" your invite, so you can't rejoin on kick delete(client.invitedTo, casefoldedChannel) return diff --git a/irc/errors.go b/irc/errors.go index 09afdf8a..70c75af1 100644 --- a/irc/errors.go +++ b/irc/errors.go @@ -67,6 +67,11 @@ var ( errInvalidMultilineBatch = errors.New("Invalid multiline batch") errTimedOut = errors.New("Operation timed out") errInvalidUtf8 = errors.New("Message rejected for invalid utf8") + errClientDestroyed = errors.New("Client was already destroyed") + errTooManyChannels = errors.New("You have joined too many channels") + errWrongChannelKey = errors.New("Cannot join password-protected channel without the password") + errInviteOnly = errors.New("Cannot join invite-only channel without an invite") + errRegisteredOnly = errors.New("Cannot join registered-only channel without an account") ) // Socket Errors diff --git a/irc/handlers.go b/irc/handlers.go index bd254b70..eaa83e27 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -1142,16 +1142,10 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp keys = strings.Split(msg.Params[1], ",") } - config := server.Config() - oper := client.Oper() for i, name := range channels { if name == "" { continue // #679 } - if config.Channels.MaxChannelsPerClient <= client.NumChannels() && oper == nil { - rb.Add(nil, server.name, ERR_TOOMANYCHANNELS, client.Nick(), name, client.t("You have joined too many channels")) - return false - } var key string if len(keys) > i { key = keys[i] @@ -1165,18 +1159,35 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp } func sendJoinError(client *Client, name string, rb *ResponseBuffer, err error) { - var errMsg string + var code, errMsg, forbiddingMode string switch err { case errInsufficientPrivs: - errMsg = `Only server operators can create new channels` + code, errMsg = ERR_NOSUCHCHANNEL, `Only server operators can create new channels` case errConfusableIdentifier: - errMsg = `That channel name is too close to the name of another channel` + code, errMsg = ERR_NOSUCHCHANNEL, `That channel name is too close to the name of another channel` case errChannelPurged: - errMsg = err.Error() + code, errMsg = ERR_NOSUCHCHANNEL, err.Error() + case errTooManyChannels: + code, errMsg = ERR_TOOMANYCHANNELS, `You have joined too many channels` + case errLimitExceeded: + code, forbiddingMode = ERR_CHANNELISFULL, "l" + case errWrongChannelKey: + code, forbiddingMode = ERR_BADCHANNELKEY, "k" + case errInviteOnly: + code, forbiddingMode = ERR_INVITEONLYCHAN, "i" + case errBanned: + code, forbiddingMode = ERR_BANNEDFROMCHAN, "b" + case errRegisteredOnly: + code, errMsg = ERR_NEEDREGGEDNICK, `You must be registered to join that channel` default: - errMsg = `No such channel` + code, errMsg = ERR_NOSUCHCHANNEL, `No such channel` } - rb.Add(nil, client.server.name, ERR_NOSUCHCHANNEL, client.Nick(), utils.SafeErrorParam(name), client.t(errMsg)) + if forbiddingMode != "" { + errMsg = fmt.Sprintf(client.t("Cannot join channel (+%s)"), forbiddingMode) + } else { + errMsg = client.t(errMsg) + } + rb.Add(nil, client.server.name, code, client.Nick(), utils.SafeErrorParam(name), errMsg) } // SAJOIN [nick] #channel{,#channel} diff --git a/irc/types.go b/irc/types.go index 6e1a1110..7307a725 100644 --- a/irc/types.go +++ b/irc/types.go @@ -69,4 +69,4 @@ func (members MemberSet) AnyHasMode(mode modes.Mode) bool { } // ChannelSet is a set of channels. -type ChannelSet map[*Channel]bool +type ChannelSet map[*Channel]empty