3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-13 23:49:30 +01:00

Merge pull request #224 from slingamn/channelkeys.1

Updates to channel persistence
This commit is contained in:
Daniel Oaks 2018-04-16 13:35:48 +10:00 committed by GitHub
commit c75d2c91c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 291 additions and 110 deletions

View File

@ -6,6 +6,7 @@
package irc package irc
import ( import (
"crypto/subtle"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@ -36,11 +37,12 @@ type Channel struct {
topicSetBy string topicSetBy string
topicSetTime time.Time topicSetTime time.Time
userLimit uint64 userLimit uint64
accountToUMode map[string]modes.Mode
} }
// NewChannel creates a new channel from a `Server` and a `name` // NewChannel creates a new channel from a `Server` and a `name`
// string, which must be unique on the server. // string, which must be unique on the server.
func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *RegisteredChannel) *Channel { func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
casefoldedName, err := CasefoldChannel(name) casefoldedName, err := CasefoldChannel(name)
if err != nil { if err != nil {
s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err)) s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err))
@ -59,16 +61,15 @@ func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *Registere
name: name, name: name,
nameCasefolded: casefoldedName, nameCasefolded: casefoldedName,
server: s, server: s,
} accountToUMode: make(map[string]modes.Mode),
if addDefaultModes {
for _, mode := range s.DefaultChannelModes() {
channel.flags[mode] = true
}
} }
if regInfo != nil { if regInfo != nil {
channel.applyRegInfo(regInfo) channel.applyRegInfo(regInfo)
} else {
for _, mode := range s.DefaultChannelModes() {
channel.flags[mode] = true
}
} }
return channel return channel
@ -83,6 +84,11 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
channel.topicSetTime = chanReg.TopicSetTime channel.topicSetTime = chanReg.TopicSetTime
channel.name = chanReg.Name channel.name = chanReg.Name
channel.createdTime = chanReg.RegisteredAt channel.createdTime = chanReg.RegisteredAt
channel.key = chanReg.Key
for _, mode := range chanReg.Modes {
channel.flags[mode] = true
}
for _, mask := range chanReg.Banlist { for _, mask := range chanReg.Banlist {
channel.lists[modes.BanMask].Add(mask) channel.lists[modes.BanMask].Add(mask)
} }
@ -92,21 +98,34 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
for _, mask := range chanReg.Invitelist { for _, mask := range chanReg.Invitelist {
channel.lists[modes.InviteMask].Add(mask) channel.lists[modes.InviteMask].Add(mask)
} }
for account, mode := range chanReg.AccountToUMode {
channel.accountToUMode[account] = mode
}
} }
// obtain a consistent snapshot of the channel state that can be persisted to the DB // obtain a consistent snapshot of the channel state that can be persisted to the DB
func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredChannel) { func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) {
channel.stateMutex.RLock() channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
info.Name = channel.name info.Name = channel.name
info.Topic = channel.topic
info.TopicSetBy = channel.topicSetBy
info.TopicSetTime = channel.topicSetTime
info.Founder = channel.registeredFounder info.Founder = channel.registeredFounder
info.RegisteredAt = channel.registeredTime info.RegisteredAt = channel.registeredTime
if includeLists { if includeFlags&IncludeTopic != 0 {
info.Topic = channel.topic
info.TopicSetBy = channel.topicSetBy
info.TopicSetTime = channel.topicSetTime
}
if includeFlags&IncludeModes != 0 {
info.Key = channel.key
for mode := range channel.flags {
info.Modes = append(info.Modes, mode)
}
}
if includeFlags&IncludeLists != 0 {
for mask := range channel.lists[modes.BanMask].masks { for mask := range channel.lists[modes.BanMask].masks {
info.Banlist = append(info.Banlist, mask) info.Banlist = append(info.Banlist, mask)
} }
@ -116,6 +135,10 @@ func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredCh
for mask := range channel.lists[modes.InviteMask].masks { for mask := range channel.lists[modes.InviteMask].masks {
info.Invitelist = append(info.Invitelist, mask) info.Invitelist = append(info.Invitelist, mask)
} }
info.AccountToUMode = make(map[string]modes.Mode)
for account, mode := range channel.accountToUMode {
info.AccountToUMode[account] = mode
}
} }
return return
@ -131,6 +154,7 @@ func (channel *Channel) SetRegistered(founder string) error {
} }
channel.registeredFounder = founder channel.registeredFounder = founder
channel.registeredTime = time.Now() channel.registeredTime = time.Now()
channel.accountToUMode[founder] = modes.ChannelFounder
return nil return nil
} }
@ -338,7 +362,12 @@ func (channel *Channel) IsFull() bool {
// CheckKey returns true if the key is not set or matches the given key. // CheckKey returns true if the key is not set or matches the given key.
func (channel *Channel) CheckKey(key string) bool { func (channel *Channel) CheckKey(key string) bool {
return (channel.key == "") || (channel.key == key) chkey := channel.Key()
if chkey == "" {
return true
}
return subtle.ConstantTimeCompare([]byte(key), []byte(chkey)) == 1
} }
func (channel *Channel) IsEmpty() bool { func (channel *Channel) IsEmpty() bool {
@ -404,21 +433,22 @@ func (channel *Channel) Join(client *Client, key string, rb *ResponseBuffer) {
client.addChannel(channel) client.addChannel(channel)
// give channel mode if necessary
newChannel := firstJoin && !channel.IsRegistered()
var givenMode *modes.Mode
account := client.Account() account := client.Account()
cffounder, _ := CasefoldName(channel.registeredFounder)
if account != "" && account == cffounder { // give channel mode if necessary
givenMode = &modes.ChannelFounder channel.stateMutex.Lock()
newChannel := firstJoin && channel.registeredFounder == ""
mode, persistentModeExists := channel.accountToUMode[account]
var givenMode *modes.Mode
if persistentModeExists {
givenMode = &mode
} else if newChannel { } else if newChannel {
givenMode = &modes.ChannelOperator givenMode = &modes.ChannelOperator
} }
if givenMode != nil { if givenMode != nil {
channel.stateMutex.Lock()
channel.members[client][*givenMode] = true channel.members[client][*givenMode] = true
channel.stateMutex.Unlock()
} }
channel.stateMutex.Unlock()
if client.capabilities.Has(caps.ExtendedJoin) { if client.capabilities.Has(caps.ExtendedJoin) {
rb.Add(nil, client.nickMaskString, "JOIN", channel.name, client.AccountName(), client.realname) rb.Add(nil, client.nickMaskString, "JOIN", channel.name, client.AccountName(), client.realname)
@ -513,7 +543,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
} }
} }
go channel.server.channelRegistry.StoreChannel(channel, false) go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic)
} }
// CanSpeak returns true if the client can speak on this channel. // CanSpeak returns true if the client can speak on this channel.

