3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-22 10:42:52 +01:00

Merge pull request #446 from slingamn/chanregrefactor.6

refactor channel registration
This commit is contained in:
Daniel Oaks 2019-04-04 21:59:25 +10:00 committed by GitHub
commit 8c7027c604
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 425 additions and 227 deletions

View File

@ -59,18 +59,15 @@ type AccountManager struct {
accountToMethod map[string]NickReservationMethod accountToMethod map[string]NickReservationMethod
} }
func NewAccountManager(server *Server) *AccountManager { func (am *AccountManager) Initialize(server *Server) {
am := AccountManager{ am.accountToClients = make(map[string][]*Client)
accountToClients: make(map[string][]*Client), am.nickToAccount = make(map[string]string)
nickToAccount: make(map[string]string), am.skeletonToAccount = make(map[string]string)
skeletonToAccount: make(map[string]string), am.accountToMethod = make(map[string]NickReservationMethod)
accountToMethod: make(map[string]NickReservationMethod), am.server = server
server: server,
}
am.buildNickToAccountIndex() am.buildNickToAccountIndex()
am.initVHostRequestQueue() am.initVHostRequestQueue()
return &am
} }
func (am *AccountManager) buildNickToAccountIndex() { func (am *AccountManager) buildNickToAccountIndex() {
@ -855,6 +852,7 @@ func (am *AccountManager) Unregister(account string) error {
verificationCodeKey := fmt.Sprintf(keyAccountVerificationCode, casefoldedAccount) verificationCodeKey := fmt.Sprintf(keyAccountVerificationCode, casefoldedAccount)
verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount) verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount)
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount) nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
enforcementKey := fmt.Sprintf(keyAccountEnforcement, casefoldedAccount)
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount) vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount) vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount) channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
@ -865,14 +863,7 @@ func (am *AccountManager) Unregister(account string) error {
// on our way out, unregister all the account's channels and delete them from the db // on our way out, unregister all the account's channels and delete them from the db
defer func() { defer func() {
for _, channelName := range registeredChannels { for _, channelName := range registeredChannels {
info := am.server.channelRegistry.LoadChannel(channelName) am.server.channels.SetUnregistered(channelName, casefoldedAccount)
if info != nil && info.Founder == casefoldedAccount {
am.server.channelRegistry.Delete(channelName, *info)
}
channel := am.server.channels.Get(channelName)
if channel != nil {
channel.SetUnregistered(casefoldedAccount)
}
} }
}() }()
@ -892,6 +883,7 @@ func (am *AccountManager) Unregister(account string) error {
tx.Delete(registeredTimeKey) tx.Delete(registeredTimeKey)
tx.Delete(callbackKey) tx.Delete(callbackKey)
tx.Delete(verificationCodeKey) tx.Delete(verificationCodeKey)
tx.Delete(enforcementKey)
rawNicks, _ = tx.Get(nicksKey) rawNicks, _ = tx.Get(nicksKey)
tx.Delete(nicksKey) tx.Delete(nicksKey)
credText, err = tx.Get(credentialsKey) credText, err = tx.Get(credentialsKey)

View File

@ -16,7 +16,6 @@ type Set [bitsetLen]uint64
// NewSet returns a new Set, with the given capabilities enabled. // NewSet returns a new Set, with the given capabilities enabled.
func NewSet(capabs ...Capability) *Set { func NewSet(capabs ...Capability) *Set {
var newSet Set var newSet Set
utils.BitsetInitialize(newSet[:])
newSet.Enable(capabs...) newSet.Enable(capabs...)
return &newSet return &newSet
} }

View File

@ -22,7 +22,7 @@ import (
// Channel represents a channel that clients can join. // Channel represents a channel that clients can join.
type Channel struct { type Channel struct {
flags *modes.ModeSet flags modes.ModeSet
lists map[modes.Mode]*UserMaskSet lists map[modes.Mode]*UserMaskSet
key string key string
members MemberSet members MemberSet
@ -33,19 +33,22 @@ type Channel struct {
createdTime time.Time createdTime time.Time
registeredFounder string registeredFounder string
registeredTime time.Time registeredTime time.Time
stateMutex sync.RWMutex // tier 1
joinPartMutex sync.Mutex // tier 3
topic string topic string
topicSetBy string topicSetBy string
topicSetTime time.Time topicSetTime time.Time
userLimit int userLimit int
accountToUMode map[string]modes.Mode accountToUMode map[string]modes.Mode
history history.Buffer history history.Buffer
stateMutex sync.RWMutex // tier 1
writerSemaphore Semaphore // tier 1.5
joinPartMutex sync.Mutex // tier 3
ensureLoaded utils.Once // manages loading stored registration info from the database
dirtyBits uint
} }
// NewChannel creates a new channel from a `Server` and a `name` // NewChannel creates a new channel from a `Server` and a `name`
// string, which must be unique on the server. // string, which must be unique on the server.
func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel { func NewChannel(s *Server, name string, registered bool) *Channel {
casefoldedName, err := CasefoldChannel(name) casefoldedName, err := CasefoldChannel(name)
if err != nil { if err != nil {
s.logger.Error("internal", "Bad channel name", name, err.Error()) s.logger.Error("internal", "Bad channel name", name, err.Error())
@ -54,7 +57,6 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
channel := &Channel{ channel := &Channel{
createdTime: time.Now(), // may be overwritten by applyRegInfo createdTime: time.Now(), // may be overwritten by applyRegInfo
flags: modes.NewModeSet(),
lists: map[modes.Mode]*UserMaskSet{ lists: map[modes.Mode]*UserMaskSet{
modes.BanMask: NewUserMaskSet(), modes.BanMask: NewUserMaskSet(),
modes.ExceptMask: NewUserMaskSet(), modes.ExceptMask: NewUserMaskSet(),
@ -69,21 +71,43 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
config := s.Config() config := s.Config()
if regInfo != nil { channel.writerSemaphore.Initialize(1)
channel.applyRegInfo(regInfo) channel.history.Initialize(config.History.ChannelLength)
} else {
if !registered {
for _, mode := range config.Channels.defaultModes { for _, mode := range config.Channels.defaultModes {
channel.flags.SetMode(mode, true) channel.flags.SetMode(mode, true)
} }
} // no loading to do, so "mark" the load operation as "done":
channel.ensureLoaded.Do(func() {})
channel.history.Initialize(config.History.ChannelLength) } // else: modes will be loaded before first join
return channel return channel
} }
// EnsureLoaded blocks until the channel's registration info has been loaded
// from the database.
func (channel *Channel) EnsureLoaded() {
channel.ensureLoaded.Do(func() {
nmc := channel.NameCasefolded()
info, err := channel.server.channelRegistry.LoadChannel(nmc)
if err == nil {
channel.applyRegInfo(info)
} else {
channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error())
}
})
}
func (channel *Channel) IsLoaded() bool {
return channel.ensureLoaded.Done()
}
// read in channel state that was persisted in the DB // read in channel state that was persisted in the DB
func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) { func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.registeredFounder = chanReg.Founder channel.registeredFounder = chanReg.Founder
channel.registeredTime = chanReg.RegisteredAt channel.registeredTime = chanReg.RegisteredAt
channel.topic = chanReg.Topic channel.topic = chanReg.Topic
@ -116,6 +140,7 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
info.Name = channel.name info.Name = channel.name
info.NameCasefolded = channel.nameCasefolded
info.Founder = channel.registeredFounder info.Founder = channel.registeredFounder
info.RegisteredAt = channel.registeredTime info.RegisteredAt = channel.registeredTime
@ -149,6 +174,115 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
return return
} }
// begin: asynchronous database writeback implementation, modeled on irc/socket.go
// MarkDirty marks part (or all) of a channel's data as needing to be written back
// to the database, then starts a writer goroutine if necessary.
// This is the equivalent of Socket.Write().
func (channel *Channel) MarkDirty(dirtyBits uint) {
channel.stateMutex.Lock()
isRegistered := channel.registeredFounder != ""
channel.dirtyBits = channel.dirtyBits | dirtyBits
channel.stateMutex.Unlock()
if !isRegistered {
return
}
channel.wakeWriter()
}
// IsClean returns whether a channel can be safely removed from the server.
// To avoid the obvious TOCTOU race condition, it must be called while holding
// ChannelManager's lock (that way, no one can join and make the channel dirty again
// between this method exiting and the actual deletion).
func (channel *Channel) IsClean() bool {
if !channel.writerSemaphore.TryAcquire() {
// a database write (which may fail) is in progress, the channel cannot be cleaned up
return false
}
defer channel.writerSemaphore.Release()
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
// the channel must be empty, and either be unregistered or fully written to the DB
return len(channel.members) == 0 && (channel.registeredFounder == "" || channel.dirtyBits == 0)
}
func (channel *Channel) wakeWriter() {
if channel.writerSemaphore.TryAcquire() {
go channel.writeLoop()
}
}
// equivalent of Socket.send()
func (channel *Channel) writeLoop() {
for {
// TODO(#357) check the error value of this and implement timed backoff
channel.performWrite(0)
channel.writerSemaphore.Release()
channel.stateMutex.RLock()
isDirty := channel.dirtyBits != 0
isEmpty := len(channel.members) == 0
channel.stateMutex.RUnlock()
if !isDirty {
if isEmpty {
channel.server.channels.Cleanup(channel)
}
return // nothing to do
} // else: isDirty, so we need to write again
if !channel.writerSemaphore.TryAcquire() {
return
}
}
}
// Store writes part (or all) of the channel's data back to the database,
// blocking until the write is complete. This is the equivalent of
// Socket.BlockingWrite.
func (channel *Channel) Store(dirtyBits uint) (err error) {
defer func() {
channel.stateMutex.Lock()
isDirty := channel.dirtyBits != 0
isEmpty := len(channel.members) == 0
channel.stateMutex.Unlock()
if isDirty {
channel.wakeWriter()
} else if isEmpty {
channel.server.channels.Cleanup(channel)
}
}()
channel.writerSemaphore.Acquire()
defer channel.writerSemaphore.Release()
return channel.performWrite(dirtyBits)
}
// do an individual write; equivalent of Socket.send()
func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) {
channel.stateMutex.Lock()
dirtyBits := channel.dirtyBits | additionalDirtyBits
channel.dirtyBits = 0
isRegistered := channel.registeredFounder != ""
channel.stateMutex.Unlock()
if !isRegistered || dirtyBits == 0 {
return
}
info := channel.ExportRegistration(dirtyBits)
err = channel.server.channelRegistry.StoreChannel(info, dirtyBits)
if err != nil {
channel.stateMutex.Lock()
channel.dirtyBits = channel.dirtyBits | dirtyBits
channel.stateMutex.Unlock()
}
return
}
// SetRegistered registers the channel, returning an error if it was already registered. // SetRegistered registers the channel, returning an error if it was already registered.
func (channel *Channel) SetRegistered(founder string) error { func (channel *Channel) SetRegistered(founder string) error {
channel.stateMutex.Lock() channel.stateMutex.Lock()
@ -698,7 +832,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
} }
} }
go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic) channel.MarkDirty(IncludeTopic)
} }
// CanSpeak returns true if the client can speak on this channel. // CanSpeak returns true if the client can speak on this channel.

View File

@ -19,25 +19,38 @@ type channelManagerEntry struct {
// providing synchronization for creation of new channels on first join, // providing synchronization for creation of new channels on first join,
// cleanup of empty channels on last part, and renames. // cleanup of empty channels on last part, and renames.
type ChannelManager struct { type ChannelManager struct {
sync.RWMutex // tier 2 sync.RWMutex // tier 2
chans map[string]*channelManagerEntry chans map[string]*channelManagerEntry
registeredChannels map[string]bool
server *Server
} }
// NewChannelManager returns a new ChannelManager. // NewChannelManager returns a new ChannelManager.
func NewChannelManager() *ChannelManager { func (cm *ChannelManager) Initialize(server *Server) {
return &ChannelManager{ cm.chans = make(map[string]*channelManagerEntry)
chans: make(map[string]*channelManagerEntry), cm.server = server
if server.Config().Channels.Registration.Enabled {
cm.loadRegisteredChannels()
} }
} }
func (cm *ChannelManager) loadRegisteredChannels() {
registeredChannels := cm.server.channelRegistry.AllChannels()
cm.Lock()
defer cm.Unlock()
cm.registeredChannels = registeredChannels
}
// Get returns an existing channel with name equivalent to `name`, or nil // Get returns an existing channel with name equivalent to `name`, or nil
func (cm *ChannelManager) Get(name string) *Channel { func (cm *ChannelManager) Get(name string) (channel *Channel) {
name, err := CasefoldChannel(name) name, err := CasefoldChannel(name)
if err == nil { if err == nil {
cm.RLock() cm.RLock()
defer cm.RUnlock() defer cm.RUnlock()
entry := cm.chans[name] entry := cm.chans[name]
if entry != nil { // if the channel is still loading, pretend we don't have it
if entry != nil && entry.channel.IsLoaded() {
return entry.channel return entry.channel
} }
} }
@ -55,28 +68,21 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
cm.Lock() cm.Lock()
entry := cm.chans[casefoldedName] entry := cm.chans[casefoldedName]
if entry == nil { if entry == nil {
// XXX give up the lock to check for a registration, then check again registered := cm.registeredChannels[casefoldedName]
// to see if we need to create the channel. we could solve this by doing LoadChannel entry = &channelManagerEntry{
// outside the lock initially on every join, so this is best thought of as an channel: NewChannel(server, name, registered),
// optimization to avoid that. pendingJoins: 0,
cm.Unlock()
info := client.server.channelRegistry.LoadChannel(casefoldedName)
cm.Lock()
entry = cm.chans[casefoldedName]
if entry == nil {
entry = &channelManagerEntry{
channel: NewChannel(server, name, info),
pendingJoins: 0,
}
cm.chans[casefoldedName] = entry
} }
cm.chans[casefoldedName] = entry
} }
entry.pendingJoins += 1 entry.pendingJoins += 1
channel := entry.channel
cm.Unlock() cm.Unlock()
entry.channel.Join(client, key, isSajoin, rb) channel.EnsureLoaded()
channel.Join(client, key, isSajoin, rb)
cm.maybeCleanup(entry.channel, true) cm.maybeCleanup(channel, true)
return nil return nil
} }
@ -85,7 +91,8 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) {
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
entry := cm.chans[channel.NameCasefolded()] nameCasefolded := channel.NameCasefolded()
entry := cm.chans[nameCasefolded]
if entry == nil || entry.channel != channel { if entry == nil || entry.channel != channel {
return return
} }
@ -93,23 +100,15 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) {
if afterJoin { if afterJoin {
entry.pendingJoins -= 1 entry.pendingJoins -= 1
} }
// TODO(slingamn) right now, registered channels cannot be cleaned up. if entry.pendingJoins == 0 && entry.channel.IsClean() {
// this is because once ChannelManager becomes the source of truth about a channel, delete(cm.chans, nameCasefolded)
// we can't move the source of truth back to the database unless we do an ACID
// store while holding the ChannelManager's Lock(). This is pending more decisions
// about where the database transaction lock fits into the overall lock model.
if !entry.channel.IsRegistered() && entry.channel.IsEmpty() && entry.pendingJoins == 0 {
// reread the name, handling the case where the channel was renamed
casefoldedName := entry.channel.NameCasefolded()
delete(cm.chans, casefoldedName)
// invalidate the entry (otherwise, a subsequent cleanup attempt could delete
// a valid, distinct entry under casefoldedName):
entry.channel = nil
} }
} }
// Part parts `client` from the channel named `name`, deleting it if it's empty. // Part parts `client` from the channel named `name`, deleting it if it's empty.
func (cm *ChannelManager) Part(client *Client, name string, message string, rb *ResponseBuffer) error { func (cm *ChannelManager) Part(client *Client, name string, message string, rb *ResponseBuffer) error {
var channel *Channel
casefoldedName, err := CasefoldChannel(name) casefoldedName, err := CasefoldChannel(name)
if err != nil { if err != nil {
return errNoSuchChannel return errNoSuchChannel
@ -117,12 +116,15 @@ func (cm *ChannelManager) Part(client *Client, name string, message string, rb *
cm.RLock() cm.RLock()
entry := cm.chans[casefoldedName] entry := cm.chans[casefoldedName]
if entry != nil {
channel = entry.channel
}
cm.RUnlock() cm.RUnlock()
if entry == nil { if channel == nil {
return errNoSuchChannel return errNoSuchChannel
} }
entry.channel.Part(client, message, rb) channel.Part(client, message, rb)
return nil return nil
} }
@ -130,8 +132,68 @@ func (cm *ChannelManager) Cleanup(channel *Channel) {
cm.maybeCleanup(channel, false) cm.maybeCleanup(channel, false)
} }
func (cm *ChannelManager) SetRegistered(channelName string, account string) (err error) {
var channel *Channel
cfname, err := CasefoldChannel(channelName)
if err != nil {
return err
}
var entry *channelManagerEntry
defer func() {
if err == nil && channel != nil {
// registration was successful: make the database reflect it
err = channel.Store(IncludeAllChannelAttrs)
}
}()
cm.Lock()
defer cm.Unlock()
entry = cm.chans[cfname]
if entry == nil {
return errNoSuchChannel
}
channel = entry.channel
err = channel.SetRegistered(account)
if err != nil {
return err
}
cm.registeredChannels[cfname] = true
return nil
}
func (cm *ChannelManager) SetUnregistered(channelName string, account string) (err error) {
cfname, err := CasefoldChannel(channelName)
if err != nil {
return err
}
var info RegisteredChannel
defer func() {
if err == nil {
err = cm.server.channelRegistry.Delete(info)
}
}()
cm.Lock()
defer cm.Unlock()
entry := cm.chans[cfname]
if entry == nil {
return errNoSuchChannel
}
info = entry.channel.ExportRegistration(0)
if info.Founder != account {
return errChannelNotOwnedByAccount
}
entry.channel.SetUnregistered(account)
delete(cm.registeredChannels, cfname)
return nil
}
// Rename renames a channel (but does not notify the members) // Rename renames a channel (but does not notify the members)
func (cm *ChannelManager) Rename(name string, newname string) error { func (cm *ChannelManager) Rename(name string, newname string) (err error) {
cfname, err := CasefoldChannel(name) cfname, err := CasefoldChannel(name)
if err != nil { if err != nil {
return errNoSuchChannel return errNoSuchChannel
@ -142,22 +204,37 @@ func (cm *ChannelManager) Rename(name string, newname string) error {
return errInvalidChannelName return errInvalidChannelName
} }
var channel *Channel
var info RegisteredChannel
defer func() {
if channel != nil && info.Founder != "" {
channel.Store(IncludeAllChannelAttrs)
// we just flushed the channel under its new name, therefore this delete
// cannot be overwritten by a write to the old name:
cm.server.channelRegistry.Delete(info)
}
}()
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
if cm.chans[cfnewname] != nil { if cm.chans[cfnewname] != nil || cm.registeredChannels[cfnewname] {
return errChannelNameInUse return errChannelNameInUse
} }
entry := cm.chans[cfname] entry := cm.chans[cfname]
if entry == nil { if entry == nil {
return errNoSuchChannel return errNoSuchChannel
} }
channel = entry.channel
info = channel.ExportRegistration(IncludeInitial)
delete(cm.chans, cfname) delete(cm.chans, cfname)
cm.chans[cfnewname] = entry cm.chans[cfnewname] = entry
entry.channel.setName(newname) if cm.registeredChannels[cfname] {
entry.channel.setNameCasefolded(cfnewname) delete(cm.registeredChannels, cfname)
cm.registeredChannels[cfnewname] = true
}
entry.channel.Rename(newname, cfnewname)
return nil return nil
} }
// Len returns the number of channels // Len returns the number of channels
@ -171,8 +248,11 @@ func (cm *ChannelManager) Len() int {
func (cm *ChannelManager) Channels() (result []*Channel) { func (cm *ChannelManager) Channels() (result []*Channel) {
cm.RLock() cm.RLock()
defer cm.RUnlock() defer cm.RUnlock()
result = make([]*Channel, 0, len(cm.chans))
for _, entry := range cm.chans { for _, entry := range cm.chans {
result = append(result, entry.channel) if entry.channel.IsLoaded() {
result = append(result, entry.channel)
}
} }
return return
} }

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"encoding/json" "encoding/json"
@ -71,6 +70,8 @@ const (
type RegisteredChannel struct { type RegisteredChannel struct {
// Name of the channel. // Name of the channel.
Name string Name string
// Casefolded name of the channel.
NameCasefolded string
// RegisteredAt represents the time that the channel was registered. // RegisteredAt represents the time that the channel was registered.
RegisteredAt time.Time RegisteredAt time.Time
// Founder indicates the founder of the channel. // Founder indicates the founder of the channel.
@ -97,58 +98,65 @@ type RegisteredChannel struct {
// ChannelRegistry manages registered channels. // ChannelRegistry manages registered channels.
type ChannelRegistry struct { type ChannelRegistry struct {
// This serializes operations of the form (read channel state, synchronously persist it); server *Server
// this is enough to guarantee eventual consistency of the database with the
// ChannelManager and Channel objects, which are the source of truth.
//
// We could use the buntdb RW transaction lock for this purpose but we share
// that with all the other modules, so let's not.
sync.Mutex // tier 2
server *Server
} }
// NewChannelRegistry returns a new ChannelRegistry. // NewChannelRegistry returns a new ChannelRegistry.
func NewChannelRegistry(server *Server) *ChannelRegistry { func (reg *ChannelRegistry) Initialize(server *Server) {
return &ChannelRegistry{ reg.server = server
server: server, }
}
func (reg *ChannelRegistry) AllChannels() (result map[string]bool) {
result = make(map[string]bool)
prefix := fmt.Sprintf(keyChannelExists, "")
reg.server.store.View(func(tx *buntdb.Tx) error {
return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
if !strings.HasPrefix(key, prefix) {
return false
}
channel := strings.TrimPrefix(key, prefix)
result[channel] = true
return true
})
})
return
} }
// StoreChannel obtains a consistent view of a channel, then persists it to the store. // StoreChannel obtains a consistent view of a channel, then persists it to the store.
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) { func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) {
if !reg.server.ChannelRegistrationEnabled() { if !reg.server.ChannelRegistrationEnabled() {
return return
} }
reg.Lock()
defer reg.Unlock()
key := channel.NameCasefolded()
info := channel.ExportRegistration(includeFlags)
if info.Founder == "" { if info.Founder == "" {
// sanity check, don't try to store an unregistered channel // sanity check, don't try to store an unregistered channel
return return
} }
reg.server.store.Update(func(tx *buntdb.Tx) error { reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.saveChannel(tx, key, info, includeFlags) reg.saveChannel(tx, info, includeFlags)
return nil return nil
}) })
return nil
} }
// LoadChannel loads a channel from the store. // LoadChannel loads a channel from the store.
func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *RegisteredChannel) { func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) {
if !reg.server.ChannelRegistrationEnabled() { if !reg.server.ChannelRegistrationEnabled() {
return nil err = errFeatureDisabled
return
} }
channelKey := nameCasefolded channelKey := nameCasefolded
// nice to have: do all JSON (de)serialization outside of the buntdb transaction // nice to have: do all JSON (de)serialization outside of the buntdb transaction
reg.server.store.View(func(tx *buntdb.Tx) error { err = reg.server.store.View(func(tx *buntdb.Tx) error {
_, err := tx.Get(fmt.Sprintf(keyChannelExists, channelKey)) _, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
if err == buntdb.ErrNotFound { if dberr == buntdb.ErrNotFound {
// chan does not already exist, return // chan does not already exist, return
return nil return errNoSuchChannel
} }
// channel exists, load it // channel exists, load it
@ -181,7 +189,7 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
accountToUMode := make(map[string]modes.Mode) accountToUMode := make(map[string]modes.Mode)
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode) _ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
info = &RegisteredChannel{ info = RegisteredChannel{
Name: name, Name: name,
RegisteredAt: time.Unix(regTimeInt, 0), RegisteredAt: time.Unix(regTimeInt, 0),
Founder: founder, Founder: founder,
@ -198,46 +206,21 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
return nil return nil
}) })
return info return
} }
func (reg *ChannelRegistry) Delete(casefoldedName string, info RegisteredChannel) { // Delete deletes a channel corresponding to `info`. If no such channel
// is present in the database, no error is returned.
func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) {
if !reg.server.ChannelRegistrationEnabled() { if !reg.server.ChannelRegistrationEnabled() {
return return
} }
reg.Lock()
defer reg.Unlock()
reg.server.store.Update(func(tx *buntdb.Tx) error { reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.deleteChannel(tx, casefoldedName, info) reg.deleteChannel(tx, info.NameCasefolded, info)
return nil
})
}
// Rename handles the persistence part of a channel rename: the channel is
// persisted under its new name, and the old name is cleaned up if necessary.
func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
if !reg.server.ChannelRegistrationEnabled() {
return
}
reg.Lock()
defer reg.Unlock()
includeFlags := IncludeAllChannelAttrs
oldKey := casefoldedOldName
key := channel.NameCasefolded()
info := channel.ExportRegistration(includeFlags)
if info.Founder == "" {
return
}
reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.deleteChannel(tx, oldKey, info)
reg.saveChannel(tx, key, info, includeFlags)
return nil return nil
}) })
return nil
} }
// delete a channel, unless it was overwritten by another registration of the same channel // delete a channel, unless it was overwritten by another registration of the same channel
@ -274,7 +257,8 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
} }
// saveChannel saves a channel to the store. // saveChannel saves a channel to the store.
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) { func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) {
channelKey := channelInfo.NameCasefolded
// maintain the mapping of account -> registered channels // maintain the mapping of account -> registered channels
chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey) chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey)
_, existsErr := tx.Get(chanExistsKey) _, existsErr := tx.Get(chanExistsKey)

