3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-14 07:59:31 +01:00

caps: Move most capability-handling types into the caps package

This commit is contained in:
Daniel Oaks 2017-09-29 17:25:58 +10:00
parent 85bfe3818b
commit 275449e6cc
10 changed files with 259 additions and 138 deletions

View File

@ -13,28 +13,12 @@ import (
var ( var (
// SupportedCapabilities are the caps we advertise. // SupportedCapabilities are the caps we advertise.
SupportedCapabilities = CapabilitySet{ // MaxLine, SASL and STS are set during server startup.
caps.AccountTag: true, SupportedCapabilities = caps.NewSet(caps.AccountTag, caps.AccountNotify, caps.AwayNotify, caps.CapNotify, caps.ChgHost, caps.EchoMessage, caps.ExtendedJoin, caps.InviteNotify, caps.MessageTags, caps.MultiPrefix, caps.Rename, caps.ServerTime, caps.UserhostInNames)
caps.AccountNotify: true,
caps.AwayNotify: true,
caps.CapNotify: true,
caps.ChgHost: true,
caps.EchoMessage: true,
caps.ExtendedJoin: true,
caps.InviteNotify: true,
// MaxLine is set during server startup
caps.MessageTags: true,
caps.MultiPrefix: true,
caps.Rename: true,
// SASL is set during server startup
caps.ServerTime: true,
// STS is set during server startup
caps.UserhostInNames: true,
}
// CapValues are the actual values we advertise to v3.2 clients. // CapValues are the actual values we advertise to v3.2 clients.
CapValues = map[caps.Capability]string{ // actual values are set during server startup.
caps.SASL: "PLAIN,EXTERNAL", CapValues = caps.NewValues()
}
) )
// CapState shows whether we're negotiating caps, finished, etc for connection registration. // CapState shows whether we're negotiating caps, finished, etc for connection registration.
@ -49,40 +33,10 @@ const (
CapNegotiated CapState = iota CapNegotiated CapState = iota
) )
// CapVersion is used to select which max version of CAP the client supports.
type CapVersion uint
const (
// Cap301 refers to the base CAP spec.
Cap301 CapVersion = 301
// Cap302 refers to the IRCv3.2 CAP spec.
Cap302 CapVersion = 302
)
// CapabilitySet is used to track supported, enabled, and existing caps.
type CapabilitySet map[caps.Capability]bool
func (set CapabilitySet) String(version CapVersion) string {
strs := make([]string, len(set))
index := 0
for capability := range set {
capString := string(capability)
if version == Cap302 {
val, exists := CapValues[capability]
if exists {
capString += "=" + val
}
}
strs[index] = capString
index++
}
return strings.Join(strs, " ")
}
// CAP <subcmd> [<caps>] // CAP <subcmd> [<caps>]
func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
subCommand := strings.ToUpper(msg.Params[0]) subCommand := strings.ToUpper(msg.Params[0])
capabilities := make(CapabilitySet) capabilities := caps.NewSet()
var capString string var capString string
if len(msg.Params) > 1 { if len(msg.Params) > 1 {
@ -90,7 +44,7 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
strs := strings.Split(capString, " ") strs := strings.Split(capString, " ")
for _, str := range strs { for _, str := range strs {
if len(str) > 0 { if len(str) > 0 {
capabilities[caps.Capability(str)] = true capabilities.Enable(caps.Capability(str))
} }
} }
} }
@ -107,22 +61,20 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
// the server.name source... otherwise it doesn't respond to the CAP message with // the server.name source... otherwise it doesn't respond to the CAP message with
// anything and just hangs on connection. // anything and just hangs on connection.
//TODO(dan): limit number of caps and send it multiline in 3.2 style as appropriate. //TODO(dan): limit number of caps and send it multiline in 3.2 style as appropriate.
client.Send(nil, server.name, "CAP", client.nick, subCommand, SupportedCapabilities.String(client.capVersion)) client.Send(nil, server.name, "CAP", client.nick, subCommand, SupportedCapabilities.String(client.capVersion, CapValues))
case "LIST": case "LIST":
client.Send(nil, server.name, "CAP", client.nick, subCommand, client.capabilities.String(Cap301)) // values not sent on LIST so force 3.1 client.Send(nil, server.name, "CAP", client.nick, subCommand, client.capabilities.String(caps.Cap301, CapValues)) // values not sent on LIST so force 3.1
case "REQ": case "REQ":
// make sure all capabilities actually exist // make sure all capabilities actually exist
for capability := range capabilities { for _, capability := range capabilities.List() {
if !SupportedCapabilities[capability] { if !SupportedCapabilities.Has(capability) {
client.Send(nil, server.name, "CAP", client.nick, "NAK", capString) client.Send(nil, server.name, "CAP", client.nick, "NAK", capString)
return false return false
} }
} }
for capability := range capabilities { client.capabilities.Enable(capabilities.List()...)
client.capabilities[capability] = true
}
client.Send(nil, server.name, "CAP", client.nick, "ACK", capString) client.Send(nil, server.name, "CAP", client.nick, "ACK", capString)
case "END": case "END":

