mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-24 21:09:30 +01:00
refactor channel registration
This commit is contained in:
parent
29db70fa7b
commit
63029e2ff5
@ -59,18 +59,15 @@ type AccountManager struct {
|
||||
accountToMethod map[string]NickReservationMethod
|
||||
}
|
||||
|
||||
func NewAccountManager(server *Server) *AccountManager {
|
||||
am := AccountManager{
|
||||
accountToClients: make(map[string][]*Client),
|
||||
nickToAccount: make(map[string]string),
|
||||
skeletonToAccount: make(map[string]string),
|
||||
accountToMethod: make(map[string]NickReservationMethod),
|
||||
server: server,
|
||||
}
|
||||
func (am *AccountManager) Initialize(server *Server) {
|
||||
am.accountToClients = make(map[string][]*Client)
|
||||
am.nickToAccount = make(map[string]string)
|
||||
am.skeletonToAccount = make(map[string]string)
|
||||
am.accountToMethod = make(map[string]NickReservationMethod)
|
||||
am.server = server
|
||||
|
||||
am.buildNickToAccountIndex()
|
||||
am.initVHostRequestQueue()
|
||||
return &am
|
||||
}
|
||||
|
||||
func (am *AccountManager) buildNickToAccountIndex() {
|
||||
@ -855,6 +852,7 @@ func (am *AccountManager) Unregister(account string) error {
|
||||
verificationCodeKey := fmt.Sprintf(keyAccountVerificationCode, casefoldedAccount)
|
||||
verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount)
|
||||
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
|
||||
enforcementKey := fmt.Sprintf(keyAccountEnforcement, casefoldedAccount)
|
||||
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
|
||||
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
|
||||
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
|
||||
@ -865,14 +863,7 @@ func (am *AccountManager) Unregister(account string) error {
|
||||
// on our way out, unregister all the account's channels and delete them from the db
|
||||
defer func() {
|
||||
for _, channelName := range registeredChannels {
|
||||
info := am.server.channelRegistry.LoadChannel(channelName)
|
||||
if info != nil && info.Founder == casefoldedAccount {
|
||||
am.server.channelRegistry.Delete(channelName, *info)
|
||||
}
|
||||
channel := am.server.channels.Get(channelName)
|
||||
if channel != nil {
|
||||
channel.SetUnregistered(casefoldedAccount)
|
||||
}
|
||||
am.server.channels.SetUnregistered(channelName, casefoldedAccount)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -892,6 +883,7 @@ func (am *AccountManager) Unregister(account string) error {
|
||||
tx.Delete(registeredTimeKey)
|
||||
tx.Delete(callbackKey)
|
||||
tx.Delete(verificationCodeKey)
|
||||
tx.Delete(enforcementKey)
|
||||
rawNicks, _ = tx.Get(nicksKey)
|
||||
tx.Delete(nicksKey)
|
||||
credText, err = tx.Get(credentialsKey)
|
||||
|
@ -16,7 +16,6 @@ type Set [bitsetLen]uint64
|
||||
// NewSet returns a new Set, with the given capabilities enabled.
|
||||
func NewSet(capabs ...Capability) *Set {
|
||||
var newSet Set
|
||||
utils.BitsetInitialize(newSet[:])
|
||||
newSet.Enable(capabs...)
|
||||
return &newSet
|
||||
}
|
||||
|
160
irc/channel.go
160
irc/channel.go
@ -22,7 +22,7 @@ import (
|
||||
|
||||
// Channel represents a channel that clients can join.
|
||||
type Channel struct {
|
||||
flags *modes.ModeSet
|
||||
flags modes.ModeSet
|
||||
lists map[modes.Mode]*UserMaskSet
|
||||
key string
|
||||
members MemberSet
|
||||
@ -33,19 +33,22 @@ type Channel struct {
|
||||
createdTime time.Time
|
||||
registeredFounder string
|
||||
registeredTime time.Time
|
||||
stateMutex sync.RWMutex // tier 1
|
||||
joinPartMutex sync.Mutex // tier 3
|
||||
topic string
|
||||
topicSetBy string
|
||||
topicSetTime time.Time
|
||||
userLimit int
|
||||
accountToUMode map[string]modes.Mode
|
||||
history history.Buffer
|
||||
stateMutex sync.RWMutex // tier 1
|
||||
writerSemaphore Semaphore // tier 1.5
|
||||
joinPartMutex sync.Mutex // tier 3
|
||||
ensureLoaded utils.Once // manages loading stored registration info from the database
|
||||
dirtyBits uint
|
||||
}
|
||||
|
||||
// NewChannel creates a new channel from a `Server` and a `name`
|
||||
// string, which must be unique on the server.
|
||||
func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
|
||||
func NewChannel(s *Server, name string, registered bool) *Channel {
|
||||
casefoldedName, err := CasefoldChannel(name)
|
||||
if err != nil {
|
||||
s.logger.Error("internal", "Bad channel name", name, err.Error())
|
||||
@ -54,7 +57,6 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
|
||||
|
||||
channel := &Channel{
|
||||
createdTime: time.Now(), // may be overwritten by applyRegInfo
|
||||
flags: modes.NewModeSet(),
|
||||
lists: map[modes.Mode]*UserMaskSet{
|
||||
modes.BanMask: NewUserMaskSet(),
|
||||
modes.ExceptMask: NewUserMaskSet(),
|
||||
@ -69,21 +71,43 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
|
||||
|
||||
config := s.Config()
|
||||
|
||||
if regInfo != nil {
|
||||
channel.applyRegInfo(regInfo)
|
||||
} else {
|
||||
channel.writerSemaphore.Initialize(1)
|
||||
channel.history.Initialize(config.History.ChannelLength)
|
||||
|
||||
if !registered {
|
||||
for _, mode := range config.Channels.defaultModes {
|
||||
channel.flags.SetMode(mode, true)
|
||||
}
|
||||
}
|
||||
|
||||
channel.history.Initialize(config.History.ChannelLength)
|
||||
// no loading to do, so "mark" the load operation as "done":
|
||||
channel.ensureLoaded.Do(func() {})
|
||||
} // else: modes will be loaded before first join
|
||||
|
||||
return channel
|
||||
}
|
||||
|
||||
// EnsureLoaded blocks until the channel's registration info has been loaded
|
||||
// from the database.
|
||||
func (channel *Channel) EnsureLoaded() {
|
||||
channel.ensureLoaded.Do(func() {
|
||||
nmc := channel.NameCasefolded()
|
||||
info, err := channel.server.channelRegistry.LoadChannel(nmc)
|
||||
if err == nil {
|
||||
channel.applyRegInfo(info)
|
||||
} else {
|
||||
channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (channel *Channel) IsLoaded() bool {
|
||||
return channel.ensureLoaded.Done()
|
||||
}
|
||||
|
||||
// read in channel state that was persisted in the DB
|
||||
func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
|
||||
func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
|
||||
channel.stateMutex.Lock()
|
||||
defer channel.stateMutex.Unlock()
|
||||
|
||||
channel.registeredFounder = chanReg.Founder
|
||||
channel.registeredTime = chanReg.RegisteredAt
|
||||
channel.topic = chanReg.Topic
|
||||
@ -116,6 +140,7 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
|
||||
defer channel.stateMutex.RUnlock()
|
||||
|
||||
info.Name = channel.name
|
||||
info.NameCasefolded = channel.nameCasefolded
|
||||
info.Founder = channel.registeredFounder
|
||||
info.RegisteredAt = channel.registeredTime
|
||||
|
||||
@ -149,6 +174,115 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
|
||||
return
|
||||
}
|
||||
|
||||
// begin: asynchronous database writeback implementation, modeled on irc/socket.go
|
||||
|
||||
// MarkDirty marks part (or all) of a channel's data as needing to be written back
|
||||
// to the database, then starts a writer goroutine if necessary.
|
||||
// This is the equivalent of Socket.Write().
|
||||
func (channel *Channel) MarkDirty(dirtyBits uint) {
|
||||
channel.stateMutex.Lock()
|
||||
isRegistered := channel.registeredFounder != ""
|
||||
channel.dirtyBits = channel.dirtyBits | dirtyBits
|
||||
channel.stateMutex.Unlock()
|
||||
if !isRegistered {
|
||||
return
|
||||
}
|
||||
|
||||
channel.wakeWriter()
|
||||
}
|
||||
|
||||
// IsClean returns whether a channel can be safely removed from the server.
|
||||
// To avoid the obvious TOCTOU race condition, it must be called while holding
|
||||
// ChannelManager's lock (that way, no one can join and make the channel dirty again
|
||||
// between this method exiting and the actual deletion).
|
||||
func (channel *Channel) IsClean() bool {
|
||||
if !channel.writerSemaphore.TryAcquire() {
|
||||
// a database write (which may fail) is in progress, the channel cannot be cleaned up
|
||||
return false
|
||||
}
|
||||
defer channel.writerSemaphore.Release()
|
||||
|
||||
channel.stateMutex.RLock()
|
||||
defer channel.stateMutex.RUnlock()
|
||||
// the channel must be empty, and either be unregistered or fully written to the DB
|
||||
return len(channel.members) == 0 && (channel.registeredFounder == "" || channel.dirtyBits == 0)
|
||||
}
|
||||
|
||||
func (channel *Channel) wakeWriter() {
|
||||
if channel.writerSemaphore.TryAcquire() {
|
||||
go channel.writeLoop()
|
||||
}
|
||||
}
|
||||
|
||||
// equivalent of Socket.send()
|
||||
func (channel *Channel) writeLoop() {
|
||||
for {
|
||||
// TODO(#357) check the error value of this and implement timed backoff
|
||||
channel.performWrite(0)
|
||||
channel.writerSemaphore.Release()
|
||||
|
||||
channel.stateMutex.RLock()
|
||||
isDirty := channel.dirtyBits != 0
|
||||
isEmpty := len(channel.members) == 0
|
||||
channel.stateMutex.RUnlock()
|
||||
|
||||
if !isDirty {
|
||||
if isEmpty {
|
||||
channel.server.channels.Cleanup(channel)
|
||||
}
|
||||
return // nothing to do
|
||||
} // else: isDirty, so we need to write again
|
||||
|
||||
if !channel.writerSemaphore.TryAcquire() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store writes part (or all) of the channel's data back to the database,
|
||||
// blocking until the write is complete. This is the equivalent of
|
||||
// Socket.BlockingWrite.
|
||||
func (channel *Channel) Store(dirtyBits uint) (err error) {
|
||||
defer func() {
|
||||
channel.stateMutex.Lock()
|
||||
isDirty := channel.dirtyBits != 0
|
||||
isEmpty := len(channel.members) == 0
|
||||
channel.stateMutex.Unlock()
|
||||
|
||||
if isDirty {
|
||||
channel.wakeWriter()
|
||||
} else if isEmpty {
|
||||
channel.server.channels.Cleanup(channel)
|
||||
}
|
||||
}()
|
||||
|
||||
channel.writerSemaphore.Acquire()
|
||||
defer channel.writerSemaphore.Release()
|
||||
return channel.performWrite(dirtyBits)
|
||||
}
|
||||
|
||||
// do an individual write; equivalent of Socket.send()
|
||||
func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) {
|
||||
channel.stateMutex.Lock()
|
||||
dirtyBits := channel.dirtyBits | additionalDirtyBits
|
||||
channel.dirtyBits = 0
|
||||
isRegistered := channel.registeredFounder != ""
|
||||
channel.stateMutex.Unlock()
|
||||
|
||||
if !isRegistered || dirtyBits == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
info := channel.ExportRegistration(dirtyBits)
|
||||
err = channel.server.channelRegistry.StoreChannel(info, dirtyBits)
|
||||
if err != nil {
|
||||
channel.stateMutex.Lock()
|
||||
channel.dirtyBits = channel.dirtyBits | dirtyBits
|
||||
channel.stateMutex.Unlock()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetRegistered registers the channel, returning an error if it was already registered.
|
||||
func (channel *Channel) SetRegistered(founder string) error {
|
||||
channel.stateMutex.Lock()
|
||||
@ -698,7 +832,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
|
||||
}
|
||||
}
|
||||
|
||||
go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic)
|
||||
channel.MarkDirty(IncludeTopic)
|
||||
}
|
||||
|
||||
// CanSpeak returns true if the client can speak on this channel.
|
||||
|
@ -19,25 +19,38 @@ type channelManagerEntry struct {
|
||||
// providing synchronization for creation of new channels on first join,
|
||||
// cleanup of empty channels on last part, and renames.
|
||||
type ChannelManager struct {
|
||||
sync.RWMutex // tier 2
|
||||
chans map[string]*channelManagerEntry
|
||||
sync.RWMutex // tier 2
|
||||
chans map[string]*channelManagerEntry
|
||||
registeredChannels map[string]bool
|
||||
server *Server
|
||||
}
|
||||
|
||||
// NewChannelManager returns a new ChannelManager.
|
||||
func NewChannelManager() *ChannelManager {
|
||||
return &ChannelManager{
|
||||
chans: make(map[string]*channelManagerEntry),
|
||||
func (cm *ChannelManager) Initialize(server *Server) {
|
||||
cm.chans = make(map[string]*channelManagerEntry)
|
||||
cm.server = server
|
||||
|
||||
if server.Config().Channels.Registration.Enabled {
|
||||
cm.loadRegisteredChannels()
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ChannelManager) loadRegisteredChannels() {
|
||||
registeredChannels := cm.server.channelRegistry.AllChannels()
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
cm.registeredChannels = registeredChannels
|
||||
}
|
||||
|
||||
// Get returns an existing channel with name equivalent to `name`, or nil
|
||||
func (cm *ChannelManager) Get(name string) *Channel {
|
||||
func (cm *ChannelManager) Get(name string) (channel *Channel) {
|
||||
name, err := CasefoldChannel(name)
|
||||
if err == nil {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
entry := cm.chans[name]
|
||||
if entry != nil {
|
||||
// if the channel is still loading, pretend we don't have it
|
||||
if entry != nil && entry.channel.IsLoaded() {
|
||||
return entry.channel
|
||||
}
|
||||
}
|
||||
@ -55,28 +68,21 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
|
||||
cm.Lock()
|
||||
entry := cm.chans[casefoldedName]
|
||||
if entry == nil {
|
||||
// XXX give up the lock to check for a registration, then check again
|
||||
// to see if we need to create the channel. we could solve this by doing LoadChannel
|
||||
// outside the lock initially on every join, so this is best thought of as an
|
||||
// optimization to avoid that.
|
||||
cm.Unlock()
|
||||
info := client.server.channelRegistry.LoadChannel(casefoldedName)
|
||||
cm.Lock()
|
||||
entry = cm.chans[casefoldedName]
|
||||
if entry == nil {
|
||||
entry = &channelManagerEntry{
|
||||
channel: NewChannel(server, name, info),
|
||||
pendingJoins: 0,
|
||||
}
|
||||
cm.chans[casefoldedName] = entry
|
||||
registered := cm.registeredChannels[casefoldedName]
|
||||
entry = &channelManagerEntry{
|
||||
channel: NewChannel(server, name, registered),
|
||||
pendingJoins: 0,
|
||||
}
|
||||
cm.chans[casefoldedName] = entry
|
||||
}
|
||||
entry.pendingJoins += 1
|
||||
channel := entry.channel
|
||||
cm.Unlock()
|
||||
|
||||
entry.channel.Join(client, key, isSajoin, rb)
|
||||
channel.EnsureLoaded()
|
||||
channel.Join(client, key, isSajoin, rb)
|
||||
|
||||
cm.maybeCleanup(entry.channel, true)
|
||||
cm.maybeCleanup(channel, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -85,7 +91,8 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
entry := cm.chans[channel.NameCasefolded()]
|
||||
nameCasefolded := channel.NameCasefolded()
|
||||
entry := cm.chans[nameCasefolded]
|
||||
if entry == nil || entry.channel != channel {
|
||||
return
|
||||
}
|
||||
@ -93,23 +100,15 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) {
|
||||
if afterJoin {
|
||||
entry.pendingJoins -= 1
|
||||
}
|
||||
// TODO(slingamn) right now, registered channels cannot be cleaned up.
|
||||
// this is because once ChannelManager becomes the source of truth about a channel,
|
||||
// we can't move the source of truth back to the database unless we do an ACID
|
||||
// store while holding the ChannelManager's Lock(). This is pending more decisions
|
||||
// about where the database transaction lock fits into the overall lock model.
|
||||
if !entry.channel.IsRegistered() && entry.channel.IsEmpty() && entry.pendingJoins == 0 {
|
||||
// reread the name, handling the case where the channel was renamed
|
||||
casefoldedName := entry.channel.NameCasefolded()
|
||||
delete(cm.chans, casefoldedName)
|
||||
// invalidate the entry (otherwise, a subsequent cleanup attempt could delete
|
||||
// a valid, distinct entry under casefoldedName):
|
||||
entry.channel = nil
|
||||
if entry.pendingJoins == 0 && entry.channel.IsClean() {
|
||||
delete(cm.chans, nameCasefolded)
|
||||
}
|
||||
}
|
||||
|
||||
// Part parts `client` from the channel named `name`, deleting it if it's empty.
|
||||
func (cm *ChannelManager) Part(client *Client, name string, message string, rb *ResponseBuffer) error {
|
||||
var channel *Channel
|
||||
|
||||
casefoldedName, err := CasefoldChannel(name)
|
||||
if err != nil {
|
||||
return errNoSuchChannel
|
||||
@ -117,12 +116,15 @@ func (cm *ChannelManager) Part(client *Client, name string, message string, rb *
|
||||
|
||||
cm.RLock()
|
||||
entry := cm.chans[casefoldedName]
|
||||
if entry != nil {
|
||||
channel = entry.channel
|
||||
}
|
||||
cm.RUnlock()
|
||||
|
||||
if entry == nil {
|
||||
if channel == nil {
|
||||
return errNoSuchChannel
|
||||
}
|
||||
entry.channel.Part(client, message, rb)
|
||||
channel.Part(client, message, rb)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -130,8 +132,68 @@ func (cm *ChannelManager) Cleanup(channel *Channel) {
|
||||
cm.maybeCleanup(channel, false)
|
||||
}
|
||||
|
||||
func (cm *ChannelManager) SetRegistered(channelName string, account string) (err error) {
|
||||
var channel *Channel
|
||||
cfname, err := CasefoldChannel(channelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var entry *channelManagerEntry
|
||||
|
||||
defer func() {
|
||||
if err == nil && channel != nil {
|
||||
// registration was successful: make the database reflect it
|
||||
err = channel.Store(IncludeAllChannelAttrs)
|
||||
}
|
||||
}()
|
||||
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
entry = cm.chans[cfname]
|
||||
if entry == nil {
|
||||
return errNoSuchChannel
|
||||
}
|
||||
channel = entry.channel
|
||||
err = channel.SetRegistered(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cm.registeredChannels[cfname] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ChannelManager) SetUnregistered(channelName string, account string) (err error) {
|
||||
cfname, err := CasefoldChannel(channelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var info RegisteredChannel
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
err = cm.server.channelRegistry.Delete(info)
|
||||
}
|
||||
}()
|
||||
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
entry := cm.chans[cfname]
|
||||
if entry == nil {
|
||||
return errNoSuchChannel
|
||||
}
|
||||
info = entry.channel.ExportRegistration(0)
|
||||
if info.Founder != account {
|
||||
return errChannelNotOwnedByAccount
|
||||
}
|
||||
entry.channel.SetUnregistered(account)
|
||||
delete(cm.registeredChannels, cfname)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rename renames a channel (but does not notify the members)
|
||||
func (cm *ChannelManager) Rename(name string, newname string) error {
|
||||
func (cm *ChannelManager) Rename(name string, newname string) (err error) {
|
||||
cfname, err := CasefoldChannel(name)
|
||||
if err != nil {
|
||||
return errNoSuchChannel
|
||||
@ -142,6 +204,17 @@ func (cm *ChannelManager) Rename(name string, newname string) error {
|
||||
return errInvalidChannelName
|
||||
}
|
||||
|
||||
var channel *Channel
|
||||
var info RegisteredChannel
|
||||
defer func() {
|
||||
if channel != nil && info.Founder != "" {
|
||||
channel.Store(IncludeAllChannelAttrs)
|
||||
// we just flushed the channel under its new name, therefore this delete
|
||||
// cannot be overwritten by a write to the old name:
|
||||
cm.server.channelRegistry.Delete(info)
|
||||
}
|
||||
}()
|
||||
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
@ -152,12 +225,12 @@ func (cm *ChannelManager) Rename(name string, newname string) error {
|
||||
if entry == nil {
|
||||
return errNoSuchChannel
|
||||
}
|
||||
channel = entry.channel
|
||||
info = channel.ExportRegistration(IncludeInitial)
|
||||
delete(cm.chans, cfname)
|
||||
cm.chans[cfnewname] = entry
|
||||
entry.channel.setName(newname)
|
||||
entry.channel.setNameCasefolded(cfnewname)
|
||||
entry.channel.Rename(newname, cfnewname)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Len returns the number of channels
|
||||
@ -171,8 +244,11 @@ func (cm *ChannelManager) Len() int {
|
||||
func (cm *ChannelManager) Channels() (result []*Channel) {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
result = make([]*Channel, 0, len(cm.chans))
|
||||
for _, entry := range cm.chans {
|
||||
result = append(result, entry.channel)
|
||||
if entry.channel.IsLoaded() {
|
||||
result = append(result, entry.channel)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"encoding/json"
|
||||
@ -71,6 +70,8 @@ const (
|
||||
type RegisteredChannel struct {
|
||||
// Name of the channel.
|
||||
Name string
|
||||
// Casefolded name of the channel.
|
||||
NameCasefolded string
|
||||
// RegisteredAt represents the time that the channel was registered.
|
||||
RegisteredAt time.Time
|
||||
// Founder indicates the founder of the channel.
|
||||
@ -97,58 +98,65 @@ type RegisteredChannel struct {
|
||||
|
||||
// ChannelRegistry manages registered channels.
|
||||
type ChannelRegistry struct {
|
||||
// This serializes operations of the form (read channel state, synchronously persist it);
|
||||
// this is enough to guarantee eventual consistency of the database with the
|
||||
// ChannelManager and Channel objects, which are the source of truth.
|
||||
//
|
||||
// We could use the buntdb RW transaction lock for this purpose but we share
|
||||
// that with all the other modules, so let's not.
|
||||
sync.Mutex // tier 2
|
||||
server *Server
|
||||
server *Server
|
||||
}
|
||||
|
||||
// NewChannelRegistry returns a new ChannelRegistry.
|
||||
func NewChannelRegistry(server *Server) *ChannelRegistry {
|
||||
return &ChannelRegistry{
|
||||
server: server,
|
||||
}
|
||||
func (reg *ChannelRegistry) Initialize(server *Server) {
|
||||
reg.server = server
|
||||
}
|
||||
|
||||
func (reg *ChannelRegistry) AllChannels() (result map[string]bool) {
|
||||
result = make(map[string]bool)
|
||||
|
||||
prefix := fmt.Sprintf(keyChannelExists, "")
|
||||
reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return false
|
||||
}
|
||||
channel := strings.TrimPrefix(key, prefix)
|
||||
result[channel] = true
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoreChannel obtains a consistent view of a channel, then persists it to the store.
|
||||
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) {
|
||||
func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
reg.Lock()
|
||||
defer reg.Unlock()
|
||||
|
||||
key := channel.NameCasefolded()
|
||||
info := channel.ExportRegistration(includeFlags)
|
||||
if info.Founder == "" {
|
||||
// sanity check, don't try to store an unregistered channel
|
||||
return
|
||||
}
|
||||
|
||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
reg.saveChannel(tx, key, info, includeFlags)
|
||||
reg.saveChannel(tx, info, includeFlags)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadChannel loads a channel from the store.
|
||||
func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *RegisteredChannel) {
|
||||
func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
return nil
|
||||
err = errFeatureDisabled
|
||||
return
|
||||
}
|
||||
|
||||
channelKey := nameCasefolded
|
||||
// nice to have: do all JSON (de)serialization outside of the buntdb transaction
|
||||
reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
_, err := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
|
||||
if err == buntdb.ErrNotFound {
|
||||
err = reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
_, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
|
||||
if dberr == buntdb.ErrNotFound {
|
||||
// chan does not already exist, return
|
||||
return nil
|
||||
return errNoSuchChannel
|
||||
}
|
||||
|
||||
// channel exists, load it
|
||||
@ -181,7 +189,7 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
|
||||
accountToUMode := make(map[string]modes.Mode)
|
||||
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
|
||||
|
||||
info = &RegisteredChannel{
|
||||
info = RegisteredChannel{
|
||||
Name: name,
|
||||
RegisteredAt: time.Unix(regTimeInt, 0),
|
||||
Founder: founder,
|
||||
@ -198,46 +206,21 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
|
||||
return nil
|
||||
})
|
||||
|
||||
return info
|
||||
return
|
||||
}
|
||||
|
||||
func (reg *ChannelRegistry) Delete(casefoldedName string, info RegisteredChannel) {
|
||||
// Delete deletes a channel corresponding to `info`. If no such channel
|
||||
// is present in the database, no error is returned.
|
||||
func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
reg.Lock()
|
||||
defer reg.Unlock()
|
||||
|
||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
reg.deleteChannel(tx, casefoldedName, info)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Rename handles the persistence part of a channel rename: the channel is
|
||||
// persisted under its new name, and the old name is cleaned up if necessary.
|
||||
func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
reg.Lock()
|
||||
defer reg.Unlock()
|
||||
|
||||
includeFlags := IncludeAllChannelAttrs
|
||||
oldKey := casefoldedOldName
|
||||
key := channel.NameCasefolded()
|
||||
info := channel.ExportRegistration(includeFlags)
|
||||
if info.Founder == "" {
|
||||
return
|
||||
}
|
||||
|
||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
reg.deleteChannel(tx, oldKey, info)
|
||||
reg.saveChannel(tx, key, info, includeFlags)
|
||||
reg.deleteChannel(tx, info.NameCasefolded, info)
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete a channel, unless it was overwritten by another registration of the same channel
|
||||
@ -274,7 +257,8 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
|
||||
}
|
||||
|
||||
// saveChannel saves a channel to the store.
|
||||
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
|
||||
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) {
|
||||
channelKey := channelInfo.NameCasefolded
|
||||
// maintain the mapping of account -> registered channels
|
||||
chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey)
|
||||
_, existsErr := tx.Get(chanExistsKey)
|
||||
|
@ -232,15 +232,12 @@ func csRegisterHandler(server *Server, client *Client, command string, params []
|
||||
}
|
||||
|
||||
// this provides the synchronization that allows exactly one registration of the channel:
|
||||
err = channelInfo.SetRegistered(account)
|
||||
err = server.channels.SetRegistered(channelKey, account)
|
||||
if err != nil {
|
||||
csNotice(rb, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// registration was successful: make the database reflect it
|
||||
go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
|
||||
|
||||
csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))
|
||||
|
||||
server.logger.Info("services", fmt.Sprintf("Client %s registered channel %s", client.nick, channelName))
|
||||
@ -297,8 +294,7 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
|
||||
return
|
||||
}
|
||||
|
||||
channel.SetUnregistered(founder)
|
||||
server.channelRegistry.Delete(channelKey, info)
|
||||
server.channels.SetUnregistered(channelKey, founder)
|
||||
csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey))
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ type Client struct {
|
||||
accountName string // display name of the account: uncasefolded, '*' if not logged in
|
||||
atime time.Time
|
||||
awayMessage string
|
||||
capabilities *caps.Set
|
||||
capabilities caps.Set
|
||||
capState caps.State
|
||||
capVersion caps.Version
|
||||
certfp string
|
||||
@ -58,7 +58,7 @@ type Client struct {
|
||||
ctime time.Time
|
||||
exitedSnomaskSent bool
|
||||
fakelag Fakelag
|
||||
flags *modes.ModeSet
|
||||
flags modes.ModeSet
|
||||
hasQuit bool
|
||||
hops int
|
||||
hostname string
|
||||
@ -125,15 +125,13 @@ func RunNewClient(server *Server, conn clientConn) {
|
||||
// give them 1k of grace over the limit:
|
||||
socket := NewSocket(conn.Conn, fullLineLenLimit+1024, config.Server.MaxSendQBytes)
|
||||
client := &Client{
|
||||
atime: now,
|
||||
capabilities: caps.NewSet(),
|
||||
capState: caps.NoneState,
|
||||
capVersion: caps.Cap301,
|
||||
channels: make(ChannelSet),
|
||||
ctime: now,
|
||||
flags: modes.NewModeSet(),
|
||||
isTor: conn.IsTor,
|
||||
languages: server.Languages().Default(),
|
||||
atime: now,
|
||||
capState: caps.NoneState,
|
||||
capVersion: caps.Cap301,
|
||||
channels: make(ChannelSet),
|
||||
ctime: now,
|
||||
isTor: conn.IsTor,
|
||||
languages: server.Languages().Default(),
|
||||
loginThrottle: connection_limits.GenericThrottle{
|
||||
Duration: config.Accounts.LoginThrottling.Duration,
|
||||
Limit: config.Accounts.LoginThrottling.MaxAttempts,
|
||||
@ -546,7 +544,6 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I
|
||||
// copy applicable state from oldClient to client as part of a resume
|
||||
func (client *Client) copyResumeData(oldClient *Client) {
|
||||
oldClient.stateMutex.RLock()
|
||||
flags := oldClient.flags
|
||||
history := oldClient.history
|
||||
nick := oldClient.nick
|
||||
nickCasefolded := oldClient.nickCasefolded
|
||||
@ -560,7 +557,7 @@ func (client *Client) copyResumeData(oldClient *Client) {
|
||||
// resume over plaintext)
|
||||
hasTLS := client.flags.HasMode(modes.TLS)
|
||||
temp := modes.NewModeSet()
|
||||
temp.Copy(flags)
|
||||
temp.Copy(&oldClient.flags)
|
||||
temp.SetMode(modes.TLS, hasTLS)
|
||||
client.flags.Copy(temp)
|
||||
|
||||
|
@ -37,12 +37,10 @@ type ClientManager struct {
|
||||
bySkeleton map[string]*Client
|
||||
}
|
||||
|
||||
// NewClientManager returns a new ClientManager.
|
||||
func NewClientManager() *ClientManager {
|
||||
return &ClientManager{
|
||||
byNick: make(map[string]*Client),
|
||||
bySkeleton: make(map[string]*Client),
|
||||
}
|
||||
// Initialize initializes a ClientManager.
|
||||
func (clients *ClientManager) Initialize() {
|
||||
clients.byNick = make(map[string]*Client)
|
||||
clients.bySkeleton = make(map[string]*Client)
|
||||
}
|
||||
|
||||
// Count returns how many clients are in the manager.
|
||||
|
@ -27,6 +27,8 @@ var (
|
||||
errAccountMustHoldNick = errors.New(`You must hold that nickname in order to register it`)
|
||||
errCallbackFailed = errors.New("Account verification could not be sent")
|
||||
errCertfpAlreadyExists = errors.New(`An account already exists for your certificate fingerprint`)
|
||||
errChannelNotOwnedByAccount = errors.New("Channel not owned by the specified account")
|
||||
errChannelDoesNotExist = errors.New("Channel does not exist")
|
||||
errChannelAlreadyRegistered = errors.New("Channel is already registered")
|
||||
errChannelNameInUse = errors.New(`Channel name in use`)
|
||||
errInvalidChannelName = errors.New(`Invalid channel name`)
|
||||
|
@ -4,6 +4,8 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/oragono/oragono/irc/isupport"
|
||||
"github.com/oragono/oragono/irc/languages"
|
||||
"github.com/oragono/oragono/irc/modes"
|
||||
@ -267,22 +269,20 @@ func (channel *Channel) Name() string {
|
||||
return channel.name
|
||||
}
|
||||
|
||||
func (channel *Channel) setName(name string) {
|
||||
channel.stateMutex.Lock()
|
||||
defer channel.stateMutex.Unlock()
|
||||
channel.name = name
|
||||
}
|
||||
|
||||
func (channel *Channel) NameCasefolded() string {
|
||||
channel.stateMutex.RLock()
|
||||
defer channel.stateMutex.RUnlock()
|
||||
return channel.nameCasefolded
|
||||
}
|
||||
|
||||
func (channel *Channel) setNameCasefolded(nameCasefolded string) {
|
||||
func (channel *Channel) Rename(name, nameCasefolded string) {
|
||||
channel.stateMutex.Lock()
|
||||
defer channel.stateMutex.Unlock()
|
||||
channel.name = name
|
||||
channel.nameCasefolded = nameCasefolded
|
||||
if channel.registeredFounder != "" {
|
||||
channel.registeredTime = time.Now()
|
||||
}
|
||||
channel.stateMutex.Unlock()
|
||||
}
|
||||
|
||||
func (channel *Channel) Members() (result []*Client) {
|
||||
@ -314,3 +314,10 @@ func (channel *Channel) Founder() string {
|
||||
defer channel.stateMutex.RUnlock()
|
||||
return channel.registeredFounder
|
||||
}
|
||||
|
||||
func (channel *Channel) DirtyBits() (dirtyBits uint) {
|
||||
channel.stateMutex.Lock()
|
||||
dirtyBits = channel.dirtyBits
|
||||
channel.stateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -1607,8 +1607,8 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
|
||||
}
|
||||
}
|
||||
|
||||
if channel.IsRegistered() && includeFlags != 0 {
|
||||
go server.channelRegistry.StoreChannel(channel, includeFlags)
|
||||
if includeFlags != 0 {
|
||||
channel.MarkDirty(includeFlags)
|
||||
}
|
||||
|
||||
// send out changes
|
||||
@ -2215,7 +2215,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re
|
||||
rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), oldName, client.t("No such channel"))
|
||||
return false
|
||||
}
|
||||
casefoldedOldName := channel.NameCasefolded()
|
||||
if !(channel.ClientIsAtLeast(client, modes.Operator) || client.HasRoleCapabs("chanreg")) {
|
||||
rb.Add(nil, server.name, ERR_CHANOPRIVSNEEDED, client.Nick(), oldName, client.t("You're not a channel operator"))
|
||||
return false
|
||||
@ -2240,9 +2239,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re
|
||||
return false
|
||||
}
|
||||
|
||||
// rename succeeded, persist it
|
||||
go server.channelRegistry.Rename(channel, casefoldedOldName)
|
||||
|
||||
// send RENAME messages
|
||||
clientPrefix := client.NickMaskString()
|
||||
for _, mcl := range channel.Members() {
|
||||
|
@ -290,14 +290,14 @@ func (channel *Channel) ProcessAccountToUmodeChange(client *Client, change modes
|
||||
case modes.Add:
|
||||
if targetModeNow != targetModeAfter {
|
||||
channel.accountToUMode[change.Arg] = change.Mode
|
||||
go client.server.channelRegistry.StoreChannel(channel, IncludeLists)
|
||||
channel.MarkDirty(IncludeLists)
|
||||
return []modes.ModeChange{change}, nil
|
||||
}
|
||||
return nil, nil
|
||||
case modes.Remove:
|
||||
if targetModeNow == change.Mode {
|
||||
delete(channel.accountToUMode, change.Arg)
|
||||
go client.server.channelRegistry.StoreChannel(channel, IncludeLists)
|
||||
channel.MarkDirty(IncludeLists)
|
||||
return []modes.ModeChange{change}, nil
|
||||
}
|
||||
return nil, nil
|
||||
|
@ -335,7 +335,6 @@ const (
|
||||
// returns a pointer to a new ModeSet
|
||||
func NewModeSet() *ModeSet {
|
||||
var set ModeSet
|
||||
utils.BitsetInitialize(set[:])
|
||||
return &set
|
||||
}
|
||||
|
||||
|
@ -32,14 +32,13 @@ type ServerSemaphores struct {
|
||||
ClientDestroy Semaphore
|
||||
}
|
||||
|
||||
// NewServerSemaphores creates a new ServerSemaphores.
|
||||
func NewServerSemaphores() (result *ServerSemaphores) {
|
||||
// Initialize initializes a set of server semaphores.
|
||||
func (serversem *ServerSemaphores) Initialize() {
|
||||
capacity := runtime.NumCPU()
|
||||
if capacity > MaxServerSemaphoreCapacity {
|
||||
capacity = MaxServerSemaphoreCapacity
|
||||
}
|
||||
result = new(ServerSemaphores)
|
||||
result.ClientDestroy.Initialize(capacity)
|
||||
serversem.ClientDestroy.Initialize(capacity)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -61,10 +61,10 @@ type ListenerWrapper struct {
|
||||
|
||||
// Server is the main Oragono server.
|
||||
type Server struct {
|
||||
accounts *AccountManager
|
||||
channels *ChannelManager
|
||||
channelRegistry *ChannelRegistry
|
||||
clients *ClientManager
|
||||
accounts AccountManager
|
||||
channels ChannelManager
|
||||
channelRegistry ChannelRegistry
|
||||
clients ClientManager
|
||||
config *Config
|
||||
configFilename string
|
||||
configurableStateMutex sync.RWMutex // tier 1; generic protection for server state modified by rehash()
|
||||
@ -89,9 +89,9 @@ type Server struct {
|
||||
snomasks *SnoManager
|
||||
store *buntdb.DB
|
||||
torLimiter connection_limits.TorLimiter
|
||||
whoWas *WhoWasList
|
||||
stats *Stats
|
||||
semaphores *ServerSemaphores
|
||||
whoWas WhoWasList
|
||||
stats Stats
|
||||
semaphores ServerSemaphores
|
||||
}
|
||||
|
||||
var (
|
||||
@ -113,8 +113,6 @@ type clientConn struct {
|
||||
func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
|
||||
// initialize data structures
|
||||
server := &Server{
|
||||
channels: NewChannelManager(),
|
||||
clients: NewClientManager(),
|
||||
connectionLimiter: connection_limits.NewLimiter(),
|
||||
connectionThrottler: connection_limits.NewThrottler(),
|
||||
listeners: make(map[string]*ListenerWrapper),
|
||||
@ -123,12 +121,12 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
|
||||
rehashSignal: make(chan os.Signal, 1),
|
||||
signals: make(chan os.Signal, len(ServerExitSignals)),
|
||||
snomasks: NewSnoManager(),
|
||||
whoWas: NewWhoWasList(config.Limits.WhowasEntries),
|
||||
stats: NewStats(),
|
||||
semaphores: NewServerSemaphores(),
|
||||
}
|
||||
|
||||
server.clients.Initialize()
|
||||
server.semaphores.Initialize()
|
||||
server.resumeManager.Initialize(server)
|
||||
server.whoWas.Initialize(config.Limits.WhowasEntries)
|
||||
|
||||
if err := server.applyConfig(config, true); err != nil {
|
||||
return nil, err
|
||||
@ -697,6 +695,12 @@ func (server *Server) applyConfig(config *Config, initial bool) (err error) {
|
||||
server.accounts.initVHostRequestQueue()
|
||||
}
|
||||
|
||||
chanRegPreviouslyDisabled := oldConfig != nil && !oldConfig.Channels.Registration.Enabled
|
||||
chanRegNowEnabled := config.Channels.Registration.Enabled
|
||||
if chanRegPreviouslyDisabled && chanRegNowEnabled {
|
||||
server.channels.loadRegisteredChannels()
|
||||
}
|
||||
|
||||
// MaxLine
|
||||
if config.Limits.LineLen.Rest != 512 {
|
||||
SupportedCapabilities.Enable(caps.MaxLine)
|
||||
@ -922,9 +926,9 @@ func (server *Server) loadDatastore(config *Config) error {
|
||||
server.loadDLines()
|
||||
server.loadKLines()
|
||||
|
||||
server.channelRegistry = NewChannelRegistry(server)
|
||||
|
||||
server.accounts = NewAccountManager(server)
|
||||
server.channelRegistry.Initialize(server)
|
||||
server.channels.Initialize(server)
|
||||
server.accounts.Initialize(server)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
11
irc/stats.go
11
irc/stats.go
@ -13,17 +13,6 @@ type Stats struct {
|
||||
Operators int
|
||||
}
|
||||
|
||||
// NewStats creates a new instance of Stats
|
||||
func NewStats() *Stats {
|
||||
serverStats := &Stats{
|
||||
Total: 0,
|
||||
Invisible: 0,
|
||||
Operators: 0,
|
||||
}
|
||||
|
||||
return serverStats
|
||||
}
|
||||
|
||||
// ChangeTotal increments the total user count on server
|
||||
func (s *Stats) ChangeTotal(i int) {
|
||||
s.Lock()
|
||||
|
@ -9,17 +9,6 @@ import "sync/atomic"
|
||||
// For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a
|
||||
// slice to use these functions.
|
||||
|
||||
// BitsetInitialize initializes a bitset.
|
||||
func BitsetInitialize(set []uint64) {
|
||||
// XXX re-zero the bitset using atomic stores. it's unclear whether this is required,
|
||||
// however, golang issue #5045 suggests that you shouldn't mix atomic operations
|
||||
// with non-atomic operations (such as the runtime's automatic zero-initialization) on
|
||||
// the same word
|
||||
for i := 0; i < len(set); i++ {
|
||||
atomic.StoreUint64(&set[i], 0)
|
||||
}
|
||||
}
|
||||
|
||||
// BitsetGet returns whether a given bit of the bitset is set.
|
||||
func BitsetGet(set []uint64, position uint) bool {
|
||||
idx := position / 64
|
||||
|
@ -10,7 +10,6 @@ type testBitset [2]uint64
|
||||
func TestSets(t *testing.T) {
|
||||
var t1 testBitset
|
||||
t1s := t1[:]
|
||||
BitsetInitialize(t1s)
|
||||
|
||||
if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) {
|
||||
t.Error("no bits should be set in a newly initialized bitset")
|
||||
@ -47,7 +46,6 @@ func TestSets(t *testing.T) {
|
||||
|
||||
var t2 testBitset
|
||||
t2s := t2[:]
|
||||
BitsetInitialize(t2s)
|
||||
|
||||
for i = 0; i < 128; i++ {
|
||||
if i%2 == 1 {
|
||||
|
35
irc/utils/sync.go
Normal file
35
irc/utils/sync.go
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Once is a fork of sync.Once to expose a Done() method.
|
||||
type Once struct {
|
||||
done uint32
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
func (o *Once) Do(f func()) {
|
||||
if atomic.LoadUint32(&o.done) == 0 {
|
||||
o.doSlow(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Once) doSlow(f func()) {
|
||||
o.m.Lock()
|
||||
defer o.m.Unlock()
|
||||
if o.done == 0 {
|
||||
defer atomic.StoreUint32(&o.done, 1)
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Once) Done() bool {
|
||||
return atomic.LoadUint32(&o.done) == 1
|
||||
}
|
@ -23,12 +23,10 @@ type WhoWasList struct {
|
||||
}
|
||||
|
||||
// NewWhoWasList returns a new WhoWasList
|
||||
func NewWhoWasList(size int) *WhoWasList {
|
||||
return &WhoWasList{
|
||||
buffer: make([]WhoWas, size),
|
||||
start: -1,
|
||||
end: -1,
|
||||
}
|
||||
func (list *WhoWasList) Initialize(size int) {
|
||||
list.buffer = make([]WhoWas, size)
|
||||
list.start = -1
|
||||
list.end = -1
|
||||
}
|
||||
|
||||
// Append adds an entry to the WhoWasList.
|
||||
|
@ -23,7 +23,8 @@ func makeTestWhowas(nick string) WhoWas {
|
||||
|
||||
func TestWhoWas(t *testing.T) {
|
||||
var results []WhoWas
|
||||
wwl := NewWhoWasList(3)
|
||||
var wwl WhoWasList
|
||||
wwl.Initialize(3)
|
||||
// test Find on empty list
|
||||
results = wwl.Find("nobody", 10)
|
||||
if len(results) != 0 {
|
||||
@ -88,7 +89,8 @@ func TestWhoWas(t *testing.T) {
|
||||
|
||||
func TestEmptyWhoWas(t *testing.T) {
|
||||
// stupid edge case; setting an empty whowas buffer should not panic
|
||||
wwl := NewWhoWasList(0)
|
||||
var wwl WhoWasList
|
||||
wwl.Initialize(0)
|
||||
results := wwl.Find("slingamn", 10)
|
||||
if len(results) != 0 {
|
||||
t.Fatalf("incorrect whowas results: %v", results)
|
||||
|
Loading…
Reference in New Issue
Block a user