From f48af3ee449271450a91d22c732075e0707a1d36 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sat, 2 Feb 2019 20:00:23 -0500 Subject: [PATCH] correctly support disabling caps with CAP REQ, fixes #337 --- irc/caps/set.go | 5 +++++ irc/caps/set_test.go | 11 +++++++++++ irc/handlers.go | 25 +++++++++++++++++-------- irc/utils/bitset.go | 16 ++++++++++++++++ irc/utils/bitset_test.go | 7 ++++++- 5 files changed, 55 insertions(+), 9 deletions(-) diff --git a/irc/caps/set.go b/irc/caps/set.go index 867617b3..b348cbf8 100644 --- a/irc/caps/set.go +++ b/irc/caps/set.go @@ -59,6 +59,11 @@ func (s *Set) Union(other *Set) { utils.BitsetUnion(s[:], other[:]) } +// Subtract removes all the capabilities of another set from this set. +func (s *Set) Subtract(other *Set) { + utils.BitsetSubtract(s[:], other[:]) +} + // Empty returns whether the set is empty. func (s *Set) Empty() bool { return utils.BitsetEmpty(s[:]) diff --git a/irc/caps/set_test.go b/irc/caps/set_test.go index 019ca76b..c32e25e6 100644 --- a/irc/caps/set_test.go +++ b/irc/caps/set_test.go @@ -60,6 +60,17 @@ func TestSets(t *testing.T) { } } +func TestSubtract(t *testing.T) { + s1 := NewSet(AccountTag, EchoMessage, UserhostInNames, ServerTime) + + toRemove := NewSet(UserhostInNames, EchoMessage) + s1.Subtract(toRemove) + + if !reflect.DeepEqual(s1, NewSet(AccountTag, ServerTime)) { + t.Errorf("subtract doesn't work") + } +} + func BenchmarkSetReads(b *testing.B) { set := NewSet(UserhostInNames, EchoMessage) b.ResetTimer() diff --git a/irc/handlers.go b/irc/handlers.go index b8075296..adfe3b0e 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -450,19 +450,27 @@ func awayHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp // CAP [] func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool { subCommand := strings.ToUpper(msg.Params[0]) - capabilities := caps.NewSet() + toAdd := caps.NewSet() + toRemove := caps.NewSet() var capString string - var badCaps []string + badCaps := false if len(msg.Params) > 1 { capString = msg.Params[1] strs := strings.Fields(capString) for _, str := range strs { + remove := false + if str[0] == '-' { + str = str[1:] + remove = true + } capab, err := caps.NameToCapability(str) - if err != nil || !SupportedCapabilities.Has(capab) { - badCaps = append(badCaps, str) + if err != nil || (!remove && !SupportedCapabilities.Has(capab)) { + badCaps = true + } else if !remove { + toAdd.Enable(capab) } else { - capabilities.Enable(capab) + toRemove.Enable(capab) } } } @@ -490,16 +498,17 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo } // make sure all capabilities actually exist - if len(badCaps) > 0 { + if badCaps { rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString) return false } - client.capabilities.Union(capabilities) + client.capabilities.Union(toAdd) + client.capabilities.Subtract(toRemove) rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString) // if this is the first time the client is requesting a resume token, // send it to them - if capabilities.Has(caps.Resume) { + if toAdd.Has(caps.Resume) { token, err := client.generateResumeToken() if err == nil { rb.Add(nil, server.name, "RESUME", "TOKEN", token) diff --git a/irc/utils/bitset.go b/irc/utils/bitset.go index 514d2fac..9e0014a8 100644 --- a/irc/utils/bitset.go +++ b/irc/utils/bitset.go @@ -89,3 +89,19 @@ func BitsetCopy(set []uint64, other []uint64) { atomic.StoreUint64(&set[i], data) } } + +// BitsetSubtract modifies `set` to subtract the contents of `other`. +// Similar caveats about race conditions as with `BitsetUnion` apply. +func BitsetSubtract(set []uint64, other []uint64) { + for i := 0; i < len(set); i++ { + for { + ourAddr := &set[i] + ourBlock := atomic.LoadUint64(ourAddr) + otherBlock := atomic.LoadUint64(&other[i]) + newBlock := ourBlock & (^otherBlock) + if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) { + break + } + } + } +} diff --git a/irc/utils/bitset_test.go b/irc/utils/bitset_test.go index 282f1c6f..a34a0097 100644 --- a/irc/utils/bitset_test.go +++ b/irc/utils/bitset_test.go @@ -73,8 +73,13 @@ func TestSets(t *testing.T) { BitsetCopy(t3s, t1s) for i = 0; i < 128; i++ { expected := (i != 72) - if BitsetGet(t1s, i) != expected { + if BitsetGet(t3s, i) != expected { t.Error("all bits should be set except 72") } } + + BitsetSubtract(t3s, t2s) + if !BitsetGet(t3s, 0) || BitsetGet(t3s, 72) || !BitsetGet(t3s, 74) || BitsetGet(t3s, 71) { + t.Error("subtract doesn't work") + } }