View File

@ -45,6 +45,7 @@ const (
UserhostInNames Capability = "userhost-in-names" UserhostInNames Capability = "userhost-in-names"
) )
func (capability Capability) String() string { // Name returns the name of the given capability.
func (capability Capability) Name() string {
return string(capability) return string(capability)
} }

115
irc/caps/set.go Normal file
View File

@ -0,0 +1,115 @@
// Package caps holds capabilities.
package caps
import (
"sort"
"strings"
"sync"
)
// Set holds a set of enabled capabilities.
type Set struct {
sync.RWMutex
// capabilities holds the capabilities this manager has.
capabilities map[Capability]bool
}
// NewSet returns a new Set, with the given capabilities enabled.
func NewSet(capabs ...Capability) *Set {
newSet := Set{
capabilities: make(map[Capability]bool),
}
newSet.Enable(capabs...)
return &newSet
}
// Enable enables the given capabilities.
func (s *Set) Enable(capabs ...Capability) {
s.Lock()
defer s.Unlock()
for _, capab := range capabs {
s.capabilities[capab] = true
}
}
// Disable disables the given capabilities.
func (s *Set) Disable(capabs ...Capability) {
s.Lock()
defer s.Unlock()
for _, capab := range capabs {
delete(s.capabilities, capab)
}
}
// Add adds the given capabilities to this set.
// this is just a wrapper to allow more clear use.
func (s *Set) Add(capabs ...Capability) {
s.Enable(capabs...)
}
// Remove removes the given capabilities from this set.
// this is just a wrapper to allow more clear use.
func (s *Set) Remove(capabs ...Capability) {
s.Disable(capabs...)
}
// Has returns true if this set has the given capabilities.
func (s *Set) Has(caps ...Capability) bool {
s.RLock()
defer s.RUnlock()
for _, cap := range caps {
if !s.capabilities[cap] {
return false
}
}
return true
}
// List return a list of our enabled capabilities.
func (s *Set) List() []Capability {
s.RLock()
defer s.RUnlock()
var allCaps []Capability
for capab := range s.capabilities {
allCaps = append(allCaps, capab)
}
return allCaps
}
// Count returns how many enabled caps this set has.
func (s *Set) Count() int {
s.RLock()
defer s.RUnlock()
return len(s.capabilities)
}
// String returns all of our enabled capabilities as a string.
func (s *Set) String(version Version, values *Values) string {
s.RLock()
defer s.RUnlock()
var strs sort.StringSlice
for capability := range s.capabilities {
capString := capability.Name()
if version == Cap302 {
val, exists := values.Get(capability)
if exists {
capString += "=" + val
}
}
strs = append(strs, capString)
}
// sort the cap string before we send it out
sort.Sort(strs)
return strings.Join(strs, " ")
}

42
irc/caps/values.go Normal file
View File