View File

@ -65,7 +65,7 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, rb *Resp
entry = cm.chans[casefoldedName] entry = cm.chans[casefoldedName]
if entry == nil { if entry == nil {
entry = &channelManagerEntry{ entry = &channelManagerEntry{
channel: NewChannel(server, name, true, info), channel: NewChannel(server, name, info),
pendingJoins: 0, pendingJoins: 0,
} }
cm.chans[casefoldedName] = entry cm.chans[casefoldedName] = entry

View File

@ -6,11 +6,13 @@ package irc
import ( import (
"fmt" "fmt"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"encoding/json" "encoding/json"
"github.com/oragono/oragono/irc/modes"
"github.com/tidwall/buntdb" "github.com/tidwall/buntdb"
) )
@ -18,16 +20,19 @@ import (
// channel creation/tracking/destruction is in channelmanager.go // channel creation/tracking/destruction is in channelmanager.go
const ( const (
keyChannelExists = "channel.exists %s" keyChannelExists = "channel.exists %s"
keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
keyChannelRegTime = "channel.registered.time %s" keyChannelRegTime = "channel.registered.time %s"
keyChannelFounder = "channel.founder %s" keyChannelFounder = "channel.founder %s"
keyChannelTopic = "channel.topic %s" keyChannelTopic = "channel.topic %s"
keyChannelTopicSetBy = "channel.topic.setby %s" keyChannelTopicSetBy = "channel.topic.setby %s"
keyChannelTopicSetTime = "channel.topic.settime %s" keyChannelTopicSetTime = "channel.topic.settime %s"
keyChannelBanlist = "channel.banlist %s" keyChannelBanlist = "channel.banlist %s"
keyChannelExceptlist = "channel.exceptlist %s" keyChannelExceptlist = "channel.exceptlist %s"
keyChannelInvitelist = "channel.invitelist %s" keyChannelInvitelist = "channel.invitelist %s"
keyChannelPassword = "channel.key %s"
keyChannelModes = "channel.modes %s"
keyChannelAccountToUMode = "channel.accounttoumode %s"
) )
var ( var (
@ -42,9 +47,26 @@ var (
keyChannelBanlist, keyChannelBanlist,
keyChannelExceptlist, keyChannelExceptlist,
keyChannelInvitelist, keyChannelInvitelist,
keyChannelPassword,
keyChannelModes,
keyChannelAccountToUMode,
} }
) )
// these are bit flags indicating what part of the channel status is "dirty"
// and needs to be read from memory and written to the db
const (
IncludeInitial uint = 1 << iota
IncludeTopic
IncludeModes
IncludeLists
)
// this is an OR of all possible flags
const (
IncludeAllChannelAttrs = ^uint(0)
)
// RegisteredChannel holds details about a given registered channel. // RegisteredChannel holds details about a given registered channel.
type RegisteredChannel struct { type RegisteredChannel struct {
// Name of the channel. // Name of the channel.
@ -59,6 +81,12 @@ type RegisteredChannel struct {
TopicSetBy string TopicSetBy string
// TopicSetTime represents the time the topic was set. // TopicSetTime represents the time the topic was set.
TopicSetTime time.Time TopicSetTime time.Time
// Modes represents the channel modes
Modes []modes.Mode
// Key represents the channel key / password
Key string
// AccountToUMode maps user accounts to their persistent channel modes (e.g., +q, +h)
AccountToUMode map[string]modes.Mode
// Banlist represents the bans set on the channel. // Banlist represents the bans set on the channel.
Banlist []string Banlist []string
// Exceptlist represents the exceptions set on the channel. // Exceptlist represents the exceptions set on the channel.
@ -87,7 +115,7 @@ func NewChannelRegistry(server *Server) *ChannelRegistry {
} }
// StoreChannel obtains a consistent view of a channel, then persists it to the store. // StoreChannel obtains a consistent view of a channel, then persists it to the store.
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) { func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) {
if !reg.server.ChannelRegistrationEnabled() { if !reg.server.ChannelRegistrationEnabled() {
return return
} }
@ -96,14 +124,14 @@ func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
defer reg.Unlock() defer reg.Unlock()
key := channel.NameCasefolded() key := channel.NameCasefolded()
info := channel.ExportRegistration(includeLists) info := channel.ExportRegistration(includeFlags)
if info.Founder == "" { if info.Founder == "" {
// sanity check, don't try to store an unregistered channel // sanity check, don't try to store an unregistered channel
return return
} }
reg.server.store.Update(func(tx *buntdb.Tx) error { reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.saveChannel(tx, key, info, includeLists) reg.saveChannel(tx, key, info, includeFlags)
return nil return nil
}) })
} }
@ -132,9 +160,17 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey)) topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
topicSetTime, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey)) topicSetTime, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
topicSetTimeInt, _ := strconv.ParseInt(topicSetTime, 10, 64) topicSetTimeInt, _ := strconv.ParseInt(topicSetTime, 10, 64)
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey)) banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey)) exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey)) invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
modeSlice := make([]modes.Mode, len(modeString))
for i, mode := range modeString {
modeSlice[i] = modes.Mode(mode)
}
var banlist []string var banlist []string
_ = json.Unmarshal([]byte(banlistString), &banlist) _ = json.Unmarshal([]byte(banlistString), &banlist)
@ -142,17 +178,22 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist) _ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
var invitelist []string var invitelist []string
_ = json.Unmarshal([]byte(invitelistString), &invitelist) _ = json.Unmarshal([]byte(invitelistString), &invitelist)
accountToUMode := make(map[string]modes.Mode)
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
info = &RegisteredChannel{ info = &RegisteredChannel{
Name: name, Name: name,
RegisteredAt: time.Unix(regTimeInt, 0), RegisteredAt: time.Unix(regTimeInt, 0),
Founder: founder, Founder: founder,
Topic: topic, Topic: topic,
TopicSetBy: topicSetBy, TopicSetBy: topicSetBy,
TopicSetTime: time.Unix(topicSetTimeInt, 0), TopicSetTime: time.Unix(topicSetTimeInt, 0),
Banlist: banlist, Key: password,
Exceptlist: exceptlist, Modes: modeSlice,
Invitelist: invitelist, Banlist: banlist,
Exceptlist: exceptlist,
Invitelist: invitelist,
AccountToUMode: accountToUMode,
} }
return nil return nil
}) })
@ -170,17 +211,17 @@ func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
reg.Lock() reg.Lock()
defer reg.Unlock() defer reg.Unlock()
includeLists := true includeFlags := IncludeAllChannelAttrs
oldKey := casefoldedOldName oldKey := casefoldedOldName
key := channel.NameCasefolded() key := channel.NameCasefolded()
info := channel.ExportRegistration(includeLists) info := channel.ExportRegistration(includeFlags)
if info.Founder == "" { if info.Founder == "" {
return return
} }
reg.server.store.Update(func(tx *buntdb.Tx) error { reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.deleteChannel(tx, oldKey, info) reg.deleteChannel(tx, oldKey, info)
reg.saveChannel(tx, key, info, includeLists) reg.saveChannel(tx, key, info, includeFlags)
return nil return nil
}) })
} }
@ -204,21 +245,37 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
} }
// saveChannel saves a channel to the store. // saveChannel saves a channel to the store.
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeLists bool) { func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil) if includeFlags&IncludeInitial != 0 {
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil) tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil) tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil) tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil) tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil) }
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
if includeLists { if includeFlags&IncludeTopic != 0 {
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
}
if includeFlags&IncludeModes != 0 {
tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil)
modeStrings := make([]string, len(channelInfo.Modes))
for i, mode := range channelInfo.Modes {
modeStrings[i] = string(mode)
}
tx.Set(fmt.Sprintf(keyChannelModes, channelKey), strings.Join(modeStrings, ""), nil)
}
if includeFlags&IncludeLists != 0 {
banlistString, _ := json.Marshal(channelInfo.Banlist) banlistString, _ := json.Marshal(channelInfo.Banlist)
tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil) tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
exceptlistString, _ := json.Marshal(channelInfo.Exceptlist) exceptlistString, _ := json.Marshal(channelInfo.Exceptlist)
tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil) tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
invitelistString, _ := json.Marshal(channelInfo.Invitelist) invitelistString, _ := json.Marshal(channelInfo.Invitelist)
tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil) tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode)
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil)
} }
} }

