mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-10 22:19:31 +01:00
correctly support disabling caps with CAP REQ, fixes #337
This commit is contained in:
parent
6667585605
commit
f48af3ee44
@ -59,6 +59,11 @@ func (s *Set) Union(other *Set) {
|
||||
utils.BitsetUnion(s[:], other[:])
|
||||
}
|
||||
|
||||
// Subtract removes all the capabilities of another set from this set.
|
||||
func (s *Set) Subtract(other *Set) {
|
||||
utils.BitsetSubtract(s[:], other[:])
|
||||
}
|
||||
|
||||
// Empty returns whether the set is empty.
|
||||
func (s *Set) Empty() bool {
|
||||
return utils.BitsetEmpty(s[:])
|
||||
|
@ -60,6 +60,17 @@ func TestSets(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubtract(t *testing.T) {
|
||||
s1 := NewSet(AccountTag, EchoMessage, UserhostInNames, ServerTime)
|
||||
|
||||
toRemove := NewSet(UserhostInNames, EchoMessage)
|
||||
s1.Subtract(toRemove)
|
||||
|
||||
if !reflect.DeepEqual(s1, NewSet(AccountTag, ServerTime)) {
|
||||
t.Errorf("subtract doesn't work")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSetReads(b *testing.B) {
|
||||
set := NewSet(UserhostInNames, EchoMessage)
|
||||
b.ResetTimer()
|
||||
|
@ -450,19 +450,27 @@ func awayHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
|
||||
// CAP <subcmd> [<caps>]
|
||||
func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool {
|
||||
subCommand := strings.ToUpper(msg.Params[0])
|
||||
capabilities := caps.NewSet()
|
||||
toAdd := caps.NewSet()
|
||||
toRemove := caps.NewSet()
|
||||
var capString string
|
||||
|
||||
var badCaps []string
|
||||
badCaps := false
|
||||
if len(msg.Params) > 1 {
|
||||
capString = msg.Params[1]
|
||||
strs := strings.Fields(capString)
|
||||
for _, str := range strs {
|
||||
remove := false
|
||||
if str[0] == '-' {
|
||||
str = str[1:]
|
||||
remove = true
|
||||
}
|
||||
capab, err := caps.NameToCapability(str)
|
||||
if err != nil || !SupportedCapabilities.Has(capab) {
|
||||
badCaps = append(badCaps, str)
|
||||
if err != nil || (!remove && !SupportedCapabilities.Has(capab)) {
|
||||
badCaps = true
|
||||
} else if !remove {
|
||||
toAdd.Enable(capab)
|
||||
} else {
|
||||
capabilities.Enable(capab)
|
||||
toRemove.Enable(capab)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -490,16 +498,17 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
|
||||
}
|
||||
|
||||
// make sure all capabilities actually exist
|
||||
if len(badCaps) > 0 {
|
||||
if badCaps {
|
||||
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
|
||||
return false
|
||||
}
|
||||
client.capabilities.Union(capabilities)
|
||||
client.capabilities.Union(toAdd)
|
||||
client.capabilities.Subtract(toRemove)
|
||||
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
|
||||
|
||||
// if this is the first time the client is requesting a resume token,
|
||||
// send it to them
|
||||
if capabilities.Has(caps.Resume) {
|
||||
if toAdd.Has(caps.Resume) {
|
||||
token, err := client.generateResumeToken()
|
||||
if err == nil {
|
||||
rb.Add(nil, server.name, "RESUME", "TOKEN", token)
|
||||
|
@ -89,3 +89,19 @@ func BitsetCopy(set []uint64, other []uint64) {
|
||||
atomic.StoreUint64(&set[i], data)
|
||||
}
|
||||
}
|
||||
|
||||
// BitsetSubtract modifies `set` to subtract the contents of `other`.
|
||||
// Similar caveats about race conditions as with `BitsetUnion` apply.
|
||||
func BitsetSubtract(set []uint64, other []uint64) {
|
||||
for i := 0; i < len(set); i++ {
|
||||
for {
|
||||
ourAddr := &set[i]
|
||||
ourBlock := atomic.LoadUint64(ourAddr)
|
||||
otherBlock := atomic.LoadUint64(&other[i])
|
||||
newBlock := ourBlock & (^otherBlock)
|
||||
if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -73,8 +73,13 @@ func TestSets(t *testing.T) {
|
||||
BitsetCopy(t3s, t1s)
|
||||
for i = 0; i < 128; i++ {
|
||||
expected := (i != 72)
|
||||
if BitsetGet(t1s, i) != expected {
|
||||
if BitsetGet(t3s, i) != expected {
|
||||
t.Error("all bits should be set except 72")
|
||||
}
|
||||
}
|
||||
|
||||
BitsetSubtract(t3s, t2s)
|
||||
if !BitsetGet(t3s, 0) || BitsetGet(t3s, 72) || !BitsetGet(t3s, 74) || BitsetGet(t3s, 71) {
|
||||
t.Error("subtract doesn't work")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user