From 2a33c1483b671914ecb74e9c2a7248816459386d Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Mon, 25 Jun 2018 18:08:15 -0400 Subject: [PATCH] atomic bitset implementations of caps.Set and modes.ModeSet --- Makefile | 7 ++ gencapdefs.py | 204 +++++++++++++++++++++++++++++++++++++++ irc/caps/constants.go | 69 +++++-------- irc/caps/defs.go | 125 ++++++++++++++++++++++++ irc/caps/set.go | 79 ++++++--------- irc/caps/set_test.go | 25 ++--- irc/handlers.go | 20 ++-- irc/modes/modes.go | 62 +++++------- irc/responsebuffer.go | 6 +- irc/server.go | 16 ++- irc/utils/bitset.go | 86 +++++++++++++++++ irc/utils/bitset_test.go | 52 ++++++++++ 12 files changed, 577 insertions(+), 174 deletions(-) create mode 100644 gencapdefs.py create mode 100644 irc/caps/defs.go create mode 100644 irc/utils/bitset.go create mode 100644 irc/utils/bitset_test.go diff --git a/Makefile b/Makefile index 572cd3a6..32ea0c27 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ .PHONY: all build +capdef_file = ./irc/caps/defs.go + all: build build: @@ -8,11 +10,16 @@ build: buildrelease: goreleaser --skip-publish --rm-dist +capdefs: + python3 ./gencapdefs.py > ${capdef_file} + deps: git submodule update --init test: + python3 ./gencapdefs.py | diff - ${capdef_file} cd irc && go test . && go vet . + cd irc/caps && go test . && go vet . cd irc/isupport && go test . && go vet . cd irc/modes && go test . && go vet . cd irc/utils && go test . && go vet . diff --git a/gencapdefs.py b/gencapdefs.py new file mode 100644 index 00000000..f828a38d --- /dev/null +++ b/gencapdefs.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 + +""" +Updates the capability definitions at irc/caps/defs.go + +To add a capability, add it to the CAPDEFS list below, +then run `make capdefs` from the project root. +""" + +import io +import subprocess +import sys +from collections import namedtuple + +CapDef = namedtuple("CapDef", ['identifier', 'name', 'url', 'standard']) + +CAPDEFS = [ + CapDef( + identifier="LabelTagName", + name="draft/label", + url="https://ircv3.net/specs/extensions/labeled-response.html", + standard="draft IRCv3 tag name", + ), + CapDef( + identifier="AccountNotify", + name="account-notify", + url="https://ircv3.net/specs/extensions/account-notify-3.1.html", + standard="IRCv3", + ), + CapDef( + identifier="AccountTag", + name="account-tag", + url="https://ircv3.net/specs/extensions/account-tag-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="AwayNotify", + name="away-notify", + url="https://ircv3.net/specs/extensions/away-notify-3.1.html", + standard="IRCv3", + ), + CapDef( + identifier="Batch", + name="batch", + url="https://ircv3.net/specs/extensions/batch-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="CapNotify", + name="cap-notify", + url="https://ircv3.net/specs/extensions/cap-notify-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="ChgHost", + name="chghost", + url="https://ircv3.net/specs/extensions/chghost-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="EchoMessage", + name="echo-message", + url="https://ircv3.net/specs/extensions/echo-message-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="ExtendedJoin", + name="extended-join", + url="https://ircv3.net/specs/extensions/extended-join-3.1.html", + standard="IRCv3", + ), + CapDef( + identifier="InviteNotify", + name="invite-notify", + url="https://ircv3.net/specs/extensions/invite-notify-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="LabeledResponse", + name="draft/labeled-response", + url="https://ircv3.net/specs/extensions/labeled-response.html", + standard="draft IRCv3", + ), + CapDef( + identifier="Languages", + name="draft/languages", + url="https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6", + standard="proposed IRCv3", + ), + CapDef( + identifier="MaxLine", + name="oragono.io/maxline", + url="https://oragono.io/maxline", + standard="Oragono-specific", + ), + CapDef( + identifier="MessageTags", + name="draft/message-tags-0.2", + url="https://ircv3.net/specs/core/message-tags-3.3.html", + standard="draft IRCv3", + ), + CapDef( + identifier="MultiPrefix", + name="multi-prefix", + url="https://ircv3.net/specs/extensions/multi-prefix-3.1.html", + standard="IRCv3", + ), + CapDef( + identifier="Rename", + name="draft/rename", + url="https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md", + standard="proposed IRCv3", + ), + CapDef( + identifier="Resume", + name="draft/resume", + url="https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md", + standard="proposed IRCv3", + ), + CapDef( + identifier="SASL", + name="sasl", + url="https://ircv3.net/specs/extensions/sasl-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="ServerTime", + name="server-time", + url="https://ircv3.net/specs/extensions/server-time-3.2.html", + standard="IRCv3", + ), + CapDef( + identifier="STS", + name="sts", + url="https://ircv3.net/specs/extensions/sts.html", + standard="IRCv3", + ), + CapDef( + identifier="UserhostInNames", + name="userhost-in-names", + url="https://ircv3.net/specs/extensions/userhost-in-names-3.2.html", + standard="IRCv3", + ), +] + +def validate_defs(): + numCaps = len(CAPDEFS) + numNames = len(set(capdef.name for capdef in CAPDEFS)) + if numCaps != numNames: + raise Exception("defs must have unique names, but found duplicates") + numIdentifiers = len(set(capdef.identifier for capdef in CAPDEFS)) + if numCaps != numIdentifiers: + raise Exception("defs must have unique identifiers, but found duplicates") + +def main(): + validate_defs() + output = io.StringIO() + print(""" +package caps + +/* + WARNING: this file is autogenerated by `make capdefs` + DO NOT EDIT MANUALLY. +*/ + + + """, file=output) + + + numCapabs = len(CAPDEFS) + bitsetLen = numCapabs // 64 + if numCapabs % 64 > 0: + bitsetLen += 1 + print (""" +const ( + // number of recognized capabilities: + numCapabs = %d + // length of the uint64 array that represents the bitset: + bitsetLen = %d +) + """ % (numCapabs, bitsetLen), file=output) + + print("const (", file=output) + for capdef in CAPDEFS: + print("// %s is the %s capability named \"%s\":" % (capdef.identifier, capdef.standard, capdef.name), file=output) + print("// %s" % (capdef.url,), file=output) + print("%s Capability = iota" % (capdef.identifier,), file=output) + print(file=output) + print(")", file=output) + + print("""var ( capabilityNames = [numCapabs]string{""", file=output) + for capdef in CAPDEFS: + print("\"%s\"," % (capdef.name,), file=output) + print("})", file=output) + + gofmt = subprocess.Popen(['gofmt', '-s'], stdin=subprocess.PIPE) + gofmt.communicate(input=output.getvalue().encode('utf-8')) + if gofmt.poll() != 0: + print(output.getvalue()) + raise Exception("gofmt failed") + return 0 + +if __name__ == '__main__': + sys.exit(main()) diff --git a/irc/caps/constants.go b/irc/caps/constants.go index 00e60a74..d618bd08 100644 --- a/irc/caps/constants.go +++ b/irc/caps/constants.go @@ -3,58 +3,30 @@ package caps +import "errors" + // Capability represents an optional feature that a client may request from the server. -type Capability string +type Capability uint -const ( - // LabelTagName is the tag name used for the labeled-response spec. - LabelTagName = "draft/label" +// actual capability definitions appear in defs.go - // AccountNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/account-notify-3.1.html - AccountNotify Capability = "account-notify" - // AccountTag is this IRCv3 capability: http://ircv3.net/specs/extensions/account-tag-3.2.html - AccountTag Capability = "account-tag" - // AwayNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/away-notify-3.1.html - AwayNotify Capability = "away-notify" - // Batch is this IRCv3 capability: http://ircv3.net/specs/extensions/batch-3.2.html - Batch Capability = "batch" - // CapNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/cap-notify-3.2.html - CapNotify Capability = "cap-notify" - // ChgHost is this IRCv3 capability: http://ircv3.net/specs/extensions/chghost-3.2.html - ChgHost Capability = "chghost" - // EchoMessage is this IRCv3 capability: http://ircv3.net/specs/extensions/echo-message-3.2.html - EchoMessage Capability = "echo-message" - // ExtendedJoin is this IRCv3 capability: http://ircv3.net/specs/extensions/extended-join-3.1.html - ExtendedJoin Capability = "extended-join" - // InviteNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/invite-notify-3.2.html - InviteNotify Capability = "invite-notify" - // LabeledResponse is this draft IRCv3 capability: http://ircv3.net/specs/extensions/labeled-response.html - LabeledResponse Capability = "draft/labeled-response" - // Languages is this proposed IRCv3 capability: https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6 - Languages Capability = "draft/languages" - // MaxLine is this capability: https://oragono.io/maxline - MaxLine Capability = "oragono.io/maxline" - // MessageTags is this draft IRCv3 capability: http://ircv3.net/specs/core/message-tags-3.3.html - MessageTags Capability = "draft/message-tags-0.2" - // MultiPrefix is this IRCv3 capability: http://ircv3.net/specs/extensions/multi-prefix-3.1.html - MultiPrefix Capability = "multi-prefix" - // Rename is this proposed capability: https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md - Rename Capability = "draft/rename" - // Resume is this proposed capability: https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md - Resume Capability = "draft/resume" - // SASL is this IRCv3 capability: http://ircv3.net/specs/extensions/sasl-3.2.html - SASL Capability = "sasl" - // ServerTime is this IRCv3 capability: http://ircv3.net/specs/extensions/server-time-3.2.html - ServerTime Capability = "server-time" - // STS is this IRCv3 capability: http://ircv3.net/specs/extensions/sts.html - STS Capability = "sts" - // UserhostInNames is this IRCv3 capability: http://ircv3.net/specs/extensions/userhost-in-names-3.2.html - UserhostInNames Capability = "userhost-in-names" +var ( + nameToCapability map[string]Capability + + NoSuchCap = errors.New("Unsupported capability name") ) // Name returns the name of the given capability. func (capability Capability) Name() string { - return string(capability) + return capabilityNames[capability] +} + +func NameToCapability(name string) (result Capability, err error) { + result, found := nameToCapability[name] + if !found { + err = NoSuchCap + } + return } // Version is used to select which max version of CAP the client supports. @@ -78,3 +50,10 @@ const ( // NegotiatedState means CAP negotiation has been successfully ended and reg should complete. NegotiatedState State = iota ) + +func init() { + nameToCapability = make(map[string]Capability) + for capab, name := range capabilityNames { + nameToCapability[name] = Capability(capab) + } +} diff --git a/irc/caps/defs.go b/irc/caps/defs.go new file mode 100644 index 00000000..7fb5d578 --- /dev/null +++ b/irc/caps/defs.go @@ -0,0 +1,125 @@ +package caps + +/* + WARNING: this file is autogenerated by `make capdefs` + DO NOT EDIT MANUALLY. +*/ + +const ( + // number of recognized capabilities: + numCapabs = 21 + // length of the uint64 array that represents the bitset: + bitsetLen = 1 +) + +const ( + // LabelTagName is the draft IRCv3 tag name capability named "draft/label": + // https://ircv3.net/specs/extensions/labeled-response.html + LabelTagName Capability = iota + + // AccountNotify is the IRCv3 capability named "account-notify": + // https://ircv3.net/specs/extensions/account-notify-3.1.html + AccountNotify Capability = iota + + // AccountTag is the IRCv3 capability named "account-tag": + // https://ircv3.net/specs/extensions/account-tag-3.2.html + AccountTag Capability = iota + + // AwayNotify is the IRCv3 capability named "away-notify": + // https://ircv3.net/specs/extensions/away-notify-3.1.html + AwayNotify Capability = iota + + // Batch is the IRCv3 capability named "batch": + // https://ircv3.net/specs/extensions/batch-3.2.html + Batch Capability = iota + + // CapNotify is the IRCv3 capability named "cap-notify": + // https://ircv3.net/specs/extensions/cap-notify-3.2.html + CapNotify Capability = iota + + // ChgHost is the IRCv3 capability named "chghost": + // https://ircv3.net/specs/extensions/chghost-3.2.html + ChgHost Capability = iota + + // EchoMessage is the IRCv3 capability named "echo-message": + // https://ircv3.net/specs/extensions/echo-message-3.2.html + EchoMessage Capability = iota + + // ExtendedJoin is the IRCv3 capability named "extended-join": + // https://ircv3.net/specs/extensions/extended-join-3.1.html + ExtendedJoin Capability = iota + + // InviteNotify is the IRCv3 capability named "invite-notify": + // https://ircv3.net/specs/extensions/invite-notify-3.2.html + InviteNotify Capability = iota + + // LabeledResponse is the draft IRCv3 capability named "draft/labeled-response": + // https://ircv3.net/specs/extensions/labeled-response.html + LabeledResponse Capability = iota + + // Languages is the proposed IRCv3 capability named "draft/languages": + // https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6 + Languages Capability = iota + + // MaxLine is the Oragono-specific capability named "oragono.io/maxline": + // https://oragono.io/maxline + MaxLine Capability = iota + + // MessageTags is the draft IRCv3 capability named "draft/message-tags-0.2": + // https://ircv3.net/specs/core/message-tags-3.3.html + MessageTags Capability = iota + + // MultiPrefix is the IRCv3 capability named "multi-prefix": + // https://ircv3.net/specs/extensions/multi-prefix-3.1.html + MultiPrefix Capability = iota + + // Rename is the proposed IRCv3 capability named "draft/rename": + // https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md + Rename Capability = iota + + // Resume is the proposed IRCv3 capability named "draft/resume": + // https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md + Resume Capability = iota + + // SASL is the IRCv3 capability named "sasl": + // https://ircv3.net/specs/extensions/sasl-3.2.html + SASL Capability = iota + + // ServerTime is the IRCv3 capability named "server-time": + // https://ircv3.net/specs/extensions/server-time-3.2.html + ServerTime Capability = iota + + // STS is the IRCv3 capability named "sts": + // https://ircv3.net/specs/extensions/sts.html + STS Capability = iota + + // UserhostInNames is the IRCv3 capability named "userhost-in-names": + // https://ircv3.net/specs/extensions/userhost-in-names-3.2.html + UserhostInNames Capability = iota +) + +var ( + capabilityNames = [numCapabs]string{ + "draft/label", + "account-notify", + "account-tag", + "away-notify", + "batch", + "cap-notify", + "chghost", + "echo-message", + "extended-join", + "invite-notify", + "draft/labeled-response", + "draft/languages", + "oragono.io/maxline", + "draft/message-tags-0.2", + "multi-prefix", + "draft/rename", + "draft/resume", + "sasl", + "server-time", + "sts", + "userhost-in-names", + } +) diff --git a/irc/caps/set.go b/irc/caps/set.go index f18852bd..867617b3 100644 --- a/irc/caps/set.go +++ b/irc/caps/set.go @@ -6,43 +6,34 @@ package caps import ( "sort" "strings" - "sync" + + "github.com/oragono/oragono/irc/utils" ) // Set holds a set of enabled capabilities. -type Set struct { - sync.RWMutex - // capabilities holds the capabilities this manager has. - capabilities map[Capability]bool -} +type Set [bitsetLen]uint64 // NewSet returns a new Set, with the given capabilities enabled. func NewSet(capabs ...Capability) *Set { - newSet := Set{ - capabilities: make(map[Capability]bool), - } + var newSet Set + utils.BitsetInitialize(newSet[:]) newSet.Enable(capabs...) - return &newSet } // Enable enables the given capabilities. func (s *Set) Enable(capabs ...Capability) { - s.Lock() - defer s.Unlock() - + asSlice := s[:] for _, capab := range capabs { - s.capabilities[capab] = true + utils.BitsetSet(asSlice, uint(capab), true) } } // Disable disables the given capabilities. func (s *Set) Disable(capabs ...Capability) { - s.Lock() - defer s.Unlock() - + asSlice := s[:] for _, capab := range capabs { - delete(s.capabilities, capab) + utils.BitsetSet(asSlice, uint(capab), false) } } @@ -58,51 +49,35 @@ func (s *Set) Remove(capabs ...Capability) { s.Disable(capabs...) } -// Has returns true if this set has the given capabilities. -func (s *Set) Has(caps ...Capability) bool { - s.RLock() - defer s.RUnlock() - - for _, cap := range caps { - if !s.capabilities[cap] { - return false - } - } - return true +// Has returns true if this set has the given capability. +func (s *Set) Has(capab Capability) bool { + return utils.BitsetGet(s[:], uint(capab)) } -// List return a list of our enabled capabilities. -func (s *Set) List() []Capability { - s.RLock() - defer s.RUnlock() - - var allCaps []Capability - for capab := range s.capabilities { - allCaps = append(allCaps, capab) - } - - return allCaps +// Union adds all the capabilities of another set to this set. +func (s *Set) Union(other *Set) { + utils.BitsetUnion(s[:], other[:]) } -// Count returns how many enabled caps this set has. -func (s *Set) Count() int { - s.RLock() - defer s.RUnlock() - - return len(s.capabilities) +// Empty returns whether the set is empty. +func (s *Set) Empty() bool { + return utils.BitsetEmpty(s[:]) } // String returns all of our enabled capabilities as a string. func (s *Set) String(version Version, values *Values) string { - s.RLock() - defer s.RUnlock() - var strs sort.StringSlice - for capability := range s.capabilities { - capString := capability.Name() + var capab Capability + asSlice := s[:] + for capab = 0; capab < numCapabs; capab++ { + // skip any capabilities that are not enabled + if !utils.BitsetGet(asSlice, uint(capab)) { + continue + } + capString := capab.Name() if version == Cap302 { - val, exists := values.Get(capability) + val, exists := values.Get(capab) if exists { capString += "=" + val } diff --git a/irc/caps/set_test.go b/irc/caps/set_test.go index 83d6c8ec..019ca76b 100644 --- a/irc/caps/set_test.go +++ b/irc/caps/set_test.go @@ -11,12 +11,12 @@ func TestSets(t *testing.T) { s1.Enable(AccountTag, EchoMessage, UserhostInNames) - if !s1.Has(AccountTag, EchoMessage, UserhostInNames) { + if !(s1.Has(AccountTag) && s1.Has(EchoMessage) && s1.Has(UserhostInNames)) { t.Error("Did not have the tags we expected") } - if s1.Has(AccountTag, EchoMessage, STS, UserhostInNames) { - t.Error("Has() returned true when we don't have all the given capabilities") + if s1.Has(STS) { + t.Error("Has() returned true when we don't have the given capability") } s1.Disable(AccountTag) @@ -25,14 +25,9 @@ func TestSets(t *testing.T) { t.Error("Disable() did not correctly disable the given capability") } - enabledCaps := make(map[Capability]bool) - for _, capab := range s1.List() { - enabledCaps[capab] = true - } - expectedCaps := map[Capability]bool{ - EchoMessage: true, - UserhostInNames: true, - } + enabledCaps := NewSet() + enabledCaps.Union(s1) + expectedCaps := NewSet(EchoMessage, UserhostInNames) if !reflect.DeepEqual(enabledCaps, expectedCaps) { t.Errorf("Enabled and expected capability lists do not match: %v, %v", enabledCaps, expectedCaps) } @@ -40,16 +35,12 @@ func TestSets(t *testing.T) { // make sure re-enabling doesn't add to the count or something weird like that s1.Enable(EchoMessage) - if s1.Count() != 2 { - t.Error("Count() did not match expected capability count") - } - // make sure add and remove work fine s1.Add(InviteNotify) s1.Remove(EchoMessage) - if s1.Count() != 2 { - t.Error("Count() did not match expected capability count") + if !s1.Has(InviteNotify) || s1.Has(EchoMessage) { + t.Error("Add/Remove don't work") } // test String() diff --git a/irc/handlers.go b/irc/handlers.go index df8def48..4d596f56 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -442,12 +442,16 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo capabilities := caps.NewSet() var capString string + var badCaps []string if len(msg.Params) > 1 { capString = msg.Params[1] - strs := strings.Split(capString, " ") + strs := strings.Fields(capString) for _, str := range strs { - if len(str) > 0 { - capabilities.Enable(caps.Capability(str)) + capab, err := caps.NameToCapability(str) + if err != nil || !SupportedCapabilities.Has(capab) { + badCaps = append(badCaps, str) + } else { + capabilities.Enable(capab) } } } @@ -475,13 +479,11 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo } // make sure all capabilities actually exist - for _, capability := range capabilities.List() { - if !SupportedCapabilities.Has(capability) { - rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString) - return false - } + if len(badCaps) > 0 { + rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString) + return false } - client.capabilities.Enable(capabilities.List()...) + client.capabilities.Union(capabilities) rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString) case "END": diff --git a/irc/modes/modes.go b/irc/modes/modes.go index 481ee3d4..d9f362f1 100644 --- a/irc/modes/modes.go +++ b/irc/modes/modes.go @@ -7,7 +7,9 @@ package modes import ( "strings" - "sync" + "sync/atomic" + + "github.com/oragono/oragono/irc/utils" ) var ( @@ -322,42 +324,29 @@ func ParseChannelModeChanges(params ...string) (ModeChanges, map[rune]bool) { } // ModeSet holds a set of modes. -type ModeSet struct { - sync.RWMutex // tier 0 - modes map[Mode]bool -} +type ModeSet [1]uint64 + +// valid modes go from 65 ('A') to 122 ('z'), making at most 58 possible values; +// subtract 65 from the mode value and use that bit of the uint64 to represent it +const ( + minMode = 65 // 'A' +) // returns a pointer to a new ModeSet func NewModeSet() *ModeSet { - return &ModeSet{ - modes: make(map[Mode]bool), - } + var set ModeSet + utils.BitsetInitialize(set[:]) + return &set } // test whether `mode` is set func (set *ModeSet) HasMode(mode Mode) bool { - if set == nil { - return false - } - - set.RLock() - defer set.RUnlock() - return set.modes[mode] + return utils.BitsetGet(set[:], uint(mode)-minMode) } // set `mode` to be on or off, return whether the value actually changed func (set *ModeSet) SetMode(mode Mode, on bool) (applied bool) { - set.Lock() - defer set.Unlock() - - previouslyOn := set.modes[mode] - needsApply := (on != previouslyOn) - if on && needsApply { - set.modes[mode] = true - } else if !on && needsApply { - delete(set.modes, mode) - } - return needsApply + return utils.BitsetSet(set[:], uint(mode)-minMode, on) } // return the modes in the set as a slice @@ -366,11 +355,12 @@ func (set *ModeSet) AllModes() (result []Mode) { return } - set.RLock() - defer set.RUnlock() - - for mode := range set.modes { - result = append(result, mode) + block := atomic.LoadUint64(&set[0]) + var i uint + for i = 0; i < 64; i++ { + if block&(1< +// released under the MIT license + +package utils + +import "sync/atomic" + +// Library functions for lock-free bitsets, typically (constant-sized) arrays of uint64. +// For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a +// slice to use these functions. + +// BitsetInitialize initializes a bitset. +func BitsetInitialize(set []uint64) { + // XXX re-zero the bitset using atomic stores. it's unclear whether this is required, + // however, golang issue #5045 suggests that you shouldn't mix atomic operations + // with non-atomic operations (such as the runtime's automatic zero-initialization) on + // the same word + for i := 0; i < len(set); i++ { + atomic.StoreUint64(&set[i], 0) + } +} + +// BitsetGet returns whether a given bit of the bitset is set. +func BitsetGet(set []uint64, position uint) bool { + idx := position / 64 + bit := position % 64 + block := atomic.LoadUint64(&set[idx]) + return (block & (1 << bit)) != 0 +} + +// BitsetSet sets a given bit of the bitset to 0 or 1, returning whether it changed. +func BitsetSet(set []uint64, position uint, on bool) (changed bool) { + idx := position / 64 + bit := position % 64 + addr := &set[idx] + var mask uint64 + mask = 1 << bit + for { + current := atomic.LoadUint64(addr) + previouslyOn := (current & mask) != 0 + if on == previouslyOn { + return false + } + var desired uint64 + if on { + desired = current | mask + } else { + desired = current & (^mask) + } + if atomic.CompareAndSwapUint64(addr, current, desired) { + return true + } + } +} + +// BitsetEmpty returns whether the bitset is empty. +// Right now, this is technically free of race conditions because we don't +// have a method that can simultaneously modify two bits separated by a word boundary +// such that one of those modifications is an unset. If we did, there would be a race +// that could produce false positives. It's probably better to assume that they are +// already possible under concurrent modification (which is not how we're using this). +func BitsetEmpty(set []uint64) (empty bool) { + for i := 0; i < len(set); i++ { + if atomic.LoadUint64(&set[i]) != 0 { + return false + } + } + return true +} + +// BitsetUnion modifies `set` to be the union of `set` and `other`. +// This has race conditions in that we don't necessarily get a single +// consistent view of `other` across word boundaries. +func BitsetUnion(set []uint64, other []uint64) { + for i := 0; i < len(set); i++ { + for { + ourAddr := &set[i] + ourBlock := atomic.LoadUint64(ourAddr) + otherBlock := atomic.LoadUint64(&other[i]) + newBlock := ourBlock | otherBlock + if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) { + break + } + } + } +} diff --git a/irc/utils/bitset_test.go b/irc/utils/bitset_test.go new file mode 100644 index 00000000..9db22668 --- /dev/null +++ b/irc/utils/bitset_test.go @@ -0,0 +1,52 @@ +// Copyright (c) 2018 Shivaram Lingamneni +// released under the MIT license + +package utils + +import "testing" + +type testBitset [2]uint64 + +func TestSets(t *testing.T) { + var t1 testBitset + t1s := t1[:] + BitsetInitialize(t1s) + + if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) { + t.Error("no bits should be set in a newly initialized bitset") + } + + var i uint + for i = 0; i < 128; i++ { + if i%2 == 0 { + BitsetSet(t1s, i, true) + } + } + + if !(BitsetGet(t1s, 0) && !BitsetGet(t1s, 1) && BitsetGet(t1s, 64) && BitsetGet(t1s, 72) && !BitsetGet(t1s, 127)) { + t.Error("exactly the even-numbered bits should be set") + } + + BitsetSet(t1s, 72, false) + if BitsetGet(t1s, 72) { + t.Error("remove doesn't work") + } + + var t2 testBitset + t2s := t2[:] + BitsetInitialize(t2s) + + for i = 0; i < 128; i++ { + if i%2 == 1 { + BitsetSet(t2s, i, true) + } + } + + BitsetUnion(t1s, t2s) + for i = 0; i < 128; i++ { + expected := (i != 72) + if BitsetGet(t1s, i) != expected { + t.Error("all bits should be set except 72") + } + } +}