View File

@ -232,15 +232,12 @@ func csRegisterHandler(server *Server, client *Client, command string, params []
} }
// this provides the synchronization that allows exactly one registration of the channel: // this provides the synchronization that allows exactly one registration of the channel:
err = channelInfo.SetRegistered(account) err = server.channels.SetRegistered(channelKey, account)
if err != nil { if err != nil {
csNotice(rb, err.Error()) csNotice(rb, err.Error())
return return
} }
// registration was successful: make the database reflect it
go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName)) csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))
server.logger.Info("services", fmt.Sprintf("Client %s registered channel %s", client.nick, channelName)) server.logger.Info("services", fmt.Sprintf("Client %s registered channel %s", client.nick, channelName))
@ -297,8 +294,7 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
return return
} }
channel.SetUnregistered(founder) server.channels.SetUnregistered(channelKey, founder)
server.channelRegistry.Delete(channelKey, info)
csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey)) csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey))
} }

View File

@ -50,7 +50,7 @@ type Client struct {
accountName string // display name of the account: uncasefolded, '*' if not logged in accountName string // display name of the account: uncasefolded, '*' if not logged in
atime time.Time atime time.Time
awayMessage string awayMessage string
capabilities *caps.Set capabilities caps.Set
capState caps.State capState caps.State
capVersion caps.Version capVersion caps.Version
certfp string certfp string
@ -58,7 +58,7 @@ type Client struct {
ctime time.Time ctime time.Time
exitedSnomaskSent bool exitedSnomaskSent bool
fakelag Fakelag fakelag Fakelag
flags *modes.ModeSet flags modes.ModeSet
hasQuit bool hasQuit bool
hops int hops int
hostname string hostname string
@ -125,15 +125,13 @@ func RunNewClient(server *Server, conn clientConn) {
// give them 1k of grace over the limit: // give them 1k of grace over the limit:
socket := NewSocket(conn.Conn, fullLineLenLimit+1024, config.Server.MaxSendQBytes) socket := NewSocket(conn.Conn, fullLineLenLimit+1024, config.Server.MaxSendQBytes)
client := &Client{ client := &Client{
atime: now, atime: now,
capabilities: caps.NewSet(), capState: caps.NoneState,
capState: caps.NoneState, capVersion: caps.Cap301,
capVersion: caps.Cap301, channels: make(ChannelSet),
channels: make(ChannelSet), ctime: now,
ctime: now, isTor: conn.IsTor,
flags: modes.NewModeSet(), languages: server.Languages().Default(),
isTor: conn.IsTor,
languages: server.Languages().Default(),
loginThrottle: connection_limits.GenericThrottle{ loginThrottle: connection_limits.GenericThrottle{
Duration: config.Accounts.LoginThrottling.Duration, Duration: config.Accounts.LoginThrottling.Duration,
Limit: config.Accounts.LoginThrottling.MaxAttempts, Limit: config.Accounts.LoginThrottling.MaxAttempts,
@ -546,7 +544,6 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I
// copy applicable state from oldClient to client as part of a resume // copy applicable state from oldClient to client as part of a resume
func (client *Client) copyResumeData(oldClient *Client) { func (client *Client) copyResumeData(oldClient *Client) {
oldClient.stateMutex.RLock() oldClient.stateMutex.RLock()
flags := oldClient.flags
history := oldClient.history history := oldClient.history
nick := oldClient.nick nick := oldClient.nick
nickCasefolded := oldClient.nickCasefolded nickCasefolded := oldClient.nickCasefolded
@ -560,7 +557,7 @@ func (client *Client) copyResumeData(oldClient *Client) {
// resume over plaintext) // resume over plaintext)
hasTLS := client.flags.HasMode(modes.TLS) hasTLS := client.flags.HasMode(modes.TLS)
temp := modes.NewModeSet() temp := modes.NewModeSet()
temp.Copy(flags) temp.Copy(&oldClient.flags)
temp.SetMode(modes.TLS, hasTLS) temp.SetMode(modes.TLS, hasTLS)
client.flags.Copy(temp) client.flags.Copy(temp)

View File

@ -37,12 +37,10 @@ type ClientManager struct {
bySkeleton map[string]*Client bySkeleton map[string]*Client
} }
// NewClientManager returns a new ClientManager. // Initialize initializes a ClientManager.
func NewClientManager() *ClientManager { func (clients *ClientManager) Initialize() {
return &ClientManager{ clients.byNick = make(map[string]*Client)
byNick: make(map[string]*Client), clients.bySkeleton = make(map[string]*Client)
bySkeleton: make(map[string]*Client),
}
} }
// Count returns how many clients are in the manager. // Count returns how many clients are in the manager.

View File

@ -27,6 +27,8 @@ var (
errAccountMustHoldNick = errors.New(`You must hold that nickname in order to register it`) errAccountMustHoldNick = errors.New(`You must hold that nickname in order to register it`)
errCallbackFailed = errors.New("Account verification could not be sent") errCallbackFailed = errors.New("Account verification could not be sent")
errCertfpAlreadyExists = errors.New(`An account already exists for your certificate fingerprint`) errCertfpAlreadyExists = errors.New(`An account already exists for your certificate fingerprint`)
errChannelNotOwnedByAccount = errors.New("Channel not owned by the specified account")
errChannelDoesNotExist = errors.New("Channel does not exist")
errChannelAlreadyRegistered = errors.New("Channel is already registered") errChannelAlreadyRegistered = errors.New("Channel is already registered")
errChannelNameInUse = errors.New(`Channel name in use`) errChannelNameInUse = errors.New(`Channel name in use`)
errInvalidChannelName = errors.New(`Invalid channel name`) errInvalidChannelName = errors.New(`Invalid channel name`)

View File

@ -4,6 +4,8 @@
package irc package irc
import ( import (
"time"
"github.com/oragono/oragono/irc/isupport" "github.com/oragono/oragono/irc/isupport"
"github.com/oragono/oragono/irc/languages" "github.com/oragono/oragono/irc/languages"
"github.com/oragono/oragono/irc/modes" "github.com/oragono/oragono/irc/modes"
@ -267,22 +269,20 @@ func (channel *Channel) Name() string {
return channel.name return channel.name
} }
func (channel *Channel) setName(name string) {
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.name = name
}
func (channel *Channel) NameCasefolded() string { func (channel *Channel) NameCasefolded() string {
channel.stateMutex.RLock() channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
return channel.nameCasefolded return channel.nameCasefolded
} }
func (channel *Channel) setNameCasefolded(nameCasefolded string) { func (channel *Channel) Rename(name, nameCasefolded string) {
channel.stateMutex.Lock() channel.stateMutex.Lock()
defer channel.stateMutex.Unlock() channel.name = name
channel.nameCasefolded = nameCasefolded channel.nameCasefolded = nameCasefolded
if channel.registeredFounder != "" {
channel.registeredTime = time.Now()
}
channel.stateMutex.Unlock()
} }
func (channel *Channel) Members() (result []*Client) { func (channel *Channel) Members() (result []*Client) {
@ -314,3 +314,10 @@ func (channel *Channel) Founder() string {
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
return channel.registeredFounder return channel.registeredFounder
} }
func (channel *Channel) DirtyBits() (dirtyBits uint) {
channel.stateMutex.Lock()
dirtyBits = channel.dirtyBits
channel.stateMutex.Unlock()
return
}

View File

@ -1607,8 +1607,8 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
} }
} }
if channel.IsRegistered() && includeFlags != 0 { if includeFlags != 0 {
go server.channelRegistry.StoreChannel(channel, includeFlags) channel.MarkDirty(includeFlags)
} }
// send out changes // send out changes
@ -2167,7 +2167,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re
rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), oldName, client.t("No such channel")) rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), oldName, client.t("No such channel"))
return false return false
} }
casefoldedOldName := channel.NameCasefolded()
if !(channel.ClientIsAtLeast(client, modes.Operator) || client.HasRoleCapabs("chanreg")) { if !(channel.ClientIsAtLeast(client, modes.Operator) || client.HasRoleCapabs("chanreg")) {
rb.Add(nil, server.name, ERR_CHANOPRIVSNEEDED, client.Nick(), oldName, client.t("You're not a channel operator")) rb.Add(nil, server.name, ERR_CHANOPRIVSNEEDED, client.Nick(), oldName, client.t("You're not a channel operator"))
return false return false
@ -2192,9 +2191,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re
return false return false
} }
// rename succeeded, persist it
go server.channelRegistry.Rename(channel, casefoldedOldName)
// send RENAME messages // send RENAME messages
clientPrefix := client.NickMaskString() clientPrefix := client.NickMaskString()
for _, mcl := range channel.Members() { for _, mcl := range channel.Members() {

View File

@ -290,14 +290,14 @@ func (channel *Channel) ProcessAccountToUmodeChange(client *Client, change modes
case modes.Add: case modes.Add:
if targetModeNow != targetModeAfter { if targetModeNow != targetModeAfter {
channel.accountToUMode[change.Arg] = change.Mode channel.accountToUMode[change.Arg] = change.Mode
go client.server.channelRegistry.StoreChannel(channel, IncludeLists) channel.MarkDirty(IncludeLists)
return []modes.ModeChange{change}, nil return []modes.ModeChange{change}, nil
} }
return nil, nil return nil, nil
case modes.Remove: case modes.Remove:
if targetModeNow == change.Mode { if targetModeNow == change.Mode {
delete(channel.accountToUMode, change.Arg) delete(channel.accountToUMode, change.Arg)
go client.server.channelRegistry.StoreChannel(channel, IncludeLists) channel.MarkDirty(IncludeLists)
return []modes.ModeChange{change}, nil return []modes.ModeChange{change}, nil
} }
return nil, nil return nil, nil

View File

@ -335,7 +335,6 @@ const (
// returns a pointer to a new ModeSet // returns a pointer to a new ModeSet
func NewModeSet() *ModeSet { func NewModeSet() *ModeSet {
var set ModeSet var set ModeSet
utils.BitsetInitialize(set[:])
return &set return &set
} }

View File

@ -32,14 +32,13 @@ type ServerSemaphores struct {
ClientDestroy Semaphore ClientDestroy Semaphore
} }
// NewServerSemaphores creates a new ServerSemaphores. // Initialize initializes a set of server semaphores.
func NewServerSemaphores() (result *ServerSemaphores) { func (serversem *ServerSemaphores) Initialize() {
capacity := runtime.NumCPU() capacity := runtime.NumCPU()
if capacity > MaxServerSemaphoreCapacity { if capacity > MaxServerSemaphoreCapacity {
capacity = MaxServerSemaphoreCapacity capacity = MaxServerSemaphoreCapacity
} }
result = new(ServerSemaphores) serversem.ClientDestroy.Initialize(capacity)
result.ClientDestroy.Initialize(capacity)
return return
} }

View File

@ -61,10 +61,10 @@ type ListenerWrapper struct {
// Server is the main Oragono server. // Server is the main Oragono server.
type Server struct { type Server struct {
accounts *AccountManager accounts AccountManager
channels *ChannelManager channels ChannelManager
channelRegistry *ChannelRegistry channelRegistry ChannelRegistry
clients *ClientManager clients ClientManager
config *Config config *Config
configFilename string configFilename string
configurableStateMutex sync.RWMutex // tier 1; generic protection for server state modified by rehash() configurableStateMutex sync.RWMutex // tier 1; generic protection for server state modified by rehash()
@ -89,9 +89,9 @@ type Server struct {
snomasks *SnoManager snomasks *SnoManager
store *buntdb.DB store *buntdb.DB
torLimiter connection_limits.TorLimiter torLimiter connection_limits.TorLimiter
whoWas *WhoWasList whoWas WhoWasList
stats *Stats stats Stats
semaphores *ServerSemaphores semaphores ServerSemaphores
} }
var ( var (
@ -113,8 +113,6 @@ type clientConn struct {
func NewServer(config *Config, logger *logger.Manager) (*Server, error) { func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
// initialize data structures // initialize data structures
server := &Server{ server := &Server{
channels: NewChannelManager(),
clients: NewClientManager(),
connectionLimiter: connection_limits.NewLimiter(), connectionLimiter: connection_limits.NewLimiter(),
connectionThrottler: connection_limits.NewThrottler(), connectionThrottler: connection_limits.NewThrottler(),
listeners: make(map[string]*ListenerWrapper), listeners: make(map[string]*ListenerWrapper),
@ -123,12 +121,12 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
rehashSignal: make(chan os.Signal, 1), rehashSignal: make(chan os.Signal, 1),
signals: make(chan os.Signal, len(ServerExitSignals)), signals: make(chan os.Signal, len(ServerExitSignals)),
snomasks: NewSnoManager(), snomasks: NewSnoManager(),
whoWas: NewWhoWasList(config.Limits.WhowasEntries),
stats: NewStats(),
semaphores: NewServerSemaphores(),
} }
server.clients.Initialize()
server.semaphores.Initialize()
server.resumeManager.Initialize(server) server.resumeManager.Initialize(server)
server.whoWas.Initialize(config.Limits.WhowasEntries)
if err := server.applyConfig(config, true); err != nil { if err := server.applyConfig(config, true); err != nil {
return nil, err return nil, err
@ -697,6 +695,12 @@ func (server *Server) applyConfig(config *Config, initial bool) (err error) {
server.accounts.initVHostRequestQueue() server.accounts.initVHostRequestQueue()
} }
chanRegPreviouslyDisabled := oldConfig != nil && !oldConfig.Channels.Registration.Enabled
chanRegNowEnabled := config.Channels.Registration.Enabled
if chanRegPreviouslyDisabled && chanRegNowEnabled {
server.channels.loadRegisteredChannels()
}
// MaxLine // MaxLine
if config.Limits.LineLen.Rest != 512 { if config.Limits.LineLen.Rest != 512 {
SupportedCapabilities.Enable(caps.MaxLine) SupportedCapabilities.Enable(caps.MaxLine)
@ -922,9 +926,9 @@ func (server *Server) loadDatastore(config *Config) error {
server.loadDLines() server.loadDLines()
server.loadKLines() server.loadKLines()
server.channelRegistry = NewChannelRegistry(server) server.channelRegistry.Initialize(server)
server.channels.Initialize(server)
server.accounts = NewAccountManager(server) server.accounts.Initialize(server)
return nil return nil
} }

View File

@ -13,17 +13,6 @@ type Stats struct {
Operators int Operators int
} }
// NewStats creates a new instance of Stats
func NewStats() *Stats {
serverStats := &Stats{
Total: 0,
Invisible: 0,
Operators: 0,
}
return serverStats
}
// ChangeTotal increments the total user count on server // ChangeTotal increments the total user count on server
func (s *Stats) ChangeTotal(i int) { func (s *Stats) ChangeTotal(i int) {
s.Lock() s.Lock()

View File

@ -9,17 +9,6 @@ import "sync/atomic"
// For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a // For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a
// slice to use these functions. // slice to use these functions.
// BitsetInitialize initializes a bitset.
func BitsetInitialize(set []uint64) {
// XXX re-zero the bitset using atomic stores. it's unclear whether this is required,
// however, golang issue #5045 suggests that you shouldn't mix atomic operations
// with non-atomic operations (such as the runtime's automatic zero-initialization) on
// the same word
for i := 0; i < len(set); i++ {
atomic.StoreUint64(&set[i], 0)
}
}
// BitsetGet returns whether a given bit of the bitset is set. // BitsetGet returns whether a given bit of the bitset is set.
func BitsetGet(set []uint64, position uint) bool { func BitsetGet(set []uint64, position uint) bool {
idx := position / 64 idx := position / 64

View File

@ -10,7 +10,6 @@ type testBitset [2]uint64
func TestSets(t *testing.T) { func TestSets(t *testing.T) {
var t1 testBitset var t1 testBitset
t1s := t1[:] t1s := t1[:]
BitsetInitialize(t1s)
if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) { if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) {
t.Error("no bits should be set in a newly initialized bitset") t.Error("no bits should be set in a newly initialized bitset")
@ -47,7 +46,6 @@ func TestSets(t *testing.T) {
var t2 testBitset var t2 testBitset
t2s := t2[:] t2s := t2[:]
BitsetInitialize(t2s)
for i = 0; i < 128; i++ { for i = 0; i < 128; i++ {
if i%2 == 1 { if i%2 == 1 {

35
irc/utils/sync.go Normal file
View File

@ -0,0 +1,35 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"sync"
"sync/atomic"
)
// Once is a fork of sync.Once to expose a Done() method.
type Once struct {
done uint32
m sync.Mutex
}
func (o *Once) Do(f func()) {
if atomic.LoadUint32(&o.done) == 0 {
o.doSlow(f)
}
}
func (o *Once) doSlow(f func()) {
o.m.Lock()
defer o.m.Unlock()
if o.done == 0 {
defer atomic.StoreUint32(&o.done, 1)
f()
}
}
func (o *Once) Done() bool {
return atomic.LoadUint32(&o.done) == 1
}

View File

@ -23,12 +23,10 @@ type WhoWasList struct {
} }
// NewWhoWasList returns a new WhoWasList // NewWhoWasList returns a new WhoWasList
func NewWhoWasList(size int) *WhoWasList { func (list *WhoWasList) Initialize(size int) {
return &WhoWasList{ list.buffer = make([]WhoWas, size)
buffer: make([]WhoWas, size), list.start = -1
start: -1, list.end = -1
end: -1,
}
} }
// Append adds an entry to the WhoWasList. // Append adds an entry to the WhoWasList.

View File

@ -23,7 +23,8 @@ func makeTestWhowas(nick string) WhoWas {
func TestWhoWas(t *testing.T) { func TestWhoWas(t *testing.T) {
var results []WhoWas var results []WhoWas
wwl := NewWhoWasList(3) var wwl WhoWasList
wwl.Initialize(3)
// test Find on empty list // test Find on empty list
results = wwl.Find("nobody", 10) results = wwl.Find("nobody", 10)
if len(results) != 0 { if len(results) != 0 {
@ -88,7 +89,8 @@ func TestWhoWas(t *testing.T) {
func TestEmptyWhoWas(t *testing.T) { func TestEmptyWhoWas(t *testing.T) {
// stupid edge case; setting an empty whowas buffer should not panic // stupid edge case; setting an empty whowas buffer should not panic
wwl := NewWhoWasList(0) var wwl WhoWasList
wwl.Initialize(0)
results := wwl.Find("slingamn", 10) results := wwl.Find("slingamn", 10)
if len(results) != 0 { if len(results) != 0 {
t.Fatalf("incorrect whowas results: %v", results) t.Fatalf("incorrect whowas results: %v", results)