3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-22 03:49:27 +01:00

atomic bitset implementations of caps.Set and modes.ModeSet

This commit is contained in:
Shivaram Lingamneni 2018-06-25 18:08:15 -04:00
parent cdbb369a9c
commit 2a33c1483b
12 changed files with 577 additions and 174 deletions

View File

@ -1,5 +1,7 @@
.PHONY: all build
capdef_file = ./irc/caps/defs.go
all: build
build:
@ -8,11 +10,16 @@ build:
buildrelease:
goreleaser --skip-publish --rm-dist
capdefs:
python3 ./gencapdefs.py > ${capdef_file}
deps:
git submodule update --init
test:
python3 ./gencapdefs.py | diff - ${capdef_file}
cd irc && go test . && go vet .
cd irc/caps && go test . && go vet .
cd irc/isupport && go test . && go vet .
cd irc/modes && go test . && go vet .
cd irc/utils && go test . && go vet .

204
gencapdefs.py Normal file
View File

@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""
Updates the capability definitions at irc/caps/defs.go
To add a capability, add it to the CAPDEFS list below,
then run `make capdefs` from the project root.
"""
import io
import subprocess
import sys
from collections import namedtuple
CapDef = namedtuple("CapDef", ['identifier', 'name', 'url', 'standard'])
CAPDEFS = [
CapDef(
identifier="LabelTagName",
name="draft/label",
url="https://ircv3.net/specs/extensions/labeled-response.html",
standard="draft IRCv3 tag name",
),
CapDef(
identifier="AccountNotify",
name="account-notify",
url="https://ircv3.net/specs/extensions/account-notify-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="AccountTag",
name="account-tag",
url="https://ircv3.net/specs/extensions/account-tag-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="AwayNotify",
name="away-notify",
url="https://ircv3.net/specs/extensions/away-notify-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="Batch",
name="batch",
url="https://ircv3.net/specs/extensions/batch-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="CapNotify",
name="cap-notify",
url="https://ircv3.net/specs/extensions/cap-notify-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="ChgHost",
name="chghost",
url="https://ircv3.net/specs/extensions/chghost-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="EchoMessage",
name="echo-message",
url="https://ircv3.net/specs/extensions/echo-message-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="ExtendedJoin",
name="extended-join",
url="https://ircv3.net/specs/extensions/extended-join-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="InviteNotify",
name="invite-notify",
url="https://ircv3.net/specs/extensions/invite-notify-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="LabeledResponse",
name="draft/labeled-response",
url="https://ircv3.net/specs/extensions/labeled-response.html",
standard="draft IRCv3",
),
CapDef(
identifier="Languages",
name="draft/languages",
url="https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6",
standard="proposed IRCv3",
),
CapDef(
identifier="MaxLine",
name="oragono.io/maxline",
url="https://oragono.io/maxline",
standard="Oragono-specific",
),
CapDef(
identifier="MessageTags",
name="draft/message-tags-0.2",
url="https://ircv3.net/specs/core/message-tags-3.3.html",
standard="draft IRCv3",
),
CapDef(
identifier="MultiPrefix",
name="multi-prefix",
url="https://ircv3.net/specs/extensions/multi-prefix-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="Rename",
name="draft/rename",
url="https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md",
standard="proposed IRCv3",
),
CapDef(
identifier="Resume",
name="draft/resume",
url="https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md",
standard="proposed IRCv3",
),
CapDef(
identifier="SASL",
name="sasl",
url="https://ircv3.net/specs/extensions/sasl-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="ServerTime",
name="server-time",
url="https://ircv3.net/specs/extensions/server-time-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="STS",
name="sts",
url="https://ircv3.net/specs/extensions/sts.html",
standard="IRCv3",
),
CapDef(
identifier="UserhostInNames",
name="userhost-in-names",
url="https://ircv3.net/specs/extensions/userhost-in-names-3.2.html",
standard="IRCv3",
),
]
def validate_defs():
numCaps = len(CAPDEFS)
numNames = len(set(capdef.name for capdef in CAPDEFS))
if numCaps != numNames:
raise Exception("defs must have unique names, but found duplicates")
numIdentifiers = len(set(capdef.identifier for capdef in CAPDEFS))
if numCaps != numIdentifiers:
raise Exception("defs must have unique identifiers, but found duplicates")
def main():
validate_defs()
output = io.StringIO()
print("""
package caps
/*
WARNING: this file is autogenerated by `make capdefs`
DO NOT EDIT MANUALLY.
*/
""", file=output)
numCapabs = len(CAPDEFS)
bitsetLen = numCapabs // 64
if numCapabs % 64 > 0:
bitsetLen += 1
print ("""
const (
// number of recognized capabilities:
numCapabs = %d
// length of the uint64 array that represents the bitset:
bitsetLen = %d
)
""" % (numCapabs, bitsetLen), file=output)
print("const (", file=output)
for capdef in CAPDEFS:
print("// %s is the %s capability named \"%s\":" % (capdef.identifier, capdef.standard, capdef.name), file=output)
print("// %s" % (capdef.url,), file=output)
print("%s Capability = iota" % (capdef.identifier,), file=output)
print(file=output)
print(")", file=output)
print("""var ( capabilityNames = [numCapabs]string{""", file=output)
for capdef in CAPDEFS:
print("\"%s\"," % (capdef.name,), file=output)
print("})", file=output)
gofmt = subprocess.Popen(['gofmt', '-s'], stdin=subprocess.PIPE)
gofmt.communicate(input=output.getvalue().encode('utf-8'))
if gofmt.poll() != 0:
print(output.getvalue())
raise Exception("gofmt failed")
return 0
if __name__ == '__main__':
sys.exit(main())

View File

@ -3,58 +3,30 @@
package caps
import "errors"
// Capability represents an optional feature that a client may request from the server.
type Capability string
type Capability uint
const (
// LabelTagName is the tag name used for the labeled-response spec.
LabelTagName = "draft/label"
// actual capability definitions appear in defs.go
// AccountNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/account-notify-3.1.html
AccountNotify Capability = "account-notify"
// AccountTag is this IRCv3 capability: http://ircv3.net/specs/extensions/account-tag-3.2.html
AccountTag Capability = "account-tag"
// AwayNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/away-notify-3.1.html
AwayNotify Capability = "away-notify"
// Batch is this IRCv3 capability: http://ircv3.net/specs/extensions/batch-3.2.html
Batch Capability = "batch"
// CapNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/cap-notify-3.2.html
CapNotify Capability = "cap-notify"
// ChgHost is this IRCv3 capability: http://ircv3.net/specs/extensions/chghost-3.2.html
ChgHost Capability = "chghost"
// EchoMessage is this IRCv3 capability: http://ircv3.net/specs/extensions/echo-message-3.2.html
EchoMessage Capability = "echo-message"
// ExtendedJoin is this IRCv3 capability: http://ircv3.net/specs/extensions/extended-join-3.1.html
ExtendedJoin Capability = "extended-join"
// InviteNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/invite-notify-3.2.html
InviteNotify Capability = "invite-notify"
// LabeledResponse is this draft IRCv3 capability: http://ircv3.net/specs/extensions/labeled-response.html
LabeledResponse Capability = "draft/labeled-response"
// Languages is this proposed IRCv3 capability: https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6
Languages Capability = "draft/languages"
// MaxLine is this capability: https://oragono.io/maxline
MaxLine Capability = "oragono.io/maxline"
// MessageTags is this draft IRCv3 capability: http://ircv3.net/specs/core/message-tags-3.3.html
MessageTags Capability = "draft/message-tags-0.2"
// MultiPrefix is this IRCv3 capability: http://ircv3.net/specs/extensions/multi-prefix-3.1.html
MultiPrefix Capability = "multi-prefix"
// Rename is this proposed capability: https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md
Rename Capability = "draft/rename"
// Resume is this proposed capability: https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md
Resume Capability = "draft/resume"
// SASL is this IRCv3 capability: http://ircv3.net/specs/extensions/sasl-3.2.html
SASL Capability = "sasl"
// ServerTime is this IRCv3 capability: http://ircv3.net/specs/extensions/server-time-3.2.html
ServerTime Capability = "server-time"
// STS is this IRCv3 capability: http://ircv3.net/specs/extensions/sts.html
STS Capability = "sts"
// UserhostInNames is this IRCv3 capability: http://ircv3.net/specs/extensions/userhost-in-names-3.2.html
UserhostInNames Capability = "userhost-in-names"
var (
nameToCapability map[string]Capability
NoSuchCap = errors.New("Unsupported capability name")
)
// Name returns the name of the given capability.
func (capability Capability) Name() string {
return string(capability)
return capabilityNames[capability]
}
func NameToCapability(name string) (result Capability, err error) {
result, found := nameToCapability[name]
if !found {
err = NoSuchCap
}
return
}
// Version is used to select which max version of CAP the client supports.
@ -78,3 +50,10 @@ const (
// NegotiatedState means CAP negotiation has been successfully ended and reg should complete.
NegotiatedState State = iota
)
func init() {
nameToCapability = make(map[string]Capability)
for capab, name := range capabilityNames {
nameToCapability[name] = Capability(capab)
}
}

125
irc/caps/defs.go Normal file
View File

@ -0,0 +1,125 @@
package caps
/*
WARNING: this file is autogenerated by `make capdefs`
DO NOT EDIT MANUALLY.
*/
const (
// number of recognized capabilities:
numCapabs = 21
// length of the uint64 array that represents the bitset:
bitsetLen = 1
)
const (
// LabelTagName is the draft IRCv3 tag name capability named "draft/label":
// https://ircv3.net/specs/extensions/labeled-response.html
LabelTagName Capability = iota
// AccountNotify is the IRCv3 capability named "account-notify":
// https://ircv3.net/specs/extensions/account-notify-3.1.html
AccountNotify Capability = iota
// AccountTag is the IRCv3 capability named "account-tag":
// https://ircv3.net/specs/extensions/account-tag-3.2.html
AccountTag Capability = iota
// AwayNotify is the IRCv3 capability named "away-notify":
// https://ircv3.net/specs/extensions/away-notify-3.1.html
AwayNotify Capability = iota
// Batch is the IRCv3 capability named "batch":
// https://ircv3.net/specs/extensions/batch-3.2.html
Batch Capability = iota
// CapNotify is the IRCv3 capability named "cap-notify":
// https://ircv3.net/specs/extensions/cap-notify-3.2.html
CapNotify Capability = iota
// ChgHost is the IRCv3 capability named "chghost":
// https://ircv3.net/specs/extensions/chghost-3.2.html
ChgHost Capability = iota
// EchoMessage is the IRCv3 capability named "echo-message":
// https://ircv3.net/specs/extensions/echo-message-3.2.html
EchoMessage Capability = iota
// ExtendedJoin is the IRCv3 capability named "extended-join":
// https://ircv3.net/specs/extensions/extended-join-3.1.html
ExtendedJoin Capability = iota
// InviteNotify is the IRCv3 capability named "invite-notify":
// https://ircv3.net/specs/extensions/invite-notify-3.2.html
InviteNotify Capability = iota
// LabeledResponse is the draft IRCv3 capability named "draft/labeled-response":
// https://ircv3.net/specs/extensions/labeled-response.html
LabeledResponse Capability = iota
// Languages is the proposed IRCv3 capability named "draft/languages":
// https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6
Languages Capability = iota
// MaxLine is the Oragono-specific capability named "oragono.io/maxline":
// https://oragono.io/maxline
MaxLine Capability = iota
// MessageTags is the draft IRCv3 capability named "draft/message-tags-0.2":
// https://ircv3.net/specs/core/message-tags-3.3.html
MessageTags Capability = iota
// MultiPrefix is the IRCv3 capability named "multi-prefix":
// https://ircv3.net/specs/extensions/multi-prefix-3.1.html
MultiPrefix Capability = iota
// Rename is the proposed IRCv3 capability named "draft/rename":
// https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md
Rename Capability = iota
// Resume is the proposed IRCv3 capability named "draft/resume":
// https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md
Resume Capability = iota
// SASL is the IRCv3 capability named "sasl":
// https://ircv3.net/specs/extensions/sasl-3.2.html
SASL Capability = iota
// ServerTime is the IRCv3 capability named "server-time":
// https://ircv3.net/specs/extensions/server-time-3.2.html
ServerTime Capability = iota
// STS is the IRCv3 capability named "sts":
// https://ircv3.net/specs/extensions/sts.html
STS Capability = iota
// UserhostInNames is the IRCv3 capability named "userhost-in-names":
// https://ircv3.net/specs/extensions/userhost-in-names-3.2.html
UserhostInNames Capability = iota
)
var (
capabilityNames = [numCapabs]string{
"draft/label",
"account-notify",
"account-tag",
"away-notify",
"batch",
"cap-notify",
"chghost",
"echo-message",
"extended-join",
"invite-notify",
"draft/labeled-response",
"draft/languages",
"oragono.io/maxline",
"draft/message-tags-0.2",
"multi-prefix",
"draft/rename",
"draft/resume",
"sasl",
"server-time",
"sts",
"userhost-in-names",
}
)

View File

@ -6,43 +6,34 @@ package caps
import (
"sort"
"strings"
"sync"
"github.com/oragono/oragono/irc/utils"
)
// Set holds a set of enabled capabilities.
type Set struct {
sync.RWMutex
// capabilities holds the capabilities this manager has.
capabilities map[Capability]bool
}
type Set [bitsetLen]uint64
// NewSet returns a new Set, with the given capabilities enabled.
func NewSet(capabs ...Capability) *Set {
newSet := Set{
capabilities: make(map[Capability]bool),
}
var newSet Set
utils.BitsetInitialize(newSet[:])
newSet.Enable(capabs...)
return &newSet
}
// Enable enables the given capabilities.
func (s *Set) Enable(capabs ...Capability) {
s.Lock()
defer s.Unlock()
asSlice := s[:]
for _, capab := range capabs {
s.capabilities[capab] = true
utils.BitsetSet(asSlice, uint(capab), true)
}
}
// Disable disables the given capabilities.
func (s *Set) Disable(capabs ...Capability) {
s.Lock()
defer s.Unlock()
asSlice := s[:]
for _, capab := range capabs {
delete(s.capabilities, capab)
utils.BitsetSet(asSlice, uint(capab), false)
}
}
@ -58,51 +49,35 @@ func (s *Set) Remove(capabs ...Capability) {
s.Disable(capabs...)
}
// Has returns true if this set has the given capabilities.
func (s *Set) Has(caps ...Capability) bool {
s.RLock()
defer s.RUnlock()
for _, cap := range caps {
if !s.capabilities[cap] {
return false
}
}
return true
// Has returns true if this set has the given capability.
func (s *Set) Has(capab Capability) bool {
return utils.BitsetGet(s[:], uint(capab))
}
// List return a list of our enabled capabilities.
func (s *Set) List() []Capability {
s.RLock()
defer s.RUnlock()
var allCaps []Capability
for capab := range s.capabilities {
allCaps = append(allCaps, capab)
}
return allCaps
// Union adds all the capabilities of another set to this set.
func (s *Set) Union(other *Set) {
utils.BitsetUnion(s[:], other[:])
}
// Count returns how many enabled caps this set has.
func (s *Set) Count() int {
s.RLock()
defer s.RUnlock()
return len(s.capabilities)
// Empty returns whether the set is empty.
func (s *Set) Empty() bool {
return utils.BitsetEmpty(s[:])
}
// String returns all of our enabled capabilities as a string.
func (s *Set) String(version Version, values *Values) string {
s.RLock()
defer s.RUnlock()
var strs sort.StringSlice
for capability := range s.capabilities {
capString := capability.Name()
var capab Capability
asSlice := s[:]
for capab = 0; capab < numCapabs; capab++ {
// skip any capabilities that are not enabled
if !utils.BitsetGet(asSlice, uint(capab)) {
continue
}
capString := capab.Name()
if version == Cap302 {
val, exists := values.Get(capability)
val, exists := values.Get(capab)
if exists {
capString += "=" + val
}

View File

@ -11,12 +11,12 @@ func TestSets(t *testing.T) {
s1.Enable(AccountTag, EchoMessage, UserhostInNames)
if !s1.Has(AccountTag, EchoMessage, UserhostInNames) {
if !(s1.Has(AccountTag) && s1.Has(EchoMessage) && s1.Has(UserhostInNames)) {
t.Error("Did not have the tags we expected")
}
if s1.Has(AccountTag, EchoMessage, STS, UserhostInNames) {
t.Error("Has() returned true when we don't have all the given capabilities")
if s1.Has(STS) {
t.Error("Has() returned true when we don't have the given capability")
}
s1.Disable(AccountTag)
@ -25,14 +25,9 @@ func TestSets(t *testing.T) {
t.Error("Disable() did not correctly disable the given capability")
}
enabledCaps := make(map[Capability]bool)
for _, capab := range s1.List() {
enabledCaps[capab] = true
}
expectedCaps := map[Capability]bool{
EchoMessage: true,
UserhostInNames: true,
}
enabledCaps := NewSet()
enabledCaps.Union(s1)
expectedCaps := NewSet(EchoMessage, UserhostInNames)
if !reflect.DeepEqual(enabledCaps, expectedCaps) {
t.Errorf("Enabled and expected capability lists do not match: %v, %v", enabledCaps, expectedCaps)
}
@ -40,16 +35,12 @@ func TestSets(t *testing.T) {
// make sure re-enabling doesn't add to the count or something weird like that
s1.Enable(EchoMessage)
if s1.Count() != 2 {
t.Error("Count() did not match expected capability count")
}
// make sure add and remove work fine
s1.Add(InviteNotify)
s1.Remove(EchoMessage)
if s1.Count() != 2 {
t.Error("Count() did not match expected capability count")
if !s1.Has(InviteNotify) || s1.Has(EchoMessage) {
t.Error("Add/Remove don't work")
}
// test String()

View File

@ -442,12 +442,16 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
capabilities := caps.NewSet()
var capString string
var badCaps []string
if len(msg.Params) > 1 {
capString = msg.Params[1]
strs := strings.Split(capString, " ")
strs := strings.Fields(capString)
for _, str := range strs {
if len(str) > 0 {
capabilities.Enable(caps.Capability(str))
capab, err := caps.NameToCapability(str)
if err != nil || !SupportedCapabilities.Has(capab) {
badCaps = append(badCaps, str)
} else {
capabilities.Enable(capab)
}
}
}
@ -475,13 +479,11 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
}
// make sure all capabilities actually exist
for _, capability := range capabilities.List() {
if !SupportedCapabilities.Has(capability) {
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
return false
}
if len(badCaps) > 0 {
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
return false
}
client.capabilities.Enable(capabilities.List()...)
client.capabilities.Union(capabilities)
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
case "END":

View File

@ -7,7 +7,9 @@ package modes
import (
"strings"
"sync"
"sync/atomic"
"github.com/oragono/oragono/irc/utils"
)
var (
@ -322,42 +324,29 @@ func ParseChannelModeChanges(params ...string) (ModeChanges, map[rune]bool) {
}
// ModeSet holds a set of modes.
type ModeSet struct {
sync.RWMutex // tier 0
modes map[Mode]bool
}
type ModeSet [1]uint64
// valid modes go from 65 ('A') to 122 ('z'), making at most 58 possible values;
// subtract 65 from the mode value and use that bit of the uint64 to represent it
const (
minMode = 65 // 'A'
)
// returns a pointer to a new ModeSet
func NewModeSet() *ModeSet {
return &ModeSet{
modes: make(map[Mode]bool),
}
var set ModeSet
utils.BitsetInitialize(set[:])
return &set
}
// test whether `mode` is set
func (set *ModeSet) HasMode(mode Mode) bool {
if set == nil {
return false
}
set.RLock()
defer set.RUnlock()
return set.modes[mode]
return utils.BitsetGet(set[:], uint(mode)-minMode)
}
// set `mode` to be on or off, return whether the value actually changed
func (set *ModeSet) SetMode(mode Mode, on bool) (applied bool) {
set.Lock()
defer set.Unlock()
previouslyOn := set.modes[mode]
needsApply := (on != previouslyOn)
if on && needsApply {
set.modes[mode] = true
} else if !on && needsApply {
delete(set.modes, mode)
}
return needsApply
return utils.BitsetSet(set[:], uint(mode)-minMode, on)
}
// return the modes in the set as a slice
@ -366,11 +355,12 @@ func (set *ModeSet) AllModes() (result []Mode) {
return
}
set.RLock()
defer set.RUnlock()
for mode := range set.modes {
result = append(result, mode)
block := atomic.LoadUint64(&set[0])
var i uint
for i = 0; i < 64; i++ {
if block&(1<<i) != 0 {
result = append(result, Mode(minMode+i))
}
}
return
}
@ -381,11 +371,8 @@ func (set *ModeSet) String() (result string) {
return
}
set.RLock()
defer set.RUnlock()
var buf strings.Builder
for mode := range set.modes {
for _, mode := range set.AllModes() {
buf.WriteRune(rune(mode))
}
return buf.String()
@ -397,12 +384,9 @@ func (set *ModeSet) Prefixes(isMultiPrefix bool) (prefixes string) {
return
}
set.RLock()
defer set.RUnlock()
// add prefixes in order from highest to lowest privs
for _, mode := range ChannelUserModes {
if set.modes[mode] {
if set.HasMode(mode) {
prefixes += ChannelModePrefixes[mode]
}
}

View File

@ -23,7 +23,7 @@ type ResponseBuffer struct {
// GetLabel returns the label from the given message.
func GetLabel(msg ircmsg.IrcMessage) string {
return msg.Tags[caps.LabelTagName].Value
return msg.Tags[caps.LabelTagName.Name()].Value
}
// NewResponseBuffer returns a new ResponseBuffer.
@ -90,13 +90,13 @@ func (rb *ResponseBuffer) Send() error {
// if label but no batch, add label to first message
if useLabel && batch == nil {
message := rb.messages[0]
message.Tags[caps.LabelTagName] = ircmsg.MakeTagValue(rb.Label)
message.Tags[caps.LabelTagName.Name()] = ircmsg.MakeTagValue(rb.Label)
rb.messages[0] = message
}
// start batch if required
if batch != nil {
batch.Start(rb.target, ircmsg.MakeTags(caps.LabelTagName, rb.Label))
batch.Start(rb.target, ircmsg.MakeTags(caps.LabelTagName.Name(), rb.Label))
}
// send each message out

View File

@ -898,13 +898,11 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
// updated caps get DEL'd and then NEW'd
// so, we can just add updated ones to both removed and added lists here and they'll be correctly handled
server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues), strconv.Itoa(updatedCaps.Count()))
for _, capab := range updatedCaps.List() {
addedCaps.Enable(capab)
removedCaps.Enable(capab)
}
server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues))
addedCaps.Union(updatedCaps)
removedCaps.Union(updatedCaps)
if 0 < addedCaps.Count() || 0 < removedCaps.Count() {
if !addedCaps.Empty() || !removedCaps.Empty() {
capBurstClients = server.clients.AllWithCaps(caps.CapNotify)
added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues)
@ -918,7 +916,7 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
// remove STS policy
//TODO(dan): this is an ugly hack. we can write this better.
stsPolicy := "sts=duration=0"
if 0 < addedCaps.Count() {
if !addedCaps.Empty() {
added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy
} else {
addedCaps.Enable(caps.STS)
@ -926,10 +924,10 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
}
}
// DEL caps and then send NEW ones so that updated caps get removed/added correctly
if 0 < removedCaps.Count() {
if !removedCaps.Empty() {
sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed)
}
if 0 < addedCaps.Count() {
if !addedCaps.Empty() {
sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion])
}
}

86
irc/utils/bitset.go Normal file
View File

@ -0,0 +1,86 @@
// Copyright (c) 2018 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package utils
import "sync/atomic"
// Library functions for lock-free bitsets, typically (constant-sized) arrays of uint64.
// For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a
// slice to use these functions.
// BitsetInitialize initializes a bitset.
func BitsetInitialize(set []uint64) {
// XXX re-zero the bitset using atomic stores. it's unclear whether this is required,
// however, golang issue #5045 suggests that you shouldn't mix atomic operations
// with non-atomic operations (such as the runtime's automatic zero-initialization) on
// the same word
for i := 0; i < len(set); i++ {
atomic.StoreUint64(&set[i], 0)
}
}
// BitsetGet returns whether a given bit of the bitset is set.
func BitsetGet(set []uint64, position uint) bool {
idx := position / 64
bit := position % 64
block := atomic.LoadUint64(&set[idx])
return (block & (1 << bit)) != 0
}
// BitsetSet sets a given bit of the bitset to 0 or 1, returning whether it changed.
func BitsetSet(set []uint64, position uint, on bool) (changed bool) {
idx := position / 64
bit := position % 64
addr := &set[idx]
var mask uint64
mask = 1 << bit
for {
current := atomic.LoadUint64(addr)
previouslyOn := (current & mask) != 0
if on == previouslyOn {
return false
}
var desired uint64
if on {
desired = current | mask
} else {
desired = current & (^mask)
}
if atomic.CompareAndSwapUint64(addr, current, desired) {
return true
}
}
}
// BitsetEmpty returns whether the bitset is empty.
// Right now, this is technically free of race conditions because we don't
// have a method that can simultaneously modify two bits separated by a word boundary
// such that one of those modifications is an unset. If we did, there would be a race
// that could produce false positives. It's probably better to assume that they are
// already possible under concurrent modification (which is not how we're using this).
func BitsetEmpty(set []uint64) (empty bool) {
for i := 0; i < len(set); i++ {
if atomic.LoadUint64(&set[i]) != 0 {
return false
}
}
return true
}
// BitsetUnion modifies `set` to be the union of `set` and `other`.
// This has race conditions in that we don't necessarily get a single
// consistent view of `other` across word boundaries.
func BitsetUnion(set []uint64, other []uint64) {
for i := 0; i < len(set); i++ {
for {
ourAddr := &set[i]
ourBlock := atomic.LoadUint64(ourAddr)
otherBlock := atomic.LoadUint64(&other[i])
newBlock := ourBlock | otherBlock
if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) {
break
}
}
}
}

52
irc/utils/bitset_test.go Normal file
View File

@ -0,0 +1,52 @@
// Copyright (c) 2018 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package utils
import "testing"
type testBitset [2]uint64
func TestSets(t *testing.T) {
var t1 testBitset
t1s := t1[:]
BitsetInitialize(t1s)
if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) {
t.Error("no bits should be set in a newly initialized bitset")
}
var i uint
for i = 0; i < 128; i++ {
if i%2 == 0 {
BitsetSet(t1s, i, true)
}
}
if !(BitsetGet(t1s, 0) && !BitsetGet(t1s, 1) && BitsetGet(t1s, 64) && BitsetGet(t1s, 72) && !BitsetGet(t1s, 127)) {
t.Error("exactly the even-numbered bits should be set")
}
BitsetSet(t1s, 72, false)
if BitsetGet(t1s, 72) {
t.Error("remove doesn't work")
}
var t2 testBitset
t2s := t2[:]
BitsetInitialize(t2s)
for i = 0; i < 128; i++ {
if i%2 == 1 {
BitsetSet(t2s, i, true)
}
}
BitsetUnion(t1s, t2s)
for i = 0; i < 128; i++ {
expected := (i != 72)
if BitsetGet(t1s, i) != expected {
t.Error("all bits should be set except 72")
}
}
}