3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-30 07:59:24 +01:00

Merge pull request #277 from slingamn/bitset.3

implement #263
This commit is contained in:
Daniel Oaks 2018-07-02 16:00:04 +10:00 committed by GitHub
commit 477a9023ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 606 additions and 171 deletions

View File

@ -1,5 +1,7 @@
.PHONY: all build .PHONY: all build
capdef_file = ./irc/caps/defs.go
all: build all: build
build: build:
@ -8,11 +10,16 @@ build:
buildrelease: buildrelease:
goreleaser --skip-publish --rm-dist goreleaser --skip-publish --rm-dist
capdefs:
python3 ./gencapdefs.py > ${capdef_file}
deps: deps:
git submodule update --init git submodule update --init
test: test:
python3 ./gencapdefs.py | diff - ${capdef_file}
cd irc && go test . && go vet . cd irc && go test . && go vet .
cd irc/caps && go test . && go vet .
cd irc/isupport && go test . && go vet . cd irc/isupport && go test . && go vet .
cd irc/modes && go test . && go vet . cd irc/modes && go test . && go vet .
cd irc/utils && go test . && go vet . cd irc/utils && go test . && go vet .

200
gencapdefs.py Normal file
View File

@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
Updates the capability definitions at irc/caps/defs.go
To add a capability, add it to the CAPDEFS list below,
then run `make capdefs` from the project root.
"""
import io
import subprocess
import sys
from collections import namedtuple
CapDef = namedtuple("CapDef", ['identifier', 'name', 'url', 'standard'])
CAPDEFS = [
CapDef(
identifier="AccountNotify",
name="account-notify",
url="https://ircv3.net/specs/extensions/account-notify-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="AccountTag",
name="account-tag",
url="https://ircv3.net/specs/extensions/account-tag-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="AwayNotify",
name="away-notify",
url="https://ircv3.net/specs/extensions/away-notify-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="Batch",
name="batch",
url="https://ircv3.net/specs/extensions/batch-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="CapNotify",
name="cap-notify",
url="https://ircv3.net/specs/extensions/cap-notify-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="ChgHost",
name="chghost",
url="https://ircv3.net/specs/extensions/chghost-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="EchoMessage",
name="echo-message",
url="https://ircv3.net/specs/extensions/echo-message-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="ExtendedJoin",
name="extended-join",
url="https://ircv3.net/specs/extensions/extended-join-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="InviteNotify",
name="invite-notify",
url="https://ircv3.net/specs/extensions/invite-notify-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="LabeledResponse",
name="draft/labeled-response",
url="https://ircv3.net/specs/extensions/labeled-response.html",
standard="draft IRCv3",
),
CapDef(
identifier="Languages",
name="draft/languages",
url="https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6",
standard="proposed IRCv3",
),
CapDef(
identifier="MaxLine",
name="oragono.io/maxline",
url="https://oragono.io/maxline",
standard="Oragono-specific",
),
CapDef(
identifier="MessageTags",
name="draft/message-tags-0.2",
url="https://ircv3.net/specs/core/message-tags-3.3.html",
standard="draft IRCv3",
),
CapDef(
identifier="MultiPrefix",
name="multi-prefix",
url="https://ircv3.net/specs/extensions/multi-prefix-3.1.html",
standard="IRCv3",
),
CapDef(
identifier="Rename",
name="draft/rename",
url="https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md",
standard="proposed IRCv3",
),
CapDef(
identifier="Resume",
name="draft/resume",
url="https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md",
standard="proposed IRCv3",
),
CapDef(
identifier="SASL",
name="sasl",
url="https://ircv3.net/specs/extensions/sasl-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="ServerTime",
name="server-time",
url="https://ircv3.net/specs/extensions/server-time-3.2.html",
standard="IRCv3",
),
CapDef(
identifier="STS",
name="sts",
url="https://ircv3.net/specs/extensions/sts.html",
standard="IRCv3",
),
CapDef(
identifier="UserhostInNames",
name="userhost-in-names",
url="https://ircv3.net/specs/extensions/userhost-in-names-3.2.html",
standard="IRCv3",
),
]
def validate_defs():
numCaps = len(CAPDEFS)
numNames = len(set(capdef.name for capdef in CAPDEFS))
if numCaps != numNames:
raise Exception("defs must have unique names, but found duplicates")
numIdentifiers = len(set(capdef.identifier for capdef in CAPDEFS))
if numCaps != numIdentifiers:
raise Exception("defs must have unique identifiers, but found duplicates")
def main():
validate_defs()
output = io.StringIO()
print("""
package caps
/*
WARNING: this file is autogenerated by `make capdefs`
DO NOT EDIT MANUALLY.
*/
""", file=output)
numCapabs = len(CAPDEFS)
bitsetLen = numCapabs // 64
if numCapabs % 64 > 0:
bitsetLen += 1
print ("""
const (
// number of recognized capabilities:
numCapabs = %d
// length of the uint64 array that represents the bitset:
bitsetLen = %d
)
""" % (numCapabs, bitsetLen), file=output)
print("const (", file=output)
for capdef in CAPDEFS:
print("// %s is the %s capability named \"%s\":" % (capdef.identifier, capdef.standard, capdef.name), file=output)
print("// %s" % (capdef.url,), file=output)
print("%s Capability = iota" % (capdef.identifier,), file=output)
print(file=output)
print(")", file=output)
print("// `capabilityNames[capab]` is the string name of the capability `capab`", file=output)
print("""var ( capabilityNames = [numCapabs]string{""", file=output)
for capdef in CAPDEFS:
print("\"%s\"," % (capdef.name,), file=output)
print("})", file=output)
# run the generated code through `gofmt -s`, which will print it to stdout
gofmt = subprocess.Popen(['gofmt', '-s'], stdin=subprocess.PIPE)
gofmt.communicate(input=output.getvalue().encode('utf-8'))
if gofmt.poll() != 0:
print(output.getvalue())
raise Exception("gofmt failed")
return 0
if __name__ == '__main__':
sys.exit(main())

View File

@ -3,58 +3,30 @@
package caps package caps
import "errors"
// Capability represents an optional feature that a client may request from the server. // Capability represents an optional feature that a client may request from the server.
type Capability string type Capability uint
const ( // actual capability definitions appear in defs.go
// LabelTagName is the tag name used for the labeled-response spec.
LabelTagName = "draft/label"
// AccountNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/account-notify-3.1.html var (
AccountNotify Capability = "account-notify" nameToCapability map[string]Capability
// AccountTag is this IRCv3 capability: http://ircv3.net/specs/extensions/account-tag-3.2.html
AccountTag Capability = "account-tag" NoSuchCap = errors.New("Unsupported capability name")
// AwayNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/away-notify-3.1.html
AwayNotify Capability = "away-notify"
// Batch is this IRCv3 capability: http://ircv3.net/specs/extensions/batch-3.2.html
Batch Capability = "batch"
// CapNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/cap-notify-3.2.html
CapNotify Capability = "cap-notify"
// ChgHost is this IRCv3 capability: http://ircv3.net/specs/extensions/chghost-3.2.html
ChgHost Capability = "chghost"
// EchoMessage is this IRCv3 capability: http://ircv3.net/specs/extensions/echo-message-3.2.html
EchoMessage Capability = "echo-message"
// ExtendedJoin is this IRCv3 capability: http://ircv3.net/specs/extensions/extended-join-3.1.html
ExtendedJoin Capability = "extended-join"
// InviteNotify is this IRCv3 capability: http://ircv3.net/specs/extensions/invite-notify-3.2.html
InviteNotify Capability = "invite-notify"
// LabeledResponse is this draft IRCv3 capability: http://ircv3.net/specs/extensions/labeled-response.html
LabeledResponse Capability = "draft/labeled-response"
// Languages is this proposed IRCv3 capability: https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6
Languages Capability = "draft/languages"
// MaxLine is this capability: https://oragono.io/maxline
MaxLine Capability = "oragono.io/maxline"
// MessageTags is this draft IRCv3 capability: http://ircv3.net/specs/core/message-tags-3.3.html
MessageTags Capability = "draft/message-tags-0.2"
// MultiPrefix is this IRCv3 capability: http://ircv3.net/specs/extensions/multi-prefix-3.1.html
MultiPrefix Capability = "multi-prefix"
// Rename is this proposed capability: https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md
Rename Capability = "draft/rename"
// Resume is this proposed capability: https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md
Resume Capability = "draft/resume"
// SASL is this IRCv3 capability: http://ircv3.net/specs/extensions/sasl-3.2.html
SASL Capability = "sasl"
// ServerTime is this IRCv3 capability: http://ircv3.net/specs/extensions/server-time-3.2.html
ServerTime Capability = "server-time"
// STS is this IRCv3 capability: http://ircv3.net/specs/extensions/sts.html
STS Capability = "sts"
// UserhostInNames is this IRCv3 capability: http://ircv3.net/specs/extensions/userhost-in-names-3.2.html
UserhostInNames Capability = "userhost-in-names"
) )
// Name returns the name of the given capability. // Name returns the name of the given capability.
func (capability Capability) Name() string { func (capability Capability) Name() string {
return string(capability) return capabilityNames[capability]
}
func NameToCapability(name string) (result Capability, err error) {
result, found := nameToCapability[name]
if !found {
err = NoSuchCap
}
return
} }
// Version is used to select which max version of CAP the client supports. // Version is used to select which max version of CAP the client supports.
@ -78,3 +50,16 @@ const (
// NegotiatedState means CAP negotiation has been successfully ended and reg should complete. // NegotiatedState means CAP negotiation has been successfully ended and reg should complete.
NegotiatedState State = iota NegotiatedState State = iota
) )
const (
// LabelTagName is the tag name used for the labeled-response spec.
// https://ircv3.net/specs/extensions/labeled-response.html
LabelTagName = "draft/label"
)
func init() {
nameToCapability = make(map[string]Capability)
for capab, name := range capabilityNames {
nameToCapability[name] = Capability(capab)
}
}

121
irc/caps/defs.go Normal file
View File

@ -0,0 +1,121 @@
package caps
/*
WARNING: this file is autogenerated by `make capdefs`
DO NOT EDIT MANUALLY.
*/
const (
// number of recognized capabilities:
numCapabs = 20
// length of the uint64 array that represents the bitset:
bitsetLen = 1
)
const (
// AccountNotify is the IRCv3 capability named "account-notify":
// https://ircv3.net/specs/extensions/account-notify-3.1.html
AccountNotify Capability = iota
// AccountTag is the IRCv3 capability named "account-tag":
// https://ircv3.net/specs/extensions/account-tag-3.2.html
AccountTag Capability = iota
// AwayNotify is the IRCv3 capability named "away-notify":
// https://ircv3.net/specs/extensions/away-notify-3.1.html
AwayNotify Capability = iota
// Batch is the IRCv3 capability named "batch":
// https://ircv3.net/specs/extensions/batch-3.2.html
Batch Capability = iota
// CapNotify is the IRCv3 capability named "cap-notify":
// https://ircv3.net/specs/extensions/cap-notify-3.2.html
CapNotify Capability = iota
// ChgHost is the IRCv3 capability named "chghost":
// https://ircv3.net/specs/extensions/chghost-3.2.html
ChgHost Capability = iota
// EchoMessage is the IRCv3 capability named "echo-message":
// https://ircv3.net/specs/extensions/echo-message-3.2.html
EchoMessage Capability = iota
// ExtendedJoin is the IRCv3 capability named "extended-join":
// https://ircv3.net/specs/extensions/extended-join-3.1.html
ExtendedJoin Capability = iota
// InviteNotify is the IRCv3 capability named "invite-notify":
// https://ircv3.net/specs/extensions/invite-notify-3.2.html
InviteNotify Capability = iota
// LabeledResponse is the draft IRCv3 capability named "draft/labeled-response":
// https://ircv3.net/specs/extensions/labeled-response.html
LabeledResponse Capability = iota
// Languages is the proposed IRCv3 capability named "draft/languages":
// https://gist.github.com/DanielOaks/8126122f74b26012a3de37db80e4e0c6
Languages Capability = iota
// MaxLine is the Oragono-specific capability named "oragono.io/maxline":
// https://oragono.io/maxline
MaxLine Capability = iota
// MessageTags is the draft IRCv3 capability named "draft/message-tags-0.2":
// https://ircv3.net/specs/core/message-tags-3.3.html
MessageTags Capability = iota
// MultiPrefix is the IRCv3 capability named "multi-prefix":
// https://ircv3.net/specs/extensions/multi-prefix-3.1.html
MultiPrefix Capability = iota
// Rename is the proposed IRCv3 capability named "draft/rename":
// https://github.com/SaberUK/ircv3-specifications/blob/rename/extensions/rename.md
Rename Capability = iota
// Resume is the proposed IRCv3 capability named "draft/resume":
// https://github.com/DanielOaks/ircv3-specifications/blob/master+resume/extensions/resume.md
Resume Capability = iota
// SASL is the IRCv3 capability named "sasl":
// https://ircv3.net/specs/extensions/sasl-3.2.html
SASL Capability = iota
// ServerTime is the IRCv3 capability named "server-time":
// https://ircv3.net/specs/extensions/server-time-3.2.html
ServerTime Capability = iota
// STS is the IRCv3 capability named "sts":
// https://ircv3.net/specs/extensions/sts.html
STS Capability = iota
// UserhostInNames is the IRCv3 capability named "userhost-in-names":
// https://ircv3.net/specs/extensions/userhost-in-names-3.2.html
UserhostInNames Capability = iota
)
// `capabilityNames[capab]` is the string name of the capability `capab`
var (
capabilityNames = [numCapabs]string{
"account-notify",
"account-tag",
"away-notify",
"batch",
"cap-notify",
"chghost",
"echo-message",
"extended-join",
"invite-notify",
"draft/labeled-response",
"draft/languages",
"oragono.io/maxline",
"draft/message-tags-0.2",
"multi-prefix",
"draft/rename",
"draft/resume",
"sasl",
"server-time",
"sts",
"userhost-in-names",
}
)

View File

@ -6,43 +6,34 @@ package caps
import ( import (
"sort" "sort"
"strings" "strings"
"sync"
"github.com/oragono/oragono/irc/utils"
) )
// Set holds a set of enabled capabilities. // Set holds a set of enabled capabilities.
type Set struct { type Set [bitsetLen]uint64
sync.RWMutex
// capabilities holds the capabilities this manager has.
capabilities map[Capability]bool
}
// NewSet returns a new Set, with the given capabilities enabled. // NewSet returns a new Set, with the given capabilities enabled.
func NewSet(capabs ...Capability) *Set { func NewSet(capabs ...Capability) *Set {
newSet := Set{ var newSet Set
capabilities: make(map[Capability]bool), utils.BitsetInitialize(newSet[:])
}
newSet.Enable(capabs...) newSet.Enable(capabs...)
return &newSet return &newSet
} }
// Enable enables the given capabilities. // Enable enables the given capabilities.
func (s *Set) Enable(capabs ...Capability) { func (s *Set) Enable(capabs ...Capability) {
s.Lock() asSlice := s[:]
defer s.Unlock()
for _, capab := range capabs { for _, capab := range capabs {
s.capabilities[capab] = true utils.BitsetSet(asSlice, uint(capab), true)
} }
} }
// Disable disables the given capabilities. // Disable disables the given capabilities.
func (s *Set) Disable(capabs ...Capability) { func (s *Set) Disable(capabs ...Capability) {
s.Lock() asSlice := s[:]
defer s.Unlock()
for _, capab := range capabs { for _, capab := range capabs {
delete(s.capabilities, capab) utils.BitsetSet(asSlice, uint(capab), false)
} }
} }
@ -58,51 +49,35 @@ func (s *Set) Remove(capabs ...Capability) {
s.Disable(capabs...) s.Disable(capabs...)
} }
// Has returns true if this set has the given capabilities. // Has returns true if this set has the given capability.
func (s *Set) Has(caps ...Capability) bool { func (s *Set) Has(capab Capability) bool {
s.RLock() return utils.BitsetGet(s[:], uint(capab))
defer s.RUnlock()
for _, cap := range caps {
if !s.capabilities[cap] {
return false
}
}
return true
} }
// List return a list of our enabled capabilities. // Union adds all the capabilities of another set to this set.
func (s *Set) List() []Capability { func (s *Set) Union(other *Set) {
s.RLock() utils.BitsetUnion(s[:], other[:])
defer s.RUnlock()
var allCaps []Capability
for capab := range s.capabilities {
allCaps = append(allCaps, capab)
}
return allCaps
} }
// Count returns how many enabled caps this set has. // Empty returns whether the set is empty.
func (s *Set) Count() int { func (s *Set) Empty() bool {
s.RLock() return utils.BitsetEmpty(s[:])
defer s.RUnlock()
return len(s.capabilities)
} }
// String returns all of our enabled capabilities as a string. // String returns all of our enabled capabilities as a string.
func (s *Set) String(version Version, values *Values) string { func (s *Set) String(version Version, values *Values) string {
s.RLock()
defer s.RUnlock()
var strs sort.StringSlice var strs sort.StringSlice
for capability := range s.capabilities { var capab Capability
capString := capability.Name() asSlice := s[:]
for capab = 0; capab < numCapabs; capab++ {
// skip any capabilities that are not enabled
if !utils.BitsetGet(asSlice, uint(capab)) {
continue
}
capString := capab.Name()
if version == Cap302 { if version == Cap302 {
val, exists := values.Get(capability) val, exists := values.Get(capab)
if exists { if exists {
capString += "=" + val capString += "=" + val
} }

View File

@ -11,12 +11,12 @@ func TestSets(t *testing.T) {
s1.Enable(AccountTag, EchoMessage, UserhostInNames) s1.Enable(AccountTag, EchoMessage, UserhostInNames)
if !s1.Has(AccountTag, EchoMessage, UserhostInNames) { if !(s1.Has(AccountTag) && s1.Has(EchoMessage) && s1.Has(UserhostInNames)) {
t.Error("Did not have the tags we expected") t.Error("Did not have the tags we expected")
} }
if s1.Has(AccountTag, EchoMessage, STS, UserhostInNames) { if s1.Has(STS) {
t.Error("Has() returned true when we don't have all the given capabilities") t.Error("Has() returned true when we don't have the given capability")
} }
s1.Disable(AccountTag) s1.Disable(AccountTag)
@ -25,14 +25,9 @@ func TestSets(t *testing.T) {
t.Error("Disable() did not correctly disable the given capability") t.Error("Disable() did not correctly disable the given capability")
} }
enabledCaps := make(map[Capability]bool) enabledCaps := NewSet()
for _, capab := range s1.List() { enabledCaps.Union(s1)
enabledCaps[capab] = true expectedCaps := NewSet(EchoMessage, UserhostInNames)
}
expectedCaps := map[Capability]bool{
EchoMessage: true,
UserhostInNames: true,
}
if !reflect.DeepEqual(enabledCaps, expectedCaps) { if !reflect.DeepEqual(enabledCaps, expectedCaps) {
t.Errorf("Enabled and expected capability lists do not match: %v, %v", enabledCaps, expectedCaps) t.Errorf("Enabled and expected capability lists do not match: %v, %v", enabledCaps, expectedCaps)
} }
@ -40,16 +35,12 @@ func TestSets(t *testing.T) {
// make sure re-enabling doesn't add to the count or something weird like that // make sure re-enabling doesn't add to the count or something weird like that
s1.Enable(EchoMessage) s1.Enable(EchoMessage)
if s1.Count() != 2 {
t.Error("Count() did not match expected capability count")
}
// make sure add and remove work fine // make sure add and remove work fine
s1.Add(InviteNotify) s1.Add(InviteNotify)
s1.Remove(EchoMessage) s1.Remove(EchoMessage)
if s1.Count() != 2 { if !s1.Has(InviteNotify) || s1.Has(EchoMessage) {
t.Error("Count() did not match expected capability count") t.Error("Add/Remove don't work")
} }
// test String() // test String()
@ -68,3 +59,24 @@ func TestSets(t *testing.T) {
t.Errorf("Generated Cap302 values string [%s] did not match expected values string [%s]", actualCap302ValuesString, expectedCap302ValuesString) t.Errorf("Generated Cap302 values string [%s] did not match expected values string [%s]", actualCap302ValuesString, expectedCap302ValuesString)
} }
} }
func BenchmarkSetReads(b *testing.B) {
set := NewSet(UserhostInNames, EchoMessage)
b.ResetTimer()
for i := 0; i < b.N; i++ {
set.Has(UserhostInNames)
set.Has(LabeledResponse)
set.Has(EchoMessage)
set.Has(Rename)
}
}
func BenchmarkSetWrites(b *testing.B) {
for i := 0; i < b.N; i++ {
set := NewSet(UserhostInNames, EchoMessage)
set.Add(Rename)
set.Add(ExtendedJoin)
set.Remove(UserhostInNames)
set.Remove(LabeledResponse)
}
}

View File

@ -442,12 +442,16 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
capabilities := caps.NewSet() capabilities := caps.NewSet()
var capString string var capString string
var badCaps []string
if len(msg.Params) > 1 { if len(msg.Params) > 1 {
capString = msg.Params[1] capString = msg.Params[1]
strs := strings.Split(capString, " ") strs := strings.Fields(capString)
for _, str := range strs { for _, str := range strs {
if len(str) > 0 { capab, err := caps.NameToCapability(str)
capabilities.Enable(caps.Capability(str)) if err != nil || !SupportedCapabilities.Has(capab) {
badCaps = append(badCaps, str)
} else {
capabilities.Enable(capab)
} }
} }
} }
@ -475,13 +479,11 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
} }
// make sure all capabilities actually exist // make sure all capabilities actually exist
for _, capability := range capabilities.List() { if len(badCaps) > 0 {
if !SupportedCapabilities.Has(capability) {
rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString) rb.Add(nil, server.name, "CAP", client.nick, "NAK", capString)
return false return false
} }
} client.capabilities.Union(capabilities)
client.capabilities.Enable(capabilities.List()...)
rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString) rb.Add(nil, server.name, "CAP", client.nick, "ACK", capString)
case "END": case "END":

View File

@ -7,7 +7,9 @@ package modes
import ( import (
"strings" "strings"
"sync" "sync/atomic"
"github.com/oragono/oragono/irc/utils"
) )
var ( var (
@ -322,42 +324,29 @@ func ParseChannelModeChanges(params ...string) (ModeChanges, map[rune]bool) {
} }
// ModeSet holds a set of modes. // ModeSet holds a set of modes.
type ModeSet struct { type ModeSet [1]uint64
sync.RWMutex // tier 0
modes map[Mode]bool // valid modes go from 65 ('A') to 122 ('z'), making at most 58 possible values;
} // subtract 65 from the mode value and use that bit of the uint64 to represent it
const (
minMode = 65 // 'A'
)
// returns a pointer to a new ModeSet // returns a pointer to a new ModeSet
func NewModeSet() *ModeSet { func NewModeSet() *ModeSet {
return &ModeSet{ var set ModeSet
modes: make(map[Mode]bool), utils.BitsetInitialize(set[:])
} return &set
} }
// test whether `mode` is set // test whether `mode` is set
func (set *ModeSet) HasMode(mode Mode) bool { func (set *ModeSet) HasMode(mode Mode) bool {
if set == nil { return utils.BitsetGet(set[:], uint(mode)-minMode)
return false
}
set.RLock()
defer set.RUnlock()
return set.modes[mode]
} }
// set `mode` to be on or off, return whether the value actually changed // set `mode` to be on or off, return whether the value actually changed
func (set *ModeSet) SetMode(mode Mode, on bool) (applied bool) { func (set *ModeSet) SetMode(mode Mode, on bool) (applied bool) {
set.Lock() return utils.BitsetSet(set[:], uint(mode)-minMode, on)
defer set.Unlock()
previouslyOn := set.modes[mode]
needsApply := (on != previouslyOn)
if on && needsApply {
set.modes[mode] = true
} else if !on && needsApply {
delete(set.modes, mode)
}
return needsApply
} }
// return the modes in the set as a slice // return the modes in the set as a slice
@ -366,11 +355,12 @@ func (set *ModeSet) AllModes() (result []Mode) {
return return
} }
set.RLock() block := atomic.LoadUint64(&set[0])
defer set.RUnlock() var i uint
for i = 0; i < 64; i++ {
for mode := range set.modes { if block&(1<<i) != 0 {
result = append(result, mode) result = append(result, Mode(minMode+i))
}
} }
return return
} }
@ -381,11 +371,8 @@ func (set *ModeSet) String() (result string) {
return return
} }
set.RLock()
defer set.RUnlock()
var buf strings.Builder var buf strings.Builder
for mode := range set.modes { for _, mode := range set.AllModes() {
buf.WriteRune(rune(mode)) buf.WriteRune(rune(mode))
} }
return buf.String() return buf.String()
@ -397,12 +384,9 @@ func (set *ModeSet) Prefixes(isMultiPrefix bool) (prefixes string) {
return return
} }
set.RLock()
defer set.RUnlock()
// add prefixes in order from highest to lowest privs // add prefixes in order from highest to lowest privs
for _, mode := range ChannelUserModes { for _, mode := range ChannelUserModes {
if set.modes[mode] { if set.HasMode(mode) {
prefixes += ChannelModePrefixes[mode] prefixes += ChannelModePrefixes[mode]
} }
} }

View File

@ -898,13 +898,11 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
// updated caps get DEL'd and then NEW'd // updated caps get DEL'd and then NEW'd
// so, we can just add updated ones to both removed and added lists here and they'll be correctly handled // so, we can just add updated ones to both removed and added lists here and they'll be correctly handled
server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues), strconv.Itoa(updatedCaps.Count())) server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues))
for _, capab := range updatedCaps.List() { addedCaps.Union(updatedCaps)
addedCaps.Enable(capab) removedCaps.Union(updatedCaps)
removedCaps.Enable(capab)
}
if 0 < addedCaps.Count() || 0 < removedCaps.Count() { if !addedCaps.Empty() || !removedCaps.Empty() {
capBurstClients = server.clients.AllWithCaps(caps.CapNotify) capBurstClients = server.clients.AllWithCaps(caps.CapNotify)
added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues) added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues)
@ -918,7 +916,7 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
// remove STS policy // remove STS policy
//TODO(dan): this is an ugly hack. we can write this better. //TODO(dan): this is an ugly hack. we can write this better.
stsPolicy := "sts=duration=0" stsPolicy := "sts=duration=0"
if 0 < addedCaps.Count() { if !addedCaps.Empty() {
added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy
} else { } else {
addedCaps.Enable(caps.STS) addedCaps.Enable(caps.STS)
@ -926,10 +924,10 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
} }
} }
// DEL caps and then send NEW ones so that updated caps get removed/added correctly // DEL caps and then send NEW ones so that updated caps get removed/added correctly
if 0 < removedCaps.Count() { if !removedCaps.Empty() {
sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed) sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed)
} }
if 0 < addedCaps.Count() { if !addedCaps.Empty() {
sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion]) sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion])
} }
} }

86
irc/utils/bitset.go Normal file
View File

@ -0,0 +1,86 @@
// Copyright (c) 2018 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package utils
import "sync/atomic"
// Library functions for lock-free bitsets, typically (constant-sized) arrays of uint64.
// For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a
// slice to use these functions.
// BitsetInitialize initializes a bitset.
func BitsetInitialize(set []uint64) {
// XXX re-zero the bitset using atomic stores. it's unclear whether this is required,
// however, golang issue #5045 suggests that you shouldn't mix atomic operations
// with non-atomic operations (such as the runtime's automatic zero-initialization) on
// the same word
for i := 0; i < len(set); i++ {
atomic.StoreUint64(&set[i], 0)
}
}
// BitsetGet returns whether a given bit of the bitset is set.
func BitsetGet(set []uint64, position uint) bool {
idx := position / 64
bit := position % 64
block := atomic.LoadUint64(&set[idx])
return (block & (1 << bit)) != 0
}
// BitsetSet sets a given bit of the bitset to 0 or 1, returning whether it changed.
func BitsetSet(set []uint64, position uint, on bool) (changed bool) {
idx := position / 64
bit := position % 64
addr := &set[idx]
var mask uint64
mask = 1 << bit
for {
current := atomic.LoadUint64(addr)
previouslyOn := (current & mask) != 0
if on == previouslyOn {
return false
}
var desired uint64
if on {
desired = current | mask
} else {
desired = current & (^mask)
}
if atomic.CompareAndSwapUint64(addr, current, desired) {
return true
}
}
}
// BitsetEmpty returns whether the bitset is empty.
// Right now, this is technically free of race conditions because we don't
// have a method that can simultaneously modify two bits separated by a word boundary
// such that one of those modifications is an unset. If we did, there would be a race
// that could produce false positives. It's probably better to assume that they are
// already possible under concurrent modification (which is not how we're using this).
func BitsetEmpty(set []uint64) (empty bool) {
for i := 0; i < len(set); i++ {
if atomic.LoadUint64(&set[i]) != 0 {
return false
}
}
return true
}
// BitsetUnion modifies `set` to be the union of `set` and `other`.
// This has race conditions in that we don't necessarily get a single
// consistent view of `other` across word boundaries.
func BitsetUnion(set []uint64, other []uint64) {
for i := 0; i < len(set); i++ {
for {
ourAddr := &set[i]
ourBlock := atomic.LoadUint64(ourAddr)
otherBlock := atomic.LoadUint64(&other[i])
newBlock := ourBlock | otherBlock
if atomic.CompareAndSwapUint64(ourAddr, ourBlock, newBlock) {
break
}
}
}
}

65
irc/utils/bitset_test.go Normal file
View File

@ -0,0 +1,65 @@
// Copyright (c) 2018 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package utils
import "testing"
type testBitset [2]uint64
func TestSets(t *testing.T) {
var t1 testBitset
t1s := t1[:]
BitsetInitialize(t1s)
if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) {
t.Error("no bits should be set in a newly initialized bitset")
}
var i uint
for i = 0; i < 128; i++ {
if i%2 == 0 {
if !BitsetSet(t1s, i, true) {
t.Error("setting an uninitialized bit should return true")
}
}
}
if BitsetSet(t1s, 24, true) {
t.Error("setting an already-set bit should return false")
}
if !(BitsetGet(t1s, 0) && !BitsetGet(t1s, 1) && BitsetGet(t1s, 64) && BitsetGet(t1s, 72) && !BitsetGet(t1s, 127)) {
t.Error("exactly the even-numbered bits should be set")
}
if !BitsetSet(t1s, 72, false) {
t.Error("removing a set bit should return true")
}
if BitsetGet(t1s, 72) {
t.Error("remove doesn't work")
}
if BitsetSet(t1s, 72, false) {
t.Error("removing an unset bit should return false")
}
var t2 testBitset
t2s := t2[:]
BitsetInitialize(t2s)
for i = 0; i < 128; i++ {
if i%2 == 1 {
BitsetSet(t2s, i, true)
}
}
BitsetUnion(t1s, t2s)
for i = 0; i < 128; i++ {
expected := (i != 72)
if BitsetGet(t1s, i) != expected {
t.Error("all bits should be set except 72")
}
}
}