diff --git a/irc/capability.go b/irc/capability.go index fa4ba9b8..afb4760e 100644 --- a/irc/capability.go +++ b/irc/capability.go @@ -13,28 +13,12 @@ import ( var ( // SupportedCapabilities are the caps we advertise. - SupportedCapabilities = CapabilitySet{ - caps.AccountTag: true, - caps.AccountNotify: true, - caps.AwayNotify: true, - caps.CapNotify: true, - caps.ChgHost: true, - caps.EchoMessage: true, - caps.ExtendedJoin: true, - caps.InviteNotify: true, - // MaxLine is set during server startup - caps.MessageTags: true, - caps.MultiPrefix: true, - caps.Rename: true, - // SASL is set during server startup - caps.ServerTime: true, - // STS is set during server startup - caps.UserhostInNames: true, - } + // MaxLine, SASL and STS are set during server startup. + SupportedCapabilities = caps.NewSet(caps.AccountTag, caps.AccountNotify, caps.AwayNotify, caps.CapNotify, caps.ChgHost, caps.EchoMessage, caps.ExtendedJoin, caps.InviteNotify, caps.MessageTags, caps.MultiPrefix, caps.Rename, caps.ServerTime, caps.UserhostInNames) + // CapValues are the actual values we advertise to v3.2 clients. - CapValues = map[caps.Capability]string{ - caps.SASL: "PLAIN,EXTERNAL", - } + // actual values are set during server startup. + CapValues = caps.NewValues() ) // CapState shows whether we're negotiating caps, finished, etc for connection registration. @@ -49,40 +33,10 @@ const ( CapNegotiated CapState = iota ) -// CapVersion is used to select which max version of CAP the client supports. -type CapVersion uint - -const ( - // Cap301 refers to the base CAP spec. - Cap301 CapVersion = 301 - // Cap302 refers to the IRCv3.2 CAP spec. - Cap302 CapVersion = 302 -) - -// CapabilitySet is used to track supported, enabled, and existing caps. -type CapabilitySet map[caps.Capability]bool - -func (set CapabilitySet) String(version CapVersion) string { - strs := make([]string, len(set)) - index := 0 - for capability := range set { - capString := string(capability) - if version == Cap302 { - val, exists := CapValues[capability] - if exists { - capString += "=" + val - } - } - strs[index] = capString - index++ - } - return strings.Join(strs, " ") -} - // CAP [] func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { subCommand := strings.ToUpper(msg.Params[0]) - capabilities := make(CapabilitySet) + capabilities := caps.NewSet() var capString string if len(msg.Params) > 1 { @@ -90,7 +44,7 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { strs := strings.Split(capString, " ") for _, str := range strs { if len(str) > 0 { - capabilities[caps.Capability(str)] = true + capabilities.Enable(caps.Capability(str)) } } } @@ -107,22 +61,20 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { // the server.name source... otherwise it doesn't respond to the CAP message with // anything and just hangs on connection. //TODO(dan): limit number of caps and send it multiline in 3.2 style as appropriate. - client.Send(nil, server.name, "CAP", client.nick, subCommand, SupportedCapabilities.String(client.capVersion)) + client.Send(nil, server.name, "CAP", client.nick, subCommand, SupportedCapabilities.String(client.capVersion, CapValues)) case "LIST": - client.Send(nil, server.name, "CAP", client.nick, subCommand, client.capabilities.String(Cap301)) // values not sent on LIST so force 3.1 + client.Send(nil, server.name, "CAP", client.nick, subCommand, client.capabilities.String(caps.Cap301, CapValues)) // values not sent on LIST so force 3.1 case "REQ": // make sure all capabilities actually exist - for capability := range capabilities { - if !SupportedCapabilities[capability] { + for _, capability := range capabilities.List() { + if !SupportedCapabilities.Has(capability) { client.Send(nil, server.name, "CAP", client.nick, "NAK", capString) return false } } - for capability := range capabilities { - client.capabilities[capability] = true - } + client.capabilities.Enable(capabilities.List()...) client.Send(nil, server.name, "CAP", client.nick, "ACK", capString) case "END": diff --git a/irc/caps/constants.go b/irc/caps/constants.go index 523b07af..c43b84be 100644 --- a/irc/caps/constants.go +++ b/irc/caps/constants.go @@ -45,6 +45,7 @@ const ( UserhostInNames Capability = "userhost-in-names" ) -func (capability Capability) String() string { +// Name returns the name of the given capability. +func (capability Capability) Name() string { return string(capability) } diff --git a/irc/caps/set.go b/irc/caps/set.go new file mode 100644 index 00000000..c5ad4207 --- /dev/null +++ b/irc/caps/set.go @@ -0,0 +1,115 @@ +// Package caps holds capabilities. +package caps + +import ( + "sort" + "strings" + "sync" +) + +// Set holds a set of enabled capabilities. +type Set struct { + sync.RWMutex + // capabilities holds the capabilities this manager has. + capabilities map[Capability]bool +} + +// NewSet returns a new Set, with the given capabilities enabled. +func NewSet(capabs ...Capability) *Set { + newSet := Set{ + capabilities: make(map[Capability]bool), + } + newSet.Enable(capabs...) + + return &newSet +} + +// Enable enables the given capabilities. +func (s *Set) Enable(capabs ...Capability) { + s.Lock() + defer s.Unlock() + + for _, capab := range capabs { + s.capabilities[capab] = true + } +} + +// Disable disables the given capabilities. +func (s *Set) Disable(capabs ...Capability) { + s.Lock() + defer s.Unlock() + + for _, capab := range capabs { + delete(s.capabilities, capab) + } +} + +// Add adds the given capabilities to this set. +// this is just a wrapper to allow more clear use. +func (s *Set) Add(capabs ...Capability) { + s.Enable(capabs...) +} + +// Remove removes the given capabilities from this set. +// this is just a wrapper to allow more clear use. +func (s *Set) Remove(capabs ...Capability) { + s.Disable(capabs...) +} + +// Has returns true if this set has the given capabilities. +func (s *Set) Has(caps ...Capability) bool { + s.RLock() + defer s.RUnlock() + + for _, cap := range caps { + if !s.capabilities[cap] { + return false + } + } + return true +} + +// List return a list of our enabled capabilities. +func (s *Set) List() []Capability { + s.RLock() + defer s.RUnlock() + + var allCaps []Capability + for capab := range s.capabilities { + allCaps = append(allCaps, capab) + } + + return allCaps +} + +// Count returns how many enabled caps this set has. +func (s *Set) Count() int { + s.RLock() + defer s.RUnlock() + + return len(s.capabilities) +} + +// String returns all of our enabled capabilities as a string. +func (s *Set) String(version Version, values *Values) string { + s.RLock() + defer s.RUnlock() + + var strs sort.StringSlice + + for capability := range s.capabilities { + capString := capability.Name() + if version == Cap302 { + val, exists := values.Get(capability) + if exists { + capString += "=" + val + } + } + strs = append(strs, capString) + } + + // sort the cap string before we send it out + sort.Sort(strs) + + return strings.Join(strs, " ") +} diff --git a/irc/caps/values.go b/irc/caps/values.go new file mode 100644 index 00000000..e8fb9979 --- /dev/null +++ b/irc/caps/values.go @@ -0,0 +1,42 @@ +package caps + +import "sync" + +// Values holds capability values. +type Values struct { + sync.RWMutex + // values holds our actual capability values. + values map[Capability]string +} + +// NewValues returns a new Values. +func NewValues() *Values { + return &Values{ + values: make(map[Capability]string), + } +} + +// Set sets the value for the given capability. +func (v *Values) Set(capab Capability, value string) { + v.Lock() + defer v.Unlock() + + v.values[capab] = value +} + +// Unset removes the value for the given capability, if it exists. +func (v *Values) Unset(capab Capability) { + v.Lock() + defer v.Unlock() + + delete(v.values, capab) +} + +// Get returns the value of the given capability, and whether one exists. +func (v *Values) Get(capab Capability) (string, bool) { + v.RLock() + defer v.RUnlock() + + value, exists := v.values[capab] + return value, exists +} diff --git a/irc/caps/version.go b/irc/caps/version.go new file mode 100644 index 00000000..b718f2fa --- /dev/null +++ b/irc/caps/version.go @@ -0,0 +1,11 @@ +package caps + +// Version is used to select which max version of CAP the client supports. +type Version uint + +const ( + // Cap301 refers to the base CAP spec. + Cap301 Version = 301 + // Cap302 refers to the IRCv3.2 CAP spec. + Cap302 Version = 302 +) diff --git a/irc/channel.go b/irc/channel.go index eba86d7d..94f6c280 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -165,8 +165,8 @@ func (modes ModeSet) Prefixes(isMultiPrefix bool) string { } func (channel *Channel) nicksNoMutex(target *Client) []string { - isMultiPrefix := (target != nil) && target.capabilities[caps.MultiPrefix] - isUserhostInNames := (target != nil) && target.capabilities[caps.UserhostInNames] + isMultiPrefix := (target != nil) && target.capabilities.Has(caps.MultiPrefix) + isUserhostInNames := (target != nil) && target.capabilities.Has(caps.UserhostInNames) nicks := make([]string, len(channel.members)) i := 0 for client, modes := range channel.members { @@ -262,7 +262,7 @@ func (channel *Channel) Join(client *Client, key string) { client.server.logger.Debug("join", fmt.Sprintf("%s joined channel %s", client.nick, channel.name)) for member := range channel.members { - if member.capabilities[caps.ExtendedJoin] { + if member.capabilities.Has(caps.ExtendedJoin) { member.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname) } else { member.Send(nil, client.nickMaskString, "JOIN", channel.name) @@ -314,7 +314,7 @@ func (channel *Channel) Join(client *Client, key string) { return nil }) - if client.capabilities[caps.ExtendedJoin] { + if client.capabilities.Has(caps.ExtendedJoin) { client.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname) } else { client.Send(nil, client.nickMaskString, "JOIN", channel.name) @@ -465,13 +465,13 @@ func (channel *Channel) sendMessage(msgid, cmd string, requiredCaps []caps.Capab // STATUSMSG continue } - if member == client && !client.capabilities[caps.EchoMessage] { + if member == client && !client.capabilities.Has(caps.EchoMessage) { continue } canReceive := true for _, capName := range requiredCaps { - if !member.capabilities[capName] { + if !member.capabilities.Has(capName) { canReceive = false } } @@ -480,7 +480,7 @@ func (channel *Channel) sendMessage(msgid, cmd string, requiredCaps []caps.Capab } var messageTagsToUse *map[string]ircmsg.TagValue - if member.capabilities[caps.MessageTags] { + if member.capabilities.Has(caps.MessageTags) { messageTagsToUse = clientOnlyTags } @@ -521,11 +521,11 @@ func (channel *Channel) sendSplitMessage(msgid, cmd string, minPrefix *Mode, cli // STATUSMSG continue } - if member == client && !client.capabilities[caps.EchoMessage] { + if member == client && !client.capabilities.Has(caps.EchoMessage) { continue } var tagsToUse *map[string]ircmsg.TagValue - if member.capabilities[caps.MessageTags] { + if member.capabilities.Has(caps.MessageTags) { tagsToUse = clientOnlyTags } @@ -729,7 +729,7 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client) { // send invite-notify for member := range channel.members { - if member.capabilities[caps.InviteNotify] && member != inviter && member != invitee && channel.ClientIsAtLeast(member, Halfop) { + if member.capabilities.Has(caps.InviteNotify) && member != inviter && member != invitee && channel.ClientIsAtLeast(member, Halfop) { member.Send(nil, inviter.nickMaskString, "INVITE", invitee.nick, channel.name) } } diff --git a/irc/client.go b/irc/client.go index f19c0467..04172f97 100644 --- a/irc/client.go +++ b/irc/client.go @@ -45,9 +45,9 @@ type Client struct { atime time.Time authorized bool awayMessage string - capabilities CapabilitySet + capabilities *caps.Set capState CapState - capVersion CapVersion + capVersion caps.Version certfp string channels ChannelSet class *OperClass @@ -95,9 +95,9 @@ func NewClient(server *Server, conn net.Conn, isTLS bool) *Client { client := &Client{ atime: now, authorized: server.password == nil, - capabilities: make(CapabilitySet), + capabilities: caps.NewSet(), capState: CapNone, - capVersion: Cap301, + capVersion: caps.Cap301, channels: make(ChannelSet), ctime: now, flags: make(map[Mode]bool), @@ -178,10 +178,10 @@ func (client *Client) IPString() string { func (client *Client) maxlens() (int, int) { maxlenTags := 512 maxlenRest := 512 - if client.capabilities[caps.MessageTags] { + if client.capabilities.Has(caps.MessageTags) { maxlenTags = 4096 } - if client.capabilities[caps.MaxLine] { + if client.capabilities.Has(caps.MaxLine) { if client.server.limits.LineLen.Tags > maxlenTags { maxlenTags = client.server.limits.LineLen.Tags } @@ -357,13 +357,13 @@ func (client *Client) ModeString() (str string) { } // Friends refers to clients that share a channel with this client. -func (client *Client) Friends(Capabilities ...caps.Capability) ClientSet { +func (client *Client) Friends(capabs ...caps.Capability) ClientSet { friends := make(ClientSet) // make sure that I have the right caps hasCaps := true - for _, Cap := range Capabilities { - if !client.capabilities[Cap] { + for _, capab := range capabs { + if !client.capabilities.Has(capab) { hasCaps = false break } @@ -377,8 +377,8 @@ func (client *Client) Friends(Capabilities ...caps.Capability) ClientSet { for member := range channel.members { // make sure they have all the required caps hasCaps = true - for _, Cap := range Capabilities { - if !member.capabilities[Cap] { + for _, capab := range capabs { + if !member.capabilities.Has(capab) { hasCaps = false break } @@ -580,7 +580,7 @@ func (client *Client) destroy() { // SendSplitMsgFromClient sends an IRC PRIVMSG/NOTICE coming from a specific client. // Adds account-tag to the line as well. func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command, target string, message SplitMessage) { - if client.capabilities[caps.MaxLine] { + if client.capabilities.Has(caps.MaxLine) { client.SendFromClient(msgid, from, tags, command, target, message.ForMaxLine) } else { for _, str := range message.For512 { @@ -593,7 +593,7 @@ func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *m // Adds account-tag to the line as well. func (client *Client) SendFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command string, params ...string) error { // attach account-tag - if client.capabilities[caps.AccountTag] && from.account != &NoAccount { + if client.capabilities.Has(caps.AccountTag) && from.account != &NoAccount { if tags == nil { tags = ircmsg.MakeTags("account", from.account.Name) } else { @@ -601,7 +601,7 @@ func (client *Client) SendFromClient(msgid string, from *Client, tags *map[strin } } // attach message-id - if len(msgid) > 0 && client.capabilities[caps.MessageTags] { + if len(msgid) > 0 && client.capabilities.Has(caps.MessageTags) { if tags == nil { tags = ircmsg.MakeTags("draft/msgid", msgid) } else { @@ -628,7 +628,7 @@ var ( // Send sends an IRC line to the client. func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, command string, params ...string) error { // attach server-time - if client.capabilities[caps.ServerTime] { + if client.capabilities.Has(caps.ServerTime) { t := time.Now().UTC().Format("2006-01-02T15:04:05.999Z") if tags == nil { tags = ircmsg.MakeTags("time", t) @@ -678,7 +678,7 @@ func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, comm // Notice sends the client a notice from the server. func (client *Client) Notice(text string) { limit := 400 - if client.capabilities[caps.MaxLine] { + if client.capabilities.Has(caps.MaxLine) { limit = client.server.limits.LineLen.Rest - 110 } lines := wordWrap(text, limit) diff --git a/irc/client_lookup_set.go b/irc/client_lookup_set.go index 74cf8114..11cc0332 100644 --- a/irc/client_lookup_set.go +++ b/irc/client_lookup_set.go @@ -156,7 +156,7 @@ func (clients *ClientLookupSet) Replace(oldNick, newNick string, client *Client) } // AllWithCaps returns all clients with the given capabilities. -func (clients *ClientLookupSet) AllWithCaps(caps ...caps.Capability) (set ClientSet) { +func (clients *ClientLookupSet) AllWithCaps(capabs ...caps.Capability) (set ClientSet) { set = make(ClientSet) clients.ByNickMutex.RLock() @@ -164,8 +164,8 @@ func (clients *ClientLookupSet) AllWithCaps(caps ...caps.Capability) (set Client var client *Client for _, client = range clients.ByNick { // make sure they have all the required caps - for _, Cap := range caps { - if !client.capabilities[Cap] { + for _, capab := range capabs { + if !client.capabilities.Has(capab) { continue } } diff --git a/irc/roleplay.go b/irc/roleplay.go index 9f1be9fc..d5d14d4d 100644 --- a/irc/roleplay.go +++ b/irc/roleplay.go @@ -90,7 +90,7 @@ func sendRoleplayMessage(server *Server, client *Client, source string, targetSt channel.membersMutex.RLock() for member := range channel.members { - if member == client && !client.capabilities[caps.EchoMessage] { + if member == client && !client.capabilities.Has(caps.EchoMessage) { continue } member.Send(nil, source, "PRIVMSG", channel.name, message) @@ -110,7 +110,7 @@ func sendRoleplayMessage(server *Server, client *Client, source string, targetSt } user.Send(nil, source, "PRIVMSG", user.nick, message) - if client.capabilities[caps.EchoMessage] { + if client.capabilities.Has(caps.EchoMessage) { client.Send(nil, source, "PRIVMSG", user.nick, message) } if user.flags[Away] { diff --git a/irc/server.go b/irc/server.go index 788a4054..7d9feabf 100644 --- a/irc/server.go +++ b/irc/server.go @@ -642,11 +642,11 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { // send RENAME messages for mcl := range channel.members { - if mcl.capabilities[caps.Rename] { + if mcl.capabilities.Has(caps.Rename) { mcl.Send(nil, client.nickMaskString, "RENAME", oldName, newName, reason) } else { mcl.Send(nil, mcl.nickMaskString, "PART", oldName, fmt.Sprintf("Channel renamed: %s", reason)) - if mcl.capabilities[caps.ExtendedJoin] { + if mcl.capabilities.Has(caps.ExtendedJoin) { accountName := "*" if mcl.account != nil { accountName = mcl.account.Name @@ -825,7 +825,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool message := msg.Params[1] // split privmsg - splitMsg := server.splitMessage(message, !client.capabilities[caps.MaxLine]) + splitMsg := server.splitMessage(message, !client.capabilities.Has(caps.MaxLine)) for i, targetString := range targets { // max of four targets per privmsg @@ -869,7 +869,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool } continue } - if !user.capabilities[caps.MessageTags] { + if !user.capabilities.Has(caps.MessageTags) { clientOnlyTags = nil } msgid := server.generateMessageID() @@ -878,7 +878,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool if !user.flags[RegisteredOnly] || client.registered { user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg) } - if client.capabilities[caps.EchoMessage] { + if client.capabilities.Has(caps.EchoMessage) { client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg) } if user.flags[Away] { @@ -939,11 +939,11 @@ func tagmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { msgid := server.generateMessageID() // end user can't receive tagmsgs - if !user.capabilities[caps.MessageTags] { + if !user.capabilities.Has(caps.MessageTags) { continue } user.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick) - if client.capabilities[caps.EchoMessage] { + if client.capabilities.Has(caps.EchoMessage) { client.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick) } if user.flags[Away] { @@ -957,7 +957,7 @@ func tagmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { // WhoisChannelsNames returns the common channel names between two users. func (client *Client) WhoisChannelsNames(target *Client) []string { - isMultiPrefix := target.capabilities[caps.MultiPrefix] + isMultiPrefix := target.capabilities.Has(caps.MultiPrefix) var chstrs []string index := 0 for channel := range client.channels { @@ -1062,7 +1062,7 @@ func (target *Client) RplWhoReplyNoMutex(channel *Channel, client *Client) { } if channel != nil { - flags += channel.members[client].Prefixes(target.capabilities[caps.MultiPrefix]) + flags += channel.members[client].Prefixes(target.capabilities.Has(caps.MultiPrefix)) channelName = channel.name } target.Send(nil, target.server.name, RPL_WHOREPLY, target.nick, channelName, client.username, client.hostname, client.server.name, client.nick, flags, strconv.Itoa(client.hops)+" "+client.realname) @@ -1288,66 +1288,66 @@ func (server *Server) applyConfig(config *Config, initial bool) error { server.connectionLimitsMutex.Unlock() // setup new and removed caps - addedCaps := make(CapabilitySet) - removedCaps := make(CapabilitySet) - updatedCaps := make(CapabilitySet) + addedCaps := caps.NewSet() + removedCaps := caps.NewSet() + updatedCaps := caps.NewSet() // SASL if config.Accounts.AuthenticationEnabled && !server.accountAuthenticationEnabled { // enabling SASL - SupportedCapabilities[caps.SASL] = true - addedCaps[caps.SASL] = true + SupportedCapabilities.Enable(caps.SASL) + CapValues.Set(caps.SASL, "PLAIN,EXTERNAL") + addedCaps.Add(caps.SASL) } if !config.Accounts.AuthenticationEnabled && server.accountAuthenticationEnabled { // disabling SASL - SupportedCapabilities[caps.SASL] = false - removedCaps[caps.SASL] = true + SupportedCapabilities.Disable(caps.SASL) + removedCaps.Add(caps.SASL) } server.accountAuthenticationEnabled = config.Accounts.AuthenticationEnabled // STS stsValue := config.Server.STS.Value() var stsDisabled bool - server.logger.Debug("rehash", "STS Vals", CapValues[caps.STS], stsValue, fmt.Sprintf("server[%v] config[%v]", server.stsEnabled, config.Server.STS.Enabled)) + stsCurrentCapValue, _ := CapValues.Get(caps.STS) + server.logger.Debug("rehash", "STS Vals", stsCurrentCapValue, stsValue, fmt.Sprintf("server[%v] config[%v]", server.stsEnabled, config.Server.STS.Enabled)) if config.Server.STS.Enabled && !server.stsEnabled { // enabling STS - SupportedCapabilities[caps.STS] = true - addedCaps[caps.STS] = true - CapValues[caps.STS] = stsValue + SupportedCapabilities.Enable(caps.STS) + addedCaps.Add(caps.STS) + CapValues.Set(caps.STS, stsValue) } else if !config.Server.STS.Enabled && server.stsEnabled { // disabling STS - SupportedCapabilities[caps.STS] = false - removedCaps[caps.STS] = true + SupportedCapabilities.Disable(caps.STS) + removedCaps.Add(caps.STS) stsDisabled = true - } else if config.Server.STS.Enabled && server.stsEnabled && stsValue != CapValues[caps.STS] { + } else if config.Server.STS.Enabled && server.stsEnabled && stsValue != stsCurrentCapValue { // STS policy updated - CapValues[caps.STS] = stsValue - updatedCaps[caps.STS] = true + CapValues.Set(caps.STS, stsValue) + updatedCaps.Add(caps.STS) } server.stsEnabled = config.Server.STS.Enabled // burst new and removed caps var capBurstClients ClientSet - added := make(map[CapVersion]string) + added := make(map[caps.Version]string) var removed string // updated caps get DEL'd and then NEW'd // so, we can just add updated ones to both removed and added lists here and they'll be correctly handled - server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(Cap301), strconv.Itoa(len(updatedCaps))) - if len(updatedCaps) > 0 { - for capab := range updatedCaps { - addedCaps[capab] = true - removedCaps[capab] = true - } + server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues), strconv.Itoa(updatedCaps.Count())) + for _, capab := range updatedCaps.List() { + addedCaps.Enable(capab) + removedCaps.Enable(capab) } - if len(addedCaps) > 0 || len(removedCaps) > 0 { + if 0 < addedCaps.Count() || 0 < removedCaps.Count() { capBurstClients = server.clients.AllWithCaps(caps.CapNotify) - added[Cap301] = addedCaps.String(Cap301) - added[Cap302] = addedCaps.String(Cap302) - // removed never has values - removed = removedCaps.String(Cap301) + added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues) + added[caps.Cap302] = addedCaps.String(caps.Cap302, CapValues) + // removed never has values, so we leave it as Cap301 + removed = removedCaps.String(caps.Cap301, CapValues) } for sClient := range capBurstClients { @@ -1355,18 +1355,18 @@ func (server *Server) applyConfig(config *Config, initial bool) error { // remove STS policy //TODO(dan): this is an ugly hack. we can write this better. stsPolicy := "sts=duration=0" - if len(addedCaps) > 0 { - added[Cap302] = added[Cap302] + " " + stsPolicy + if 0 < addedCaps.Count() { + added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy } else { - addedCaps[caps.STS] = true - added[Cap302] = stsPolicy + addedCaps.Enable(caps.STS) + added[caps.Cap302] = stsPolicy } } // DEL caps and then send NEW ones so that updated caps get removed/added correctly - if len(removedCaps) > 0 { + if 0 < removedCaps.Count() { sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed) } - if len(addedCaps) > 0 { + if 0 < addedCaps.Count() { sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion]) } } @@ -1707,7 +1707,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { message := msg.Params[1] // split privmsg - splitMsg := server.splitMessage(message, !client.capabilities[caps.MaxLine]) + splitMsg := server.splitMessage(message, !client.capabilities.Has(caps.MaxLine)) for i, targetString := range targets { // max of four targets per privmsg @@ -1748,7 +1748,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { // errors silently ignored with NOTICE as per RFC continue } - if !user.capabilities[caps.MessageTags] { + if !user.capabilities.Has(caps.MessageTags) { clientOnlyTags = nil } msgid := server.generateMessageID() @@ -1757,7 +1757,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { if !user.flags[RegisteredOnly] || client.registered { user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg) } - if client.capabilities[caps.EchoMessage] { + if client.capabilities.Has(caps.EchoMessage) { client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg) } }