3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-15 00:19:29 +01:00

refactor of channel persistence to use UUIDs

This commit is contained in:
Shivaram Lingamneni 2023-01-04 05:06:21 -05:00
parent bceae9b739
commit 7ce0636276
18 changed files with 804 additions and 653 deletions

View File

@ -39,7 +39,6 @@ const (
keyAccountSettings = "account.settings %s" keyAccountSettings = "account.settings %s"
keyAccountVHost = "account.vhost %s" keyAccountVHost = "account.vhost %s"
keyCertToAccount = "account.creds.certfp %s" keyCertToAccount = "account.creds.certfp %s"
keyAccountChannels = "account.channels %s" // channels registered to the account
keyAccountLastSeen = "account.lastseen %s" keyAccountLastSeen = "account.lastseen %s"
keyAccountReadMarkers = "account.readmarkers %s" keyAccountReadMarkers = "account.readmarkers %s"
keyAccountModes = "account.modes %s" // user modes for the always-on client as a string keyAccountModes = "account.modes %s" // user modes for the always-on client as a string
@ -1765,7 +1764,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount) nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
settingsKey := fmt.Sprintf(keyAccountSettings, casefoldedAccount) settingsKey := fmt.Sprintf(keyAccountSettings, casefoldedAccount)
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount) vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
joinedChannelsKey := fmt.Sprintf(keyAccountChannelToModes, casefoldedAccount) joinedChannelsKey := fmt.Sprintf(keyAccountChannelToModes, casefoldedAccount)
lastSeenKey := fmt.Sprintf(keyAccountLastSeen, casefoldedAccount) lastSeenKey := fmt.Sprintf(keyAccountLastSeen, casefoldedAccount)
readMarkersKey := fmt.Sprintf(keyAccountReadMarkers, casefoldedAccount) readMarkersKey := fmt.Sprintf(keyAccountReadMarkers, casefoldedAccount)
@ -1781,10 +1779,9 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
am.killClients(clients) am.killClients(clients)
}() }()
var registeredChannels []string
// on our way out, unregister all the account's channels and delete them from the db // on our way out, unregister all the account's channels and delete them from the db
defer func() { defer func() {
for _, channelName := range registeredChannels { for _, channelName := range am.server.channels.ChannelsForAccount(casefoldedAccount) {
err := am.server.channels.SetUnregistered(channelName, casefoldedAccount) err := am.server.channels.SetUnregistered(channelName, casefoldedAccount)
if err != nil { if err != nil {
am.server.logger.Error("internal", "couldn't unregister channel", channelName, err.Error()) am.server.logger.Error("internal", "couldn't unregister channel", channelName, err.Error())
@ -1799,7 +1796,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
defer am.serialCacheUpdateMutex.Unlock() defer am.serialCacheUpdateMutex.Unlock()
var accountName string var accountName string
var channelsStr string
keepProtections := false keepProtections := false
am.server.store.Update(func(tx *buntdb.Tx) error { am.server.store.Update(func(tx *buntdb.Tx) error {
// get the unfolded account name; for an active account, this is // get the unfolded account name; for an active account, this is
@ -1827,8 +1823,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
credText, err = tx.Get(credentialsKey) credText, err = tx.Get(credentialsKey)
tx.Delete(credentialsKey) tx.Delete(credentialsKey)
tx.Delete(vhostKey) tx.Delete(vhostKey)
channelsStr, _ = tx.Get(channelsKey)
tx.Delete(channelsKey)
tx.Delete(joinedChannelsKey) tx.Delete(joinedChannelsKey)
tx.Delete(lastSeenKey) tx.Delete(lastSeenKey)
tx.Delete(readMarkersKey) tx.Delete(readMarkersKey)
@ -1858,7 +1852,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
skeleton, _ := Skeleton(accountName) skeleton, _ := Skeleton(accountName)
additionalNicks := unmarshalReservedNicks(rawNicks) additionalNicks := unmarshalReservedNicks(rawNicks)
registeredChannels = unmarshalRegisteredChannels(channelsStr)
am.Lock() am.Lock()
defer am.Unlock() defer am.Unlock()
@ -1890,21 +1883,6 @@ func unmarshalRegisteredChannels(channelsStr string) (result []string) {
return return
} }
func (am *AccountManager) ChannelsForAccount(account string) (channels []string) {
cfaccount, err := CasefoldName(account)
if err != nil {
return
}
var channelStr string
key := fmt.Sprintf(keyAccountChannels, cfaccount)
am.server.store.View(func(tx *buntdb.Tx) error {
channelStr, _ = tx.Get(key)
return nil
})
return unmarshalRegisteredChannels(channelStr)
}
func (am *AccountManager) AuthenticateByCertificate(client *Client, certfp string, peerCerts []*x509.Certificate, authzid string) (err error) { func (am *AccountManager) AuthenticateByCertificate(client *Client, certfp string, peerCerts []*x509.Certificate, authzid string) (err error) {
if certfp == "" { if certfp == "" {
return errAccountInvalidCredentials return errAccountInvalidCredentials

106
irc/bunt/bunt_datastore.go Normal file
View File

@ -0,0 +1,106 @@
// Copyright (c) 2022 Shivaram Lingamneni
// released under the MIT license
package bunt
import (
"fmt"
"strings"
"time"
"github.com/tidwall/buntdb"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/logger"
"github.com/ergochat/ergo/irc/utils"
)
// BuntKey yields a string key corresponding to a (table, UUID) pair.
// Ideally this would not be public, but some of the migration code
// needs it.
func BuntKey(table datastore.Table, uuid utils.UUID) string {
return fmt.Sprintf("%x %s", table, uuid.String())
}
// buntdbDatastore implements datastore.Datastore using a buntdb.
type buntdbDatastore struct {
db *buntdb.DB
logger *logger.Manager
}
// NewBuntdbDatastore returns a datastore.Datastore backed by buntdb.
func NewBuntdbDatastore(db *buntdb.DB, logger *logger.Manager) datastore.Datastore {
return &buntdbDatastore{
db: db,
logger: logger,
}
}
func (b *buntdbDatastore) Backoff() time.Duration {
return 0
}
func (b *buntdbDatastore) GetAll(table datastore.Table) (result []datastore.KV, err error) {
tablePrefix := fmt.Sprintf("%x ", table)
err = b.db.View(func(tx *buntdb.Tx) error {
err := tx.AscendGreaterOrEqual("", tablePrefix, func(key, value string) bool {
if !strings.HasPrefix(key, tablePrefix) {
return false
}
uuid, err := utils.DecodeUUID(strings.TrimPrefix(key, tablePrefix))
if err == nil {
result = append(result, datastore.KV{UUID: uuid, Value: []byte(value)})
} else {
b.logger.Error("datastore", "invalid uuid", key)
}
return true
})
return err
})
return
}
func (b *buntdbDatastore) Get(table datastore.Table, uuid utils.UUID) (value []byte, err error) {
buntKey := BuntKey(table, uuid)
var result string
err = b.db.View(func(tx *buntdb.Tx) error {
result, err = tx.Get(buntKey)
return err
})
return []byte(result), err
}
func (b *buntdbDatastore) Set(table datastore.Table, uuid utils.UUID, value []byte, expiration time.Time) (err error) {
buntKey := BuntKey(table, uuid)
var setOptions *buntdb.SetOptions
if !expiration.IsZero() {
ttl := time.Until(expiration)
if ttl > 0 {
setOptions = &buntdb.SetOptions{Expires: true, TTL: ttl}
} else {
return nil // it already expired, i guess?
}
}
strVal := string(value)
err = b.db.Update(func(tx *buntdb.Tx) error {
_, _, err := tx.Set(buntKey, strVal, setOptions)
return err
})
return
}
func (b *buntdbDatastore) Delete(table datastore.Table, key utils.UUID) (err error) {
buntKey := BuntKey(table, key)
err = b.db.Update(func(tx *buntdb.Tx) error {
_, err := tx.Delete(buntKey)
return err
})
// deleting a nonexistent key is not considered an error
switch err {
case buntdb.ErrNotFound:
return nil
default:
return err
}
}

View File

@ -16,6 +16,7 @@ import (
"github.com/ergochat/irc-go/ircutils" "github.com/ergochat/irc-go/ircutils"
"github.com/ergochat/ergo/irc/caps" "github.com/ergochat/ergo/irc/caps"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/history" "github.com/ergochat/ergo/irc/history"
"github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/modes"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
@ -50,14 +51,14 @@ type Channel struct {
stateMutex sync.RWMutex // tier 1 stateMutex sync.RWMutex // tier 1
writebackLock sync.Mutex // tier 1.5 writebackLock sync.Mutex // tier 1.5
joinPartMutex sync.Mutex // tier 3 joinPartMutex sync.Mutex // tier 3
ensureLoaded utils.Once // manages loading stored registration info from the database
dirtyBits uint dirtyBits uint
settings ChannelSettings settings ChannelSettings
uuid utils.UUID
} }
// NewChannel creates a new channel from a `Server` and a `name` // NewChannel creates a new channel from a `Server` and a `name`
// string, which must be unique on the server. // string, which must be unique on the server.
func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channel { func NewChannel(s *Server, name, casefoldedName string, registered bool, regInfo RegisteredChannel) *Channel {
config := s.Config() config := s.Config()
channel := &Channel{ channel := &Channel{
@ -71,14 +72,15 @@ func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channe
channel.initializeLists() channel.initializeLists()
channel.history.Initialize(0, 0) channel.history.Initialize(0, 0)
if !registered { if registered {
channel.applyRegInfo(regInfo)
} else {
channel.resizeHistory(config) channel.resizeHistory(config)
for _, mode := range config.Channels.defaultModes { for _, mode := range config.Channels.defaultModes {
channel.flags.SetMode(mode, true) channel.flags.SetMode(mode, true)
} }
// no loading to do, so "mark" the load operation as "done": channel.uuid = utils.GenerateUUIDv4()
channel.ensureLoaded.Do(func() {}) }
} // else: modes will be loaded before first join
return channel return channel
} }
@ -92,24 +94,6 @@ func (channel *Channel) initializeLists() {
channel.accountToUMode = make(map[string]modes.Mode) channel.accountToUMode = make(map[string]modes.Mode)
} }
// EnsureLoaded blocks until the channel's registration info has been loaded
// from the database.
func (channel *Channel) EnsureLoaded() {
channel.ensureLoaded.Do(func() {
nmc := channel.NameCasefolded()
info, err := channel.server.channelRegistry.LoadChannel(nmc)
if err == nil {
channel.applyRegInfo(info)
} else {
channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error())
}
})
}
func (channel *Channel) IsLoaded() bool {
return channel.ensureLoaded.Done()
}
func (channel *Channel) resizeHistory(config *Config) { func (channel *Channel) resizeHistory(config *Config) {
status, _, _ := channel.historyStatus(config) status, _, _ := channel.historyStatus(config)
if status == HistoryEphemeral { if status == HistoryEphemeral {
@ -126,6 +110,7 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
channel.stateMutex.Lock() channel.stateMutex.Lock()
defer channel.stateMutex.Unlock() defer channel.stateMutex.Unlock()
channel.uuid = chanReg.UUID
channel.registeredFounder = chanReg.Founder channel.registeredFounder = chanReg.Founder
channel.registeredTime = chanReg.RegisteredAt channel.registeredTime = chanReg.RegisteredAt
channel.topic = chanReg.Topic channel.topic = chanReg.Topic
@ -150,38 +135,41 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
} }
// obtain a consistent snapshot of the channel state that can be persisted to the DB // obtain a consistent snapshot of the channel state that can be persisted to the DB
func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) { func (channel *Channel) ExportRegistration() (info RegisteredChannel) {
channel.stateMutex.RLock() channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
info.Name = channel.name info.Name = channel.name
info.NameCasefolded = channel.nameCasefolded info.UUID = channel.uuid
info.Founder = channel.registeredFounder info.Founder = channel.registeredFounder
info.RegisteredAt = channel.registeredTime info.RegisteredAt = channel.registeredTime
if includeFlags&IncludeTopic != 0 { info.Topic = channel.topic
info.Topic = channel.topic info.TopicSetBy = channel.topicSetBy
info.TopicSetBy = channel.topicSetBy info.TopicSetTime = channel.topicSetTime
info.TopicSetTime = channel.topicSetTime
}
if includeFlags&IncludeModes != 0 { info.Key = channel.key
info.Key = channel.key info.Forward = channel.forward
info.Forward = channel.forward info.Modes = channel.flags.AllModes()
info.Modes = channel.flags.AllModes() info.UserLimit = channel.userLimit
info.UserLimit = channel.userLimit
}
if includeFlags&IncludeLists != 0 { info.Bans = channel.lists[modes.BanMask].Masks()
info.Bans = channel.lists[modes.BanMask].Masks() info.Invites = channel.lists[modes.InviteMask].Masks()
info.Invites = channel.lists[modes.InviteMask].Masks() info.Excepts = channel.lists[modes.ExceptMask].Masks()
info.Excepts = channel.lists[modes.ExceptMask].Masks() info.AccountToUMode = utils.CopyMap(channel.accountToUMode)
info.AccountToUMode = utils.CopyMap(channel.accountToUMode)
}
if includeFlags&IncludeSettings != 0 { info.Settings = channel.settings
info.Settings = channel.settings
} return
}
func (channel *Channel) exportSummary() (info RegisteredChannel) {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
info.Name = channel.name
info.Founder = channel.registeredFounder
info.RegisteredAt = channel.registeredTime
return return
} }
@ -288,9 +276,19 @@ func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) {
return return
} }
info := channel.ExportRegistration(dirtyBits) var success bool
err = channel.server.channelRegistry.StoreChannel(info, dirtyBits) info := channel.ExportRegistration()
if err != nil { if b, err := info.Serialize(); err == nil {
if err := channel.server.dstore.Set(datastore.TableChannels, info.UUID, b, time.Time{}); err == nil {
success = true
} else {
channel.server.logger.Error("internal", "couldn't persist channel", info.Name, err.Error())
}
} else {
channel.server.logger.Error("internal", "couldn't serialize channel", info.Name, err.Error())
}
if !success {
channel.stateMutex.Lock() channel.stateMutex.Lock()
channel.dirtyBits = channel.dirtyBits | dirtyBits channel.dirtyBits = channel.dirtyBits | dirtyBits
channel.stateMutex.Unlock() channel.stateMutex.Unlock()
@ -314,6 +312,7 @@ func (channel *Channel) SetRegistered(founder string) error {
// SetUnregistered deletes the channel's registration information. // SetUnregistered deletes the channel's registration information.
func (channel *Channel) SetUnregistered(expectedFounder string) { func (channel *Channel) SetUnregistered(expectedFounder string) {
uuid := utils.GenerateUUIDv4()
channel.stateMutex.Lock() channel.stateMutex.Lock()
defer channel.stateMutex.Unlock() defer channel.stateMutex.Unlock()
@ -324,6 +323,9 @@ func (channel *Channel) SetUnregistered(expectedFounder string) {
var zeroTime time.Time var zeroTime time.Time
channel.registeredTime = zeroTime channel.registeredTime = zeroTime
channel.accountToUMode = make(map[string]modes.Mode) channel.accountToUMode = make(map[string]modes.Mode)
// reset the UUID so that any re-registration will persist under
// a separate key:
channel.uuid = uuid
} }
// implements `CHANSERV CLEAR #chan ACCESS` (resets bans, invites, excepts, and amodes) // implements `CHANSERV CLEAR #chan ACCESS` (resets bans, invites, excepts, and amodes)

View File

@ -6,7 +6,9 @@ package irc
import ( import (
"sort" "sort"
"sync" "sync"
"time"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
) )
@ -25,85 +27,75 @@ type channelManagerEntry struct {
type ChannelManager struct { type ChannelManager struct {
sync.RWMutex // tier 2 sync.RWMutex // tier 2
// chans is the main data structure, mapping casefolded name -> *Channel // chans is the main data structure, mapping casefolded name -> *Channel
chans map[string]*channelManagerEntry chans map[string]*channelManagerEntry
chansSkeletons utils.HashSet[string] // skeletons of *unregistered* chans chansSkeletons utils.HashSet[string]
registeredChannels utils.HashSet[string] // casefolds of registered chans purgedChannels map[string]ChannelPurgeRecord // casefolded name to purge record
registeredSkeletons utils.HashSet[string] // skeletons of registered chans server *Server
purgedChannels utils.HashSet[string] // casefolds of purged chans
server *Server
} }
// NewChannelManager returns a new ChannelManager. // NewChannelManager returns a new ChannelManager.
func (cm *ChannelManager) Initialize(server *Server) { func (cm *ChannelManager) Initialize(server *Server, config *Config) (err error) {
cm.chans = make(map[string]*channelManagerEntry) cm.chans = make(map[string]*channelManagerEntry)
cm.chansSkeletons = make(utils.HashSet[string]) cm.chansSkeletons = make(utils.HashSet[string])
cm.server = server cm.server = server
return cm.loadRegisteredChannels(config)
// purging should work even if registration is disabled
cm.purgedChannels = cm.server.channelRegistry.PurgedChannels()
cm.loadRegisteredChannels(server.Config())
} }
func (cm *ChannelManager) loadRegisteredChannels(config *Config) { func (cm *ChannelManager) loadRegisteredChannels(config *Config) (err error) {
if !config.Channels.Registration.Enabled { allChannels, err := FetchAndDeserializeAll[RegisteredChannel](datastore.TableChannels, cm.server.dstore, cm.server.logger)
if err != nil {
return
}
allPurgeRecords, err := FetchAndDeserializeAll[ChannelPurgeRecord](datastore.TableChannelPurges, cm.server.dstore, cm.server.logger)
if err != nil {
return return
} }
var newChannels []*Channel
var collisions []string
defer func() {
for _, ch := range newChannels {
ch.EnsureLoaded()
cm.server.logger.Debug("channels", "initialized registered channel", ch.Name())
}
for _, collision := range collisions {
cm.server.logger.Warning("channels", "registered channel collides with existing channel", collision)
}
}()
rawNames := cm.server.channelRegistry.AllChannels()
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
cm.registeredChannels = make(utils.HashSet[string], len(rawNames)) cm.purgedChannels = make(map[string]ChannelPurgeRecord, len(allPurgeRecords))
cm.registeredSkeletons = make(utils.HashSet[string], len(rawNames)) for _, purge := range allPurgeRecords {
for _, name := range rawNames { cm.purgedChannels[purge.NameCasefolded] = purge
cfname, err := CasefoldChannel(name) }
if err == nil {
cm.registeredChannels.Add(cfname) for _, regInfo := range allChannels {
cfname, err := CasefoldChannel(regInfo.Name)
if err != nil {
cm.server.logger.Error("channels", "couldn't casefold registered channel, skipping", regInfo.Name, err.Error())
continue
} else {
cm.server.logger.Debug("channels", "initializing registered channel", regInfo.Name)
} }
skeleton, err := Skeleton(name) skeleton, err := Skeleton(regInfo.Name)
if err == nil { if err == nil {
cm.registeredSkeletons.Add(skeleton) cm.chansSkeletons.Add(skeleton)
} }
if !cm.purgedChannels.Has(cfname) { if _, ok := cm.purgedChannels[cfname]; !ok {
if _, ok := cm.chans[cfname]; !ok { ch := NewChannel(cm.server, regInfo.Name, cfname, true, regInfo)
ch := NewChannel(cm.server, name, cfname, true) cm.chans[cfname] = &channelManagerEntry{
cm.chans[cfname] = &channelManagerEntry{ channel: ch,
channel: ch, pendingJoins: 0,
pendingJoins: 0, skeleton: skeleton,
}
newChannels = append(newChannels, ch)
} else {
collisions = append(collisions, name)
} }
} }
} }
return nil
} }
// Get returns an existing channel with name equivalent to `name`, or nil // Get returns an existing channel with name equivalent to `name`, or nil
func (cm *ChannelManager) Get(name string) (channel *Channel) { func (cm *ChannelManager) Get(name string) (channel *Channel) {
name, err := CasefoldChannel(name) name, err := CasefoldChannel(name)
if err == nil { if err != nil {
cm.RLock() return nil
defer cm.RUnlock() }
entry := cm.chans[name] cm.RLock()
// if the channel is still loading, pretend we don't have it defer cm.RUnlock()
if entry != nil && entry.channel.IsLoaded() { entry := cm.chans[name]
return entry.channel if entry != nil {
} return entry.channel
} }
return nil return nil
} }
@ -122,33 +114,26 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
if cm.purgedChannels.Has(casefoldedName) { // check purges first; a registered purged channel will still be present in `chans`
if _, ok := cm.purgedChannels[casefoldedName]; ok {
return nil, errChannelPurged, false return nil, errChannelPurged, false
} }
entry := cm.chans[casefoldedName] entry := cm.chans[casefoldedName]
if entry == nil { if entry == nil {
registered := cm.registeredChannels.Has(casefoldedName) if server.Config().Channels.OpOnlyCreation &&
// enforce OpOnlyCreation
if !registered && server.Config().Channels.OpOnlyCreation &&
!(isSajoin || client.HasRoleCapabs("chanreg")) { !(isSajoin || client.HasRoleCapabs("chanreg")) {
return nil, errInsufficientPrivs, false return nil, errInsufficientPrivs, false
} }
// enforce confusables // enforce confusables
if !registered && (cm.chansSkeletons.Has(skeleton) || cm.registeredSkeletons.Has(skeleton)) { if cm.chansSkeletons.Has(skeleton) {
return nil, errConfusableIdentifier, false return nil, errConfusableIdentifier, false
} }
entry = &channelManagerEntry{ entry = &channelManagerEntry{
channel: NewChannel(server, name, casefoldedName, registered), channel: NewChannel(server, name, casefoldedName, false, RegisteredChannel{}),
pendingJoins: 0, pendingJoins: 0,
} }
if !registered { cm.chansSkeletons.Add(skeleton)
// for an unregistered channel, we already have the correct unfolded name entry.skeleton = skeleton
// and therefore the final skeleton. for a registered channel, we don't have
// the unfolded name yet (it needs to be loaded from the db), but we already
// have the final skeleton in `registeredSkeletons` so we don't need to track it
cm.chansSkeletons.Add(skeleton)
entry.skeleton = skeleton
}
cm.chans[casefoldedName] = entry cm.chans[casefoldedName] = entry
newChannel = true newChannel = true
} }
@ -160,7 +145,6 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
return err, "" return err, ""
} }
channel.EnsureLoaded()
err, forward = channel.Join(client, key, isSajoin || newChannel, rb) err, forward = channel.Join(client, key, isSajoin || newChannel, rb)
cm.maybeCleanup(channel, true) cm.maybeCleanup(channel, true)
@ -252,13 +236,6 @@ func (cm *ChannelManager) SetRegistered(channelName string, account string) (err
if err != nil { if err != nil {
return err return err
} }
// transfer the skeleton from chansSkeletons to registeredSkeletons
skeleton := entry.skeleton
delete(cm.chansSkeletons, skeleton)
entry.skeleton = ""
cm.chans[cfname] = entry
cm.registeredChannels.Add(cfname)
cm.registeredSkeletons.Add(skeleton)
return nil return nil
} }
@ -268,17 +245,13 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e
return err return err
} }
info, err := cm.server.channelRegistry.LoadChannel(cfname) var uuid utils.UUID
if err != nil {
return err
}
if info.Founder != account {
return errChannelNotOwnedByAccount
}
defer func() { defer func() {
if err == nil { if err == nil {
err = cm.server.channelRegistry.Delete(info) if delErr := cm.server.dstore.Delete(datastore.TableChannels, uuid); delErr != nil {
cm.server.logger.Error("datastore", "couldn't delete channel registration", cfname, delErr.Error())
}
} }
}() }()
@ -286,15 +259,11 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e
defer cm.Unlock() defer cm.Unlock()
entry := cm.chans[cfname] entry := cm.chans[cfname]
if entry != nil { if entry != nil {
entry.channel.SetUnregistered(account) if entry.channel.Founder() != account {
delete(cm.registeredChannels, cfname) return errChannelNotOwnedByAccount
// transfer the skeleton from registeredSkeletons to chansSkeletons
if skel, err := Skeleton(entry.channel.Name()); err == nil {
delete(cm.registeredSkeletons, skel)
cm.chansSkeletons.Add(skel)
entry.skeleton = skel
cm.chans[cfname] = entry
} }
uuid = entry.channel.UUID()
entry.channel.SetUnregistered(account) // changes the UUID
// #1619: if the channel has 0 members and was only being retained // #1619: if the channel has 0 members and was only being retained
// because it was registered, clean it up: // because it was registered, clean it up:
cm.maybeCleanupInternal(cfname, entry, false) cm.maybeCleanupInternal(cfname, entry, false)
@ -322,12 +291,11 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
var info RegisteredChannel var info RegisteredChannel
defer func() { defer func() {
if channel != nil && info.Founder != "" { if channel != nil && info.Founder != "" {
channel.Store(IncludeAllAttrs) channel.MarkDirty(IncludeAllAttrs)
if oldCfname != newCfname { }
// we just flushed the channel under its new name, therefore this delete // always-on clients need to update their saved channel memberships
// cannot be overwritten by a write to the old name: for _, member := range channel.Members() {
cm.server.channelRegistry.Delete(info) member.markDirty(IncludeChannels)
}
} }
}() }()
@ -335,11 +303,11 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
defer cm.Unlock() defer cm.Unlock()
entry := cm.chans[oldCfname] entry := cm.chans[oldCfname]
if entry == nil || !entry.channel.IsLoaded() { if entry == nil {
return errNoSuchChannel return errNoSuchChannel
} }
channel = entry.channel channel = entry.channel
info = channel.ExportRegistration(IncludeInitial) info = channel.ExportRegistration()
registered := info.Founder != "" registered := info.Founder != ""
oldSkeleton, err := Skeleton(info.Name) oldSkeleton, err := Skeleton(info.Name)
@ -348,13 +316,13 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
} }
if newCfname != oldCfname { if newCfname != oldCfname {
if cm.chans[newCfname] != nil || cm.registeredChannels.Has(newCfname) { if cm.chans[newCfname] != nil {
return errChannelNameInUse return errChannelNameInUse
} }
} }
if oldSkeleton != newSkeleton { if oldSkeleton != newSkeleton {
if cm.chansSkeletons.Has(newSkeleton) || cm.registeredSkeletons.Has(newSkeleton) { if cm.chansSkeletons.Has(newSkeleton) {
return errConfusableIdentifier return errConfusableIdentifier
} }
} }
@ -364,15 +332,8 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
entry.skeleton = newSkeleton entry.skeleton = newSkeleton
} }
cm.chans[newCfname] = entry cm.chans[newCfname] = entry
if registered { delete(cm.chansSkeletons, oldSkeleton)
delete(cm.registeredChannels, oldCfname) cm.chansSkeletons.Add(newSkeleton)
cm.registeredChannels.Add(newCfname)
delete(cm.registeredSkeletons, oldSkeleton)
cm.registeredSkeletons.Add(newSkeleton)
} else {
delete(cm.chansSkeletons, oldSkeleton)
cm.chansSkeletons.Add(newSkeleton)
}
entry.channel.Rename(newName, newCfname) entry.channel.Rename(newName, newCfname)
return nil return nil
} }
@ -390,7 +351,18 @@ func (cm *ChannelManager) Channels() (result []*Channel) {
defer cm.RUnlock() defer cm.RUnlock()
result = make([]*Channel, 0, len(cm.chans)) result = make([]*Channel, 0, len(cm.chans))
for _, entry := range cm.chans { for _, entry := range cm.chans {
if entry.channel.IsLoaded() { result = append(result, entry.channel)
}
return
}
// ListableChannels returns a slice of all non-purged channels.
func (cm *ChannelManager) ListableChannels() (result []*Channel) {
cm.RLock()
defer cm.RUnlock()
result = make([]*Channel, 0, len(cm.chans))
for cfname, entry := range cm.chans {
if _, ok := cm.purgedChannels[cfname]; !ok {
result = append(result, entry.channel) result = append(result, entry.channel)
} }
} }
@ -403,29 +375,46 @@ func (cm *ChannelManager) Purge(chname string, record ChannelPurgeRecord) (err e
if err != nil { if err != nil {
return errInvalidChannelName return errInvalidChannelName
} }
skel, err := Skeleton(chname)
if err != nil {
return errInvalidChannelName
}
cm.Lock() record.NameCasefolded = chname
cm.purgedChannels.Add(chname) record.UUID = utils.GenerateUUIDv4()
entry := cm.chans[chname]
if entry != nil { channel, err := func() (channel *Channel, err error) {
delete(cm.chans, chname) cm.Lock()
if entry.channel.Founder() != "" { defer cm.Unlock()
delete(cm.registeredSkeletons, skel)
} else { if _, ok := cm.purgedChannels[chname]; ok {
delete(cm.chansSkeletons, skel) return nil, errChannelPurgedAlready
} }
}
cm.Unlock()
cm.server.channelRegistry.PurgeChannel(chname, record) entry := cm.chans[chname]
if entry != nil { // atomically prevent anyone from rejoining
entry.channel.Purge("") cm.purgedChannels[chname] = record
if entry != nil {
channel = entry.channel
}
return
}()
if err != nil {
return err
} }
return nil
if channel != nil {
// actually kick everyone off the channel
channel.Purge("")
}
var purgeBytes []byte
if purgeBytes, err = record.Serialize(); err != nil {
cm.server.logger.Error("internal", "couldn't serialize purge record", channel.Name(), err.Error())
}
// TODO we need a better story about error handling for later
if err = cm.server.dstore.Set(datastore.TableChannelPurges, record.UUID, purgeBytes, time.Time{}); err != nil {
cm.server.logger.Error("datastore", "couldn't store purge record", chname, err.Error())
}
return
} }
// IsPurged queries whether a channel is purged. // IsPurged queries whether a channel is purged.
@ -436,7 +425,7 @@ func (cm *ChannelManager) IsPurged(chname string) (result bool) {
} }
cm.RLock() cm.RLock()
result = cm.purgedChannels.Has(chname) _, result = cm.purgedChannels[chname]
cm.RUnlock() cm.RUnlock()
return return
} }
@ -449,14 +438,16 @@ func (cm *ChannelManager) Unpurge(chname string) (err error) {
} }
cm.Lock() cm.Lock()
found := cm.purgedChannels.Has(chname) record, found := cm.purgedChannels[chname]
delete(cm.purgedChannels, chname) delete(cm.purgedChannels, chname)
cm.Unlock() cm.Unlock()
cm.server.channelRegistry.UnpurgeChannel(chname)
if !found { if !found {
return errNoSuchChannel return errNoSuchChannel
} }
if err := cm.server.dstore.Delete(datastore.TableChannelPurges, record.UUID); err != nil {
cm.server.logger.Error("datastore", "couldn't delete purge record", chname, err.Error())
}
return nil return nil
} }
@ -475,8 +466,46 @@ func (cm *ChannelManager) UnfoldName(cfname string) (result string) {
cm.RLock() cm.RLock()
entry := cm.chans[cfname] entry := cm.chans[cfname]
cm.RUnlock() cm.RUnlock()
if entry != nil && entry.channel.IsLoaded() { if entry != nil {
return entry.channel.Name() return entry.channel.Name()
} }
return cfname return cfname
} }
func (cm *ChannelManager) LoadPurgeRecord(cfchname string) (record ChannelPurgeRecord, err error) {
cm.RLock()
defer cm.RUnlock()
if record, ok := cm.purgedChannels[cfchname]; ok {
return record, nil
} else {
return record, errNoSuchChannel
}
}
func (cm *ChannelManager) ChannelsForAccount(account string) (channels []string) {
cm.RLock()
defer cm.RUnlock()
for cfname, entry := range cm.chans {
if entry.channel.Founder() == account {
channels = append(channels, cfname)
}
}
return
}
// AllChannels returns the uncasefolded names of all registered channels.
func (cm *ChannelManager) AllRegisteredChannels() (result []string) {
cm.RLock()
defer cm.RUnlock()
for cfname, entry := range cm.chans {
if entry.channel.Founder() != "" {
result = append(result, cfname)
}
}
return
}