View File

@ -252,7 +252,7 @@ func csRegisterHandler(server *Server, client *Client, command, params string, r
} }
// registration was successful: make the database reflect it // registration was successful: make the database reflect it
go server.channelRegistry.StoreChannel(channelInfo, true) go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName)) csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))

View File

@ -6,11 +6,13 @@ package irc
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt" "fmt"
"log" "log"
"os" "os"
"strings" "strings"
"github.com/oragono/oragono/irc/modes"
"github.com/oragono/oragono/irc/passwd" "github.com/oragono/oragono/irc/passwd"
"github.com/tidwall/buntdb" "github.com/tidwall/buntdb"
@ -20,11 +22,22 @@ const (
// 'version' of the database schema // 'version' of the database schema
keySchemaVersion = "db.version" keySchemaVersion = "db.version"
// latest schema of the db // latest schema of the db
latestDbSchema = "2" latestDbSchema = "3"
// key for the primary salt used by the ircd // key for the primary salt used by the ircd
keySalt = "crypto.salt" keySalt = "crypto.salt"
) )
type SchemaChanger func(*Config, *buntdb.Tx) error
type SchemaChange struct {
InitialVersion string // the change will take this version
TargetVersion string // and transform it into this version
Changer SchemaChanger
}
// maps an initial version to a schema change capable of upgrading it
var schemaChanges map[string]SchemaChange
// InitDB creates the database. // InitDB creates the database.
func InitDB(path string) { func InitDB(path string) {
// prepare kvstore db // prepare kvstore db
@ -46,7 +59,7 @@ func InitDB(path string) {
tx.Set(keySalt, encodedSalt, nil) tx.Set(keySalt, encodedSalt, nil)
// set schema version // set schema version
tx.Set(keySchemaVersion, "2", nil) tx.Set(keySchemaVersion, latestDbSchema, nil)
return nil return nil
}) })
@ -82,58 +95,142 @@ func OpenDatabase(path string) (*buntdb.DB, error) {
} }
// UpgradeDB upgrades the datastore to the latest schema. // UpgradeDB upgrades the datastore to the latest schema.
func UpgradeDB(path string) { func UpgradeDB(config *Config) {
store, err := buntdb.Open(path) store, err := buntdb.Open(config.Datastore.Path)
if err != nil { if err != nil {
log.Fatal(fmt.Sprintf("Failed to open datastore: %s", err.Error())) log.Fatal(fmt.Sprintf("Failed to open datastore: %s", err.Error()))
} }
defer store.Close() defer store.Close()
var version string
err = store.Update(func(tx *buntdb.Tx) error { err = store.Update(func(tx *buntdb.Tx) error {
version, _ := tx.Get(keySchemaVersion) for {
version, _ = tx.Get(keySchemaVersion)
// == version 1 -> 2 == change, schemaNeedsChange := schemaChanges[version]
// account key changes and account.verified key bugfix. if !schemaNeedsChange {
if version == "1" { break
log.Println("Updating store v1 to v2")
var keysToRemove []string
newKeys := make(map[string]string)
tx.AscendKeys("account *", func(key, value string) bool {
keysToRemove = append(keysToRemove, key)
splitkey := strings.Split(key, " ")
// work around bug
if splitkey[2] == "exists" {
// manually create new verified key
newVerifiedKey := fmt.Sprintf("%s.verified %s", splitkey[0], splitkey[1])
newKeys[newVerifiedKey] = "1"
} else if splitkey[1] == "%s" {
return true
}
newKey := fmt.Sprintf("%s.%s %s", splitkey[0], splitkey[2], splitkey[1])
newKeys[newKey] = value
return true
})
for _, key := range keysToRemove {
tx.Delete(key)
} }
for key, value := range newKeys { log.Println("attempting to update store from version " + version)
tx.Set(key, value, nil) err := change.Changer(config, tx)
if err != nil {
return err
} }
_, _, err = tx.Set(keySchemaVersion, change.TargetVersion, nil)
tx.Set(keySchemaVersion, "2", nil) if err != nil {
return err
}
log.Println("successfully updated store to version " + change.TargetVersion)
} }
return nil return nil
}) })
if err != nil { if err != nil {
log.Fatal("Could not update datastore:", err.Error()) log.Fatal("Could not update datastore:", err.Error())
} }
return return
} }
func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error {
// == version 1 -> 2 ==
// account key changes and account.verified key bugfix.
var keysToRemove []string
newKeys := make(map[string]string)
tx.AscendKeys("account *", func(key, value string) bool {
keysToRemove = append(keysToRemove, key)
splitkey := strings.Split(key, " ")
// work around bug
if splitkey[2] == "exists" {
// manually create new verified key
newVerifiedKey := fmt.Sprintf("%s.verified %s", splitkey[0], splitkey[1])
newKeys[newVerifiedKey] = "1"
} else if splitkey[1] == "%s" {
return true
}
newKey := fmt.Sprintf("%s.%s %s", splitkey[0], splitkey[2], splitkey[1])
newKeys[newKey] = value
return true
})
for _, key := range keysToRemove {
tx.Delete(key)
}
for key, value := range newKeys {
tx.Set(key, value, nil)
}
return nil
}
// 1. channel founder names should be casefolded
// 2. founder should be explicitly granted the ChannelFounder user mode
// 3. explicitly initialize stored channel modes to the server default values
func schemaChangeV2ToV3(config *Config, tx *buntdb.Tx) error {
var channels []string
prefix := "channel.exists "
tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
if !strings.HasPrefix(key, prefix) {
return false
}
chname := strings.TrimPrefix(key, prefix)
channels = append(channels, chname)
return true
})
// founder names should be casefolded
// founder should be explicitly granted the ChannelFounder user mode
for _, channel := range channels {
founderKey := "channel.founder " + channel
founder, _ := tx.Get(founderKey)
if founder != "" {
founder, err := CasefoldName(founder)
if err == nil {
tx.Set(founderKey, founder, nil)
accountToUmode := map[string]modes.Mode{
founder: modes.ChannelFounder,
}
atustr, _ := json.Marshal(accountToUmode)
tx.Set("channel.accounttoumode "+channel, string(atustr), nil)
}
}
}
// explicitly store the channel modes
defaultModes := ParseDefaultChannelModes(config)
modeStrings := make([]string, len(defaultModes))
for i, mode := range defaultModes {
modeStrings[i] = string(mode)
}
defaultModeString := strings.Join(modeStrings, "")
for _, channel := range channels {
tx.Set("channel.modes "+channel, defaultModeString, nil)
}
return nil
}
func init() {
allChanges := []SchemaChange{
SchemaChange{
InitialVersion: "1",
TargetVersion: "2",
Changer: schemaChangeV1toV2,
},
SchemaChange{
InitialVersion: "2",
TargetVersion: "3",
Changer: schemaChangeV2ToV3,
},
}
// build the index
schemaChanges = make(map[string]SchemaChange)
for _, change := range allChanges {
schemaChanges[change.InitialVersion] = change
}
}