@ -0,0 +1,42 @@
package caps
import "sync"
// Values holds capability values.
type Values struct {
sync.RWMutex
// values holds our actual capability values.
values map[Capability]string
}
// NewValues returns a new Values.
func NewValues() *Values {
return &Values{
values: make(map[Capability]string),
}
}
// Set sets the value for the given capability.
func (v *Values) Set(capab Capability, value string) {
v.Lock()
defer v.Unlock()
v.values[capab] = value
}
// Unset removes the value for the given capability, if it exists.
func (v *Values) Unset(capab Capability) {
v.Lock()
defer v.Unlock()
delete(v.values, capab)
}
// Get returns the value of the given capability, and whether one exists.
func (v *Values) Get(capab Capability) (string, bool) {
v.RLock()
defer v.RUnlock()
value, exists := v.values[capab]
return value, exists
}

11
irc/caps/version.go Normal file
View File

@ -0,0 +1,11 @@
package caps
// Version is used to select which max version of CAP the client supports.
type Version uint
const (
// Cap301 refers to the base CAP spec.
Cap301 Version = 301
// Cap302 refers to the IRCv3.2 CAP spec.
Cap302 Version = 302
)

View File

@ -165,8 +165,8 @@ func (modes ModeSet) Prefixes(isMultiPrefix bool) string {
} }
func (channel *Channel) nicksNoMutex(target *Client) []string { func (channel *Channel) nicksNoMutex(target *Client) []string {
isMultiPrefix := (target != nil) && target.capabilities[caps.MultiPrefix] isMultiPrefix := (target != nil) && target.capabilities.Has(caps.MultiPrefix)
isUserhostInNames := (target != nil) && target.capabilities[caps.UserhostInNames] isUserhostInNames := (target != nil) && target.capabilities.Has(caps.UserhostInNames)
nicks := make([]string, len(channel.members)) nicks := make([]string, len(channel.members))
i := 0 i := 0
for client, modes := range channel.members { for client, modes := range channel.members {
@ -262,7 +262,7 @@ func (channel *Channel) Join(client *Client, key string) {
client.server.logger.Debug("join", fmt.Sprintf("%s joined channel %s", client.nick, channel.name)) client.server.logger.Debug("join", fmt.Sprintf("%s joined channel %s", client.nick, channel.name))
for member := range channel.members { for member := range channel.members {
if member.capabilities[caps.ExtendedJoin] { if member.capabilities.Has(caps.ExtendedJoin) {
member.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname) member.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname)
} else { } else {
member.Send(nil, client.nickMaskString, "JOIN", channel.name) member.Send(nil, client.nickMaskString, "JOIN", channel.name)
@ -314,7 +314,7 @@ func (channel *Channel) Join(client *Client, key string) {
return nil return nil
}) })
if client.capabilities[caps.ExtendedJoin] { if client.capabilities.Has(caps.ExtendedJoin) {
client.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname) client.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname)
} else { } else {
client.Send(nil, client.nickMaskString, "JOIN", channel.name) client.Send(nil, client.nickMaskString, "JOIN", channel.name)
@ -465,13 +465,13 @@ func (channel *Channel) sendMessage(msgid, cmd string, requiredCaps []caps.Capab
// STATUSMSG // STATUSMSG
continue continue
} }
if member == client && !client.capabilities[caps.EchoMessage] { if member == client && !client.capabilities.Has(caps.EchoMessage) {
continue continue
} }
canReceive := true canReceive := true
for _, capName := range requiredCaps { for _, capName := range requiredCaps {
if !member.capabilities[capName] { if !member.capabilities.Has(capName) {
canReceive = false canReceive = false
} }
} }
@ -480,7 +480,7 @@ func (channel *Channel) sendMessage(msgid, cmd string, requiredCaps []caps.Capab
} }
var messageTagsToUse *map[string]ircmsg.TagValue var messageTagsToUse *map[string]ircmsg.TagValue
if member.capabilities[caps.MessageTags] { if member.capabilities.Has(caps.MessageTags) {
messageTagsToUse = clientOnlyTags messageTagsToUse = clientOnlyTags
} }
@ -521,11 +521,11 @@ func (channel *Channel) sendSplitMessage(msgid, cmd string, minPrefix *Mode, cli
// STATUSMSG // STATUSMSG
continue continue
} }
if member == client && !client.capabilities[caps.EchoMessage] { if member == client && !client.capabilities.Has(caps.EchoMessage) {
continue continue
} }
var tagsToUse *map[string]ircmsg.TagValue var tagsToUse *map[string]ircmsg.TagValue
if member.capabilities[caps.MessageTags] { if member.capabilities.Has(caps.MessageTags) {
tagsToUse = clientOnlyTags tagsToUse = clientOnlyTags
} }
@ -729,7 +729,7 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client) {
// send invite-notify // send invite-notify
for member := range channel.members { for member := range channel.members {
if member.capabilities[caps.InviteNotify] && member != inviter && member != invitee && channel.ClientIsAtLeast(member, Halfop) { if member.capabilities.Has(caps.InviteNotify) && member != inviter && member != invitee && channel.ClientIsAtLeast(member, Halfop) {
member.Send(nil, inviter.nickMaskString, "INVITE", invitee.nick, channel.name) member.Send(nil, inviter.nickMaskString, "INVITE", invitee.nick, channel.name)
} }
} }

View File

@ -45,9 +45,9 @@ type Client struct {
atime time.Time atime time.Time
authorized bool authorized bool
awayMessage string awayMessage string
capabilities CapabilitySet capabilities *caps.Set
capState CapState capState CapState
capVersion CapVersion capVersion caps.Version
certfp string certfp string
channels ChannelSet channels ChannelSet
class *OperClass class *OperClass
@ -95,9 +95,9 @@ func NewClient(server *Server, conn net.Conn, isTLS bool) *Client {
client := &Client{ client := &Client{
atime: now, atime: now,
authorized: server.password == nil, authorized: server.password == nil,
capabilities: make(CapabilitySet), capabilities: caps.NewSet(),
capState: CapNone, capState: CapNone,
capVersion: Cap301, capVersion: caps.Cap301,
channels: make(ChannelSet), channels: make(ChannelSet),
ctime: now, ctime: now,
flags: make(map[Mode]bool), flags: make(map[Mode]bool),
@ -178,10 +178,10 @@ func (client *Client) IPString() string {
func (client *Client) maxlens() (int, int) { func (client *Client) maxlens() (int, int) {
maxlenTags := 512 maxlenTags := 512
maxlenRest := 512 maxlenRest := 512
if client.capabilities[caps.MessageTags] { if client.capabilities.Has(caps.MessageTags) {
maxlenTags = 4096 maxlenTags = 4096
} }
if client.capabilities[caps.MaxLine] { if client.capabilities.Has(caps.MaxLine) {
if client.server.limits.LineLen.Tags > maxlenTags { if client.server.limits.LineLen.Tags > maxlenTags {
maxlenTags = client.server.limits.LineLen.Tags maxlenTags = client.server.limits.LineLen.Tags
} }
@ -357,13 +357,13 @@ func (client *Client) ModeString() (str string) {
} }
// Friends refers to clients that share a channel with this client. // Friends refers to clients that share a channel with this client.
func (client *Client) Friends(Capabilities ...caps.Capability) ClientSet { func (client *Client) Friends(capabs ...caps.Capability) ClientSet {
friends := make(ClientSet) friends := make(ClientSet)
// make sure that I have the right caps // make sure that I have the right caps
hasCaps := true hasCaps := true
for _, Cap := range Capabilities { for _, capab := range capabs {
if !client.capabilities[Cap] { if !client.capabilities.Has(capab) {
hasCaps = false hasCaps = false
break break
} }
@ -377,8 +377,8 @@ func (client *Client) Friends(Capabilities ...caps.Capability) ClientSet {
for member := range channel.members { for member := range channel.members {
// make sure they have all the required caps // make sure they have all the required caps
hasCaps = true hasCaps = true
for _, Cap := range Capabilities { for _, capab := range capabs {
if !member.capabilities[Cap] { if !member.capabilities.Has(capab) {
hasCaps = false hasCaps = false
break break
} }
@ -580,7 +580,7 @@ func (client *Client) destroy() {
// SendSplitMsgFromClient sends an IRC PRIVMSG/NOTICE coming from a specific client. // SendSplitMsgFromClient sends an IRC PRIVMSG/NOTICE coming from a specific client.
// Adds account-tag to the line as well. // Adds account-tag to the line as well.
func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command, target string, message SplitMessage) { func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command, target string, message SplitMessage) {
if client.capabilities[caps.MaxLine] { if client.capabilities.Has(caps.MaxLine) {
client.SendFromClient(msgid, from, tags, command, target, message.ForMaxLine) client.SendFromClient(msgid, from, tags, command, target, message.ForMaxLine)
} else { } else {
for _, str := range message.For512 { for _, str := range message.For512 {
@ -593,7 +593,7 @@ func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *m
// Adds account-tag to the line as well. // Adds account-tag to the line as well.
func (client *Client) SendFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command string, params ...string) error { func (client *Client) SendFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command string, params ...string) error {
// attach account-tag // attach account-tag
if client.capabilities[caps.AccountTag] && from.account != &NoAccount { if client.capabilities.Has(caps.AccountTag) && from.account != &NoAccount {
if tags == nil { if tags == nil {
tags = ircmsg.MakeTags("account", from.account.Name) tags = ircmsg.MakeTags("account", from.account.Name)
} else { } else {
@ -601,7 +601,7 @@ func (client *Client) SendFromClient(msgid string, from *Client, tags *map[strin
} }
} }
// attach message-id // attach message-id
if len(msgid) > 0 && client.capabilities[caps.MessageTags] { if len(msgid) > 0 && client.capabilities.Has(caps.MessageTags) {
if tags == nil { if tags == nil {
tags = ircmsg.MakeTags("draft/msgid", msgid) tags = ircmsg.MakeTags("draft/msgid", msgid)
} else { } else {
@ -628,7 +628,7 @@ var (
// Send sends an IRC line to the client. // Send sends an IRC line to the client.
func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, command string, params ...string) error { func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, command string, params ...string) error {
// attach server-time // attach server-time
if client.capabilities[caps.ServerTime] { if client.capabilities.Has(caps.ServerTime) {
t := time.Now().UTC().Format("2006-01-02T15:04:05.999Z") t := time.Now().UTC().Format("2006-01-02T15:04:05.999Z")
if tags == nil { if tags == nil {
tags = ircmsg.MakeTags("time", t) tags = ircmsg.MakeTags("time", t)
@ -678,7 +678,7 @@ func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, comm
// Notice sends the client a notice from the server. // Notice sends the client a notice from the server.
func (client *Client) Notice(text string) { func (client *Client) Notice(text string) {
limit := 400 limit := 400
if client.capabilities[caps.MaxLine] { if client.capabilities.Has(caps.MaxLine) {
limit = client.server.limits.LineLen.Rest - 110 limit = client.server.limits.LineLen.Rest - 110
} }
lines := wordWrap(text, limit) lines := wordWrap(text, limit)

View File

@ -156,7 +156,7 @@ func (clients *ClientLookupSet) Replace(oldNick, newNick string, client *Client)
} }
// AllWithCaps returns all clients with the given capabilities. // AllWithCaps returns all clients with the given capabilities.
func (clients *ClientLookupSet) AllWithCaps(caps ...caps.Capability) (set ClientSet) { func (clients *ClientLookupSet) AllWithCaps(capabs ...caps.Capability) (set ClientSet) {
set = make(ClientSet) set = make(ClientSet)
clients.ByNickMutex.RLock() clients.ByNickMutex.RLock()
@ -164,8 +164,8 @@ func (clients *ClientLookupSet) AllWithCaps(caps ...caps.Capability) (set Client
var client *Client var client *Client
for _, client = range clients.ByNick { for _, client = range clients.ByNick {
// make sure they have all the required caps // make sure they have all the required caps
for _, Cap := range caps { for _, capab := range capabs {
if !client.capabilities[Cap] { if !client.capabilities.Has(capab) {
continue continue
} }
} }

View File

@ -90,7 +90,7 @@ func sendRoleplayMessage(server *Server, client *Client, source string, targetSt
channel.membersMutex.RLock() channel.membersMutex.RLock()
for member := range channel.members { for member := range channel.members {
if member == client && !client.capabilities[caps.EchoMessage] { if member == client && !client.capabilities.Has(caps.EchoMessage) {
continue continue
} }
member.Send(nil, source, "PRIVMSG", channel.name, message) member.Send(nil, source, "PRIVMSG", channel.name, message)
@ -110,7 +110,7 @@ func sendRoleplayMessage(server *Server, client *Client, source string, targetSt
} }
user.Send(nil, source, "PRIVMSG", user.nick, message) user.Send(nil, source, "PRIVMSG", user.nick, message)
if client.capabilities[caps.EchoMessage] { if client.capabilities.Has(caps.EchoMessage) {
client.Send(nil, source, "PRIVMSG", user.nick, message) client.Send(nil, source, "PRIVMSG", user.nick, message)
} }
if user.flags[Away] { if user.flags[Away] {

View File

@ -642,11 +642,11 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
// send RENAME messages // send RENAME messages
for mcl := range channel.members { for mcl := range channel.members {
if mcl.capabilities[caps.Rename] { if mcl.capabilities.Has(caps.Rename) {
mcl.Send(nil, client.nickMaskString, "RENAME", oldName, newName, reason) mcl.Send(nil, client.nickMaskString, "RENAME", oldName, newName, reason)
} else { } else {
mcl.Send(nil, mcl.nickMaskString, "PART", oldName, fmt.Sprintf("Channel renamed: %s", reason)) mcl.Send(nil, mcl.nickMaskString, "PART", oldName, fmt.Sprintf("Channel renamed: %s", reason))
if mcl.capabilities[caps.ExtendedJoin] { if mcl.capabilities.Has(caps.ExtendedJoin) {
accountName := "*" accountName := "*"
if mcl.account != nil { if mcl.account != nil {
accountName = mcl.account.Name accountName = mcl.account.Name
@ -825,7 +825,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool
message := msg.Params[1] message := msg.Params[1]
// split privmsg // split privmsg
splitMsg := server.splitMessage(message, !client.capabilities[caps.MaxLine]) splitMsg := server.splitMessage(message, !client.capabilities.Has(caps.MaxLine))
for i, targetString := range targets { for i, targetString := range targets {
// max of four targets per privmsg // max of four targets per privmsg
@ -869,7 +869,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool
} }
continue continue
} }
if !user.capabilities[caps.MessageTags] { if !user.capabilities.Has(caps.MessageTags) {
clientOnlyTags = nil clientOnlyTags = nil
} }
msgid := server.generateMessageID() msgid := server.generateMessageID()
@ -878,7 +878,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool
if !user.flags[RegisteredOnly] || client.registered { if !user.flags[RegisteredOnly] || client.registered {
user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg) user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg)
} }
if client.capabilities[caps.EchoMessage] { if client.capabilities.Has(caps.EchoMessage) {
client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg) client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg)
} }
if user.flags[Away] { if user.flags[Away] {
@ -939,11 +939,11 @@ func tagmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
msgid := server.generateMessageID() msgid := server.generateMessageID()
// end user can't receive tagmsgs // end user can't receive tagmsgs
if !user.capabilities[caps.MessageTags] { if !user.capabilities.Has(caps.MessageTags) {
continue continue
} }
user.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick) user.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick)
if client.capabilities[caps.EchoMessage] { if client.capabilities.Has(caps.EchoMessage) {
client.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick) client.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick)
} }
if user.flags[Away] { if user.flags[Away] {
@ -957,7 +957,7 @@ func tagmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
// WhoisChannelsNames returns the common channel names between two users. // WhoisChannelsNames returns the common channel names between two users.
func (client *Client) WhoisChannelsNames(target *Client) []string { func (client *Client) WhoisChannelsNames(target *Client) []string {
isMultiPrefix := target.capabilities[caps.MultiPrefix] isMultiPrefix := target.capabilities.Has(caps.MultiPrefix)
var chstrs []string var chstrs []string
index := 0 index := 0
for channel := range client.channels { for channel := range client.channels {
@ -1062,7 +1062,7 @@ func (target *Client) RplWhoReplyNoMutex(channel *Channel, client *Client) {
} }
if channel != nil { if channel != nil {
flags += channel.members[client].Prefixes(target.capabilities[caps.MultiPrefix]) flags += channel.members[client].Prefixes(target.capabilities.Has(caps.MultiPrefix))
channelName = channel.name channelName = channel.name
} }
target.Send(nil, target.server.name, RPL_WHOREPLY, target.nick, channelName, client.username, client.hostname, client.server.name, client.nick, flags, strconv.Itoa(client.hops)+" "+client.realname) target.Send(nil, target.server.name, RPL_WHOREPLY, target.nick, channelName, client.username, client.hostname, client.server.name, client.nick, flags, strconv.Itoa(client.hops)+" "+client.realname)
@ -1288,66 +1288,66 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
server.connectionLimitsMutex.Unlock() server.connectionLimitsMutex.Unlock()
// setup new and removed caps // setup new and removed caps
addedCaps := make(CapabilitySet) addedCaps := caps.NewSet()
removedCaps := make(CapabilitySet) removedCaps := caps.NewSet()
updatedCaps := make(CapabilitySet) updatedCaps := caps.NewSet()
// SASL // SASL
if config.Accounts.AuthenticationEnabled && !server.accountAuthenticationEnabled { if config.Accounts.AuthenticationEnabled && !server.accountAuthenticationEnabled {
// enabling SASL // enabling SASL
SupportedCapabilities[caps.SASL] = true SupportedCapabilities.Enable(caps.SASL)
addedCaps[caps.SASL] = true CapValues.Set(caps.SASL, "PLAIN,EXTERNAL")
addedCaps.Add(caps.SASL)
} }
if !config.Accounts.AuthenticationEnabled && server.accountAuthenticationEnabled { if !config.Accounts.AuthenticationEnabled && server.accountAuthenticationEnabled {
// disabling SASL // disabling SASL
SupportedCapabilities[caps.SASL] = false SupportedCapabilities.Disable(caps.SASL)
removedCaps[caps.SASL] = true removedCaps.Add(caps.SASL)
} }
server.accountAuthenticationEnabled = config.Accounts.AuthenticationEnabled server.accountAuthenticationEnabled = config.Accounts.AuthenticationEnabled
// STS // STS
stsValue := config.Server.STS.Value() stsValue := config.Server.STS.Value()
var stsDisabled bool var stsDisabled bool
server.logger.Debug("rehash", "STS Vals", CapValues[caps.STS], stsValue, fmt.Sprintf("server[%v] config[%v]", server.stsEnabled, config.Server.STS.Enabled)) stsCurrentCapValue, _ := CapValues.Get(caps.STS)
server.logger.Debug("rehash", "STS Vals", stsCurrentCapValue, stsValue, fmt.Sprintf("server[%v] config[%v]", server.stsEnabled, config.Server.STS.Enabled))
if config.Server.STS.Enabled && !server.stsEnabled { if config.Server.STS.Enabled && !server.stsEnabled {
// enabling STS // enabling STS
SupportedCapabilities[caps.STS] = true SupportedCapabilities.Enable(caps.STS)
addedCaps[caps.STS] = true addedCaps.Add(caps.STS)
CapValues[caps.STS] = stsValue CapValues.Set(caps.STS, stsValue)
} else if !config.Server.STS.Enabled && server.stsEnabled { } else if !config.Server.STS.Enabled && server.stsEnabled {
// disabling STS // disabling STS
SupportedCapabilities[caps.STS] = false SupportedCapabilities.Disable(caps.STS)
removedCaps[caps.STS] = true removedCaps.Add(caps.STS)
stsDisabled = true stsDisabled = true
} else if config.Server.STS.Enabled && server.stsEnabled && stsValue != CapValues[caps.STS] { } else if config.Server.STS.Enabled && server.stsEnabled && stsValue != stsCurrentCapValue {
// STS policy updated // STS policy updated
CapValues[caps.STS] = stsValue CapValues.Set(caps.STS, stsValue)
updatedCaps[caps.STS] = true updatedCaps.Add(caps.STS)
} }
server.stsEnabled = config.Server.STS.Enabled server.stsEnabled = config.Server.STS.Enabled
// burst new and removed caps // burst new and removed caps
var capBurstClients ClientSet var capBurstClients ClientSet
added := make(map[CapVersion]string) added := make(map[caps.Version]string)
var removed string var removed string
// updated caps get DEL'd and then NEW'd // updated caps get DEL'd and then NEW'd
// so, we can just add updated ones to both removed and added lists here and they'll be correctly handled // so, we can just add updated ones to both removed and added lists here and they'll be correctly handled
server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(Cap301), strconv.Itoa(len(updatedCaps))) server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues), strconv.Itoa(updatedCaps.Count()))
if len(updatedCaps) > 0 { for _, capab := range updatedCaps.List() {
for capab := range updatedCaps { addedCaps.Enable(capab)
addedCaps[capab] = true removedCaps.Enable(capab)
removedCaps[capab] = true
}
} }
if len(addedCaps) > 0 || len(removedCaps) > 0 { if 0 < addedCaps.Count() || 0 < removedCaps.Count() {
capBurstClients = server.clients.AllWithCaps(caps.CapNotify) capBurstClients = server.clients.AllWithCaps(caps.CapNotify)
added[Cap301] = addedCaps.String(Cap301) added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues)
added[Cap302] = addedCaps.String(Cap302) added[caps.Cap302] = addedCaps.String(caps.Cap302, CapValues)
// removed never has values // removed never has values, so we leave it as Cap301
removed = removedCaps.String(Cap301) removed = removedCaps.String(caps.Cap301, CapValues)
} }
for sClient := range capBurstClients { for sClient := range capBurstClients {
@ -1355,18 +1355,18 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
// remove STS policy // remove STS policy
//TODO(dan): this is an ugly hack. we can write this better. //TODO(dan): this is an ugly hack. we can write this better.
stsPolicy := "sts=duration=0" stsPolicy := "sts=duration=0"
if len(addedCaps) > 0 { if 0 < addedCaps.Count() {
added[Cap302] = added[Cap302] + " " + stsPolicy added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy
} else { } else {
addedCaps[caps.STS] = true addedCaps.Enable(caps.STS)
added[Cap302] = stsPolicy added[caps.Cap302] = stsPolicy
} }
} }
// DEL caps and then send NEW ones so that updated caps get removed/added correctly // DEL caps and then send NEW ones so that updated caps get removed/added correctly
if len(removedCaps) > 0 { if 0 < removedCaps.Count() {
sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed) sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed)
} }
if len(addedCaps) > 0 { if 0 < addedCaps.Count() {
sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion]) sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion])
} }
} }
@ -1707,7 +1707,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
message := msg.Params[1] message := msg.Params[1]
// split privmsg // split privmsg
splitMsg := server.splitMessage(message, !client.capabilities[caps.MaxLine]) splitMsg := server.splitMessage(message, !client.capabilities.Has(caps.MaxLine))
for i, targetString := range targets { for i, targetString := range targets {
// max of four targets per privmsg // max of four targets per privmsg
@ -1748,7 +1748,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
// errors silently ignored with NOTICE as per RFC // errors silently ignored with NOTICE as per RFC
continue continue
} }
if !user.capabilities[caps.MessageTags] { if !user.capabilities.Has(caps.MessageTags) {
clientOnlyTags = nil clientOnlyTags = nil
} }
msgid := server.generateMessageID() msgid := server.generateMessageID()
@ -1757,7 +1757,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
if !user.flags[RegisteredOnly] || client.registered { if !user.flags[RegisteredOnly] || client.registered {
user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg) user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg)
} }
if client.capabilities[caps.EchoMessage] { if client.capabilities.Has(caps.EchoMessage) {
client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg) client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg)
} }
} }