View File

@ -5,13 +5,8 @@ package irc
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strconv"
"strings"
"time" "time"
"github.com/tidwall/buntdb"
"github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/modes"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
) )
@ -19,48 +14,6 @@ import (
// this is exclusively the *persistence* layer for channel registration; // this is exclusively the *persistence* layer for channel registration;
// channel creation/tracking/destruction is in channelmanager.go // channel creation/tracking/destruction is in channelmanager.go
const (
keyChannelExists = "channel.exists %s"
keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
keyChannelRegTime = "channel.registered.time %s"
keyChannelFounder = "channel.founder %s"
keyChannelTopic = "channel.topic %s"
keyChannelTopicSetBy = "channel.topic.setby %s"
keyChannelTopicSetTime = "channel.topic.settime %s"
keyChannelBanlist = "channel.banlist %s"
keyChannelExceptlist = "channel.exceptlist %s"
keyChannelInvitelist = "channel.invitelist %s"
keyChannelPassword = "channel.key %s"
keyChannelModes = "channel.modes %s"
keyChannelAccountToUMode = "channel.accounttoumode %s"
keyChannelUserLimit = "channel.userlimit %s"
keyChannelSettings = "channel.settings %s"
keyChannelForward = "channel.forward %s"
keyChannelPurged = "channel.purged %s"
)
var (
channelKeyStrings = []string{
keyChannelExists,
keyChannelName,
keyChannelRegTime,
keyChannelFounder,
keyChannelTopic,
keyChannelTopicSetBy,
keyChannelTopicSetTime,
keyChannelBanlist,
keyChannelExceptlist,
keyChannelInvitelist,
keyChannelPassword,
keyChannelModes,
keyChannelAccountToUMode,
keyChannelUserLimit,
keyChannelSettings,
keyChannelForward,
}
)
// these are bit flags indicating what part of the channel status is "dirty" // these are bit flags indicating what part of the channel status is "dirty"
// and needs to be read from memory and written to the db // and needs to be read from memory and written to the db
const ( const (
@ -80,8 +33,8 @@ const (
type RegisteredChannel struct { type RegisteredChannel struct {
// Name of the channel. // Name of the channel.
Name string Name string
// Casefolded name of the channel. // UUID for the datastore.
NameCasefolded string UUID utils.UUID
// RegisteredAt represents the time that the channel was registered. // RegisteredAt represents the time that the channel was registered.
RegisteredAt time.Time RegisteredAt time.Time
// Founder indicates the founder of the channel. // Founder indicates the founder of the channel.
@ -112,322 +65,26 @@ type RegisteredChannel struct {
Settings ChannelSettings Settings ChannelSettings
} }
func (r *RegisteredChannel) Serialize() ([]byte, error) {
return json.Marshal(r)
}
func (r *RegisteredChannel) Deserialize(b []byte) (err error) {
return json.Unmarshal(b, r)
}
type ChannelPurgeRecord struct { type ChannelPurgeRecord struct {
Oper string NameCasefolded string `json:"Name"`
PurgedAt time.Time UUID utils.UUID
Reason string Oper string
PurgedAt time.Time
Reason string
} }
// ChannelRegistry manages registered channels. func (c *ChannelPurgeRecord) Serialize() ([]byte, error) {
type ChannelRegistry struct { return json.Marshal(c)
server *Server
} }
// NewChannelRegistry returns a new ChannelRegistry. func (c *ChannelPurgeRecord) Deserialize(b []byte) error {
func (reg *ChannelRegistry) Initialize(server *Server) { return json.Unmarshal(b, c)
reg.server = server
}
// AllChannels returns the uncasefolded names of all registered channels.
func (reg *ChannelRegistry) AllChannels() (result []string) {
prefix := fmt.Sprintf(keyChannelName, "")
reg.server.store.View(func(tx *buntdb.Tx) error {
return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
if !strings.HasPrefix(key, prefix) {
return false
}
result = append(result, value)
return true
})
})
return
}
// PurgedChannels returns the set of all casefolded channel names that have been purged
func (reg *ChannelRegistry) PurgedChannels() (result utils.HashSet[string]) {
result = make(utils.HashSet[string])
prefix := fmt.Sprintf(keyChannelPurged, "")
reg.server.store.View(func(tx *buntdb.Tx) error {
return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
if !strings.HasPrefix(key, prefix) {
return false
}
channel := strings.TrimPrefix(key, prefix)
result.Add(channel)
return true
})
})
return
}
// StoreChannel obtains a consistent view of a channel, then persists it to the store.
func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) {
if !reg.server.ChannelRegistrationEnabled() {
return
}
if info.Founder == "" {
// sanity check, don't try to store an unregistered channel
return
}
reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.saveChannel(tx, info, includeFlags)
return nil
})
return nil
}
// LoadChannel loads a channel from the store.
func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) {
if !reg.server.ChannelRegistrationEnabled() {
err = errFeatureDisabled
return
}
channelKey := nameCasefolded
// nice to have: do all JSON (de)serialization outside of the buntdb transaction
err = reg.server.store.View(func(tx *buntdb.Tx) error {
_, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
if dberr == buntdb.ErrNotFound {
// chan does not already exist, return
return errNoSuchChannel
}
// channel exists, load it
name, _ := tx.Get(fmt.Sprintf(keyChannelName, channelKey))
regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, channelKey))
regTimeInt, _ := strconv.ParseInt(regTime, 10, 64)
founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, channelKey))
topic, _ := tx.Get(fmt.Sprintf(keyChannelTopic, channelKey))
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
var topicSetTime time.Time
topicSetTimeStr, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
if topicSetTimeInt, topicSetTimeErr := strconv.ParseInt(topicSetTimeStr, 10, 64); topicSetTimeErr == nil {
topicSetTime = time.Unix(0, topicSetTimeInt).UTC()
}
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
userLimitString, _ := tx.Get(fmt.Sprintf(keyChannelUserLimit, channelKey))
forward, _ := tx.Get(fmt.Sprintf(keyChannelForward, channelKey))
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
settingsString, _ := tx.Get(fmt.Sprintf(keyChannelSettings, channelKey))
modeSlice := make([]modes.Mode, len(modeString))
for i, mode := range modeString {
modeSlice[i] = modes.Mode(mode)
}
userLimit, _ := strconv.Atoi(userLimitString)
var banlist map[string]MaskInfo
_ = json.Unmarshal([]byte(banlistString), &banlist)
var exceptlist map[string]MaskInfo
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
var invitelist map[string]MaskInfo
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
accountToUMode := make(map[string]modes.Mode)
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
var settings ChannelSettings
_ = json.Unmarshal([]byte(settingsString), &settings)
info = RegisteredChannel{
Name: name,
NameCasefolded: nameCasefolded,
RegisteredAt: time.Unix(0, regTimeInt).UTC(),
Founder: founder,
Topic: topic,
TopicSetBy: topicSetBy,
TopicSetTime: topicSetTime,
Key: password,
Modes: modeSlice,
Bans: banlist,
Excepts: exceptlist,
Invites: invitelist,
AccountToUMode: accountToUMode,
UserLimit: int(userLimit),
Settings: settings,
Forward: forward,
}
return nil
})
return
}
// Delete deletes a channel corresponding to `info`. If no such channel
// is present in the database, no error is returned.
func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) {
if !reg.server.ChannelRegistrationEnabled() {
return
}
reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.deleteChannel(tx, info.NameCasefolded, info)
return nil
})
return nil
}
// delete a channel, unless it was overwritten by another registration of the same channel
func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info RegisteredChannel) {
_, err := tx.Get(fmt.Sprintf(keyChannelExists, key))
if err == nil {
regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, key))
regTimeInt, _ := strconv.ParseInt(regTime, 10, 64)
registeredAt := time.Unix(0, regTimeInt).UTC()
founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, key))
// to see if we're deleting the right channel, confirm the founder and the registration time
if founder == info.Founder && registeredAt.Equal(info.RegisteredAt) {
for _, keyFmt := range channelKeyStrings {
tx.Delete(fmt.Sprintf(keyFmt, key))
}
// remove this channel from the client's list of registered channels
channelsKey := fmt.Sprintf(keyAccountChannels, info.Founder)
channelsStr, err := tx.Get(channelsKey)
if err == buntdb.ErrNotFound {
return
}
registeredChannels := unmarshalRegisteredChannels(channelsStr)
var nowRegisteredChannels []string
for _, channel := range registeredChannels {
if channel != key {
nowRegisteredChannels = append(nowRegisteredChannels, channel)
}
}
tx.Set(channelsKey, strings.Join(nowRegisteredChannels, ","), nil)
}
}
}
func (reg *ChannelRegistry) updateAccountToChannelMapping(tx *buntdb.Tx, channelInfo RegisteredChannel) {
channelKey := channelInfo.NameCasefolded
chanFounderKey := fmt.Sprintf(keyChannelFounder, channelKey)
founder, existsErr := tx.Get(chanFounderKey)
if existsErr == buntdb.ErrNotFound || founder != channelInfo.Founder {
// add to new founder's list
accountChannelsKey := fmt.Sprintf(keyAccountChannels, channelInfo.Founder)
alreadyChannels, _ := tx.Get(accountChannelsKey)
newChannels := channelKey // this is the casefolded channel name
if alreadyChannels != "" {
newChannels = fmt.Sprintf("%s,%s", alreadyChannels, newChannels)
}
tx.Set(accountChannelsKey, newChannels, nil)
}
if existsErr == nil && founder != channelInfo.Founder {
// remove from old founder's list
accountChannelsKey := fmt.Sprintf(keyAccountChannels, founder)
alreadyChannelsRaw, _ := tx.Get(accountChannelsKey)
var newChannels []string
if alreadyChannelsRaw != "" {
for _, chname := range strings.Split(alreadyChannelsRaw, ",") {
if chname != channelInfo.NameCasefolded {
newChannels = append(newChannels, chname)
}
}
}
tx.Set(accountChannelsKey, strings.Join(newChannels, ","), nil)
}
}
// saveChannel saves a channel to the store.
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) {
channelKey := channelInfo.NameCasefolded
// maintain the mapping of account -> registered channels
reg.updateAccountToChannelMapping(tx, channelInfo)
if includeFlags&IncludeInitial != 0 {
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.UnixNano(), 10), nil)
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
}
if includeFlags&IncludeTopic != 0 {
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
var topicSetTimeStr string
if !channelInfo.TopicSetTime.IsZero() {
topicSetTimeStr = strconv.FormatInt(channelInfo.TopicSetTime.UnixNano(), 10)
}
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), topicSetTimeStr, nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
}
if includeFlags&IncludeModes != 0 {
tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil)
modeString := modes.Modes(channelInfo.Modes).String()
tx.Set(fmt.Sprintf(keyChannelModes, channelKey), modeString, nil)
tx.Set(fmt.Sprintf(keyChannelUserLimit, channelKey), strconv.Itoa(channelInfo.UserLimit), nil)
tx.Set(fmt.Sprintf(keyChannelForward, channelKey), channelInfo.Forward, nil)
}
if includeFlags&IncludeLists != 0 {
banlistString, _ := json.Marshal(channelInfo.Bans)
tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
exceptlistString, _ := json.Marshal(channelInfo.Excepts)
tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
invitelistString, _ := json.Marshal(channelInfo.Invites)
tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode)
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil)
}
if includeFlags&IncludeSettings != 0 {
settingsString, _ := json.Marshal(channelInfo.Settings)
tx.Set(fmt.Sprintf(keyChannelSettings, channelKey), string(settingsString), nil)
}
}
// PurgeChannel records a channel purge.
func (reg *ChannelRegistry) PurgeChannel(chname string, record ChannelPurgeRecord) (err error) {
serialized, err := json.Marshal(record)
if err != nil {
return err
}
serializedStr := string(serialized)
key := fmt.Sprintf(keyChannelPurged, chname)
return reg.server.store.Update(func(tx *buntdb.Tx) error {
tx.Set(key, serializedStr, nil)
return nil
})
}
// LoadPurgeRecord retrieves information about whether and how a channel was purged.
func (reg *ChannelRegistry) LoadPurgeRecord(chname string) (record ChannelPurgeRecord, err error) {
var rawRecord string
key := fmt.Sprintf(keyChannelPurged, chname)
reg.server.store.View(func(tx *buntdb.Tx) error {
rawRecord, _ = tx.Get(key)
return nil
})
if rawRecord == "" {
err = errNoSuchChannel
return
}
err = json.Unmarshal([]byte(rawRecord), &record)
if err != nil {
reg.server.logger.Error("internal", "corrupt purge record", chname, err.Error())
err = errNoSuchChannel
return
}
return
}
// UnpurgeChannel deletes the record of a channel purge.
func (reg *ChannelRegistry) UnpurgeChannel(chname string) (err error) {
key := fmt.Sprintf(keyChannelPurged, chname)
return reg.server.store.Update(func(tx *buntdb.Tx) error {
tx.Delete(key)
return nil
})
} }

