mirror of
https://github.com/ergochat/ergo.git
synced 2024-12-22 18:52:41 +01:00
correctly support disabling caps with CAP REQ, fixes #337
This commit is contained in:
parent
6667585605
commit
f48af3ee44
@ -59,6 +59,11 @@ func (s *Set) Union(other *Set) {
|
|||||||
utils.BitsetUnion(s[:], other[:])
|
utils.BitsetUnion(s[:], other[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Subtract removes all the capabilities of another set from this set.
|
||||||
|
func (s *Set) Subtract(other *Set) {
|
||||||
|
utils.BitsetSubtract(s[:], other[:])
|
||||||
|
}
|
||||||
|
|
||||||
// Empty returns whether the set is empty.
|
// Empty returns whether the set is empty.
|
||||||
func (s *Set) Empty() bool {
|
func (s *Set) Empty() bool {
|
||||||
return utils.BitsetEmpty(s[:])
|
return utils.BitsetEmpty(s[:])
|
||||||
|
@ -60,6 +60,17 @@ func TestSets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSubtract(t *testing.T) {
|
||||||
|
s1 := NewSet(AccountTag, EchoMessage, UserhostInNames, ServerTime)
|
||||||
|
|
||||||
|
toRemove := NewSet(UserhostInNames, EchoMessage)
|
||||||
|
s1.Subtract(toRemove)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(s1, NewSet(AccountTag, ServerTime)) {
|
||||||
|
t.Errorf("subtract doesn't work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkSetReads(b *testing.B) {
|
func BenchmarkSetReads(b *testing.B) {
|
||||||
set := NewSet(UserhostInNames, EchoMessage)
|
set := NewSet(UserhostInNames, EchoMessage)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -450,19 +450,27 @@ func awayHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
|
|||||||
// CAP <subcmd> [<caps>]
|
// CAP <subcmd> [<caps>]
|
||||||
func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool {
|
func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool {
|
||||||
subCommand := strings.ToUpper(msg.Params[0])
|
subCommand := strings.ToUpper(msg.Params[0])
|
||||||
capabilities := caps.NewSet()
|
toAdd := caps.NewSet()
|
||||||
|
toRemove := caps.NewSet()
|
||||||
var capString string
|
var capString string
|
||||||
|
|
||||||
var badCaps []string
|
badCaps := false
|
||||||
if len(msg.Params) > 1 {
|
if len(msg.Params) > 1 {
|
||||||
capString = msg.Params[1]
|
capString = msg.Params[1]
|
||||||
strs := strings.Fields(capString)
|
strs := strings.Fields(capString)
|
||||||
for _, str := range strs {
|
for _, str := range strs {
|
||||||
|
remove := false
|
||||||
|
if str[0] == '-' {
|
||||||
|
str = str[1:]
|
||||||
|
remove = true
|
||||||
|
}
|
||||||
capab, err := caps.NameToCapability(str)
|
capab, err := caps.NameToCapability(str)
|
||||||
if err != nil || !SupportedCapabilities.Has(capab) {
|
if err != nil || (!remove && !SupportedCapabilities.Has(capab)) {
|
||||||
badCaps = append(badCaps, str)
|
badCaps = true
|
||||||
|
} else if !remove {
|
||||||
|
toAdd.Enable(capab)
|
||||||
} else {
|
} else {
|
||||||
capabilities.Enable(capab)
|
toRemove.Enable(capab)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -490,16 +498,17 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure all capabilities actually exist
|
// make sure all capabilities actually exist
|
||||||
if len(badCaps) > 0 {
|
if badCaps {
|
||||||
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
|
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
client.capabilities.Union(capabilities)
|
client.capabilities.Union(toAdd)
|
||||||
|
client.capabilities.Subtract(toRemove)
|
||||||
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
|
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
|
||||||
|
|
||||||
// if this is the first time the client is requesting a resume token,
|
// if this is the first time the client is requesting a resume token,
|
||||||
// send it to them
|
// send it to them
|
||||||
if capabilities.Has(caps.Resume) {
|
if toAdd.Has(caps.Resume) {
|
||||||
token, err := client.generateResumeToken()
|
token, err := client.generateResumeToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rb.Add(nil, server.name, "RESUME", "TOKEN", token)
|
rb.Add(nil, server.name, "RESUME", "TOKEN", token)
|
||||||
|
@ -89,3 +89,19 @@ func BitsetCopy(set []uint64, other []uint64) {
|
|||||||
atomic.StoreUint64(&set[i], data)
|
atomic.StoreUint64(&set[i], data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BitsetSubtract modifies `set` to subtract the contents of `other`.
|
||||||
|
// Similar caveats about race conditions as with `BitsetUnion` apply.
|
||||||
|
func BitsetSubtract(set []uint64, other []uint64) {
|
||||||
|
for i := 0; i < len(set); i++ {
|
||||||
|
for {
|
||||||
|
ourAddr := &set[i]
|
||||||
|
ourBlock := atomic.LoadUint64(ourAddr)
|
||||||
|
otherBlock := atomic.LoadUint64(&other[i])
|
||||||
|
newBlock := ourBlock & (^otherBlock)
|
||||||
|
if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -73,8 +73,13 @@ func TestSets(t *testing.T) {
|
|||||||
BitsetCopy(t3s, t1s)
|
BitsetCopy(t3s, t1s)
|
||||||
for i = 0; i < 128; i++ {
|
for i = 0; i < 128; i++ {
|
||||||
expected := (i != 72)
|
expected := (i != 72)
|
||||||
if BitsetGet(t1s, i) != expected {
|
if BitsetGet(t3s, i) != expected {
|
||||||
t.Error("all bits should be set except 72")
|
t.Error("all bits should be set except 72")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BitsetSubtract(t3s, t2s)
|
||||||
|
if !BitsetGet(t3s, 0) || BitsetGet(t3s, 72) || !BitsetGet(t3s, 74) || BitsetGet(t3s, 71) {
|
||||||
|
t.Error("subtract doesn't work")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user