View File

@ -256,8 +256,8 @@ func (channel *Channel) Key() string {
func (channel *Channel) setKey(key string) { func (channel *Channel) setKey(key string) {
channel.stateMutex.Lock() channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.key = key channel.key = key
channel.stateMutex.Unlock()
} }
func (channel *Channel) HasMode(mode modes.Mode) bool { func (channel *Channel) HasMode(mode modes.Mode) bool {

View File

@ -1351,20 +1351,17 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
applied = channel.ApplyChannelModeChanges(client, msg.Command == "SAMODE", changes, rb) applied = channel.ApplyChannelModeChanges(client, msg.Command == "SAMODE", changes, rb)
} }
// save changes to banlist/exceptlist/invexlist // save changes
var banlistUpdated, exceptlistUpdated, invexlistUpdated bool var includeFlags uint
for _, change := range applied { for _, change := range applied {
if change.Mode == modes.BanMask { includeFlags |= IncludeModes
banlistUpdated = true if change.Mode == modes.BanMask || change.Mode == modes.ExceptMask || change.Mode == modes.InviteMask {
} else if change.Mode == modes.ExceptMask { includeFlags |= IncludeLists
exceptlistUpdated = true
} else if change.Mode == modes.InviteMask {
invexlistUpdated = true
} }
} }
if (banlistUpdated || exceptlistUpdated || invexlistUpdated) && channel.IsRegistered() { if channel.IsRegistered() && includeFlags != 0 {
go server.channelRegistry.StoreChannel(channel, true) go server.channelRegistry.StoreChannel(channel, includeFlags)
} }
// send out changes // send out changes

View File

@ -84,7 +84,7 @@ Options:
log.Println("database initialized: ", config.Datastore.Path) log.Println("database initialized: ", config.Datastore.Path)
} }
} else if arguments["upgradedb"].(bool) { } else if arguments["upgradedb"].(bool) {
irc.UpgradeDB(config.Datastore.Path) irc.UpgradeDB(config)
if !arguments["--quiet"].(bool) { if !arguments["--quiet"].(bool) {
log.Println("database upgraded: ", config.Datastore.Path) log.Println("database upgraded: ", config.Datastore.Path)
} }