mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-23 18:54:08 +01:00
refactor of channel persistence to use UUIDs
This commit is contained in:
parent
bceae9b739
commit
7ce0636276
@ -39,7 +39,6 @@ const (
|
||||
keyAccountSettings = "account.settings %s"
|
||||
keyAccountVHost = "account.vhost %s"
|
||||
keyCertToAccount = "account.creds.certfp %s"
|
||||
keyAccountChannels = "account.channels %s" // channels registered to the account
|
||||
keyAccountLastSeen = "account.lastseen %s"
|
||||
keyAccountReadMarkers = "account.readmarkers %s"
|
||||
keyAccountModes = "account.modes %s" // user modes for the always-on client as a string
|
||||
@ -1765,7 +1764,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
|
||||
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
|
||||
settingsKey := fmt.Sprintf(keyAccountSettings, casefoldedAccount)
|
||||
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
|
||||
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
|
||||
joinedChannelsKey := fmt.Sprintf(keyAccountChannelToModes, casefoldedAccount)
|
||||
lastSeenKey := fmt.Sprintf(keyAccountLastSeen, casefoldedAccount)
|
||||
readMarkersKey := fmt.Sprintf(keyAccountReadMarkers, casefoldedAccount)
|
||||
@ -1781,10 +1779,9 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
|
||||
am.killClients(clients)
|
||||
}()
|
||||
|
||||
var registeredChannels []string
|
||||
// on our way out, unregister all the account's channels and delete them from the db
|
||||
defer func() {
|
||||
for _, channelName := range registeredChannels {
|
||||
for _, channelName := range am.server.channels.ChannelsForAccount(casefoldedAccount) {
|
||||
err := am.server.channels.SetUnregistered(channelName, casefoldedAccount)
|
||||
if err != nil {
|
||||
am.server.logger.Error("internal", "couldn't unregister channel", channelName, err.Error())
|
||||
@ -1799,7 +1796,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
|
||||
defer am.serialCacheUpdateMutex.Unlock()
|
||||
|
||||
var accountName string
|
||||
var channelsStr string
|
||||
keepProtections := false
|
||||
am.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
// get the unfolded account name; for an active account, this is
|
||||
@ -1827,8 +1823,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
|
||||
credText, err = tx.Get(credentialsKey)
|
||||
tx.Delete(credentialsKey)
|
||||
tx.Delete(vhostKey)
|
||||
channelsStr, _ = tx.Get(channelsKey)
|
||||
tx.Delete(channelsKey)
|
||||
tx.Delete(joinedChannelsKey)
|
||||
tx.Delete(lastSeenKey)
|
||||
tx.Delete(readMarkersKey)
|
||||
@ -1858,7 +1852,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
|
||||
|
||||
skeleton, _ := Skeleton(accountName)
|
||||
additionalNicks := unmarshalReservedNicks(rawNicks)
|
||||
registeredChannels = unmarshalRegisteredChannels(channelsStr)
|
||||
|
||||
am.Lock()
|
||||
defer am.Unlock()
|
||||
@ -1890,21 +1883,6 @@ func unmarshalRegisteredChannels(channelsStr string) (result []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (am *AccountManager) ChannelsForAccount(account string) (channels []string) {
|
||||
cfaccount, err := CasefoldName(account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var channelStr string
|
||||
key := fmt.Sprintf(keyAccountChannels, cfaccount)
|
||||
am.server.store.View(func(tx *buntdb.Tx) error {
|
||||
channelStr, _ = tx.Get(key)
|
||||
return nil
|
||||
})
|
||||
return unmarshalRegisteredChannels(channelStr)
|
||||
}
|
||||
|
||||
func (am *AccountManager) AuthenticateByCertificate(client *Client, certfp string, peerCerts []*x509.Certificate, authzid string) (err error) {
|
||||
if certfp == "" {
|
||||
return errAccountInvalidCredentials
|
||||
|
106
irc/bunt/bunt_datastore.go
Normal file
106
irc/bunt/bunt_datastore.go
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright (c) 2022 Shivaram Lingamneni
|
||||
// released under the MIT license
|
||||
|
||||
package bunt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/buntdb"
|
||||
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/logger"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
||||
// BuntKey yields a string key corresponding to a (table, UUID) pair.
|
||||
// Ideally this would not be public, but some of the migration code
|
||||
// needs it.
|
||||
func BuntKey(table datastore.Table, uuid utils.UUID) string {
|
||||
return fmt.Sprintf("%x %s", table, uuid.String())
|
||||
}
|
||||
|
||||
// buntdbDatastore implements datastore.Datastore using a buntdb.
|
||||
type buntdbDatastore struct {
|
||||
db *buntdb.DB
|
||||
logger *logger.Manager
|
||||
}
|
||||
|
||||
// NewBuntdbDatastore returns a datastore.Datastore backed by buntdb.
|
||||
func NewBuntdbDatastore(db *buntdb.DB, logger *logger.Manager) datastore.Datastore {
|
||||
return &buntdbDatastore{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *buntdbDatastore) Backoff() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (b *buntdbDatastore) GetAll(table datastore.Table) (result []datastore.KV, err error) {
|
||||
tablePrefix := fmt.Sprintf("%x ", table)
|
||||
err = b.db.View(func(tx *buntdb.Tx) error {
|
||||
err := tx.AscendGreaterOrEqual("", tablePrefix, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, tablePrefix) {
|
||||
return false
|
||||
}
|
||||
uuid, err := utils.DecodeUUID(strings.TrimPrefix(key, tablePrefix))
|
||||
if err == nil {
|
||||
result = append(result, datastore.KV{UUID: uuid, Value: []byte(value)})
|
||||
} else {
|
||||
b.logger.Error("datastore", "invalid uuid", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (b *buntdbDatastore) Get(table datastore.Table, uuid utils.UUID) (value []byte, err error) {
|
||||
buntKey := BuntKey(table, uuid)
|
||||
var result string
|
||||
err = b.db.View(func(tx *buntdb.Tx) error {
|
||||
result, err = tx.Get(buntKey)
|
||||
return err
|
||||
})
|
||||
return []byte(result), err
|
||||
}
|
||||
|
||||
func (b *buntdbDatastore) Set(table datastore.Table, uuid utils.UUID, value []byte, expiration time.Time) (err error) {
|
||||
buntKey := BuntKey(table, uuid)
|
||||
var setOptions *buntdb.SetOptions
|
||||
if !expiration.IsZero() {
|
||||
ttl := time.Until(expiration)
|
||||
if ttl > 0 {
|
||||
setOptions = &buntdb.SetOptions{Expires: true, TTL: ttl}
|
||||
} else {
|
||||
return nil // it already expired, i guess?
|
||||
}
|
||||
}
|
||||
strVal := string(value)
|
||||
|
||||
err = b.db.Update(func(tx *buntdb.Tx) error {
|
||||
_, _, err := tx.Set(buntKey, strVal, setOptions)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (b *buntdbDatastore) Delete(table datastore.Table, key utils.UUID) (err error) {
|
||||
buntKey := BuntKey(table, key)
|
||||
err = b.db.Update(func(tx *buntdb.Tx) error {
|
||||
_, err := tx.Delete(buntKey)
|
||||
return err
|
||||
})
|
||||
// deleting a nonexistent key is not considered an error
|
||||
switch err {
|
||||
case buntdb.ErrNotFound:
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
100
irc/channel.go
100
irc/channel.go
@ -16,6 +16,7 @@ import (
|
||||
"github.com/ergochat/irc-go/ircutils"
|
||||
|
||||
"github.com/ergochat/ergo/irc/caps"
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/history"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
@ -50,14 +51,14 @@ type Channel struct {
|
||||
stateMutex sync.RWMutex // tier 1
|
||||
writebackLock sync.Mutex // tier 1.5
|
||||
joinPartMutex sync.Mutex // tier 3
|
||||
ensureLoaded utils.Once // manages loading stored registration info from the database
|
||||
dirtyBits uint
|
||||
settings ChannelSettings
|
||||
uuid utils.UUID
|
||||
}
|
||||
|
||||
// NewChannel creates a new channel from a `Server` and a `name`
|
||||
// string, which must be unique on the server.
|
||||
func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channel {
|
||||
func NewChannel(s *Server, name, casefoldedName string, registered bool, regInfo RegisteredChannel) *Channel {
|
||||
config := s.Config()
|
||||
|
||||
channel := &Channel{
|
||||
@ -71,14 +72,15 @@ func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channe
|
||||
channel.initializeLists()
|
||||
channel.history.Initialize(0, 0)
|
||||
|
||||
if !registered {
|
||||
if registered {
|
||||
channel.applyRegInfo(regInfo)
|
||||
} else {
|
||||
channel.resizeHistory(config)
|
||||
for _, mode := range config.Channels.defaultModes {
|
||||
channel.flags.SetMode(mode, true)
|
||||
}
|
||||
// no loading to do, so "mark" the load operation as "done":
|
||||
channel.ensureLoaded.Do(func() {})
|
||||
} // else: modes will be loaded before first join
|
||||
channel.uuid = utils.GenerateUUIDv4()
|
||||
}
|
||||
|
||||
return channel
|
||||
}
|
||||
@ -92,24 +94,6 @@ func (channel *Channel) initializeLists() {
|
||||
channel.accountToUMode = make(map[string]modes.Mode)
|
||||
}
|
||||
|
||||
// EnsureLoaded blocks until the channel's registration info has been loaded
|
||||
// from the database.
|
||||
func (channel *Channel) EnsureLoaded() {
|
||||
channel.ensureLoaded.Do(func() {
|
||||
nmc := channel.NameCasefolded()
|
||||
info, err := channel.server.channelRegistry.LoadChannel(nmc)
|
||||
if err == nil {
|
||||
channel.applyRegInfo(info)
|
||||
} else {
|
||||
channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (channel *Channel) IsLoaded() bool {
|
||||
return channel.ensureLoaded.Done()
|
||||
}
|
||||
|
||||
func (channel *Channel) resizeHistory(config *Config) {
|
||||
status, _, _ := channel.historyStatus(config)
|
||||
if status == HistoryEphemeral {
|
||||
@ -126,6 +110,7 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
|
||||
channel.stateMutex.Lock()
|
||||
defer channel.stateMutex.Unlock()
|
||||
|
||||
channel.uuid = chanReg.UUID
|
||||
channel.registeredFounder = chanReg.Founder
|
||||
channel.registeredTime = chanReg.RegisteredAt
|
||||
channel.topic = chanReg.Topic
|
||||
@ -150,38 +135,41 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
|
||||
}
|
||||
|
||||
// obtain a consistent snapshot of the channel state that can be persisted to the DB
|
||||
func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) {
|
||||
func (channel *Channel) ExportRegistration() (info RegisteredChannel) {
|
||||
channel.stateMutex.RLock()
|
||||
defer channel.stateMutex.RUnlock()
|
||||
|
||||
info.Name = channel.name
|
||||
info.NameCasefolded = channel.nameCasefolded
|
||||
info.UUID = channel.uuid
|
||||
info.Founder = channel.registeredFounder
|
||||
info.RegisteredAt = channel.registeredTime
|
||||
|
||||
if includeFlags&IncludeTopic != 0 {
|
||||
info.Topic = channel.topic
|
||||
info.TopicSetBy = channel.topicSetBy
|
||||
info.TopicSetTime = channel.topicSetTime
|
||||
}
|
||||
info.Topic = channel.topic
|
||||
info.TopicSetBy = channel.topicSetBy
|
||||
info.TopicSetTime = channel.topicSetTime
|
||||
|
||||
if includeFlags&IncludeModes != 0 {
|
||||
info.Key = channel.key
|
||||
info.Forward = channel.forward
|
||||
info.Modes = channel.flags.AllModes()
|
||||
info.UserLimit = channel.userLimit
|
||||
}
|
||||
info.Key = channel.key
|
||||
info.Forward = channel.forward
|
||||
info.Modes = channel.flags.AllModes()
|
||||
info.UserLimit = channel.userLimit
|
||||
|
||||
if includeFlags&IncludeLists != 0 {
|
||||
info.Bans = channel.lists[modes.BanMask].Masks()
|
||||
info.Invites = channel.lists[modes.InviteMask].Masks()
|
||||
info.Excepts = channel.lists[modes.ExceptMask].Masks()
|
||||
info.AccountToUMode = utils.CopyMap(channel.accountToUMode)
|
||||
}
|
||||
info.Bans = channel.lists[modes.BanMask].Masks()
|
||||
info.Invites = channel.lists[modes.InviteMask].Masks()
|
||||
info.Excepts = channel.lists[modes.ExceptMask].Masks()
|
||||
info.AccountToUMode = utils.CopyMap(channel.accountToUMode)
|
||||
|
||||
if includeFlags&IncludeSettings != 0 {
|
||||
info.Settings = channel.settings
|
||||
}
|
||||
info.Settings = channel.settings
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (channel *Channel) exportSummary() (info RegisteredChannel) {
|
||||
channel.stateMutex.RLock()
|
||||
defer channel.stateMutex.RUnlock()
|
||||
|
||||
info.Name = channel.name
|
||||
info.Founder = channel.registeredFounder
|
||||
info.RegisteredAt = channel.registeredTime
|
||||
|
||||
return
|
||||
}
|
||||
@ -288,9 +276,19 @@ func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
info := channel.ExportRegistration(dirtyBits)
|
||||
err = channel.server.channelRegistry.StoreChannel(info, dirtyBits)
|
||||
if err != nil {
|
||||
var success bool
|
||||
info := channel.ExportRegistration()
|
||||
if b, err := info.Serialize(); err == nil {
|
||||
if err := channel.server.dstore.Set(datastore.TableChannels, info.UUID, b, time.Time{}); err == nil {
|
||||
success = true
|
||||
} else {
|
||||
channel.server.logger.Error("internal", "couldn't persist channel", info.Name, err.Error())
|
||||
}
|
||||
} else {
|
||||
channel.server.logger.Error("internal", "couldn't serialize channel", info.Name, err.Error())
|
||||
}
|
||||
|
||||
if !success {
|
||||
channel.stateMutex.Lock()
|
||||
channel.dirtyBits = channel.dirtyBits | dirtyBits
|
||||
channel.stateMutex.Unlock()
|
||||
@ -314,6 +312,7 @@ func (channel *Channel) SetRegistered(founder string) error {
|
||||
|
||||
// SetUnregistered deletes the channel's registration information.
|
||||
func (channel *Channel) SetUnregistered(expectedFounder string) {
|
||||
uuid := utils.GenerateUUIDv4()
|
||||
channel.stateMutex.Lock()
|
||||
defer channel.stateMutex.Unlock()
|
||||
|
||||
@ -324,6 +323,9 @@ func (channel *Channel) SetUnregistered(expectedFounder string) {
|
||||
var zeroTime time.Time
|
||||
channel.registeredTime = zeroTime
|
||||
channel.accountToUMode = make(map[string]modes.Mode)
|
||||
// reset the UUID so that any re-registration will persist under
|
||||
// a separate key:
|
||||
channel.uuid = uuid
|
||||
}
|
||||
|
||||
// implements `CHANSERV CLEAR #chan ACCESS` (resets bans, invites, excepts, and amodes)
|
||||
|
@ -6,7 +6,9 @@ package irc
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
||||
@ -25,85 +27,75 @@ type channelManagerEntry struct {
|
||||
type ChannelManager struct {
|
||||
sync.RWMutex // tier 2
|
||||
// chans is the main data structure, mapping casefolded name -> *Channel
|
||||
chans map[string]*channelManagerEntry
|
||||
chansSkeletons utils.HashSet[string] // skeletons of *unregistered* chans
|
||||
registeredChannels utils.HashSet[string] // casefolds of registered chans
|
||||
registeredSkeletons utils.HashSet[string] // skeletons of registered chans
|
||||
purgedChannels utils.HashSet[string] // casefolds of purged chans
|
||||
server *Server
|
||||
chans map[string]*channelManagerEntry
|
||||
chansSkeletons utils.HashSet[string]
|
||||
purgedChannels map[string]ChannelPurgeRecord // casefolded name to purge record
|
||||
server *Server
|
||||
}
|
||||
|
||||
// NewChannelManager returns a new ChannelManager.
|
||||
func (cm *ChannelManager) Initialize(server *Server) {
|
||||
func (cm *ChannelManager) Initialize(server *Server, config *Config) (err error) {
|
||||
cm.chans = make(map[string]*channelManagerEntry)
|
||||
cm.chansSkeletons = make(utils.HashSet[string])
|
||||
cm.server = server
|
||||
|
||||
// purging should work even if registration is disabled
|
||||
cm.purgedChannels = cm.server.channelRegistry.PurgedChannels()
|
||||
cm.loadRegisteredChannels(server.Config())
|
||||
return cm.loadRegisteredChannels(config)
|
||||
}
|
||||
|
||||
func (cm *ChannelManager) loadRegisteredChannels(config *Config) {
|
||||
if !config.Channels.Registration.Enabled {
|
||||
func (cm *ChannelManager) loadRegisteredChannels(config *Config) (err error) {
|
||||
allChannels, err := FetchAndDeserializeAll[RegisteredChannel](datastore.TableChannels, cm.server.dstore, cm.server.logger)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
allPurgeRecords, err := FetchAndDeserializeAll[ChannelPurgeRecord](datastore.TableChannelPurges, cm.server.dstore, cm.server.logger)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var newChannels []*Channel
|
||||
var collisions []string
|
||||
defer func() {
|
||||
for _, ch := range newChannels {
|
||||
ch.EnsureLoaded()
|
||||
cm.server.logger.Debug("channels", "initialized registered channel", ch.Name())
|
||||
}
|
||||
for _, collision := range collisions {
|
||||
cm.server.logger.Warning("channels", "registered channel collides with existing channel", collision)
|
||||
}
|
||||
}()
|
||||
|
||||
rawNames := cm.server.channelRegistry.AllChannels()
|
||||
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
cm.registeredChannels = make(utils.HashSet[string], len(rawNames))
|
||||
cm.registeredSkeletons = make(utils.HashSet[string], len(rawNames))
|
||||
for _, name := range rawNames {
|
||||
cfname, err := CasefoldChannel(name)
|
||||
if err == nil {
|
||||
cm.registeredChannels.Add(cfname)
|
||||
cm.purgedChannels = make(map[string]ChannelPurgeRecord, len(allPurgeRecords))
|
||||
for _, purge := range allPurgeRecords {
|
||||
cm.purgedChannels[purge.NameCasefolded] = purge
|
||||
}
|
||||
|
||||
for _, regInfo := range allChannels {
|
||||
cfname, err := CasefoldChannel(regInfo.Name)
|
||||
if err != nil {
|
||||
cm.server.logger.Error("channels", "couldn't casefold registered channel, skipping", regInfo.Name, err.Error())
|
||||
continue
|
||||
} else {
|
||||
cm.server.logger.Debug("channels", "initializing registered channel", regInfo.Name)
|
||||
}
|
||||
skeleton, err := Skeleton(name)
|
||||
skeleton, err := Skeleton(regInfo.Name)
|
||||
if err == nil {
|
||||
cm.registeredSkeletons.Add(skeleton)
|
||||
cm.chansSkeletons.Add(skeleton)
|
||||
}
|
||||
|
||||
if !cm.purgedChannels.Has(cfname) {
|
||||
if _, ok := cm.chans[cfname]; !ok {
|
||||
ch := NewChannel(cm.server, name, cfname, true)
|
||||
cm.chans[cfname] = &channelManagerEntry{
|
||||
channel: ch,
|
||||
pendingJoins: 0,
|
||||
}
|
||||
newChannels = append(newChannels, ch)
|
||||
} else {
|
||||
collisions = append(collisions, name)
|
||||
if _, ok := cm.purgedChannels[cfname]; !ok {
|
||||
ch := NewChannel(cm.server, regInfo.Name, cfname, true, regInfo)
|
||||
cm.chans[cfname] = &channelManagerEntry{
|
||||
channel: ch,
|
||||
pendingJoins: 0,
|
||||
skeleton: skeleton,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns an existing channel with name equivalent to `name`, or nil
|
||||
func (cm *ChannelManager) Get(name string) (channel *Channel) {
|
||||
name, err := CasefoldChannel(name)
|
||||
if err == nil {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
entry := cm.chans[name]
|
||||
// if the channel is still loading, pretend we don't have it
|
||||
if entry != nil && entry.channel.IsLoaded() {
|
||||
return entry.channel
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
entry := cm.chans[name]
|
||||
if entry != nil {
|
||||
return entry.channel
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -122,33 +114,26 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
if cm.purgedChannels.Has(casefoldedName) {
|
||||
// check purges first; a registered purged channel will still be present in `chans`
|
||||
if _, ok := cm.purgedChannels[casefoldedName]; ok {
|
||||
return nil, errChannelPurged, false
|
||||
}
|
||||
entry := cm.chans[casefoldedName]
|
||||
if entry == nil {
|
||||
registered := cm.registeredChannels.Has(casefoldedName)
|
||||
// enforce OpOnlyCreation
|
||||
if !registered && server.Config().Channels.OpOnlyCreation &&
|
||||
if server.Config().Channels.OpOnlyCreation &&
|
||||
!(isSajoin || client.HasRoleCapabs("chanreg")) {
|
||||
return nil, errInsufficientPrivs, false
|
||||
}
|
||||
// enforce confusables
|
||||
if !registered && (cm.chansSkeletons.Has(skeleton) || cm.registeredSkeletons.Has(skeleton)) {
|
||||
if cm.chansSkeletons.Has(skeleton) {
|
||||
return nil, errConfusableIdentifier, false
|
||||
}
|
||||
entry = &channelManagerEntry{
|
||||
channel: NewChannel(server, name, casefoldedName, registered),
|
||||
channel: NewChannel(server, name, casefoldedName, false, RegisteredChannel{}),
|
||||
pendingJoins: 0,
|
||||
}
|
||||
if !registered {
|
||||
// for an unregistered channel, we already have the correct unfolded name
|
||||
// and therefore the final skeleton. for a registered channel, we don't have
|
||||
// the unfolded name yet (it needs to be loaded from the db), but we already
|
||||
// have the final skeleton in `registeredSkeletons` so we don't need to track it
|
||||
cm.chansSkeletons.Add(skeleton)
|
||||
entry.skeleton = skeleton
|
||||
}
|
||||
cm.chansSkeletons.Add(skeleton)
|
||||
entry.skeleton = skeleton
|
||||
cm.chans[casefoldedName] = entry
|
||||
newChannel = true
|
||||
}
|
||||
@ -160,7 +145,6 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
|
||||
return err, ""
|
||||
}
|
||||
|
||||
channel.EnsureLoaded()
|
||||
err, forward = channel.Join(client, key, isSajoin || newChannel, rb)
|
||||
|
||||
cm.maybeCleanup(channel, true)
|
||||
@ -252,13 +236,6 @@ func (cm *ChannelManager) SetRegistered(channelName string, account string) (err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// transfer the skeleton from chansSkeletons to registeredSkeletons
|
||||
skeleton := entry.skeleton
|
||||
delete(cm.chansSkeletons, skeleton)
|
||||
entry.skeleton = ""
|
||||
cm.chans[cfname] = entry
|
||||
cm.registeredChannels.Add(cfname)
|
||||
cm.registeredSkeletons.Add(skeleton)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -268,17 +245,13 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := cm.server.channelRegistry.LoadChannel(cfname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Founder != account {
|
||||
return errChannelNotOwnedByAccount
|
||||
}
|
||||
var uuid utils.UUID
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
err = cm.server.channelRegistry.Delete(info)
|
||||
if delErr := cm.server.dstore.Delete(datastore.TableChannels, uuid); delErr != nil {
|
||||
cm.server.logger.Error("datastore", "couldn't delete channel registration", cfname, delErr.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@ -286,15 +259,11 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e
|
||||
defer cm.Unlock()
|
||||
entry := cm.chans[cfname]
|
||||
if entry != nil {
|
||||
entry.channel.SetUnregistered(account)
|
||||
delete(cm.registeredChannels, cfname)
|
||||
// transfer the skeleton from registeredSkeletons to chansSkeletons
|
||||
if skel, err := Skeleton(entry.channel.Name()); err == nil {
|
||||
delete(cm.registeredSkeletons, skel)
|
||||
cm.chansSkeletons.Add(skel)
|
||||
entry.skeleton = skel
|
||||
cm.chans[cfname] = entry
|
||||
if entry.channel.Founder() != account {
|
||||
return errChannelNotOwnedByAccount
|
||||
}
|
||||
uuid = entry.channel.UUID()
|
||||
entry.channel.SetUnregistered(account) // changes the UUID
|
||||
// #1619: if the channel has 0 members and was only being retained
|
||||
// because it was registered, clean it up:
|
||||
cm.maybeCleanupInternal(cfname, entry, false)
|
||||
@ -322,12 +291,11 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
|
||||
var info RegisteredChannel
|
||||
defer func() {
|
||||
if channel != nil && info.Founder != "" {
|
||||
channel.Store(IncludeAllAttrs)
|
||||
if oldCfname != newCfname {
|
||||
// we just flushed the channel under its new name, therefore this delete
|
||||
// cannot be overwritten by a write to the old name:
|
||||
cm.server.channelRegistry.Delete(info)
|
||||
}
|
||||
channel.MarkDirty(IncludeAllAttrs)
|
||||
}
|
||||
// always-on clients need to update their saved channel memberships
|
||||
for _, member := range channel.Members() {
|
||||
member.markDirty(IncludeChannels)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -335,11 +303,11 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
|
||||
defer cm.Unlock()
|
||||
|
||||
entry := cm.chans[oldCfname]
|
||||
if entry == nil || !entry.channel.IsLoaded() {
|
||||
if entry == nil {
|
||||
return errNoSuchChannel
|
||||
}
|
||||
channel = entry.channel
|
||||
info = channel.ExportRegistration(IncludeInitial)
|
||||
info = channel.ExportRegistration()
|
||||
registered := info.Founder != ""
|
||||
|
||||
oldSkeleton, err := Skeleton(info.Name)
|
||||
@ -348,13 +316,13 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
|
||||
}
|
||||
|
||||
if newCfname != oldCfname {
|
||||
if cm.chans[newCfname] != nil || cm.registeredChannels.Has(newCfname) {
|
||||
if cm.chans[newCfname] != nil {
|
||||
return errChannelNameInUse
|
||||
}
|
||||
}
|
||||
|
||||
if oldSkeleton != newSkeleton {
|
||||
if cm.chansSkeletons.Has(newSkeleton) || cm.registeredSkeletons.Has(newSkeleton) {
|
||||
if cm.chansSkeletons.Has(newSkeleton) {
|
||||
return errConfusableIdentifier
|
||||
}
|
||||
}
|
||||
@ -364,15 +332,8 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) {
|
||||
entry.skeleton = newSkeleton
|
||||
}
|
||||
cm.chans[newCfname] = entry
|
||||
if registered {
|
||||
delete(cm.registeredChannels, oldCfname)
|
||||
cm.registeredChannels.Add(newCfname)
|
||||
delete(cm.registeredSkeletons, oldSkeleton)
|
||||
cm.registeredSkeletons.Add(newSkeleton)
|
||||
} else {
|
||||
delete(cm.chansSkeletons, oldSkeleton)
|
||||
cm.chansSkeletons.Add(newSkeleton)
|
||||
}
|
||||
delete(cm.chansSkeletons, oldSkeleton)
|
||||
cm.chansSkeletons.Add(newSkeleton)
|
||||
entry.channel.Rename(newName, newCfname)
|
||||
return nil
|
||||
}
|
||||
@ -390,7 +351,18 @@ func (cm *ChannelManager) Channels() (result []*Channel) {
|
||||
defer cm.RUnlock()
|
||||
result = make([]*Channel, 0, len(cm.chans))
|
||||
for _, entry := range cm.chans {
|
||||
if entry.channel.IsLoaded() {
|
||||
result = append(result, entry.channel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ListableChannels returns a slice of all non-purged channels.
|
||||
func (cm *ChannelManager) ListableChannels() (result []*Channel) {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
result = make([]*Channel, 0, len(cm.chans))
|
||||
for cfname, entry := range cm.chans {
|
||||
if _, ok := cm.purgedChannels[cfname]; !ok {
|
||||
result = append(result, entry.channel)
|
||||
}
|
||||
}
|
||||
@ -403,29 +375,46 @@ func (cm *ChannelManager) Purge(chname string, record ChannelPurgeRecord) (err e
|
||||
if err != nil {
|
||||
return errInvalidChannelName
|
||||
}
|
||||
skel, err := Skeleton(chname)
|
||||
if err != nil {
|
||||
return errInvalidChannelName
|
||||
}
|
||||
|
||||
cm.Lock()
|
||||
cm.purgedChannels.Add(chname)
|
||||
entry := cm.chans[chname]
|
||||
if entry != nil {
|
||||
delete(cm.chans, chname)
|
||||
if entry.channel.Founder() != "" {
|
||||
delete(cm.registeredSkeletons, skel)
|
||||
} else {
|
||||
delete(cm.chansSkeletons, skel)
|
||||
record.NameCasefolded = chname
|
||||
record.UUID = utils.GenerateUUIDv4()
|
||||
|
||||
channel, err := func() (channel *Channel, err error) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
if _, ok := cm.purgedChannels[chname]; ok {
|
||||
return nil, errChannelPurgedAlready
|
||||
}
|
||||
}
|
||||
cm.Unlock()
|
||||
|
||||
cm.server.channelRegistry.PurgeChannel(chname, record)
|
||||
if entry != nil {
|
||||
entry.channel.Purge("")
|
||||
entry := cm.chans[chname]
|
||||
// atomically prevent anyone from rejoining
|
||||
cm.purgedChannels[chname] = record
|
||||
if entry != nil {
|
||||
channel = entry.channel
|
||||
}
|
||||
return
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
if channel != nil {
|
||||
// actually kick everyone off the channel
|
||||
channel.Purge("")
|
||||
}
|
||||
|
||||
var purgeBytes []byte
|
||||
if purgeBytes, err = record.Serialize(); err != nil {
|
||||
cm.server.logger.Error("internal", "couldn't serialize purge record", channel.Name(), err.Error())
|
||||
}
|
||||
// TODO we need a better story about error handling for later
|
||||
if err = cm.server.dstore.Set(datastore.TableChannelPurges, record.UUID, purgeBytes, time.Time{}); err != nil {
|
||||
cm.server.logger.Error("datastore", "couldn't store purge record", chname, err.Error())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsPurged queries whether a channel is purged.
|
||||
@ -436,7 +425,7 @@ func (cm *ChannelManager) IsPurged(chname string) (result bool) {
|
||||
}
|
||||
|
||||
cm.RLock()
|
||||
result = cm.purgedChannels.Has(chname)
|
||||
_, result = cm.purgedChannels[chname]
|
||||
cm.RUnlock()
|
||||
return
|
||||
}
|
||||
@ -449,14 +438,16 @@ func (cm *ChannelManager) Unpurge(chname string) (err error) {
|
||||
}
|
||||
|
||||
cm.Lock()
|
||||
found := cm.purgedChannels.Has(chname)
|
||||
record, found := cm.purgedChannels[chname]
|
||||
delete(cm.purgedChannels, chname)
|
||||
cm.Unlock()
|
||||
|
||||
cm.server.channelRegistry.UnpurgeChannel(chname)
|
||||
if !found {
|
||||
return errNoSuchChannel
|
||||
}
|
||||
if err := cm.server.dstore.Delete(datastore.TableChannelPurges, record.UUID); err != nil {
|
||||
cm.server.logger.Error("datastore", "couldn't delete purge record", chname, err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -475,8 +466,46 @@ func (cm *ChannelManager) UnfoldName(cfname string) (result string) {
|
||||
cm.RLock()
|
||||
entry := cm.chans[cfname]
|
||||
cm.RUnlock()
|
||||
if entry != nil && entry.channel.IsLoaded() {
|
||||
if entry != nil {
|
||||
return entry.channel.Name()
|
||||
}
|
||||
return cfname
|
||||
}
|
||||
|
||||
func (cm *ChannelManager) LoadPurgeRecord(cfchname string) (record ChannelPurgeRecord, err error) {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
|
||||
if record, ok := cm.purgedChannels[cfchname]; ok {
|
||||
return record, nil
|
||||
} else {
|
||||
return record, errNoSuchChannel
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ChannelManager) ChannelsForAccount(account string) (channels []string) {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
|
||||
for cfname, entry := range cm.chans {
|
||||
if entry.channel.Founder() == account {
|
||||
channels = append(channels, cfname)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// AllChannels returns the uncasefolded names of all registered channels.
|
||||
func (cm *ChannelManager) AllRegisteredChannels() (result []string) {
|
||||
cm.RLock()
|
||||
defer cm.RUnlock()
|
||||
|
||||
for cfname, entry := range cm.chans {
|
||||
if entry.channel.Founder() != "" {
|
||||
result = append(result, cfname)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -5,13 +5,8 @@ package irc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/buntdb"
|
||||
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
@ -19,48 +14,6 @@ import (
|
||||
// this is exclusively the *persistence* layer for channel registration;
|
||||
// channel creation/tracking/destruction is in channelmanager.go
|
||||
|
||||
const (
|
||||
keyChannelExists = "channel.exists %s"
|
||||
keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
|
||||
keyChannelRegTime = "channel.registered.time %s"
|
||||
keyChannelFounder = "channel.founder %s"
|
||||
keyChannelTopic = "channel.topic %s"
|
||||
keyChannelTopicSetBy = "channel.topic.setby %s"
|
||||
keyChannelTopicSetTime = "channel.topic.settime %s"
|
||||
keyChannelBanlist = "channel.banlist %s"
|
||||
keyChannelExceptlist = "channel.exceptlist %s"
|
||||
keyChannelInvitelist = "channel.invitelist %s"
|
||||
keyChannelPassword = "channel.key %s"
|
||||
keyChannelModes = "channel.modes %s"
|
||||
keyChannelAccountToUMode = "channel.accounttoumode %s"
|
||||
keyChannelUserLimit = "channel.userlimit %s"
|
||||
keyChannelSettings = "channel.settings %s"
|
||||
keyChannelForward = "channel.forward %s"
|
||||
|
||||
keyChannelPurged = "channel.purged %s"
|
||||
)
|
||||
|
||||
var (
|
||||
channelKeyStrings = []string{
|
||||
keyChannelExists,
|
||||
keyChannelName,
|
||||
keyChannelRegTime,
|
||||
keyChannelFounder,
|
||||
keyChannelTopic,
|
||||
keyChannelTopicSetBy,
|
||||
keyChannelTopicSetTime,
|
||||
keyChannelBanlist,
|
||||
keyChannelExceptlist,
|
||||
keyChannelInvitelist,
|
||||
keyChannelPassword,
|
||||
keyChannelModes,
|
||||
keyChannelAccountToUMode,
|
||||
keyChannelUserLimit,
|
||||
keyChannelSettings,
|
||||
keyChannelForward,
|
||||
}
|
||||
)
|
||||
|
||||
// these are bit flags indicating what part of the channel status is "dirty"
|
||||
// and needs to be read from memory and written to the db
|
||||
const (
|
||||
@ -80,8 +33,8 @@ const (
|
||||
type RegisteredChannel struct {
|
||||
// Name of the channel.
|
||||
Name string
|
||||
// Casefolded name of the channel.
|
||||
NameCasefolded string
|
||||
// UUID for the datastore.
|
||||
UUID utils.UUID
|
||||
// RegisteredAt represents the time that the channel was registered.
|
||||
RegisteredAt time.Time
|
||||
// Founder indicates the founder of the channel.
|
||||
@ -112,322 +65,26 @@ type RegisteredChannel struct {
|
||||
Settings ChannelSettings
|
||||
}
|
||||
|
||||
func (r *RegisteredChannel) Serialize() ([]byte, error) {
|
||||
return json.Marshal(r)
|
||||
}
|
||||
|
||||
func (r *RegisteredChannel) Deserialize(b []byte) (err error) {
|
||||
return json.Unmarshal(b, r)
|
||||
}
|
||||
|
||||
type ChannelPurgeRecord struct {
|
||||
Oper string
|
||||
PurgedAt time.Time
|
||||
Reason string
|
||||
NameCasefolded string `json:"Name"`
|
||||
UUID utils.UUID
|
||||
Oper string
|
||||
PurgedAt time.Time
|
||||
Reason string
|
||||
}
|
||||
|
||||
// ChannelRegistry manages registered channels.
|
||||
type ChannelRegistry struct {
|
||||
server *Server
|
||||
func (c *ChannelPurgeRecord) Serialize() ([]byte, error) {
|
||||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
// NewChannelRegistry returns a new ChannelRegistry.
|
||||
func (reg *ChannelRegistry) Initialize(server *Server) {
|
||||
reg.server = server
|
||||
}
|
||||
|
||||
// AllChannels returns the uncasefolded names of all registered channels.
|
||||
func (reg *ChannelRegistry) AllChannels() (result []string) {
|
||||
prefix := fmt.Sprintf(keyChannelName, "")
|
||||
reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return false
|
||||
}
|
||||
result = append(result, value)
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// PurgedChannels returns the set of all casefolded channel names that have been purged
|
||||
func (reg *ChannelRegistry) PurgedChannels() (result utils.HashSet[string]) {
|
||||
result = make(utils.HashSet[string])
|
||||
|
||||
prefix := fmt.Sprintf(keyChannelPurged, "")
|
||||
reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return false
|
||||
}
|
||||
channel := strings.TrimPrefix(key, prefix)
|
||||
result.Add(channel)
|
||||
return true
|
||||
})
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoreChannel obtains a consistent view of a channel, then persists it to the store.
|
||||
func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
if info.Founder == "" {
|
||||
// sanity check, don't try to store an unregistered channel
|
||||
return
|
||||
}
|
||||
|
||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
reg.saveChannel(tx, info, includeFlags)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadChannel loads a channel from the store.
|
||||
func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
err = errFeatureDisabled
|
||||
return
|
||||
}
|
||||
|
||||
channelKey := nameCasefolded
|
||||
// nice to have: do all JSON (de)serialization outside of the buntdb transaction
|
||||
err = reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
_, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
|
||||
if dberr == buntdb.ErrNotFound {
|
||||
// chan does not already exist, return
|
||||
return errNoSuchChannel
|
||||
}
|
||||
|
||||
// channel exists, load it
|
||||
name, _ := tx.Get(fmt.Sprintf(keyChannelName, channelKey))
|
||||
regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, channelKey))
|
||||
regTimeInt, _ := strconv.ParseInt(regTime, 10, 64)
|
||||
founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, channelKey))
|
||||
topic, _ := tx.Get(fmt.Sprintf(keyChannelTopic, channelKey))
|
||||
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
|
||||
var topicSetTime time.Time
|
||||
topicSetTimeStr, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
|
||||
if topicSetTimeInt, topicSetTimeErr := strconv.ParseInt(topicSetTimeStr, 10, 64); topicSetTimeErr == nil {
|
||||
topicSetTime = time.Unix(0, topicSetTimeInt).UTC()
|
||||
}
|
||||
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
|
||||
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
|
||||
userLimitString, _ := tx.Get(fmt.Sprintf(keyChannelUserLimit, channelKey))
|
||||
forward, _ := tx.Get(fmt.Sprintf(keyChannelForward, channelKey))
|
||||
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
|
||||
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
|
||||
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
|
||||
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
|
||||
settingsString, _ := tx.Get(fmt.Sprintf(keyChannelSettings, channelKey))
|
||||
|
||||
modeSlice := make([]modes.Mode, len(modeString))
|
||||
for i, mode := range modeString {
|
||||
modeSlice[i] = modes.Mode(mode)
|
||||
}
|
||||
|
||||
userLimit, _ := strconv.Atoi(userLimitString)
|
||||
|
||||
var banlist map[string]MaskInfo
|
||||
_ = json.Unmarshal([]byte(banlistString), &banlist)
|
||||
var exceptlist map[string]MaskInfo
|
||||
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
|
||||
var invitelist map[string]MaskInfo
|
||||
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
|
||||
accountToUMode := make(map[string]modes.Mode)
|
||||
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
|
||||
|
||||
var settings ChannelSettings
|
||||
_ = json.Unmarshal([]byte(settingsString), &settings)
|
||||
|
||||
info = RegisteredChannel{
|
||||
Name: name,
|
||||
NameCasefolded: nameCasefolded,
|
||||
RegisteredAt: time.Unix(0, regTimeInt).UTC(),
|
||||
Founder: founder,
|
||||
Topic: topic,
|
||||
TopicSetBy: topicSetBy,
|
||||
TopicSetTime: topicSetTime,
|
||||
Key: password,
|
||||
Modes: modeSlice,
|
||||
Bans: banlist,
|
||||
Excepts: exceptlist,
|
||||
Invites: invitelist,
|
||||
AccountToUMode: accountToUMode,
|
||||
UserLimit: int(userLimit),
|
||||
Settings: settings,
|
||||
Forward: forward,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Delete deletes a channel corresponding to `info`. If no such channel
|
||||
// is present in the database, no error is returned.
|
||||
func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) {
|
||||
if !reg.server.ChannelRegistrationEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
reg.deleteChannel(tx, info.NameCasefolded, info)
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete a channel, unless it was overwritten by another registration of the same channel
|
||||
func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info RegisteredChannel) {
|
||||
_, err := tx.Get(fmt.Sprintf(keyChannelExists, key))
|
||||
if err == nil {
|
||||
regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, key))
|
||||
regTimeInt, _ := strconv.ParseInt(regTime, 10, 64)
|
||||
registeredAt := time.Unix(0, regTimeInt).UTC()
|
||||
founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, key))
|
||||
|
||||
// to see if we're deleting the right channel, confirm the founder and the registration time
|
||||
if founder == info.Founder && registeredAt.Equal(info.RegisteredAt) {
|
||||
for _, keyFmt := range channelKeyStrings {
|
||||
tx.Delete(fmt.Sprintf(keyFmt, key))
|
||||
}
|
||||
|
||||
// remove this channel from the client's list of registered channels
|
||||
channelsKey := fmt.Sprintf(keyAccountChannels, info.Founder)
|
||||
channelsStr, err := tx.Get(channelsKey)
|
||||
if err == buntdb.ErrNotFound {
|
||||
return
|
||||
}
|
||||
registeredChannels := unmarshalRegisteredChannels(channelsStr)
|
||||
var nowRegisteredChannels []string
|
||||
for _, channel := range registeredChannels {
|
||||
if channel != key {
|
||||
nowRegisteredChannels = append(nowRegisteredChannels, channel)
|
||||
}
|
||||
}
|
||||
tx.Set(channelsKey, strings.Join(nowRegisteredChannels, ","), nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (reg *ChannelRegistry) updateAccountToChannelMapping(tx *buntdb.Tx, channelInfo RegisteredChannel) {
|
||||
channelKey := channelInfo.NameCasefolded
|
||||
chanFounderKey := fmt.Sprintf(keyChannelFounder, channelKey)
|
||||
founder, existsErr := tx.Get(chanFounderKey)
|
||||
if existsErr == buntdb.ErrNotFound || founder != channelInfo.Founder {
|
||||
// add to new founder's list
|
||||
accountChannelsKey := fmt.Sprintf(keyAccountChannels, channelInfo.Founder)
|
||||
alreadyChannels, _ := tx.Get(accountChannelsKey)
|
||||
newChannels := channelKey // this is the casefolded channel name
|
||||
if alreadyChannels != "" {
|
||||
newChannels = fmt.Sprintf("%s,%s", alreadyChannels, newChannels)
|
||||
}
|
||||
tx.Set(accountChannelsKey, newChannels, nil)
|
||||
}
|
||||
if existsErr == nil && founder != channelInfo.Founder {
|
||||
// remove from old founder's list
|
||||
accountChannelsKey := fmt.Sprintf(keyAccountChannels, founder)
|
||||
alreadyChannelsRaw, _ := tx.Get(accountChannelsKey)
|
||||
var newChannels []string
|
||||
if alreadyChannelsRaw != "" {
|
||||
for _, chname := range strings.Split(alreadyChannelsRaw, ",") {
|
||||
if chname != channelInfo.NameCasefolded {
|
||||
newChannels = append(newChannels, chname)
|
||||
}
|
||||
}
|
||||
}
|
||||
tx.Set(accountChannelsKey, strings.Join(newChannels, ","), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// saveChannel saves a channel to the store.
|
||||
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) {
|
||||
channelKey := channelInfo.NameCasefolded
|
||||
// maintain the mapping of account -> registered channels
|
||||
reg.updateAccountToChannelMapping(tx, channelInfo)
|
||||
|
||||
if includeFlags&IncludeInitial != 0 {
|
||||
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.UnixNano(), 10), nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
|
||||
}
|
||||
|
||||
if includeFlags&IncludeTopic != 0 {
|
||||
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
|
||||
var topicSetTimeStr string
|
||||
if !channelInfo.TopicSetTime.IsZero() {
|
||||
topicSetTimeStr = strconv.FormatInt(channelInfo.TopicSetTime.UnixNano(), 10)
|
||||
}
|
||||
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), topicSetTimeStr, nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
|
||||
}
|
||||
|
||||
if includeFlags&IncludeModes != 0 {
|
||||
tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil)
|
||||
modeString := modes.Modes(channelInfo.Modes).String()
|
||||
tx.Set(fmt.Sprintf(keyChannelModes, channelKey), modeString, nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelUserLimit, channelKey), strconv.Itoa(channelInfo.UserLimit), nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelForward, channelKey), channelInfo.Forward, nil)
|
||||
}
|
||||
|
||||
if includeFlags&IncludeLists != 0 {
|
||||
banlistString, _ := json.Marshal(channelInfo.Bans)
|
||||
tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
|
||||
exceptlistString, _ := json.Marshal(channelInfo.Excepts)
|
||||
tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
|
||||
invitelistString, _ := json.Marshal(channelInfo.Invites)
|
||||
tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
|
||||
accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode)
|
||||
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil)
|
||||
}
|
||||
|
||||
if includeFlags&IncludeSettings != 0 {
|
||||
settingsString, _ := json.Marshal(channelInfo.Settings)
|
||||
tx.Set(fmt.Sprintf(keyChannelSettings, channelKey), string(settingsString), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// PurgeChannel records a channel purge.
|
||||
func (reg *ChannelRegistry) PurgeChannel(chname string, record ChannelPurgeRecord) (err error) {
|
||||
serialized, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serializedStr := string(serialized)
|
||||
key := fmt.Sprintf(keyChannelPurged, chname)
|
||||
|
||||
return reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
tx.Set(key, serializedStr, nil)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// LoadPurgeRecord retrieves information about whether and how a channel was purged.
|
||||
func (reg *ChannelRegistry) LoadPurgeRecord(chname string) (record ChannelPurgeRecord, err error) {
|
||||
var rawRecord string
|
||||
key := fmt.Sprintf(keyChannelPurged, chname)
|
||||
reg.server.store.View(func(tx *buntdb.Tx) error {
|
||||
rawRecord, _ = tx.Get(key)
|
||||
return nil
|
||||
})
|
||||
if rawRecord == "" {
|
||||
err = errNoSuchChannel
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(rawRecord), &record)
|
||||
if err != nil {
|
||||
reg.server.logger.Error("internal", "corrupt purge record", chname, err.Error())
|
||||
err = errNoSuchChannel
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnpurgeChannel deletes the record of a channel purge.
|
||||
func (reg *ChannelRegistry) UnpurgeChannel(chname string) (err error) {
|
||||
key := fmt.Sprintf(keyChannelPurged, chname)
|
||||
return reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
tx.Delete(key)
|
||||
return nil
|
||||
})
|
||||
func (c *ChannelPurgeRecord) Deserialize(b []byte) error {
|
||||
return json.Unmarshal(b, c)
|
||||
}
|
||||
|
@ -459,7 +459,7 @@ func csRegisterHandler(service *ircService, server *Server, client *Client, comm
|
||||
// check whether a client has already registered too many channels
|
||||
func checkChanLimit(service *ircService, client *Client, rb *ResponseBuffer) (ok bool) {
|
||||
account := client.Account()
|
||||
channelsAlreadyRegistered := client.server.accounts.ChannelsForAccount(account)
|
||||
channelsAlreadyRegistered := client.server.channels.ChannelsForAccount(account)
|
||||
ok = len(channelsAlreadyRegistered) < client.server.Config().Channels.Registration.MaxChannelsPerAccount || client.HasRoleCapabs("chanreg")
|
||||
if !ok {
|
||||
service.Notice(rb, client.t("You have already registered the maximum number of channels; try dropping some with /CS UNREGISTER"))
|
||||
@ -496,8 +496,8 @@ func csUnregisterHandler(service *ircService, server *Server, client *Client, co
|
||||
return
|
||||
}
|
||||
|
||||
info := channel.ExportRegistration(0)
|
||||
channelKey := info.NameCasefolded
|
||||
info := channel.exportSummary()
|
||||
channelKey := channel.NameCasefolded()
|
||||
if !csPrivsCheck(service, info, client, rb) {
|
||||
return
|
||||
}
|
||||
@ -519,7 +519,7 @@ func csClearHandler(service *ircService, server *Server, client *Client, command
|
||||
service.Notice(rb, client.t("Channel does not exist"))
|
||||
return
|
||||
}
|
||||
if !csPrivsCheck(service, channel.ExportRegistration(0), client, rb) {
|
||||
if !csPrivsCheck(service, channel.exportSummary(), client, rb) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -550,7 +550,7 @@ func csTransferHandler(service *ircService, server *Server, client *Client, comm
|
||||
service.Notice(rb, client.t("Channel does not exist"))
|
||||
return
|
||||
}
|
||||
regInfo := channel.ExportRegistration(0)
|
||||
regInfo := channel.exportSummary()
|
||||
chname = regInfo.Name
|
||||
account := client.Account()
|
||||
isFounder := account != "" && account == regInfo.Founder
|
||||
@ -729,11 +729,6 @@ func csPurgeListHandler(service *ircService, client *Client, rb *ResponseBuffer)
|
||||
}
|
||||
|
||||
func csListHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
|
||||
if !client.HasRoleCapabs("chanreg") {
|
||||
service.Notice(rb, client.t("Insufficient privileges"))
|
||||
return
|
||||
}
|
||||
|
||||
var searchRegex *regexp.Regexp
|
||||
if len(params) > 0 {
|
||||
var err error
|
||||
@ -746,7 +741,7 @@ func csListHandler(service *ircService, server *Server, client *Client, command
|
||||
|
||||
service.Notice(rb, ircfmt.Unescape(client.t("*** $bChanServ LIST$b ***")))
|
||||
|
||||
channels := server.channelRegistry.AllChannels()
|
||||
channels := server.channels.AllRegisteredChannels()
|
||||
for _, channel := range channels {
|
||||
if searchRegex == nil || searchRegex.MatchString(channel) {
|
||||
service.Notice(rb, fmt.Sprintf(" %s", channel))
|
||||
@ -771,7 +766,7 @@ func csInfoHandler(service *ircService, server *Server, client *Client, command
|
||||
|
||||
// purge status
|
||||
if client.HasRoleCapabs("chanreg") {
|
||||
purgeRecord, err := server.channelRegistry.LoadPurgeRecord(chname)
|
||||
purgeRecord, err := server.channels.LoadPurgeRecord(chname)
|
||||
if err == nil {
|
||||
service.Notice(rb, fmt.Sprintf(client.t("Channel %s was purged by the server operators and cannot be used"), chname))
|
||||
service.Notice(rb, fmt.Sprintf(client.t("Purged by operator: %s"), purgeRecord.Oper))
|
||||
@ -789,13 +784,7 @@ func csInfoHandler(service *ircService, server *Server, client *Client, command
|
||||
var chinfo RegisteredChannel
|
||||
channel := server.channels.Get(params[0])
|
||||
if channel != nil {
|
||||
chinfo = channel.ExportRegistration(0)
|
||||
} else {
|
||||
chinfo, err = server.channelRegistry.LoadChannel(chname)
|
||||
if err != nil && !(err == errNoSuchChannel || err == errFeatureDisabled) {
|
||||
service.Notice(rb, client.t("An error occurred"))
|
||||
return
|
||||
}
|
||||
chinfo = channel.exportSummary()
|
||||
}
|
||||
|
||||
// channel exists but is unregistered, or doesn't exist:
|
||||
@ -835,12 +824,12 @@ func csGetHandler(service *ircService, server *Server, client *Client, command s
|
||||
service.Notice(rb, client.t("No such channel"))
|
||||
return
|
||||
}
|
||||
info := channel.ExportRegistration(IncludeSettings)
|
||||
info := channel.exportSummary()
|
||||
if !csPrivsCheck(service, info, client, rb) {
|
||||
return
|
||||
}
|
||||
|
||||
displayChannelSetting(service, setting, info.Settings, client, rb)
|
||||
displayChannelSetting(service, setting, channel.Settings(), client, rb)
|
||||
}
|
||||
|
||||
func csSetHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
|
||||
@ -850,12 +839,12 @@ func csSetHandler(service *ircService, server *Server, client *Client, command s
|
||||
service.Notice(rb, client.t("No such channel"))
|
||||
return
|
||||
}
|
||||
info := channel.ExportRegistration(IncludeSettings)
|
||||
settings := info.Settings
|
||||
info := channel.exportSummary()
|
||||
if !csPrivsCheck(service, info, client, rb) {
|
||||
return
|
||||
}
|
||||
|
||||
settings := channel.Settings()
|
||||
var err error
|
||||
switch strings.ToLower(setting) {
|
||||
case "history":
|
||||
|
159
irc/database.go
159
irc/database.go
@ -14,6 +14,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ergochat/ergo/irc/bunt"
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
|
||||
@ -21,12 +23,19 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// 'version' of the database schema
|
||||
keySchemaVersion = "db.version"
|
||||
// latest schema of the db
|
||||
latestDbSchema = 22
|
||||
// TODO migrate metadata keys as well
|
||||
|
||||
keyCloakSecret = "crypto.cloak_secret"
|
||||
// 'version' of the database schema
|
||||
// latest schema of the db
|
||||
latestDbSchema = 23
|
||||
)
|
||||
|
||||
var (
|
||||
schemaVersionUUID = utils.UUID{0, 255, 85, 13, 212, 10, 191, 121, 245, 152, 142, 89, 97, 141, 219, 87} // AP9VDdQKv3n1mI5ZYY3bVw
|
||||
cloakSecretUUID = utils.UUID{170, 214, 184, 208, 116, 181, 67, 75, 161, 23, 233, 16, 113, 251, 94, 229} // qta40HS1Q0uhF-kQcfte5Q
|
||||
|
||||
keySchemaVersion = bunt.BuntKey(datastore.TableMetadata, schemaVersionUUID)
|
||||
keyCloakSecret = bunt.BuntKey(datastore.TableMetadata, cloakSecretUUID)
|
||||
)
|
||||
|
||||
type SchemaChanger func(*Config, *buntdb.Tx) error
|
||||
@ -99,10 +108,7 @@ func openDatabaseInternal(config *Config, allowAutoupgrade bool) (db *buntdb.DB,
|
||||
// read the current version string
|
||||
var version int
|
||||
err = db.View(func(tx *buntdb.Tx) (err error) {
|
||||
vStr, err := tx.Get(keySchemaVersion)
|
||||
if err == nil {
|
||||
version, err = strconv.Atoi(vStr)
|
||||
}
|
||||
version, err = retrieveSchemaVersion(tx)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -130,6 +136,17 @@ func openDatabaseInternal(config *Config, allowAutoupgrade bool) (db *buntdb.DB,
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveSchemaVersion(tx *buntdb.Tx) (version int, err error) {
|
||||
if val, err := tx.Get(keySchemaVersion); err == nil {
|
||||
return strconv.Atoi(val)
|
||||
}
|
||||
// legacy key:
|
||||
if val, err := tx.Get("db.version"); err == nil {
|
||||
return strconv.Atoi(val)
|
||||
}
|
||||
return 0, buntdb.ErrNotFound
|
||||
}
|
||||
|
||||
func performAutoUpgrade(currentVersion int, config *Config) (err error) {
|
||||
path := config.Datastore.Path
|
||||
log.Printf("attempting to auto-upgrade schema from version %d to %d\n", currentVersion, latestDbSchema)
|
||||
@ -167,8 +184,12 @@ func UpgradeDB(config *Config) (err error) {
|
||||
var version int
|
||||
err = store.Update(func(tx *buntdb.Tx) error {
|
||||
for {
|
||||
vStr, _ := tx.Get(keySchemaVersion)
|
||||
version, _ = strconv.Atoi(vStr)
|
||||
if version == 0 {
|
||||
version, err = retrieveSchemaVersion(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if version == latestDbSchema {
|
||||
// success!
|
||||
break
|
||||
@ -183,11 +204,12 @@ func UpgradeDB(config *Config) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _, err = tx.Set(keySchemaVersion, strconv.Itoa(change.TargetVersion), nil)
|
||||
version = change.TargetVersion
|
||||
_, _, err = tx.Set(keySchemaVersion, strconv.Itoa(version), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("successfully updated schema to version %d\n", change.TargetVersion)
|
||||
log.Printf("successfully updated schema to version %d\n", version)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@ -198,19 +220,17 @@ func UpgradeDB(config *Config) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func LoadCloakSecret(db *buntdb.DB) (result string) {
|
||||
db.View(func(tx *buntdb.Tx) error {
|
||||
result, _ = tx.Get(keyCloakSecret)
|
||||
return nil
|
||||
})
|
||||
return
|
||||
func LoadCloakSecret(dstore datastore.Datastore) (result string, err error) {
|
||||
val, err := dstore.Get(datastore.TableMetadata, cloakSecretUUID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return string(val), nil
|
||||
}
|
||||
|
||||
func StoreCloakSecret(db *buntdb.DB, secret string) {
|
||||
db.Update(func(tx *buntdb.Tx) error {
|
||||
tx.Set(keyCloakSecret, secret, nil)
|
||||
return nil
|
||||
})
|
||||
func StoreCloakSecret(dstore datastore.Datastore, secret string) {
|
||||
// TODO error checking
|
||||
dstore.Set(datastore.TableMetadata, cloakSecretUUID, []byte(secret), time.Time{})
|
||||
}
|
||||
|
||||
func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error {
|
||||
@ -1112,6 +1132,92 @@ func schemaChangeV21To22(config *Config, tx *buntdb.Tx) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// first phase of document-oriented database refactor: channels
|
||||
func schemaChangeV22ToV23(config *Config, tx *buntdb.Tx) error {
|
||||
keyChannelExists := "channel.exists "
|
||||
var channelNames []string
|
||||
tx.AscendGreaterOrEqual("", keyChannelExists, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, keyChannelExists) {
|
||||
return false
|
||||
}
|
||||
channelNames = append(channelNames, strings.TrimPrefix(key, keyChannelExists))
|
||||
return true
|
||||
})
|
||||
for _, channelName := range channelNames {
|
||||
channel, err := loadLegacyChannel(tx, channelName)
|
||||
if err != nil {
|
||||
log.Printf("error loading legacy channel %s: %v", channelName, err)
|
||||
continue
|
||||
}
|
||||
channel.UUID = utils.GenerateUUIDv4()
|
||||
newKey := bunt.BuntKey(datastore.TableChannels, channel.UUID)
|
||||
j, err := json.Marshal(channel)
|
||||
if err != nil {
|
||||
log.Printf("error marshaling channel %s: %v", channelName, err)
|
||||
continue
|
||||
}
|
||||
tx.Set(newKey, string(j), nil)
|
||||
deleteLegacyChannel(tx, channelName)
|
||||
}
|
||||
|
||||
// purges
|
||||
keyChannelPurged := "channel.purged "
|
||||
var purgeKeys []string
|
||||
var channelPurges []ChannelPurgeRecord
|
||||
tx.AscendGreaterOrEqual("", keyChannelPurged, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, keyChannelPurged) {
|
||||
return false
|
||||
}
|
||||
purgeKeys = append(purgeKeys, key)
|
||||
cfname := strings.TrimPrefix(key, keyChannelPurged)
|
||||
var record ChannelPurgeRecord
|
||||
err := json.Unmarshal([]byte(value), &record)
|
||||
if err != nil {
|
||||
log.Printf("error unmarshaling channel purge for %s: %v", cfname, err)
|
||||
return true
|
||||
}
|
||||
record.NameCasefolded = cfname
|
||||
record.UUID = utils.GenerateUUIDv4()
|
||||
channelPurges = append(channelPurges, record)
|
||||
return true
|
||||
})
|
||||
for _, record := range channelPurges {
|
||||
newKey := bunt.BuntKey(datastore.TableChannelPurges, record.UUID)
|
||||
j, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
log.Printf("error marshaling channel purge %s: %v", record.NameCasefolded, err)
|
||||
continue
|
||||
}
|
||||
tx.Set(newKey, string(j), nil)
|
||||
}
|
||||
for _, purgeKey := range purgeKeys {
|
||||
tx.Delete(purgeKey)
|
||||
}
|
||||
|
||||
// clean up denormalized account-to-channels mapping
|
||||
keyAccountChannels := "account.channels "
|
||||
var accountToChannels []string
|
||||
tx.AscendGreaterOrEqual("", keyAccountChannels, func(key, value string) bool {
|
||||
if !strings.HasPrefix(key, keyAccountChannels) {
|
||||
return false
|
||||
}
|
||||
accountToChannels = append(accountToChannels, key)
|
||||
return true
|
||||
})
|
||||
for _, key := range accountToChannels {
|
||||
tx.Delete(key)
|
||||
}
|
||||
|
||||
// migrate cloak secret
|
||||
val, _ := tx.Get("crypto.cloak_secret")
|
||||
tx.Set(keyCloakSecret, val, nil)
|
||||
|
||||
// bump the legacy version key to mark the database as downgrade-incompatible
|
||||
tx.Set("db.version", "23", nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSchemaChange(initialVersion int) (result SchemaChange, ok bool) {
|
||||
for _, change := range allChanges {
|
||||
if initialVersion == change.InitialVersion {
|
||||
@ -1227,4 +1333,9 @@ var allChanges = []SchemaChange{
|
||||
TargetVersion: 22,
|
||||
Changer: schemaChangeV21To22,
|
||||
},
|
||||
{
|
||||
InitialVersion: 22,
|
||||
TargetVersion: 23,
|
||||
Changer: schemaChangeV22ToV23,
|
||||
},
|
||||
}
|
||||
|
45
irc/datastore/datastore.go
Normal file
45
irc/datastore/datastore.go
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2022 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||
// released under the MIT license
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
||||
type Table uint16
|
||||
|
||||
// XXX these are persisted and must remain stable;
|
||||
// do not reorder, when deleting use _ to ensure that the deleted value is skipped
|
||||
const (
|
||||
TableMetadata Table = iota
|
||||
TableChannels
|
||||
TableChannelPurges
|
||||
)
|
||||
|
||||
type KV struct {
|
||||
UUID utils.UUID
|
||||
Value []byte
|
||||
}
|
||||
|
||||
// A Datastore provides the following abstraction:
|
||||
// 1. Tables, each keyed on a UUID (the implementation is free to merge
|
||||
// the table name and the UUID into a single key as long as the rest of
|
||||
// the contract can be satisfied). Table names are [a-z0-9_]+
|
||||
// 2. The ability to efficiently enumerate all uuid-value pairs in a table
|
||||
// 3. Gets, sets, and deletes for individual (table, uuid) keys
|
||||
type Datastore interface {
|
||||
Backoff() time.Duration
|
||||
|
||||
GetAll(table Table) ([]KV, error)
|
||||
|
||||
// This is rarely used because it would typically lead to TOCTOU races
|
||||
Get(table Table, key utils.UUID) (value []byte, err error)
|
||||
|
||||
Set(table Table, key utils.UUID, value []byte, expiration time.Time) error
|
||||
|
||||
// Note that deleting a nonexistent key is not considered an error
|
||||
Delete(table Table, key utils.UUID) error
|
||||
}
|
@ -51,6 +51,7 @@ var (
|
||||
errNoExistingBan = errors.New("Ban does not exist")
|
||||
errNoSuchChannel = errors.New(`No such channel`)
|
||||
errChannelPurged = errors.New(`This channel was purged by the server operators and cannot be used`)
|
||||
errChannelPurgedAlready = errors.New(`This channel was already purged and cannot be purged again`)
|
||||
errConfusableIdentifier = errors.New("This identifier is confusable with one already in use")
|
||||
errInsufficientPrivs = errors.New("Insufficient privileges")
|
||||
errInvalidUsername = errors.New("Invalid username")
|
||||
|
@ -638,3 +638,9 @@ func (channel *Channel) getAmode(cfaccount string) (result modes.Mode) {
|
||||
defer channel.stateMutex.RUnlock()
|
||||
return channel.accountToUMode[cfaccount]
|
||||
}
|
||||
|
||||
func (channel *Channel) UUID() utils.UUID {
|
||||
channel.stateMutex.RLock()
|
||||
defer channel.stateMutex.RUnlock()
|
||||
return channel.uuid
|
||||
}
|
||||
|
@ -1718,7 +1718,7 @@ func listHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons
|
||||
|
||||
clientIsOp := client.HasRoleCapabs("sajoin")
|
||||
if len(channels) == 0 {
|
||||
for _, channel := range server.channels.Channels() {
|
||||
for _, channel := range server.channels.ListableChannels() {
|
||||
if !clientIsOp && channel.flags.HasMode(modes.Secret) && !channel.hasClient(client) {
|
||||
continue
|
||||
}
|
||||
|
@ -193,6 +193,6 @@ func hsSetCloakSecretHandler(service *ircService, server *Server, client *Client
|
||||
service.Notice(rb, fmt.Sprintf(client.t("To confirm, run this command: %s"), fmt.Sprintf("/HS SETCLOAKSECRET %s %s", secret, expectedCode)))
|
||||
return
|
||||
}
|
||||
StoreCloakSecret(server.store, secret)
|
||||
StoreCloakSecret(server.dstore, secret)
|
||||
service.Notice(rb, client.t("Rotated the cloak secret; you must rehash or restart the server for it to take effect"))
|
||||
}
|
||||
|
@ -9,9 +9,13 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/buntdb"
|
||||
|
||||
"github.com/ergochat/ergo/irc/bunt"
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
||||
@ -20,7 +24,7 @@ const (
|
||||
// XXX instead of referencing, e.g., keyAccountExists, we should write in the string literal
|
||||
// (to ensure that no matter what code changes happen elsewhere, we're still producing a
|
||||
// db of the hardcoded version)
|
||||
importDBSchemaVersion = 22
|
||||
importDBSchemaVersion = 23
|
||||
)
|
||||
|
||||
type userImport struct {
|
||||
@ -54,8 +58,8 @@ type databaseImport struct {
|
||||
Channels map[string]channelImport
|
||||
}
|
||||
|
||||
func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result []byte, err error) {
|
||||
processed := make(map[string]int, len(raw))
|
||||
func convertAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result map[string]modes.Mode, err error) {
|
||||
result = make(map[string]modes.Mode)
|
||||
for accountName, mode := range raw {
|
||||
if len(mode) != 1 {
|
||||
return nil, fmt.Errorf("invalid mode %s for account %s", mode, accountName)
|
||||
@ -64,10 +68,9 @@ func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[strin
|
||||
if err != nil || !validCfUsernames.Has(cfname) {
|
||||
log.Printf("skipping invalid amode recipient %s\n", accountName)
|
||||
} else {
|
||||
processed[cfname] = int(mode[0])
|
||||
result[cfname] = modes.Mode(mode[0])
|
||||
}
|
||||
}
|
||||
result, err = json.Marshal(processed)
|
||||
return
|
||||
}
|
||||
|
||||
@ -147,8 +150,9 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
|
||||
cfUsernames.Add(cfUsername)
|
||||
}
|
||||
|
||||
// TODO fix this:
|
||||
for chname, chInfo := range dbImport.Channels {
|
||||
cfchname, err := CasefoldChannel(chname)
|
||||
_, err := CasefoldChannel(chname)
|
||||
if err != nil {
|
||||
log.Printf("invalid channel name %s: %v", chname, err)
|
||||
continue
|
||||
@ -158,43 +162,42 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
|
||||
log.Printf("invalid founder %s for channel %s: %v", chInfo.Founder, chname, err)
|
||||
continue
|
||||
}
|
||||
tx.Set(fmt.Sprintf(keyChannelExists, cfchname), "1", nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelName, cfchname), chname, nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelRegTime, cfchname), strconv.FormatInt(chInfo.RegisteredAt, 10), nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelFounder, cfchname), cffounder, nil)
|
||||
accountChannelsKey := fmt.Sprintf(keyAccountChannels, cffounder)
|
||||
founderChannels, fcErr := tx.Get(accountChannelsKey)
|
||||
if fcErr != nil || founderChannels == "" {
|
||||
founderChannels = cfchname
|
||||
} else {
|
||||
founderChannels = fmt.Sprintf("%s,%s", founderChannels, cfchname)
|
||||
}
|
||||
tx.Set(accountChannelsKey, founderChannels, nil)
|
||||
var regInfo RegisteredChannel
|
||||
regInfo.Name = chname
|
||||
regInfo.UUID = utils.GenerateUUIDv4()
|
||||
regInfo.Founder = cffounder
|
||||
regInfo.RegisteredAt = time.Unix(0, chInfo.RegisteredAt).UTC()
|
||||
if chInfo.Topic != "" {
|
||||
tx.Set(fmt.Sprintf(keyChannelTopic, cfchname), chInfo.Topic, nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, cfchname), strconv.FormatInt(chInfo.TopicSetAt, 10), nil)
|
||||
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, cfchname), chInfo.TopicSetBy, nil)
|
||||
regInfo.Topic = chInfo.Topic
|
||||
regInfo.TopicSetBy = chInfo.TopicSetBy
|
||||
regInfo.TopicSetTime = time.Unix(0, chInfo.TopicSetAt).UTC()
|
||||
}
|
||||
|
||||
if len(chInfo.Amode) != 0 {
|
||||
m, err := serializeAmodes(chInfo.Amode, cfUsernames)
|
||||
m, err := convertAmodes(chInfo.Amode, cfUsernames)
|
||||
if err == nil {
|
||||
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, cfchname), string(m), nil)
|
||||
regInfo.AccountToUMode = m
|
||||
} else {
|
||||
log.Printf("couldn't serialize amodes for %s: %v", chname, err)
|
||||
log.Printf("couldn't process amodes for %s: %v", chname, err)
|
||||
}
|
||||
}
|
||||
tx.Set(fmt.Sprintf(keyChannelModes, cfchname), chInfo.Modes, nil)
|
||||
if chInfo.Key != "" {
|
||||
tx.Set(fmt.Sprintf(keyChannelPassword, cfchname), chInfo.Key, nil)
|
||||
for _, mode := range chInfo.Modes {
|
||||
regInfo.Modes = append(regInfo.Modes, modes.Mode(mode))
|
||||
}
|
||||
regInfo.Key = chInfo.Key
|
||||
if chInfo.Limit > 0 {
|
||||
tx.Set(fmt.Sprintf(keyChannelUserLimit, cfchname), strconv.Itoa(chInfo.Limit), nil)
|
||||
regInfo.UserLimit = chInfo.Limit
|
||||
}
|
||||
if chInfo.Forward != "" {
|
||||
if _, err := CasefoldChannel(chInfo.Forward); err == nil {
|
||||
tx.Set(fmt.Sprintf(keyChannelForward, cfchname), chInfo.Forward, nil)
|
||||
regInfo.Forward = chInfo.Forward
|
||||
}
|
||||
}
|
||||
if j, err := json.Marshal(regInfo); err == nil {
|
||||
tx.Set(bunt.BuntKey(datastore.TableChannels, regInfo.UUID), string(j), nil)
|
||||
} else {
|
||||
log.Printf("couldn't serialize channel %s: %v", chname, err)
|
||||
}
|
||||
}
|
||||
|
||||
if warnSkeletons {
|
||||
|
121
irc/legacy.go
121
irc/legacy.go
@ -4,7 +4,15 @@ package irc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/buntdb"
|
||||
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -25,3 +33,116 @@ func decodeLegacyPasswordHash(hash string) ([]byte, error) {
|
||||
return nil, errInvalidPasswordHash
|
||||
}
|
||||
}
|
||||
|
||||
// legacy channel registration code
|
||||
|
||||
const (
|
||||
keyChannelExists = "channel.exists %s"
|
||||
keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
|
||||
keyChannelRegTime = "channel.registered.time %s"
|
||||
keyChannelFounder = "channel.founder %s"
|
||||
keyChannelTopic = "channel.topic %s"
|
||||
keyChannelTopicSetBy = "channel.topic.setby %s"
|
||||
keyChannelTopicSetTime = "channel.topic.settime %s"
|
||||
keyChannelBanlist = "channel.banlist %s"
|
||||
keyChannelExceptlist = "channel.exceptlist %s"
|
||||
keyChannelInvitelist = "channel.invitelist %s"
|
||||
keyChannelPassword = "channel.key %s"
|
||||
keyChannelModes = "channel.modes %s"
|
||||
keyChannelAccountToUMode = "channel.accounttoumode %s"
|
||||
keyChannelUserLimit = "channel.userlimit %s"
|
||||
keyChannelSettings = "channel.settings %s"
|
||||
keyChannelForward = "channel.forward %s"
|
||||
|
||||
keyChannelPurged = "channel.purged %s"
|
||||
)
|
||||
|
||||
func deleteLegacyChannel(tx *buntdb.Tx, nameCasefolded string) {
|
||||
tx.Delete(fmt.Sprintf(keyChannelExists, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelName, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelRegTime, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelFounder, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelTopic, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelTopicSetBy, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelTopicSetTime, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelBanlist, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelExceptlist, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelInvitelist, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelPassword, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelModes, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelAccountToUMode, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelUserLimit, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelSettings, nameCasefolded))
|
||||
tx.Delete(fmt.Sprintf(keyChannelForward, nameCasefolded))
|
||||
}
|
||||
|
||||
func loadLegacyChannel(tx *buntdb.Tx, nameCasefolded string) (info RegisteredChannel, err error) {
|
||||
channelKey := nameCasefolded
|
||||
// nice to have: do all JSON (de)serialization outside of the buntdb transaction
|
||||
_, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
|
||||
if dberr == buntdb.ErrNotFound {
|
||||
// chan does not already exist, return
|
||||
err = errNoSuchChannel
|
||||
return
|
||||
}
|
||||
|
||||
// channel exists, load it
|
||||
name, _ := tx.Get(fmt.Sprintf(keyChannelName, channelKey))
|
||||
regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, channelKey))
|
||||
regTimeInt, _ := strconv.ParseInt(regTime, 10, 64)
|
||||
founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, channelKey))
|
||||
topic, _ := tx.Get(fmt.Sprintf(keyChannelTopic, channelKey))
|
||||
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
|
||||
var topicSetTime time.Time
|
||||
topicSetTimeStr, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
|
||||
if topicSetTimeInt, topicSetTimeErr := strconv.ParseInt(topicSetTimeStr, 10, 64); topicSetTimeErr == nil {
|
||||
topicSetTime = time.Unix(0, topicSetTimeInt).UTC()
|
||||
}
|
||||
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
|
||||
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
|
||||
userLimitString, _ := tx.Get(fmt.Sprintf(keyChannelUserLimit, channelKey))
|
||||
forward, _ := tx.Get(fmt.Sprintf(keyChannelForward, channelKey))
|
||||
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
|
||||
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
|
||||
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
|
||||
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
|
||||
settingsString, _ := tx.Get(fmt.Sprintf(keyChannelSettings, channelKey))
|
||||
|
||||
modeSlice := make([]modes.Mode, len(modeString))
|
||||
for i, mode := range modeString {
|
||||
modeSlice[i] = modes.Mode(mode)
|
||||
}
|
||||
|
||||
userLimit, _ := strconv.Atoi(userLimitString)
|
||||
|
||||
var banlist map[string]MaskInfo
|
||||
_ = json.Unmarshal([]byte(banlistString), &banlist)
|
||||
var exceptlist map[string]MaskInfo
|
||||
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
|
||||
var invitelist map[string]MaskInfo
|
||||
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
|
||||
accountToUMode := make(map[string]modes.Mode)
|
||||
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
|
||||
|
||||
var settings ChannelSettings
|
||||
_ = json.Unmarshal([]byte(settingsString), &settings)
|
||||
|
||||
info = RegisteredChannel{
|
||||
Name: name,
|
||||
RegisteredAt: time.Unix(0, regTimeInt).UTC(),
|
||||
Founder: founder,
|
||||
Topic: topic,
|
||||
TopicSetBy: topicSetBy,
|
||||
TopicSetTime: topicSetTime,
|
||||
Key: password,
|
||||
Modes: modeSlice,
|
||||
Bans: banlist,
|
||||
Excepts: exceptlist,
|
||||
Invites: invitelist,
|
||||
AccountToUMode: accountToUMode,
|
||||
UserLimit: int(userLimit),
|
||||
Settings: settings,
|
||||
Forward: forward,
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
@ -954,9 +954,9 @@ func nsInfoHandler(service *ircService, server *Server, client *Client, command
|
||||
|
||||
func listRegisteredChannels(service *ircService, accountName string, rb *ResponseBuffer) {
|
||||
client := rb.session.client
|
||||
channels := client.server.accounts.ChannelsForAccount(accountName)
|
||||
channels := client.server.channels.ChannelsForAccount(accountName)
|
||||
service.Notice(rb, fmt.Sprintf(client.t("Account %s has %d registered channel(s)."), accountName, len(channels)))
|
||||
for _, channel := range rb.session.client.server.accounts.ChannelsForAccount(accountName) {
|
||||
for _, channel := range channels {
|
||||
service.Notice(rb, fmt.Sprintf(client.t("Registered channel: %s"), channel))
|
||||
}
|
||||
}
|
||||
|
37
irc/serde.go
Normal file
37
irc/serde.go
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2022 Shivaram Lingamneni
|
||||
// released under the MIT license
|
||||
|
||||
package irc
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/logger"
|
||||
)
|
||||
|
||||
type Serializable interface {
|
||||
Serialize() ([]byte, error)
|
||||
Deserialize([]byte) error
|
||||
}
|
||||
|
||||
func FetchAndDeserializeAll[T any, C interface {
|
||||
*T
|
||||
Serializable
|
||||
}](table datastore.Table, dstore datastore.Datastore, log *logger.Manager) (result []T, err error) {
|
||||
rawRecords, err := dstore.GetAll(table)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
result = make([]T, len(rawRecords))
|
||||
pos := 0
|
||||
for _, record := range rawRecords {
|
||||
err := C(&result[pos]).Deserialize(record.Value)
|
||||
if err != nil {
|
||||
log.Error("internal", "deserialization error", strconv.Itoa(int(table)), record.UUID.String(), err.Error())
|
||||
continue
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return result[:pos], nil
|
||||
}
|
@ -22,9 +22,12 @@ import (
|
||||
|
||||
"github.com/ergochat/irc-go/ircfmt"
|
||||
"github.com/okzk/sdnotify"
|
||||
"github.com/tidwall/buntdb"
|
||||
|
||||
"github.com/ergochat/ergo/irc/bunt"
|
||||
"github.com/ergochat/ergo/irc/caps"
|
||||
"github.com/ergochat/ergo/irc/connection_limits"
|
||||
"github.com/ergochat/ergo/irc/datastore"
|
||||
"github.com/ergochat/ergo/irc/flatip"
|
||||
"github.com/ergochat/ergo/irc/flock"
|
||||
"github.com/ergochat/ergo/irc/history"
|
||||
@ -33,7 +36,6 @@ import (
|
||||
"github.com/ergochat/ergo/irc/mysql"
|
||||
"github.com/ergochat/ergo/irc/sno"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
"github.com/tidwall/buntdb"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -66,7 +68,6 @@ type Server struct {
|
||||
accepts AcceptManager
|
||||
accounts AccountManager
|
||||
channels ChannelManager
|
||||
channelRegistry ChannelRegistry
|
||||
clients ClientManager
|
||||
config atomic.Pointer[Config]
|
||||
configFilename string
|
||||
@ -87,6 +88,7 @@ type Server struct {
|
||||
tracebackSignal chan os.Signal
|
||||
snomasks SnoManager
|
||||
store *buntdb.DB
|
||||
dstore datastore.Datastore
|
||||
historyDB mysql.MySQL
|
||||
torLimiter connection_limits.TorLimiter
|
||||
whoWas WhoWasList
|
||||
@ -98,6 +100,10 @@ type Server struct {
|
||||
|
||||
// NewServer returns a new Oragono server.
|
||||
func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
|
||||
// sanity check that kernel randomness is available; on modern Linux,
|
||||
// this will block until it is, on other platforms it may panic:
|
||||
utils.GenerateUUIDv4()
|
||||
|
||||
// initialize data structures
|
||||
server := &Server{
|
||||
ctime: time.Now().UTC(),
|
||||
@ -716,7 +722,11 @@ func (server *Server) applyConfig(config *Config) (err error) {
|
||||
// now that the datastore is initialized, we can load the cloak secret from it
|
||||
// XXX this modifies config after the initial load, which is naughty,
|
||||
// but there's no data race because we haven't done SetConfig yet
|
||||
config.Server.Cloaks.SetSecret(LoadCloakSecret(server.store))
|
||||
cloakSecret, err := LoadCloakSecret(server.dstore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not load cloak secret: %w", err)
|
||||
}
|
||||
config.Server.Cloaks.SetSecret(cloakSecret)
|
||||
|
||||
// activate the new config
|
||||
server.config.Store(config)
|
||||
@ -837,6 +847,7 @@ func (server *Server) loadDatastore(config *Config) error {
|
||||
db, err := OpenDatabase(config)
|
||||
if err == nil {
|
||||
server.store = db
|
||||
server.dstore = bunt.NewBuntdbDatastore(db, server.logger)
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("Failed to open datastore: %s", err.Error())
|
||||
@ -849,8 +860,7 @@ func (server *Server) loadFromDatastore(config *Config) (err error) {
|
||||
server.loadDLines()
|
||||
server.loadKLines()
|
||||
|
||||
server.channelRegistry.Initialize(server)
|
||||
server.channels.Initialize(server)
|
||||
server.channels.Initialize(server, config)
|
||||
server.accounts.Initialize(server)
|
||||
|
||||
if config.Datastore.MySQL.Enabled {
|
||||
|
56
irc/utils/uuid.go
Normal file
56
irc/utils/uuid.go
Normal file
@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2022 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||
// released under the MIT license
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUUID = errors.New("Invalid uuid")
|
||||
)
|
||||
|
||||
// Technically a UUIDv4 has version bits set, but this doesn't matter in practice
|
||||
type UUID [16]byte
|
||||
|
||||
func (u UUID) MarshalJSON() (b []byte, err error) {
|
||||
b = make([]byte, 24)
|
||||
b[0] = '"'
|
||||
base64.RawURLEncoding.Encode(b[1:], u[:])
|
||||
b[23] = '"'
|
||||
return
|
||||
}
|
||||
|
||||
func (u *UUID) UnmarshalJSON(b []byte) (err error) {
|
||||
if len(b) != 24 {
|
||||
return ErrInvalidUUID
|
||||
}
|
||||
readLen, err := base64.RawURLEncoding.Decode(u[:], b[1:23])
|
||||
if readLen != 16 {
|
||||
return ErrInvalidUUID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UUID) String() string {
|
||||
return base64.RawURLEncoding.EncodeToString(u[:])
|
||||
}
|
||||
|
||||
func GenerateUUIDv4() (result UUID) {
|
||||
_, err := rand.Read(result[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func DecodeUUID(ustr string) (result UUID, err error) {
|
||||
length, err := base64.RawURLEncoding.Decode(result[:], []byte(ustr))
|
||||
if err == nil && length != 16 {
|
||||
err = ErrInvalidUUID
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue
Block a user