mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-25 13:29:27 +01:00
atomic bitset implementations of caps.Set and modes.ModeSet
This commit is contained in:
parent
cdbb369a9c
commit
2a33c1483b
7
Makefile
7
Makefile
@ -1,5 +1,7 @@
|
||||
.PHONY: all build
|
||||
|
||||
capdef_file = ./irc/caps/defs.go
|
||||
|
||||
all: build
|
||||
|
||||
build:
|
||||
@ -8,11 +10,16 @@ build:
|
||||
buildrelease:
|
||||
goreleaser --skip-publish --rm-dist
|
||||
|
||||
capdefs:
|
||||
python3 ./gencapdefs.py > ${capdef_file}
|
||||
|
||||
deps:
|
||||
git submodule update --init
|
||||
|
||||
test:
|
||||
python3 ./gencapdefs.py | diff - ${capdef_file}
|
||||
cd irc && go test . && go vet .
|
||||
cd irc/caps && go test . && go vet .
|
||||
cd irc/isupport && go test . && go vet .
|
||||
cd irc/modes && go test . && go vet .
|
||||
cd irc/utils && go test . && go vet .
|
||||
|
204
gencapdefs.py
Normal file
204
gencapdefs.py
Normal file
@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Updates the capability definitions at irc/caps/defs.go
|
||||
|
||||
To add a capability, add it to the CAPDEFS list below,
|
||||
then run `make capdefs` from the project root.
|
||||
"""
|
||||
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
|
||||
CapDef = namedtuple("CapDef", ['identifier', 'name', 'url', 'standard'])
|
||||
|
||||
CAPDEFS = [
|
||||
CapDef(
|
||||
identifier="LabelTagName",
|
||||
name="draft/label",
|
||||
url="https://ircv3.net/specs/extensions/labeled-response.html",
|
||||
standard="draft IRCv3 tag name",
|
||||
),
|
||||
CapDef(
|
||||
identifier="AccountNotify",
|
||||
name="account-notify",
|
||||
url="https://ircv3.net/specs/extensions/account-notify-3.1.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="AccountTag",
|
||||
name="account-tag",
|
||||
url="https://ircv3.net/specs/extensions/account-tag-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="AwayNotify",
|
||||
name="away-notify",
|
||||
url="https://ircv3.net/specs/extensions/away-notify-3.1.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="Batch",
|
||||
name="batch",
|
||||
url="https://ircv3.net/specs/extensions/batch-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="CapNotify",
|
||||
name="cap-notify",
|
||||
url="https://ircv3.net/specs/extensions/cap-notify-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="ChgHost",
|
||||
name="chghost",
|
||||
url="https://ircv3.net/specs/extensions/chghost-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="EchoMessage",
|
||||
name="echo-message",
|
||||
url="https://ircv3.net/specs/extensions/echo-message-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="ExtendedJoin",
|
||||
name="extended-join",
|
||||
url="https://ircv3.net/specs/extensions/extended-join-3.1.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="InviteNotify",
|
||||
name="invite-notify",
|
||||
url="https://ircv3.net/specs/extensions/invite-notify-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="LabeledResponse",
|
||||
name="draft/labeled-response",
|
||||
url="https://ircv3.net/specs/extensions/labeled-response.html",
|
||||
standard="draft IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="Languages",
|
||||
name="draft/languages",
|
||||
url="https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6",
|
||||
standard="proposed IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="MaxLine",
|
||||
name="oragono.io/maxline",
|
||||
url="https://oragono.io/maxline",
|
||||
standard="Oragono-specific",
|
||||
),
|
||||
CapDef(
|
||||
identifier="MessageTags",
|
||||
name="draft/message-tags-0.2",
|
||||
url="https://ircv3.net/specs/core/message-tags-3.3.html",
|
||||
standard="draft IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="MultiPrefix",
|
||||
name="multi-prefix",
|
||||
url="https://ircv3.net/specs/extensions/multi-prefix-3.1.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="Rename",
|
||||
name="draft/rename",
|
||||
url="https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md",
|
||||
standard="proposed IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="Resume",
|
||||
name="draft/resume",
|
||||
url="https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md",
|
||||
standard="proposed IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="SASL",
|
||||
name="sasl",
|
||||
url="https://ircv3.net/specs/extensions/sasl-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="ServerTime",
|
||||
name="server-time",
|
||||
url="https://ircv3.net/specs/extensions/server-time-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="STS",
|
||||
name="sts",
|
||||
url="https://ircv3.net/specs/extensions/sts.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
CapDef(
|
||||
identifier="UserhostInNames",
|
||||
name="userhost-in-names",
|
||||
url="https://ircv3.net/specs/extensions/userhost-in-names-3.2.html",
|
||||
standard="IRCv3",
|
||||
),
|
||||
]
|
||||
|
||||
def validate_defs():
|
||||
numCaps = len(CAPDEFS)
|
||||
numNames = len(set(capdef.name for capdef in CAPDEFS))
|
||||
if numCaps != numNames:
|
||||
raise Exception("defs must have unique names, but found duplicates")
|
||||
numIdentifiers = len(set(capdef.identifier for capdef in CAPDEFS))
|
||||
if numCaps != numIdentifiers:
|
||||
raise Exception("defs must have unique identifiers, but found duplicates")
|
||||
|
||||
def main():
|
||||
validate_defs()
|
||||
output = io.StringIO()
|
||||
print("""
|
||||
package caps
|
||||
|
||||
/*
|
||||
WARNING: this file is autogenerated by `make capdefs`
|
||||
DO NOT EDIT MANUALLY.
|
||||
*/
|
||||
|
||||
|
||||
""", file=output)
|
||||
|
||||
|
||||
numCapabs = len(CAPDEFS)
|
||||
bitsetLen = numCapabs // 64
|
||||
if numCapabs % 64 > 0:
|
||||
bitsetLen += 1
|
||||
print ("""
|
||||
const (
|
||||
// number of recognized capabilities:
|
||||
numCapabs = %d
|
||||
// length of the uint64 array that represents the bitset:
|
||||
bitsetLen = %d
|
||||
)
|
||||
""" % (numCapabs, bitsetLen), file=output)
|
||||
|
||||
print("const (", file=output)
|
||||
for capdef in CAPDEFS:
|
||||
print("// %s is the %s capability named \"%s\":" % (capdef.identifier, capdef.standard, capdef.name), file=output)
|
||||
print("// %s" % (capdef.url,), file=output)
|
||||
print("%s Capability = iota" % (capdef.identifier,), file=output)
|
||||
print(file=output)
|
||||
print(")", file=output)
|
||||
|
||||
print("""var ( capabilityNames = [numCapabs]string{""", file=output)
|
||||
for capdef in CAPDEFS:
|
||||
print("\"%s\"," % (capdef.name,), file=output)
|
||||
print("})", file=output)
|
||||
|
||||
gofmt = subprocess.Popen(['gofmt', '-s'], stdin=subprocess.PIPE)
|
||||
gofmt.communicate(input=output.getvalue().encode('utf-8'))
|
||||
if gofmt.poll() != 0:
|
||||
print(output.getvalue())
|
||||
raise Exception("gofmt failed")
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
@ -3,58 +3,30 @@
|
||||
|
||||
package caps
|
||||
|
||||
import "errors"
|
||||
|
||||
// Capability represents an optional feature that a client may request from the server.
|
||||
type Capability string
|
||||
type Capability uint
|
||||
|
||||
const (
|
||||
// LabelTagName is the tag name used for the labeled-response spec.
|
||||
LabelTagName = "draft/label"
|
||||
// actual capability definitions appear in defs.go
|
||||
|
||||
// AccountNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/account-notify-3.1.html
|
||||
AccountNotify Capability = "account-notify"
|
||||
// AccountTag is this IRCv3 capability: http://ircv3.net/specs/extensions/account-tag-3.2.html
|
||||
AccountTag Capability = "account-tag"
|
||||
// AwayNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/away-notify-3.1.html
|
||||
AwayNotify Capability = "away-notify"
|
||||
// Batch is this IRCv3 capability: http://ircv3.net/specs/extensions/batch-3.2.html
|
||||
Batch Capability = "batch"
|
||||
// CapNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/cap-notify-3.2.html
|
||||
CapNotify Capability = "cap-notify"
|
||||
// ChgHost is this IRCv3 capability: http://ircv3.net/specs/extensions/chghost-3.2.html
|
||||
ChgHost Capability = "chghost"
|
||||
// EchoMessage is this IRCv3 capability: http://ircv3.net/specs/extensions/echo-message-3.2.html
|
||||
EchoMessage Capability = "echo-message"
|
||||
// ExtendedJoin is this IRCv3 capability: http://ircv3.net/specs/extensions/extended-join-3.1.html
|
||||
ExtendedJoin Capability = "extended-join"
|
||||
// InviteNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/invite-notify-3.2.html
|
||||
InviteNotify Capability = "invite-notify"
|
||||
// LabeledResponse is this draft IRCv3 capability: http://ircv3.net/specs/extensions/labeled-response.html
|
||||
LabeledResponse Capability = "draft/labeled-response"
|
||||
// Languages is this proposed IRCv3 capability: https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6
|
||||
Languages Capability = "draft/languages"
|
||||
// MaxLine is this capability: https://oragono.io/maxline
|
||||
MaxLine Capability = "oragono.io/maxline"
|
||||
// MessageTags is this draft IRCv3 capability: http://ircv3.net/specs/core/message-tags-3.3.html
|
||||
MessageTags Capability = "draft/message-tags-0.2"
|
||||
// MultiPrefix is this IRCv3 capability: http://ircv3.net/specs/extensions/multi-prefix-3.1.html
|
||||
MultiPrefix Capability = "multi-prefix"
|
||||
// Rename is this proposed capability: https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md
|
||||
Rename Capability = "draft/rename"
|
||||
// Resume is this proposed capability: https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md
|
||||
Resume Capability = "draft/resume"
|
||||
// SASL is this IRCv3 capability: http://ircv3.net/specs/extensions/sasl-3.2.html
|
||||
SASL Capability = "sasl"
|
||||
// ServerTime is this IRCv3 capability: http://ircv3.net/specs/extensions/server-time-3.2.html
|
||||
ServerTime Capability = "server-time"
|
||||
// STS is this IRCv3 capability: http://ircv3.net/specs/extensions/sts.html
|
||||
STS Capability = "sts"
|
||||
// UserhostInNames is this IRCv3 capability: http://ircv3.net/specs/extensions/userhost-in-names-3.2.html
|
||||
UserhostInNames Capability = "userhost-in-names"
|
||||
var (
|
||||
nameToCapability map[string]Capability
|
||||
|
||||
NoSuchCap = errors.New("Unsupported capability name")
|
||||
)
|
||||
|
||||
// Name returns the name of the given capability.
|
||||
func (capability Capability) Name() string {
|
||||
return string(capability)
|
||||
return capabilityNames[capability]
|
||||
}
|
||||
|
||||
func NameToCapability(name string) (result Capability, err error) {
|
||||
result, found := nameToCapability[name]
|
||||
if !found {
|
||||
err = NoSuchCap
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Version is used to select which max version of CAP the client supports.
|
||||
@ -78,3 +50,10 @@ const (
|
||||
// NegotiatedState means CAP negotiation has been successfully ended and reg should complete.
|
||||
NegotiatedState State = iota
|
||||
)
|
||||
|
||||
func init() {
|
||||
nameToCapability = make(map[string]Capability)
|
||||
for capab, name := range capabilityNames {
|
||||
nameToCapability[name] = Capability(capab)
|
||||
}
|
||||
}
|
||||
|
125
irc/caps/defs.go
Normal file
125
irc/caps/defs.go
Normal file
@ -0,0 +1,125 @@
|
||||
package caps
|
||||
|
||||
/*
|
||||
WARNING: this file is autogenerated by `make capdefs`
|
||||
DO NOT EDIT MANUALLY.
|
||||
*/
|
||||
|
||||
const (
|
||||
// number of recognized capabilities:
|
||||
numCapabs = 21
|
||||
// length of the uint64 array that represents the bitset:
|
||||
bitsetLen = 1
|
||||
)
|
||||
|
||||
const (
|
||||
// LabelTagName is the draft IRCv3 tag name capability named "draft/label":
|
||||
// https://ircv3.net/specs/extensions/labeled-response.html
|
||||
LabelTagName Capability = iota
|
||||
|
||||
// AccountNotify is the IRCv3 capability named "account-notify":
|
||||
// https://ircv3.net/specs/extensions/account-notify-3.1.html
|
||||
AccountNotify Capability = iota
|
||||
|
||||
// AccountTag is the IRCv3 capability named "account-tag":
|
||||
// https://ircv3.net/specs/extensions/account-tag-3.2.html
|
||||
AccountTag Capability = iota
|
||||
|
||||
// AwayNotify is the IRCv3 capability named "away-notify":
|
||||
// https://ircv3.net/specs/extensions/away-notify-3.1.html
|
||||
AwayNotify Capability = iota
|
||||
|
||||
// Batch is the IRCv3 capability named "batch":
|
||||
// https://ircv3.net/specs/extensions/batch-3.2.html
|
||||
Batch Capability = iota
|
||||
|
||||
// CapNotify is the IRCv3 capability named "cap-notify":
|
||||
// https://ircv3.net/specs/extensions/cap-notify-3.2.html
|
||||
CapNotify Capability = iota
|
||||
|
||||
// ChgHost is the IRCv3 capability named "chghost":
|
||||
// https://ircv3.net/specs/extensions/chghost-3.2.html
|
||||
ChgHost Capability = iota
|
||||
|
||||
// EchoMessage is the IRCv3 capability named "echo-message":
|
||||
// https://ircv3.net/specs/extensions/echo-message-3.2.html
|
||||
EchoMessage Capability = iota
|
||||
|
||||
// ExtendedJoin is the IRCv3 capability named "extended-join":
|
||||
// https://ircv3.net/specs/extensions/extended-join-3.1.html
|
||||
ExtendedJoin Capability = iota
|
||||
|
||||
// InviteNotify is the IRCv3 capability named "invite-notify":
|
||||
// https://ircv3.net/specs/extensions/invite-notify-3.2.html
|
||||
InviteNotify Capability = iota
|
||||
|
||||
// LabeledResponse is the draft IRCv3 capability named "draft/labeled-response":
|
||||
// https://ircv3.net/specs/extensions/labeled-response.html
|
||||
LabeledResponse Capability = iota
|
||||
|
||||
// Languages is the proposed IRCv3 capability named "draft/languages":
|
||||
// https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6
|
||||
Languages Capability = iota
|
||||
|
||||
// MaxLine is the Oragono-specific capability named "oragono.io/maxline":
|
||||
// https://oragono.io/maxline
|
||||
MaxLine Capability = iota
|
||||
|
||||
// MessageTags is the draft IRCv3 capability named "draft/message-tags-0.2":
|
||||
// https://ircv3.net/specs/core/message-tags-3.3.html
|
||||
MessageTags Capability = iota
|
||||
|
||||
// MultiPrefix is the IRCv3 capability named "multi-prefix":
|
||||
// https://ircv3.net/specs/extensions/multi-prefix-3.1.html
|
||||
MultiPrefix Capability = iota
|
||||
|
||||
// Rename is the proposed IRCv3 capability named "draft/rename":
|
||||
// https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md
|
||||
Rename Capability = iota
|
||||
|
||||
// Resume is the proposed IRCv3 capability named "draft/resume":
|
||||
// https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md
|
||||
Resume Capability = iota
|
||||
|
||||
// SASL is the IRCv3 capability named "sasl":
|
||||
// https://ircv3.net/specs/extensions/sasl-3.2.html
|
||||
SASL Capability = iota
|
||||
|
||||
// ServerTime is the IRCv3 capability named "server-time":
|
||||
// https://ircv3.net/specs/extensions/server-time-3.2.html
|
||||
ServerTime Capability = iota
|
||||
|
||||
// STS is the IRCv3 capability named "sts":
|
||||
// https://ircv3.net/specs/extensions/sts.html
|
||||
STS Capability = iota
|
||||
|
||||
// UserhostInNames is the IRCv3 capability named "userhost-in-names":
|
||||
// https://ircv3.net/specs/extensions/userhost-in-names-3.2.html
|
||||
UserhostInNames Capability = iota
|
||||
)
|
||||
|
||||
var (
|
||||
capabilityNames = [numCapabs]string{
|
||||
"draft/label",
|
||||
"account-notify",
|
||||
"account-tag",
|
||||
"away-notify",
|
||||
"batch",
|
||||
"cap-notify",
|
||||
"chghost",
|
||||
"echo-message",
|
||||
"extended-join",
|
||||
"invite-notify",
|
||||
"draft/labeled-response",
|
||||
"draft/languages",
|
||||
"oragono.io/maxline",
|
||||
"draft/message-tags-0.2",
|
||||
"multi-prefix",
|
||||
"draft/rename",
|
||||
"draft/resume",
|
||||
"sasl",
|
||||
"server-time",
|
||||
"sts",
|
||||
"userhost-in-names",
|
||||
}
|
||||
)
|
@ -6,43 +6,34 @@ package caps
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/oragono/oragono/irc/utils"
|
||||
)
|
||||
|
||||
// Set holds a set of enabled capabilities.
|
||||
type Set struct {
|
||||
sync.RWMutex
|
||||
// capabilities holds the capabilities this manager has.
|
||||
capabilities map[Capability]bool
|
||||
}
|
||||
type Set [bitsetLen]uint64
|
||||
|
||||
// NewSet returns a new Set, with the given capabilities enabled.
|
||||
func NewSet(capabs ...Capability) *Set {
|
||||
newSet := Set{
|
||||
capabilities: make(map[Capability]bool),
|
||||
}
|
||||
var newSet Set
|
||||
utils.BitsetInitialize(newSet[:])
|
||||
newSet.Enable(capabs...)
|
||||
|
||||
return &newSet
|
||||
}
|
||||
|
||||
// Enable enables the given capabilities.
|
||||
func (s *Set) Enable(capabs ...Capability) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
asSlice := s[:]
|
||||
for _, capab := range capabs {
|
||||
s.capabilities[capab] = true
|
||||
utils.BitsetSet(asSlice, uint(capab), true)
|
||||
}
|
||||
}
|
||||
|
||||
// Disable disables the given capabilities.
|
||||
func (s *Set) Disable(capabs ...Capability) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
asSlice := s[:]
|
||||
for _, capab := range capabs {
|
||||
delete(s.capabilities, capab)
|
||||
utils.BitsetSet(asSlice, uint(capab), false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,51 +49,35 @@ func (s *Set) Remove(capabs ...Capability) {
|
||||
s.Disable(capabs...)
|
||||
}
|
||||
|
||||
// Has returns true if this set has the given capabilities.
|
||||
func (s *Set) Has(caps ...Capability) bool {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
for _, cap := range caps {
|
||||
if !s.capabilities[cap] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
// Has returns true if this set has the given capability.
|
||||
func (s *Set) Has(capab Capability) bool {
|
||||
return utils.BitsetGet(s[:], uint(capab))
|
||||
}
|
||||
|
||||
// List return a list of our enabled capabilities.
|
||||
func (s *Set) List() []Capability {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
var allCaps []Capability
|
||||
for capab := range s.capabilities {
|
||||
allCaps = append(allCaps, capab)
|
||||
}
|
||||
|
||||
return allCaps
|
||||
// Union adds all the capabilities of another set to this set.
|
||||
func (s *Set) Union(other *Set) {
|
||||
utils.BitsetUnion(s[:], other[:])
|
||||
}
|
||||
|
||||
// Count returns how many enabled caps this set has.
|
||||
func (s *Set) Count() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
return len(s.capabilities)
|
||||
// Empty returns whether the set is empty.
|
||||
func (s *Set) Empty() bool {
|
||||
return utils.BitsetEmpty(s[:])
|
||||
}
|
||||
|
||||
// String returns all of our enabled capabilities as a string.
|
||||
func (s *Set) String(version Version, values *Values) string {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
var strs sort.StringSlice
|
||||
|
||||
for capability := range s.capabilities {
|
||||
capString := capability.Name()
|
||||
var capab Capability
|
||||
asSlice := s[:]
|
||||
for capab = 0; capab < numCapabs; capab++ {
|
||||
// skip any capabilities that are not enabled
|
||||
if !utils.BitsetGet(asSlice, uint(capab)) {
|
||||
continue
|
||||
}
|
||||
capString := capab.Name()
|
||||
if version == Cap302 {
|
||||
val, exists := values.Get(capability)
|
||||
val, exists := values.Get(capab)
|
||||
if exists {
|
||||
capString += "=" + val
|
||||
}
|
||||
|
@ -11,12 +11,12 @@ func TestSets(t *testing.T) {
|
||||
|
||||
s1.Enable(AccountTag, EchoMessage, UserhostInNames)
|
||||
|
||||
if !s1.Has(AccountTag, EchoMessage, UserhostInNames) {
|
||||
if !(s1.Has(AccountTag) && s1.Has(EchoMessage) && s1.Has(UserhostInNames)) {
|
||||
t.Error("Did not have the tags we expected")
|
||||
}
|
||||
|
||||
if s1.Has(AccountTag, EchoMessage, STS, UserhostInNames) {
|
||||
t.Error("Has() returned true when we don't have all the given capabilities")
|
||||
if s1.Has(STS) {
|
||||
t.Error("Has() returned true when we don't have the given capability")
|
||||
}
|
||||
|
||||
s1.Disable(AccountTag)
|
||||
@ -25,14 +25,9 @@ func TestSets(t *testing.T) {
|
||||
t.Error("Disable() did not correctly disable the given capability")
|
||||
}
|
||||
|
||||
enabledCaps := make(map[Capability]bool)
|
||||
for _, capab := range s1.List() {
|
||||
enabledCaps[capab] = true
|
||||
}
|
||||
expectedCaps := map[Capability]bool{
|
||||
EchoMessage: true,
|
||||
UserhostInNames: true,
|
||||
}
|
||||
enabledCaps := NewSet()
|
||||
enabledCaps.Union(s1)
|
||||
expectedCaps := NewSet(EchoMessage, UserhostInNames)
|
||||
if !reflect.DeepEqual(enabledCaps, expectedCaps) {
|
||||
t.Errorf("Enabled and expected capability lists do not match: %v, %v", enabledCaps, expectedCaps)
|
||||
}
|
||||
@ -40,16 +35,12 @@ func TestSets(t *testing.T) {
|
||||
// make sure re-enabling doesn't add to the count or something weird like that
|
||||
s1.Enable(EchoMessage)
|
||||
|
||||
if s1.Count() != 2 {
|
||||
t.Error("Count() did not match expected capability count")
|
||||
}
|
||||
|
||||
// make sure add and remove work fine
|
||||
s1.Add(InviteNotify)
|
||||
s1.Remove(EchoMessage)
|
||||
|
||||
if s1.Count() != 2 {
|
||||
t.Error("Count() did not match expected capability count")
|
||||
if !s1.Has(InviteNotify) || s1.Has(EchoMessage) {
|
||||
t.Error("Add/Remove don't work")
|
||||
}
|
||||
|
||||
// test String()
|
||||
|
@ -442,12 +442,16 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
|
||||
capabilities := caps.NewSet()
|
||||
var capString string
|
||||
|
||||
var badCaps []string
|
||||
if len(msg.Params) > 1 {
|
||||
capString = msg.Params[1]
|
||||
strs := strings.Split(capString, " ")
|
||||
strs := strings.Fields(capString)
|
||||
for _, str := range strs {
|
||||
if len(str) > 0 {
|
||||
capabilities.Enable(caps.Capability(str))
|
||||
capab, err := caps.NameToCapability(str)
|
||||
if err != nil || !SupportedCapabilities.Has(capab) {
|
||||
badCaps = append(badCaps, str)
|
||||
} else {
|
||||
capabilities.Enable(capab)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -475,13 +479,11 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
|
||||
}
|
||||
|
||||
// make sure all capabilities actually exist
|
||||
for _, capability := range capabilities.List() {
|
||||
if !SupportedCapabilities.Has(capability) {
|
||||
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
|
||||
return false
|
||||
}
|
||||
if len(badCaps) > 0 {
|
||||
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
|
||||
return false
|
||||
}
|
||||
client.capabilities.Enable(capabilities.List()...)
|
||||
client.capabilities.Union(capabilities)
|
||||
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
|
||||
|
||||
case "END":
|
||||
|
@ -7,7 +7,9 @@ package modes
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/oragono/oragono/irc/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -322,42 +324,29 @@ func ParseChannelModeChanges(params ...string) (ModeChanges, map[rune]bool) {
|
||||
}
|
||||
|
||||
// ModeSet holds a set of modes.
|
||||
type ModeSet struct {
|
||||
sync.RWMutex // tier 0
|
||||
modes map[Mode]bool
|
||||
}
|
||||
type ModeSet [1]uint64
|
||||
|
||||
// valid modes go from 65 ('A') to 122 ('z'), making at most 58 possible values;
|
||||
// subtract 65 from the mode value and use that bit of the uint64 to represent it
|
||||
const (
|
||||
minMode = 65 // 'A'
|
||||
)
|
||||
|
||||
// returns a pointer to a new ModeSet
|
||||
func NewModeSet() *ModeSet {
|
||||
return &ModeSet{
|
||||
modes: make(map[Mode]bool),
|
||||
}
|
||||
var set ModeSet
|
||||
utils.BitsetInitialize(set[:])
|
||||
return &set
|
||||
}
|
||||
|
||||
// test whether `mode` is set
|
||||
func (set *ModeSet) HasMode(mode Mode) bool {
|
||||
if set == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
set.RLock()
|
||||
defer set.RUnlock()
|
||||
return set.modes[mode]
|
||||
return utils.BitsetGet(set[:], uint(mode)-minMode)
|
||||
}
|
||||
|
||||
// set `mode` to be on or off, return whether the value actually changed
|
||||
func (set *ModeSet) SetMode(mode Mode, on bool) (applied bool) {
|
||||
set.Lock()
|
||||
defer set.Unlock()
|
||||
|
||||
previouslyOn := set.modes[mode]
|
||||
needsApply := (on != previouslyOn)
|
||||
if on && needsApply {
|
||||
set.modes[mode] = true
|
||||
} else if !on && needsApply {
|
||||
delete(set.modes, mode)
|
||||
}
|
||||
return needsApply
|
||||
return utils.BitsetSet(set[:], uint(mode)-minMode, on)
|
||||
}
|
||||
|
||||
// return the modes in the set as a slice
|
||||
@ -366,11 +355,12 @@ func (set *ModeSet) AllModes() (result []Mode) {
|
||||
return
|
||||
}
|
||||
|
||||
set.RLock()
|
||||
defer set.RUnlock()
|
||||
|
||||
for mode := range set.modes {
|
||||
result = append(result, mode)
|
||||
block := atomic.LoadUint64(&set[0])
|
||||
var i uint
|
||||
for i = 0; i < 64; i++ {
|
||||
if block&(1<<i) != 0 {
|
||||
result = append(result, Mode(minMode+i))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -381,11 +371,8 @@ func (set *ModeSet) String() (result string) {
|
||||
return
|
||||
}
|
||||
|
||||
set.RLock()
|
||||
defer set.RUnlock()
|
||||
|
||||
var buf strings.Builder
|
||||
for mode := range set.modes {
|
||||
for _, mode := range set.AllModes() {
|
||||
buf.WriteRune(rune(mode))
|
||||
}
|
||||
return buf.String()
|
||||
@ -397,12 +384,9 @@ func (set *ModeSet) Prefixes(isMultiPrefix bool) (prefixes string) {
|
||||
return
|
||||
}
|
||||
|
||||
set.RLock()
|
||||
defer set.RUnlock()
|
||||
|
||||
// add prefixes in order from highest to lowest privs
|
||||
for _, mode := range ChannelUserModes {
|
||||
if set.modes[mode] {
|
||||
if set.HasMode(mode) {
|
||||
prefixes += ChannelModePrefixes[mode]
|
||||
}
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ type ResponseBuffer struct {
|
||||
|
||||
// GetLabel returns the label from the given message.
|
||||
func GetLabel(msg ircmsg.IrcMessage) string {
|
||||
return msg.Tags[caps.LabelTagName].Value
|
||||
return msg.Tags[caps.LabelTagName.Name()].Value
|
||||
}
|
||||
|
||||
// NewResponseBuffer returns a new ResponseBuffer.
|
||||
@ -90,13 +90,13 @@ func (rb *ResponseBuffer) Send() error {
|
||||
// if label but no batch, add label to first message
|
||||
if useLabel && batch == nil {
|
||||
message := rb.messages[0]
|
||||
message.Tags[caps.LabelTagName] = ircmsg.MakeTagValue(rb.Label)
|
||||
message.Tags[caps.LabelTagName.Name()] = ircmsg.MakeTagValue(rb.Label)
|
||||
rb.messages[0] = message
|
||||
}
|
||||
|
||||
// start batch if required
|
||||
if batch != nil {
|
||||
batch.Start(rb.target, ircmsg.MakeTags(caps.LabelTagName, rb.Label))
|
||||
batch.Start(rb.target, ircmsg.MakeTags(caps.LabelTagName.Name(), rb.Label))
|
||||
}
|
||||
|
||||
// send each message out
|
||||
|
@ -898,13 +898,11 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
|
||||
|
||||
// updated caps get DEL'd and then NEW'd
|
||||
// so, we can just add updated ones to both removed and added lists here and they'll be correctly handled
|
||||
server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues), strconv.Itoa(updatedCaps.Count()))
|
||||
for _, capab := range updatedCaps.List() {
|
||||
addedCaps.Enable(capab)
|
||||
removedCaps.Enable(capab)
|
||||
}
|
||||
server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues))
|
||||
addedCaps.Union(updatedCaps)
|
||||
removedCaps.Union(updatedCaps)
|
||||
|
||||
if 0 < addedCaps.Count() || 0 < removedCaps.Count() {
|
||||
if !addedCaps.Empty() || !removedCaps.Empty() {
|
||||
capBurstClients = server.clients.AllWithCaps(caps.CapNotify)
|
||||
|
||||
added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues)
|
||||
@ -918,7 +916,7 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
|
||||
// remove STS policy
|
||||
//TODO(dan): this is an ugly hack. we can write this better.
|
||||
stsPolicy := "sts=duration=0"
|
||||
if 0 < addedCaps.Count() {
|
||||
if !addedCaps.Empty() {
|
||||
added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy
|
||||
} else {
|
||||
addedCaps.Enable(caps.STS)
|
||||
@ -926,10 +924,10 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
|
||||
}
|
||||
}
|
||||
// DEL caps and then send NEW ones so that updated caps get removed/added correctly
|
||||
if 0 < removedCaps.Count() {
|
||||
if !removedCaps.Empty() {
|
||||
sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed)
|
||||
}
|
||||
if 0 < addedCaps.Count() {
|
||||
if !addedCaps.Empty() {
|
||||
sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion])
|
||||
}
|
||||
}
|
||||
|
86
irc/utils/bitset.go
Normal file
86
irc/utils/bitset.go
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright (c) 2018 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||
// released under the MIT license
|
||||
|
||||
package utils
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Library functions for lock-free bitsets, typically (constant-sized) arrays of uint64.
|
||||
// For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a
|
||||
// slice to use these functions.
|
||||
|
||||
// BitsetInitialize initializes a bitset.
|
||||
func BitsetInitialize(set []uint64) {
|
||||
// XXX re-zero the bitset using atomic stores. it's unclear whether this is required,
|
||||
// however, golang issue #5045 suggests that you shouldn't mix atomic operations
|
||||
// with non-atomic operations (such as the runtime's automatic zero-initialization) on
|
||||
// the same word
|
||||
for i := 0; i < len(set); i++ {
|
||||
atomic.StoreUint64(&set[i], 0)
|
||||
}
|
||||
}
|
||||
|
||||
// BitsetGet returns whether a given bit of the bitset is set.
|
||||
func BitsetGet(set []uint64, position uint) bool {
|
||||
idx := position / 64
|
||||
bit := position % 64
|
||||
block := atomic.LoadUint64(&set[idx])
|
||||
return (block & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
// BitsetSet sets a given bit of the bitset to 0 or 1, returning whether it changed.
|
||||
func BitsetSet(set []uint64, position uint, on bool) (changed bool) {
|
||||
idx := position / 64
|
||||
bit := position % 64
|
||||
addr := &set[idx]
|
||||
var mask uint64
|
||||
mask = 1 << bit
|
||||
for {
|
||||
current := atomic.LoadUint64(addr)
|
||||
previouslyOn := (current & mask) != 0
|
||||
if on == previouslyOn {
|
||||
return false
|
||||
}
|
||||
var desired uint64
|
||||
if on {
|
||||
desired = current | mask
|
||||
} else {
|
||||
desired = current & (^mask)
|
||||
}
|
||||
if atomic.CompareAndSwapUint64(addr, current, desired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BitsetEmpty returns whether the bitset is empty.
|
||||
// Right now, this is technically free of race conditions because we don't
|
||||
// have a method that can simultaneously modify two bits separated by a word boundary
|
||||
// such that one of those modifications is an unset. If we did, there would be a race
|
||||
// that could produce false positives. It's probably better to assume that they are
|
||||
// already possible under concurrent modification (which is not how we're using this).
|
||||
func BitsetEmpty(set []uint64) (empty bool) {
|
||||
for i := 0; i < len(set); i++ {
|
||||
if atomic.LoadUint64(&set[i]) != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// BitsetUnion modifies `set` to be the union of `set` and `other`.
|
||||
// This has race conditions in that we don't necessarily get a single
|
||||
// consistent view of `other` across word boundaries.
|
||||
func BitsetUnion(set []uint64, other []uint64) {
|
||||
for i := 0; i < len(set); i++ {
|
||||
for {
|
||||
ourAddr := &set[i]
|
||||
ourBlock := atomic.LoadUint64(ourAddr)
|
||||
otherBlock := atomic.LoadUint64(&other[i])
|
||||
newBlock := ourBlock | otherBlock
|
||||
if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
52
irc/utils/bitset_test.go
Normal file
52
irc/utils/bitset_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2018 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||
// released under the MIT license
|
||||
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
type testBitset [2]uint64
|
||||
|
||||
func TestSets(t *testing.T) {
|
||||
var t1 testBitset
|
||||
t1s := t1[:]
|
||||
BitsetInitialize(t1s)
|
||||
|
||||
if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) {
|
||||
t.Error("no bits should be set in a newly initialized bitset")
|
||||
}
|
||||
|
||||
var i uint
|
||||
for i = 0; i < 128; i++ {
|
||||
if i%2 == 0 {
|
||||
BitsetSet(t1s, i, true)
|
||||
}
|
||||
}
|
||||
|
||||
if !(BitsetGet(t1s, 0) && !BitsetGet(t1s, 1) && BitsetGet(t1s, 64) && BitsetGet(t1s, 72) && !BitsetGet(t1s, 127)) {
|
||||
t.Error("exactly the even-numbered bits should be set")
|
||||
}
|
||||
|
||||
BitsetSet(t1s, 72, false)
|
||||
if BitsetGet(t1s, 72) {
|
||||
t.Error("remove doesn't work")
|
||||
}
|
||||
|
||||
var t2 testBitset
|
||||
t2s := t2[:]
|
||||
BitsetInitialize(t2s)
|
||||
|
||||
for i = 0; i < 128; i++ {
|
||||
if i%2 == 1 {
|
||||
BitsetSet(t2s, i, true)
|
||||
}
|
||||
}
|
||||
|
||||
BitsetUnion(t1s, t2s)
|
||||
for i = 0; i < 128; i++ {
|
||||
expected := (i != 72)
|
||||
if BitsetGet(t1s, i) != expected {
|
||||
t.Error("all bits should be set except 72")
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user