correctly support disabling caps with CAP REQ, fixes #337

This commit is contained in:
Shivaram Lingamneni 2019-02-02 20:00:23 -05:00
parent 6667585605
commit f48af3ee44
5 changed files with 55 additions and 9 deletions

View File

@ -59,6 +59,11 @@ func (s *Set) Union(other *Set) {
utils.BitsetUnion(s[:], other[:]) utils.BitsetUnion(s[:], other[:])
} }
// Subtract removes all the capabilities of another set from this set.
func (s *Set) Subtract(other *Set) {
utils.BitsetSubtract(s[:], other[:])
}
// Empty returns whether the set is empty. // Empty returns whether the set is empty.
func (s *Set) Empty() bool { func (s *Set) Empty() bool {
return utils.BitsetEmpty(s[:]) return utils.BitsetEmpty(s[:])

View File

@ -60,6 +60,17 @@ func TestSets(t *testing.T) {
} }
} }
func TestSubtract(t *testing.T) {
s1 := NewSet(AccountTag, EchoMessage, UserhostInNames, ServerTime)
toRemove := NewSet(UserhostInNames, EchoMessage)
s1.Subtract(toRemove)
if !reflect.DeepEqual(s1, NewSet(AccountTag, ServerTime)) {
t.Errorf("subtract doesn't work")
}
}
func BenchmarkSetReads(b *testing.B) { func BenchmarkSetReads(b *testing.B) {
set := NewSet(UserhostInNames, EchoMessage) set := NewSet(UserhostInNames, EchoMessage)
b.ResetTimer() b.ResetTimer()

View File

@ -450,19 +450,27 @@ func awayHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
// CAP <subcmd> [<caps>] // CAP <subcmd> [<caps>]
func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool { func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool {
subCommand := strings.ToUpper(msg.Params[0]) subCommand := strings.ToUpper(msg.Params[0])
capabilities := caps.NewSet() toAdd := caps.NewSet()
toRemove := caps.NewSet()
var capString string var capString string
var badCaps []string badCaps := false
if len(msg.Params) > 1 { if len(msg.Params) > 1 {
capString = msg.Params[1] capString = msg.Params[1]
strs := strings.Fields(capString) strs := strings.Fields(capString)
for _, str := range strs { for _, str := range strs {
remove := false
if str[0] == '-' {
str = str[1:]
remove = true
}
capab, err := caps.NameToCapability(str) capab, err := caps.NameToCapability(str)
if err != nil || !SupportedCapabilities.Has(capab) { if err != nil || (!remove && !SupportedCapabilities.Has(capab)) {
badCaps = append(badCaps, str) badCaps = true
} else if !remove {
toAdd.Enable(capab)
} else { } else {
capabilities.Enable(capab) toRemove.Enable(capab)
} }
} }
} }
@ -490,16 +498,17 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
} }
// make sure all capabilities actually exist // make sure all capabilities actually exist
if len(badCaps) > 0 { if badCaps {
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString) rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
return false return false
} }
client.capabilities.Union(capabilities) client.capabilities.Union(toAdd)
client.capabilities.Subtract(toRemove)
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString) rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
// if this is the first time the client is requesting a resume token, // if this is the first time the client is requesting a resume token,
// send it to them // send it to them
if capabilities.Has(caps.Resume) { if toAdd.Has(caps.Resume) {
token, err := client.generateResumeToken() token, err := client.generateResumeToken()
if err == nil { if err == nil {
rb.Add(nil, server.name, "RESUME", "TOKEN", token) rb.Add(nil, server.name, "RESUME", "TOKEN", token)

View File

@ -89,3 +89,19 @@ func BitsetCopy(set []uint64, other []uint64) {
atomic.StoreUint64(&set[i], data) atomic.StoreUint64(&set[i], data)
} }
} }
// BitsetSubtract modifies `set` to subtract the contents of `other`.
// Similar caveats about race conditions as with `BitsetUnion` apply.
func BitsetSubtract(set []uint64, other []uint64) {
for i := 0; i < len(set); i++ {
for {
ourAddr := &set[i]
ourBlock := atomic.LoadUint64(ourAddr)
otherBlock := atomic.LoadUint64(&other[i])
newBlock := ourBlock & (^otherBlock)
if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) {
break
}
}
}
}

View File

@ -73,8 +73,13 @@ func TestSets(t *testing.T) {
BitsetCopy(t3s, t1s) BitsetCopy(t3s, t1s)
for i = 0; i < 128; i++ { for i = 0; i < 128; i++ {
expected := (i != 72) expected := (i != 72)
if BitsetGet(t1s, i) != expected { if BitsetGet(t3s, i) != expected {
t.Error("all bits should be set except 72") t.Error("all bits should be set except 72")
} }
} }
BitsetSubtract(t3s, t2s)
if !BitsetGet(t3s, 0) || BitsetGet(t3s, 72) || !BitsetGet(t3s, 74) || BitsetGet(t3s, 71) {
t.Error("subtract doesn't work")
}
} }