View File

@ -459,7 +459,7 @@ func csRegisterHandler(service *ircService, server *Server, client *Client, comm
// check whether a client has already registered too many channels // check whether a client has already registered too many channels
func checkChanLimit(service *ircService, client *Client, rb *ResponseBuffer) (ok bool) { func checkChanLimit(service *ircService, client *Client, rb *ResponseBuffer) (ok bool) {
account := client.Account() account := client.Account()
channelsAlreadyRegistered := client.server.accounts.ChannelsForAccount(account) channelsAlreadyRegistered := client.server.channels.ChannelsForAccount(account)
ok = len(channelsAlreadyRegistered) < client.server.Config().Channels.Registration.MaxChannelsPerAccount || client.HasRoleCapabs("chanreg") ok = len(channelsAlreadyRegistered) < client.server.Config().Channels.Registration.MaxChannelsPerAccount || client.HasRoleCapabs("chanreg")
if !ok { if !ok {
service.Notice(rb, client.t("You have already registered the maximum number of channels; try dropping some with /CS UNREGISTER")) service.Notice(rb, client.t("You have already registered the maximum number of channels; try dropping some with /CS UNREGISTER"))
@ -496,8 +496,8 @@ func csUnregisterHandler(service *ircService, server *Server, client *Client, co
return return
} }
info := channel.ExportRegistration(0) info := channel.exportSummary()
channelKey := info.NameCasefolded channelKey := channel.NameCasefolded()
if !csPrivsCheck(service, info, client, rb) { if !csPrivsCheck(service, info, client, rb) {
return return
} }
@ -519,7 +519,7 @@ func csClearHandler(service *ircService, server *Server, client *Client, command
service.Notice(rb, client.t("Channel does not exist")) service.Notice(rb, client.t("Channel does not exist"))
return return
} }
if !csPrivsCheck(service, channel.ExportRegistration(0), client, rb) { if !csPrivsCheck(service, channel.exportSummary(), client, rb) {
return return
} }
@ -550,7 +550,7 @@ func csTransferHandler(service *ircService, server *Server, client *Client, comm
service.Notice(rb, client.t("Channel does not exist")) service.Notice(rb, client.t("Channel does not exist"))
return return
} }
regInfo := channel.ExportRegistration(0) regInfo := channel.exportSummary()
chname = regInfo.Name chname = regInfo.Name
account := client.Account() account := client.Account()
isFounder := account != "" && account == regInfo.Founder isFounder := account != "" && account == regInfo.Founder
@ -729,11 +729,6 @@ func csPurgeListHandler(service *ircService, client *Client, rb *ResponseBuffer)
} }
func csListHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) { func csListHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
if !client.HasRoleCapabs("chanreg") {
service.Notice(rb, client.t("Insufficient privileges"))
return
}
var searchRegex *regexp.Regexp var searchRegex *regexp.Regexp
if len(params) > 0 { if len(params) > 0 {
var err error var err error
@ -746,7 +741,7 @@ func csListHandler(service *ircService, server *Server, client *Client, command
service.Notice(rb, ircfmt.Unescape(client.t("*** $bChanServ LIST$b ***"))) service.Notice(rb, ircfmt.Unescape(client.t("*** $bChanServ LIST$b ***")))
channels := server.channelRegistry.AllChannels() channels := server.channels.AllRegisteredChannels()
for _, channel := range channels { for _, channel := range channels {
if searchRegex == nil || searchRegex.MatchString(channel) { if searchRegex == nil || searchRegex.MatchString(channel) {
service.Notice(rb, fmt.Sprintf(" %s", channel)) service.Notice(rb, fmt.Sprintf(" %s", channel))
@ -771,7 +766,7 @@ func csInfoHandler(service *ircService, server *Server, client *Client, command
// purge status // purge status
if client.HasRoleCapabs("chanreg") { if client.HasRoleCapabs("chanreg") {
purgeRecord, err := server.channelRegistry.LoadPurgeRecord(chname) purgeRecord, err := server.channels.LoadPurgeRecord(chname)
if err == nil { if err == nil {
service.Notice(rb, fmt.Sprintf(client.t("Channel %s was purged by the server operators and cannot be used"), chname)) service.Notice(rb, fmt.Sprintf(client.t("Channel %s was purged by the server operators and cannot be used"), chname))
service.Notice(rb, fmt.Sprintf(client.t("Purged by operator: %s"), purgeRecord.Oper)) service.Notice(rb, fmt.Sprintf(client.t("Purged by operator: %s"), purgeRecord.Oper))
@ -789,13 +784,7 @@ func csInfoHandler(service *ircService, server *Server, client *Client, command
var chinfo RegisteredChannel var chinfo RegisteredChannel
channel := server.channels.Get(params[0]) channel := server.channels.Get(params[0])
if channel != nil { if channel != nil {
chinfo = channel.ExportRegistration(0) chinfo = channel.exportSummary()
} else {
chinfo, err = server.channelRegistry.LoadChannel(chname)
if err != nil && !(err == errNoSuchChannel || err == errFeatureDisabled) {
service.Notice(rb, client.t("An error occurred"))
return
}
} }
// channel exists but is unregistered, or doesn't exist: // channel exists but is unregistered, or doesn't exist:
@ -835,12 +824,12 @@ func csGetHandler(service *ircService, server *Server, client *Client, command s
service.Notice(rb, client.t("No such channel")) service.Notice(rb, client.t("No such channel"))
return return
} }
info := channel.ExportRegistration(IncludeSettings) info := channel.exportSummary()
if !csPrivsCheck(service, info, client, rb) { if !csPrivsCheck(service, info, client, rb) {
return return
} }
displayChannelSetting(service, setting, info.Settings, client, rb) displayChannelSetting(service, setting, channel.Settings(), client, rb)
} }
func csSetHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) { func csSetHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
@ -850,12 +839,12 @@ func csSetHandler(service *ircService, server *Server, client *Client, command s
service.Notice(rb, client.t("No such channel")) service.Notice(rb, client.t("No such channel"))
return return
} }
info := channel.ExportRegistration(IncludeSettings) info := channel.exportSummary()
settings := info.Settings
if !csPrivsCheck(service, info, client, rb) { if !csPrivsCheck(service, info, client, rb) {
return return
} }
settings := channel.Settings()
var err error var err error
switch strings.ToLower(setting) { switch strings.ToLower(setting) {
case "history": case "history":

View File

@ -14,6 +14,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/ergochat/ergo/irc/bunt"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/modes"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
@ -21,12 +23,19 @@ import (
) )
const ( const (
// 'version' of the database schema // TODO migrate metadata keys as well
keySchemaVersion = "db.version"
// latest schema of the db
latestDbSchema = 22
keyCloakSecret = "crypto.cloak_secret" // 'version' of the database schema
// latest schema of the db
latestDbSchema = 23
)
var (
schemaVersionUUID = utils.UUID{0, 255, 85, 13, 212, 10, 191, 121, 245, 152, 142, 89, 97, 141, 219, 87} // AP9VDdQKv3n1mI5ZYY3bVw
cloakSecretUUID = utils.UUID{170, 214, 184, 208, 116, 181, 67, 75, 161, 23, 233, 16, 113, 251, 94, 229} // qta40HS1Q0uhF-kQcfte5Q
keySchemaVersion = bunt.BuntKey(datastore.TableMetadata, schemaVersionUUID)
keyCloakSecret = bunt.BuntKey(datastore.TableMetadata, cloakSecretUUID)
) )
type SchemaChanger func(*Config, *buntdb.Tx) error type SchemaChanger func(*Config, *buntdb.Tx) error
@ -99,10 +108,7 @@ func openDatabaseInternal(config *Config, allowAutoupgrade bool) (db *buntdb.DB,
// read the current version string // read the current version string
var version int var version int
err = db.View(func(tx *buntdb.Tx) (err error) { err = db.View(func(tx *buntdb.Tx) (err error) {
vStr, err := tx.Get(keySchemaVersion) version, err = retrieveSchemaVersion(tx)
if err == nil {
version, err = strconv.Atoi(vStr)
}
return err return err
}) })
if err != nil { if err != nil {
@ -130,6 +136,17 @@ func openDatabaseInternal(config *Config, allowAutoupgrade bool) (db *buntdb.DB,
} }
} }
func retrieveSchemaVersion(tx *buntdb.Tx) (version int, err error) {
if val, err := tx.Get(keySchemaVersion); err == nil {
return strconv.Atoi(val)
}
// legacy key:
if val, err := tx.Get("db.version"); err == nil {
return strconv.Atoi(val)
}
return 0, buntdb.ErrNotFound
}
func performAutoUpgrade(currentVersion int, config *Config) (err error) { func performAutoUpgrade(currentVersion int, config *Config) (err error) {
path := config.Datastore.Path path := config.Datastore.Path
log.Printf("attempting to auto-upgrade schema from version %d to %d\n", currentVersion, latestDbSchema) log.Printf("attempting to auto-upgrade schema from version %d to %d\n", currentVersion, latestDbSchema)
@ -167,8 +184,12 @@ func UpgradeDB(config *Config) (err error) {
var version int var version int
err = store.Update(func(tx *buntdb.Tx) error { err = store.Update(func(tx *buntdb.Tx) error {
for { for {
vStr, _ := tx.Get(keySchemaVersion) if version == 0 {
version, _ = strconv.Atoi(vStr) version, err = retrieveSchemaVersion(tx)
if err != nil {
return err
}
}
if version == latestDbSchema { if version == latestDbSchema {
// success! // success!
break break
@ -183,11 +204,12 @@ func UpgradeDB(config *Config) (err error) {
if err != nil { if err != nil {
return err return err
} }
_, _, err = tx.Set(keySchemaVersion, strconv.Itoa(change.TargetVersion), nil) version = change.TargetVersion
_, _, err = tx.Set(keySchemaVersion, strconv.Itoa(version), nil)
if err != nil { if err != nil {
return err return err
} }
log.Printf("successfully updated schema to version %d\n", change.TargetVersion) log.Printf("successfully updated schema to version %d\n", version)
} }
return nil return nil
}) })
@ -198,19 +220,17 @@ func UpgradeDB(config *Config) (err error) {
return err return err
} }
func LoadCloakSecret(db *buntdb.DB) (result string) { func LoadCloakSecret(dstore datastore.Datastore) (result string, err error) {
db.View(func(tx *buntdb.Tx) error { val, err := dstore.Get(datastore.TableMetadata, cloakSecretUUID)
result, _ = tx.Get(keyCloakSecret) if err != nil {
return nil return
}) }
return return string(val), nil
} }
func StoreCloakSecret(db *buntdb.DB, secret string) { func StoreCloakSecret(dstore datastore.Datastore, secret string) {
db.Update(func(tx *buntdb.Tx) error { // TODO error checking
tx.Set(keyCloakSecret, secret, nil) dstore.Set(datastore.TableMetadata, cloakSecretUUID, []byte(secret), time.Time{})
return nil
})
} }
func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error { func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error {
@ -1112,6 +1132,92 @@ func schemaChangeV21To22(config *Config, tx *buntdb.Tx) error {
return nil return nil
} }
// first phase of document-oriented database refactor: channels
func schemaChangeV22ToV23(config *Config, tx *buntdb.Tx) error {
keyChannelExists := "channel.exists "
var channelNames []string
tx.AscendGreaterOrEqual("", keyChannelExists, func(key, value string) bool {
if !strings.HasPrefix(key, keyChannelExists) {
return false
}
channelNames = append(channelNames, strings.TrimPrefix(key, keyChannelExists))
return true
})
for _, channelName := range channelNames {
channel, err := loadLegacyChannel(tx, channelName)
if err != nil {
log.Printf("error loading legacy channel %s: %v", channelName, err)
continue
}
channel.UUID = utils.GenerateUUIDv4()
newKey := bunt.BuntKey(datastore.TableChannels, channel.UUID)
j, err := json.Marshal(channel)
if err != nil {
log.Printf("error marshaling channel %s: %v", channelName, err)
continue
}
tx.Set(newKey, string(j), nil)
deleteLegacyChannel(tx, channelName)
}
// purges
keyChannelPurged := "channel.purged "
var purgeKeys []string
var channelPurges []ChannelPurgeRecord
tx.AscendGreaterOrEqual("", keyChannelPurged, func(key, value string) bool {
if !strings.HasPrefix(key, keyChannelPurged) {
return false
}
purgeKeys = append(purgeKeys, key)
cfname := strings.TrimPrefix(key, keyChannelPurged)
var record ChannelPurgeRecord
err := json.Unmarshal([]byte(value), &record)
if err != nil {
log.Printf("error unmarshaling channel purge for %s: %v", cfname, err)
return true
}
record.NameCasefolded = cfname
record.UUID = utils.GenerateUUIDv4()
channelPurges = append(channelPurges, record)
return true
})
for _, record := range channelPurges {
newKey := bunt.BuntKey(datastore.TableChannelPurges, record.UUID)
j, err := json.Marshal(record)
if err != nil {
log.Printf("error marshaling channel purge %s: %v", record.NameCasefolded, err)
continue
}
tx.Set(newKey, string(j), nil)
}
for _, purgeKey := range purgeKeys {
tx.Delete(purgeKey)
}
// clean up denormalized account-to-channels mapping
keyAccountChannels := "account.channels "
var accountToChannels []string
tx.AscendGreaterOrEqual("", keyAccountChannels, func(key, value string) bool {
if !strings.HasPrefix(key, keyAccountChannels) {
return false
}
accountToChannels = append(accountToChannels, key)
return true
})
for _, key := range accountToChannels {
tx.Delete(key)
}
// migrate cloak secret
val, _ := tx.Get("crypto.cloak_secret")
tx.Set(keyCloakSecret, val, nil)
// bump the legacy version key to mark the database as downgrade-incompatible
tx.Set("db.version", "23", nil)
return nil
}
func getSchemaChange(initialVersion int) (result SchemaChange, ok bool) { func getSchemaChange(initialVersion int) (result SchemaChange, ok bool) {
for _, change := range allChanges { for _, change := range allChanges {
if initialVersion == change.InitialVersion { if initialVersion == change.InitialVersion {
@ -1227,4 +1333,9 @@ var allChanges = []SchemaChange{
TargetVersion: 22, TargetVersion: 22,
Changer: schemaChangeV21To22, Changer: schemaChangeV21To22,
}, },
{
InitialVersion: 22,
TargetVersion: 23,
Changer: schemaChangeV22ToV23,
},
} }

View File

@ -0,0 +1,45 @@
// Copyright (c) 2022 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package datastore
import (
"time"
"github.com/ergochat/ergo/irc/utils"
)
type Table uint16
// XXX these are persisted and must remain stable;
// do not reorder, when deleting use _ to ensure that the deleted value is skipped
const (
TableMetadata Table = iota
TableChannels
TableChannelPurges
)
type KV struct {
UUID utils.UUID
Value []byte
}
// A Datastore provides the following abstraction:
// 1. Tables, each keyed on a UUID (the implementation is free to merge
// the table name and the UUID into a single key as long as the rest of
// the contract can be satisfied). Table names are [a-z0-9_]+
// 2. The ability to efficiently enumerate all uuid-value pairs in a table
// 3. Gets, sets, and deletes for individual (table, uuid) keys
type Datastore interface {
Backoff() time.Duration
GetAll(table Table) ([]KV, error)
// This is rarely used because it would typically lead to TOCTOU races
Get(table Table, key utils.UUID) (value []byte, err error)
Set(table Table, key utils.UUID, value []byte, expiration time.Time) error
// Note that deleting a nonexistent key is not considered an error
Delete(table Table, key utils.UUID) error
}

View File

@ -51,6 +51,7 @@ var (
errNoExistingBan = errors.New("Ban does not exist") errNoExistingBan = errors.New("Ban does not exist")
errNoSuchChannel = errors.New(`No such channel`) errNoSuchChannel = errors.New(`No such channel`)
errChannelPurged = errors.New(`This channel was purged by the server operators and cannot be used`) errChannelPurged = errors.New(`This channel was purged by the server operators and cannot be used`)
errChannelPurgedAlready = errors.New(`This channel was already purged and cannot be purged again`)
errConfusableIdentifier = errors.New("This identifier is confusable with one already in use") errConfusableIdentifier = errors.New("This identifier is confusable with one already in use")
errInsufficientPrivs = errors.New("Insufficient privileges") errInsufficientPrivs = errors.New("Insufficient privileges")
errInvalidUsername = errors.New("Invalid username") errInvalidUsername = errors.New("Invalid username")

View File

@ -638,3 +638,9 @@ func (channel *Channel) getAmode(cfaccount string) (result modes.Mode) {
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
return channel.accountToUMode[cfaccount] return channel.accountToUMode[cfaccount]
} }
func (channel *Channel) UUID() utils.UUID {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
return channel.uuid
}

View File

@ -1718,7 +1718,7 @@ func listHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons
clientIsOp := client.HasRoleCapabs("sajoin") clientIsOp := client.HasRoleCapabs("sajoin")
if len(channels) == 0 { if len(channels) == 0 {
for _, channel := range server.channels.Channels() { for _, channel := range server.channels.ListableChannels() {
if !clientIsOp && channel.flags.HasMode(modes.Secret) && !channel.hasClient(client) { if !clientIsOp && channel.flags.HasMode(modes.Secret) && !channel.hasClient(client) {
continue continue
} }

View File

@ -193,6 +193,6 @@ func hsSetCloakSecretHandler(service *ircService, server *Server, client *Client
service.Notice(rb, fmt.Sprintf(client.t("To confirm, run this command: %s"), fmt.Sprintf("/HS SETCLOAKSECRET %s %s", secret, expectedCode))) service.Notice(rb, fmt.Sprintf(client.t("To confirm, run this command: %s"), fmt.Sprintf("/HS SETCLOAKSECRET %s %s", secret, expectedCode)))
return return
} }
StoreCloakSecret(server.store, secret) StoreCloakSecret(server.dstore, secret)
service.Notice(rb, client.t("Rotated the cloak secret; you must rehash or restart the server for it to take effect")) service.Notice(rb, client.t("Rotated the cloak secret; you must rehash or restart the server for it to take effect"))
} }

View File

@ -9,9 +9,13 @@ import (
"log" "log"
"os" "os"
"strconv" "strconv"
"time"
"github.com/tidwall/buntdb" "github.com/tidwall/buntdb"
"github.com/ergochat/ergo/irc/bunt"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/modes"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
) )
@ -20,7 +24,7 @@ const (
// XXX instead of referencing, e.g., keyAccountExists, we should write in the string literal // XXX instead of referencing, e.g., keyAccountExists, we should write in the string literal
// (to ensure that no matter what code changes happen elsewhere, we're still producing a // (to ensure that no matter what code changes happen elsewhere, we're still producing a
// db of the hardcoded version) // db of the hardcoded version)
importDBSchemaVersion = 22 importDBSchemaVersion = 23
) )
type userImport struct { type userImport struct {
@ -54,8 +58,8 @@ type databaseImport struct {
Channels map[string]channelImport Channels map[string]channelImport
} }
func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result []byte, err error) { func convertAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result map[string]modes.Mode, err error) {
processed := make(map[string]int, len(raw)) result = make(map[string]modes.Mode)
for accountName, mode := range raw { for accountName, mode := range raw {
if len(mode) != 1 { if len(mode) != 1 {
return nil, fmt.Errorf("invalid mode %s for account %s", mode, accountName) return nil, fmt.Errorf("invalid mode %s for account %s", mode, accountName)
@ -64,10 +68,9 @@ func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[strin
if err != nil || !validCfUsernames.Has(cfname) { if err != nil || !validCfUsernames.Has(cfname) {
log.Printf("skipping invalid amode recipient %s\n", accountName) log.Printf("skipping invalid amode recipient %s\n", accountName)
} else { } else {
processed[cfname] = int(mode[0]) result[cfname] = modes.Mode(mode[0])
} }
} }
result, err = json.Marshal(processed)
return return
} }
@ -147,8 +150,9 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
cfUsernames.Add(cfUsername) cfUsernames.Add(cfUsername)
} }
// TODO fix this:
for chname, chInfo := range dbImport.Channels { for chname, chInfo := range dbImport.Channels {
cfchname, err := CasefoldChannel(chname) _, err := CasefoldChannel(chname)
if err != nil { if err != nil {
log.Printf("invalid channel name %s: %v", chname, err) log.Printf("invalid channel name %s: %v", chname, err)
continue continue
@ -158,43 +162,42 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
log.Printf("invalid founder %s for channel %s: %v", chInfo.Founder, chname, err) log.Printf("invalid founder %s for channel %s: %v", chInfo.Founder, chname, err)
continue continue
} }
tx.Set(fmt.Sprintf(keyChannelExists, cfchname), "1", nil) var regInfo RegisteredChannel
tx.Set(fmt.Sprintf(keyChannelName, cfchname), chname, nil) regInfo.Name = chname
tx.Set(fmt.Sprintf(keyChannelRegTime, cfchname), strconv.FormatInt(chInfo.RegisteredAt, 10), nil) regInfo.UUID = utils.GenerateUUIDv4()
tx.Set(fmt.Sprintf(keyChannelFounder, cfchname), cffounder, nil) regInfo.Founder = cffounder
accountChannelsKey := fmt.Sprintf(keyAccountChannels, cffounder) regInfo.RegisteredAt = time.Unix(0, chInfo.RegisteredAt).UTC()
founderChannels, fcErr := tx.Get(accountChannelsKey)
if fcErr != nil || founderChannels == "" {
founderChannels = cfchname
} else {
founderChannels = fmt.Sprintf("%s,%s", founderChannels, cfchname)
}
tx.Set(accountChannelsKey, founderChannels, nil)
if chInfo.Topic != "" { if chInfo.Topic != "" {
tx.Set(fmt.Sprintf(keyChannelTopic, cfchname), chInfo.Topic, nil) regInfo.Topic = chInfo.Topic
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, cfchname), strconv.FormatInt(chInfo.TopicSetAt, 10), nil) regInfo.TopicSetBy = chInfo.TopicSetBy
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, cfchname), chInfo.TopicSetBy, nil) regInfo.TopicSetTime = time.Unix(0, chInfo.TopicSetAt).UTC()
} }
if len(chInfo.Amode) != 0 { if len(chInfo.Amode) != 0 {
m, err := serializeAmodes(chInfo.Amode, cfUsernames) m, err := convertAmodes(chInfo.Amode, cfUsernames)
if err == nil { if err == nil {
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, cfchname), string(m), nil) regInfo.AccountToUMode = m
} else { } else {
log.Printf("couldn't serialize amodes for %s: %v", chname, err) log.Printf("couldn't process amodes for %s: %v", chname, err)
} }
} }
tx.Set(fmt.Sprintf(keyChannelModes, cfchname), chInfo.Modes, nil) for _, mode := range chInfo.Modes {
if chInfo.Key != "" { regInfo.Modes = append(regInfo.Modes, modes.Mode(mode))
tx.Set(fmt.Sprintf(keyChannelPassword, cfchname), chInfo.Key, nil)
} }
regInfo.Key = chInfo.Key
if chInfo.Limit > 0 { if chInfo.Limit > 0 {
tx.Set(fmt.Sprintf(keyChannelUserLimit, cfchname), strconv.Itoa(chInfo.Limit), nil) regInfo.UserLimit = chInfo.Limit
} }
if chInfo.Forward != "" { if chInfo.Forward != "" {
if _, err := CasefoldChannel(chInfo.Forward); err == nil { if _, err := CasefoldChannel(chInfo.Forward); err == nil {
tx.Set(fmt.Sprintf(keyChannelForward, cfchname), chInfo.Forward, nil) regInfo.Forward = chInfo.Forward
} }
} }
if j, err := json.Marshal(regInfo); err == nil {
tx.Set(bunt.BuntKey(datastore.TableChannels, regInfo.UUID), string(j), nil)
} else {
log.Printf("couldn't serialize channel %s: %v", chname, err)
}
} }
if warnSkeletons { if warnSkeletons {

View File

@ -4,7 +4,15 @@ package irc
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt"
"strconv"
"time"
"github.com/tidwall/buntdb"
"github.com/ergochat/ergo/irc/modes"
) )
var ( var (
@ -25,3 +33,116 @@ func decodeLegacyPasswordHash(hash string) ([]byte, error) {
return nil, errInvalidPasswordHash return nil, errInvalidPasswordHash
} }
} }
// legacy channel registration code
const (
keyChannelExists = "channel.exists %s"
keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
keyChannelRegTime = "channel.registered.time %s"
keyChannelFounder = "channel.founder %s"
keyChannelTopic = "channel.topic %s"
keyChannelTopicSetBy = "channel.topic.setby %s"
keyChannelTopicSetTime = "channel.topic.settime %s"
keyChannelBanlist = "channel.banlist %s"
keyChannelExceptlist = "channel.exceptlist %s"
keyChannelInvitelist = "channel.invitelist %s"
keyChannelPassword = "channel.key %s"
keyChannelModes = "channel.modes %s"
keyChannelAccountToUMode = "channel.accounttoumode %s"
keyChannelUserLimit = "channel.userlimit %s"
keyChannelSettings = "channel.settings %s"
keyChannelForward = "channel.forward %s"
keyChannelPurged = "channel.purged %s"
)
func deleteLegacyChannel(tx *buntdb.Tx, nameCasefolded string) {
tx.Delete(fmt.Sprintf(keyChannelExists, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelName, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelRegTime, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelFounder, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelTopic, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelTopicSetBy, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelTopicSetTime, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelBanlist, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelExceptlist, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelInvitelist, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelPassword, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelModes, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelAccountToUMode, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelUserLimit, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelSettings, nameCasefolded))
tx.Delete(fmt.Sprintf(keyChannelForward, nameCasefolded))
}
func loadLegacyChannel(tx *buntdb.Tx, nameCasefolded string) (info RegisteredChannel, err error) {
channelKey := nameCasefolded
// nice to have: do all JSON (de)serialization outside of the buntdb transaction
_, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
if dberr == buntdb.ErrNotFound {
// chan does not already exist, return
err = errNoSuchChannel
return
}
// channel exists, load it
name, _ := tx.Get(fmt.Sprintf(keyChannelName, channelKey))
regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, channelKey))
regTimeInt, _ := strconv.ParseInt(regTime, 10, 64)
founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, channelKey))
topic, _ := tx.Get(fmt.Sprintf(keyChannelTopic, channelKey))
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
var topicSetTime time.Time
topicSetTimeStr, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
if topicSetTimeInt, topicSetTimeErr := strconv.ParseInt(topicSetTimeStr, 10, 64); topicSetTimeErr == nil {
topicSetTime = time.Unix(0, topicSetTimeInt).UTC()
}
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
userLimitString, _ := tx.Get(fmt.Sprintf(keyChannelUserLimit, channelKey))
forward, _ := tx.Get(fmt.Sprintf(keyChannelForward, channelKey))
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
settingsString, _ := tx.Get(fmt.Sprintf(keyChannelSettings, channelKey))
modeSlice := make([]modes.Mode, len(modeString))
for i, mode := range modeString {
modeSlice[i] = modes.Mode(mode)
}
userLimit, _ := strconv.Atoi(userLimitString)
var banlist map[string]MaskInfo
_ = json.Unmarshal([]byte(banlistString), &banlist)
var exceptlist map[string]MaskInfo
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
var invitelist map[string]MaskInfo
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
accountToUMode := make(map[string]modes.Mode)
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
var settings ChannelSettings
_ = json.Unmarshal([]byte(settingsString), &settings)
info = RegisteredChannel{
Name: name,
RegisteredAt: time.Unix(0, regTimeInt).UTC(),
Founder: founder,
Topic: topic,
TopicSetBy: topicSetBy,
TopicSetTime: topicSetTime,
Key: password,
Modes: modeSlice,
Bans: banlist,
Excepts: exceptlist,
Invites: invitelist,
AccountToUMode: accountToUMode,
UserLimit: int(userLimit),
Settings: settings,
Forward: forward,
}
return info, nil
}

View File

@ -954,9 +954,9 @@ func nsInfoHandler(service *ircService, server *Server, client *Client, command
func listRegisteredChannels(service *ircService, accountName string, rb *ResponseBuffer) { func listRegisteredChannels(service *ircService, accountName string, rb *ResponseBuffer) {
client := rb.session.client client := rb.session.client
channels := client.server.accounts.ChannelsForAccount(accountName) channels := client.server.channels.ChannelsForAccount(accountName)
service.Notice(rb, fmt.Sprintf(client.t("Account %s has %d registered channel(s)."), accountName, len(channels))) service.Notice(rb, fmt.Sprintf(client.t("Account %s has %d registered channel(s)."), accountName, len(channels)))
for _, channel := range rb.session.client.server.accounts.ChannelsForAccount(accountName) { for _, channel := range channels {
service.Notice(rb, fmt.Sprintf(client.t("Registered channel: %s"), channel)) service.Notice(rb, fmt.Sprintf(client.t("Registered channel: %s"), channel))
} }
} }

37
irc/serde.go Normal file
View File

@ -0,0 +1,37 @@
// Copyright (c) 2022 Shivaram Lingamneni
// released under the MIT license
package irc
import (
"strconv"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/logger"
)
type Serializable interface {
Serialize() ([]byte, error)
Deserialize([]byte) error
}
func FetchAndDeserializeAll[T any, C interface {
*T
Serializable
}](table datastore.Table, dstore datastore.Datastore, log *logger.Manager) (result []T, err error) {
rawRecords, err := dstore.GetAll(table)
if err != nil {
return
}
result = make([]T, len(rawRecords))
pos := 0
for _, record := range rawRecords {
err := C(&result[pos]).Deserialize(record.Value)
if err != nil {
log.Error("internal", "deserialization error", strconv.Itoa(int(table)), record.UUID.String(), err.Error())
continue
}
pos++
}
return result[:pos], nil
}

View File

@ -22,9 +22,12 @@ import (
"github.com/ergochat/irc-go/ircfmt" "github.com/ergochat/irc-go/ircfmt"
"github.com/okzk/sdnotify" "github.com/okzk/sdnotify"
"github.com/tidwall/buntdb"
"github.com/ergochat/ergo/irc/bunt"
"github.com/ergochat/ergo/irc/caps" "github.com/ergochat/ergo/irc/caps"
"github.com/ergochat/ergo/irc/connection_limits" "github.com/ergochat/ergo/irc/connection_limits"
"github.com/ergochat/ergo/irc/datastore"
"github.com/ergochat/ergo/irc/flatip" "github.com/ergochat/ergo/irc/flatip"
"github.com/ergochat/ergo/irc/flock" "github.com/ergochat/ergo/irc/flock"
"github.com/ergochat/ergo/irc/history" "github.com/ergochat/ergo/irc/history"
@ -33,7 +36,6 @@ import (
"github.com/ergochat/ergo/irc/mysql" "github.com/ergochat/ergo/irc/mysql"
"github.com/ergochat/ergo/irc/sno" "github.com/ergochat/ergo/irc/sno"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
"github.com/tidwall/buntdb"
) )
const ( const (
@ -66,7 +68,6 @@ type Server struct {
accepts AcceptManager accepts AcceptManager
accounts AccountManager accounts AccountManager
channels ChannelManager channels ChannelManager
channelRegistry ChannelRegistry
clients ClientManager clients ClientManager
config atomic.Pointer[Config] config atomic.Pointer[Config]
configFilename string configFilename string
@ -87,6 +88,7 @@ type Server struct {
tracebackSignal chan os.Signal tracebackSignal chan os.Signal
snomasks SnoManager snomasks SnoManager
store *buntdb.DB store *buntdb.DB
dstore datastore.Datastore
historyDB mysql.MySQL historyDB mysql.MySQL
torLimiter connection_limits.TorLimiter torLimiter connection_limits.TorLimiter
whoWas WhoWasList whoWas WhoWasList
@ -98,6 +100,10 @@ type Server struct {
// NewServer returns a new Oragono server. // NewServer returns a new Oragono server.
func NewServer(config *Config, logger *logger.Manager) (*Server, error) { func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
// sanity check that kernel randomness is available; on modern Linux,
// this will block until it is, on other platforms it may panic:
utils.GenerateUUIDv4()
// initialize data structures // initialize data structures
server := &Server{ server := &Server{
ctime: time.Now().UTC(), ctime: time.Now().UTC(),
@ -716,7 +722,11 @@ func (server *Server) applyConfig(config *Config) (err error) {
// now that the datastore is initialized, we can load the cloak secret from it // now that the datastore is initialized, we can load the cloak secret from it
// XXX this modifies config after the initial load, which is naughty, // XXX this modifies config after the initial load, which is naughty,
// but there's no data race because we haven't done SetConfig yet // but there's no data race because we haven't done SetConfig yet
config.Server.Cloaks.SetSecret(LoadCloakSecret(server.store)) cloakSecret, err := LoadCloakSecret(server.dstore)
if err != nil {
return fmt.Errorf("Could not load cloak secret: %w", err)
}
config.Server.Cloaks.SetSecret(cloakSecret)
// activate the new config // activate the new config
server.config.Store(config) server.config.Store(config)
@ -837,6 +847,7 @@ func (server *Server) loadDatastore(config *Config) error {
db, err := OpenDatabase(config) db, err := OpenDatabase(config)
if err == nil { if err == nil {
server.store = db server.store = db
server.dstore = bunt.NewBuntdbDatastore(db, server.logger)
return nil return nil
} else { } else {
return fmt.Errorf("Failed to open datastore: %s", err.Error()) return fmt.Errorf("Failed to open datastore: %s", err.Error())
@ -849,8 +860,7 @@ func (server *Server) loadFromDatastore(config *Config) (err error) {
server.loadDLines() server.loadDLines()
server.loadKLines() server.loadKLines()
server.channelRegistry.Initialize(server) server.channels.Initialize(server, config)
server.channels.Initialize(server)
server.accounts.Initialize(server) server.accounts.Initialize(server)
if config.Datastore.MySQL.Enabled { if config.Datastore.MySQL.Enabled {

56
irc/utils/uuid.go Normal file
View File

@ -0,0 +1,56 @@
// Copyright (c) 2022 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package utils
import (
"crypto/rand"
"encoding/base64"
"errors"
)
var (
ErrInvalidUUID = errors.New("Invalid uuid")
)
// Technically a UUIDv4 has version bits set, but this doesn't matter in practice
type UUID [16]byte
func (u UUID) MarshalJSON() (b []byte, err error) {
b = make([]byte, 24)
b[0] = '"'
base64.RawURLEncoding.Encode(b[1:], u[:])
b[23] = '"'
return
}
func (u *UUID) UnmarshalJSON(b []byte) (err error) {
if len(b) != 24 {
return ErrInvalidUUID
}
readLen, err := base64.RawURLEncoding.Decode(u[:], b[1:23])
if readLen != 16 {
return ErrInvalidUUID
}
return nil
}
func (u *UUID) String() string {
return base64.RawURLEncoding.EncodeToString(u[:])
}
func GenerateUUIDv4() (result UUID) {
_, err := rand.Read(result[:])
if err != nil {
panic(err)
}
return
}
func DecodeUUID(ustr string) (result UUID, err error) {
length, err := base64.RawURLEncoding.Decode(result[:], []byte(ustr))
if err == nil && length != 16 {
err = ErrInvalidUUID
}
return
}