3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-14 07:59:31 +01:00

Merge pull request #224 from slingamn/channelkeys.1

Updates to channel persistence
This commit is contained in:
Daniel Oaks 2018-04-16 13:35:48 +10:00 committed by GitHub
commit c75d2c91c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 291 additions and 110 deletions

View File

@ -6,6 +6,7 @@
package irc
import (
"crypto/subtle"
"fmt"
"strconv"
"time"
@ -36,11 +37,12 @@ type Channel struct {
topicSetBy string
topicSetTime time.Time
userLimit uint64
accountToUMode map[string]modes.Mode
}
// NewChannel creates a new channel from a `Server` and a `name`
// string, which must be unique on the server.
func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *RegisteredChannel) *Channel {
func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
casefoldedName, err := CasefoldChannel(name)
if err != nil {
s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err))
@ -59,16 +61,15 @@ func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *Registere
name: name,
nameCasefolded: casefoldedName,
server: s,
}
if addDefaultModes {
for _, mode := range s.DefaultChannelModes() {
channel.flags[mode] = true
}
accountToUMode: make(map[string]modes.Mode),
}
if regInfo != nil {
channel.applyRegInfo(regInfo)
} else {
for _, mode := range s.DefaultChannelModes() {
channel.flags[mode] = true
}
}
return channel
@ -83,6 +84,11 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
channel.topicSetTime = chanReg.TopicSetTime
channel.name = chanReg.Name
channel.createdTime = chanReg.RegisteredAt
channel.key = chanReg.Key
for _, mode := range chanReg.Modes {
channel.flags[mode] = true
}
for _, mask := range chanReg.Banlist {
channel.lists[modes.BanMask].Add(mask)
}
@ -92,21 +98,34 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
for _, mask := range chanReg.Invitelist {
channel.lists[modes.InviteMask].Add(mask)
}
for account, mode := range chanReg.AccountToUMode {
channel.accountToUMode[account] = mode
}
}
// obtain a consistent snapshot of the channel state that can be persisted to the DB
func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredChannel) {
func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
info.Name = channel.name
info.Topic = channel.topic
info.TopicSetBy = channel.topicSetBy
info.TopicSetTime = channel.topicSetTime
info.Founder = channel.registeredFounder
info.RegisteredAt = channel.registeredTime
if includeLists {
if includeFlags&IncludeTopic != 0 {
info.Topic = channel.topic
info.TopicSetBy = channel.topicSetBy
info.TopicSetTime = channel.topicSetTime
}
if includeFlags&IncludeModes != 0 {
info.Key = channel.key
for mode := range channel.flags {
info.Modes = append(info.Modes, mode)
}
}
if includeFlags&IncludeLists != 0 {
for mask := range channel.lists[modes.BanMask].masks {
info.Banlist = append(info.Banlist, mask)
}
@ -116,6 +135,10 @@ func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredCh
for mask := range channel.lists[modes.InviteMask].masks {
info.Invitelist = append(info.Invitelist, mask)
}
info.AccountToUMode = make(map[string]modes.Mode)
for account, mode := range channel.accountToUMode {
info.AccountToUMode[account] = mode
}
}
return
@ -131,6 +154,7 @@ func (channel *Channel) SetRegistered(founder string) error {
}
channel.registeredFounder = founder
channel.registeredTime = time.Now()
channel.accountToUMode[founder] = modes.ChannelFounder
return nil
}
@ -338,7 +362,12 @@ func (channel *Channel) IsFull() bool {
// CheckKey returns true if the key is not set or matches the given key.
func (channel *Channel) CheckKey(key string) bool {
return (channel.key == "") || (channel.key == key)
chkey := channel.Key()
if chkey == "" {
return true
}
return subtle.ConstantTimeCompare([]byte(key), []byte(chkey)) == 1
}
func (channel *Channel) IsEmpty() bool {
@ -404,21 +433,22 @@ func (channel *Channel) Join(client *Client, key string, rb *ResponseBuffer) {
client.addChannel(channel)
// give channel mode if necessary
newChannel := firstJoin && !channel.IsRegistered()
var givenMode *modes.Mode
account := client.Account()
cffounder, _ := CasefoldName(channel.registeredFounder)
if account != "" && account == cffounder {
givenMode = &modes.ChannelFounder
// give channel mode if necessary
channel.stateMutex.Lock()
newChannel := firstJoin && channel.registeredFounder == ""
mode, persistentModeExists := channel.accountToUMode[account]
var givenMode *modes.Mode
if persistentModeExists {
givenMode = &mode
} else if newChannel {
givenMode = &modes.ChannelOperator
}
if givenMode != nil {
channel.stateMutex.Lock()
channel.members[client][*givenMode] = true
channel.stateMutex.Unlock()
}
channel.stateMutex.Unlock()
if client.capabilities.Has(caps.ExtendedJoin) {
rb.Add(nil, client.nickMaskString, "JOIN", channel.name, client.AccountName(), client.realname)
@ -513,7 +543,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
}
}
go channel.server.channelRegistry.StoreChannel(channel, false)
go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic)
}
// CanSpeak returns true if the client can speak on this channel.

View File

@ -65,7 +65,7 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, rb *Resp
entry = cm.chans[casefoldedName]
if entry == nil {
entry = &channelManagerEntry{
channel: NewChannel(server, name, true, info),
channel: NewChannel(server, name, info),
pendingJoins: 0,
}
cm.chans[casefoldedName] = entry

View File

@ -6,11 +6,13 @@ package irc
import (
"fmt"
"strconv"
"strings"
"sync"
"time"
"encoding/json"
"github.com/oragono/oragono/irc/modes"
"github.com/tidwall/buntdb"
)
@ -28,6 +30,9 @@ const (
keyChannelBanlist = "channel.banlist %s"
keyChannelExceptlist = "channel.exceptlist %s"
keyChannelInvitelist = "channel.invitelist %s"
keyChannelPassword = "channel.key %s"
keyChannelModes = "channel.modes %s"
keyChannelAccountToUMode = "channel.accounttoumode %s"
)
var (
@ -42,9 +47,26 @@ var (
keyChannelBanlist,
keyChannelExceptlist,
keyChannelInvitelist,
keyChannelPassword,
keyChannelModes,
keyChannelAccountToUMode,
}
)
// these are bit flags indicating what part of the channel status is "dirty"
// and needs to be read from memory and written to the db
const (
IncludeInitial uint = 1 << iota
IncludeTopic
IncludeModes
IncludeLists
)
// this is an OR of all possible flags
const (
IncludeAllChannelAttrs = ^uint(0)
)
// RegisteredChannel holds details about a given registered channel.
type RegisteredChannel struct {
// Name of the channel.
@ -59,6 +81,12 @@ type RegisteredChannel struct {
TopicSetBy string
// TopicSetTime represents the time the topic was set.
TopicSetTime time.Time
// Modes represents the channel modes
Modes []modes.Mode
// Key represents the channel key / password
Key string
// AccountToUMode maps user accounts to their persistent channel modes (e.g., +q, +h)
AccountToUMode map[string]modes.Mode
// Banlist represents the bans set on the channel.
Banlist []string
// Exceptlist represents the exceptions set on the channel.
@ -87,7 +115,7 @@ func NewChannelRegistry(server *Server) *ChannelRegistry {
}
// StoreChannel obtains a consistent view of a channel, then persists it to the store.
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) {
if !reg.server.ChannelRegistrationEnabled() {
return
}
@ -96,14 +124,14 @@ func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
defer reg.Unlock()
key := channel.NameCasefolded()
info := channel.ExportRegistration(includeLists)
info := channel.ExportRegistration(includeFlags)
if info.Founder == "" {
// sanity check, don't try to store an unregistered channel
return
}
reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.saveChannel(tx, key, info, includeLists)
reg.saveChannel(tx, key, info, includeFlags)
return nil
})
}
@ -132,9 +160,17 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
topicSetTime, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
topicSetTimeInt, _ := strconv.ParseInt(topicSetTime, 10, 64)
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
modeSlice := make([]modes.Mode, len(modeString))
for i, mode := range modeString {
modeSlice[i] = modes.Mode(mode)
}
var banlist []string
_ = json.Unmarshal([]byte(banlistString), &banlist)
@ -142,6 +178,8 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
var invitelist []string
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
accountToUMode := make(map[string]modes.Mode)
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
info = &RegisteredChannel{
Name: name,
@ -150,9 +188,12 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
Topic: topic,
TopicSetBy: topicSetBy,
TopicSetTime: time.Unix(topicSetTimeInt, 0),
Key: password,
Modes: modeSlice,
Banlist: banlist,
Exceptlist: exceptlist,
Invitelist: invitelist,
AccountToUMode: accountToUMode,
}
return nil
})
@ -170,17 +211,17 @@ func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
reg.Lock()
defer reg.Unlock()
includeLists := true
includeFlags := IncludeAllChannelAttrs
oldKey := casefoldedOldName
key := channel.NameCasefolded()
info := channel.ExportRegistration(includeLists)
info := channel.ExportRegistration(includeFlags)
if info.Founder == "" {
return
}
reg.server.store.Update(func(tx *buntdb.Tx) error {
reg.deleteChannel(tx, oldKey, info)
reg.saveChannel(tx, key, info, includeLists)
reg.saveChannel(tx, key, info, includeFlags)
return nil
})
}
@ -204,21 +245,37 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
}
// saveChannel saves a channel to the store.
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeLists bool) {
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
if includeFlags&IncludeInitial != 0 {
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
}
if includeLists {
if includeFlags&IncludeTopic != 0 {
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
}
if includeFlags&IncludeModes != 0 {
tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil)
modeStrings := make([]string, len(channelInfo.Modes))
for i, mode := range channelInfo.Modes {
modeStrings[i] = string(mode)
}
tx.Set(fmt.Sprintf(keyChannelModes, channelKey), strings.Join(modeStrings, ""), nil)
}
if includeFlags&IncludeLists != 0 {
banlistString, _ := json.Marshal(channelInfo.Banlist)
tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
exceptlistString, _ := json.Marshal(channelInfo.Exceptlist)
tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
invitelistString, _ := json.Marshal(channelInfo.Invitelist)
tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode)
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil)
}
}

View File

@ -252,7 +252,7 @@ func csRegisterHandler(server *Server, client *Client, command, params string, r
}
// registration was successful: make the database reflect it
go server.channelRegistry.StoreChannel(channelInfo, true)
go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))

View File

@ -6,11 +6,13 @@ package irc
import (
"encoding/base64"
"encoding/json"
"fmt"
"log"
"os"
"strings"
"github.com/oragono/oragono/irc/modes"
"github.com/oragono/oragono/irc/passwd"
"github.com/tidwall/buntdb"
@ -20,11 +22,22 @@ const (
// 'version' of the database schema
keySchemaVersion = "db.version"
// latest schema of the db
latestDbSchema = "2"
latestDbSchema = "3"
// key for the primary salt used by the ircd
keySalt = "crypto.salt"
)
type SchemaChanger func(*Config, *buntdb.Tx) error
type SchemaChange struct {
InitialVersion string // the change will take this version
TargetVersion string // and transform it into this version
Changer SchemaChanger
}
// maps an initial version to a schema change capable of upgrading it
var schemaChanges map[string]SchemaChange
// InitDB creates the database.
func InitDB(path string) {
// prepare kvstore db
@ -46,7 +59,7 @@ func InitDB(path string) {
tx.Set(keySalt, encodedSalt, nil)
// set schema version
tx.Set(keySchemaVersion, "2", nil)
tx.Set(keySchemaVersion, latestDbSchema, nil)
return nil
})
@ -82,20 +95,45 @@ func OpenDatabase(path string) (*buntdb.DB, error) {
}
// UpgradeDB upgrades the datastore to the latest schema.
func UpgradeDB(path string) {
store, err := buntdb.Open(path)
func UpgradeDB(config *Config) {
store, err := buntdb.Open(config.Datastore.Path)
if err != nil {
log.Fatal(fmt.Sprintf("Failed to open datastore: %s", err.Error()))
}
defer store.Close()
var version string
err = store.Update(func(tx *buntdb.Tx) error {
version, _ := tx.Get(keySchemaVersion)
for {
version, _ = tx.Get(keySchemaVersion)
change, schemaNeedsChange := schemaChanges[version]
if !schemaNeedsChange {
break
}
log.Println("attempting to update store from version " + version)
err := change.Changer(config, tx)
if err != nil {
return err
}
_, _, err = tx.Set(keySchemaVersion, change.TargetVersion, nil)
if err != nil {
return err
}
log.Println("successfully updated store to version " + change.TargetVersion)
}
return nil
})
if err != nil {
log.Fatal("Could not update datastore:", err.Error())
}
return
}
func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error {
// == version 1 -> 2 ==
// account key changes and account.verified key bugfix.
if version == "1" {
log.Println("Updating store v1 to v2")
var keysToRemove []string
newKeys := make(map[string]string)
@ -126,14 +164,73 @@ func UpgradeDB(path string) {
tx.Set(key, value, nil)
}
tx.Set(keySchemaVersion, "2", nil)
return nil
}
// 1. channel founder names should be casefolded
// 2. founder should be explicitly granted the ChannelFounder user mode
// 3. explicitly initialize stored channel modes to the server default values
func schemaChangeV2ToV3(config *Config, tx *buntdb.Tx) error {
var channels []string
prefix := "channel.exists "
tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
if !strings.HasPrefix(key, prefix) {
return false
}
chname := strings.TrimPrefix(key, prefix)
channels = append(channels, chname)
return true
})
// founder names should be casefolded
// founder should be explicitly granted the ChannelFounder user mode
for _, channel := range channels {
founderKey := "channel.founder " + channel
founder, _ := tx.Get(founderKey)
if founder != "" {
founder, err := CasefoldName(founder)
if err == nil {
tx.Set(founderKey, founder, nil)
accountToUmode := map[string]modes.Mode{
founder: modes.ChannelFounder,
}
atustr, _ := json.Marshal(accountToUmode)
tx.Set("channel.accounttoumode "+channel, string(atustr), nil)
}
}
}
// explicitly store the channel modes
defaultModes := ParseDefaultChannelModes(config)
modeStrings := make([]string, len(defaultModes))
for i, mode := range defaultModes {
modeStrings[i] = string(mode)
}
defaultModeString := strings.Join(modeStrings, "")
for _, channel := range channels {
tx.Set("channel.modes "+channel, defaultModeString, nil)
}
return nil
})
if err != nil {
log.Fatal("Could not update datastore:", err.Error())
}
return
func init() {
allChanges := []SchemaChange{
SchemaChange{
InitialVersion: "1",
TargetVersion: "2",
Changer: schemaChangeV1toV2,
},
SchemaChange{
InitialVersion: "2",
TargetVersion: "3",
Changer: schemaChangeV2ToV3,
},
}
// build the index
schemaChanges = make(map[string]SchemaChange)
for _, change := range allChanges {
schemaChanges[change.InitialVersion] = change
}
}

View File

@ -256,8 +256,8 @@ func (channel *Channel) Key() string {
func (channel *Channel) setKey(key string) {
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.key = key
channel.stateMutex.Unlock()
}
func (channel *Channel) HasMode(mode modes.Mode) bool {

View File

@ -1351,20 +1351,17 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
applied = channel.ApplyChannelModeChanges(client, msg.Command == "SAMODE", changes, rb)
}
// save changes to banlist/exceptlist/invexlist
var banlistUpdated, exceptlistUpdated, invexlistUpdated bool
// save changes
var includeFlags uint
for _, change := range applied {
if change.Mode == modes.BanMask {
banlistUpdated = true
} else if change.Mode == modes.ExceptMask {
exceptlistUpdated = true
} else if change.Mode == modes.InviteMask {
invexlistUpdated = true
includeFlags |= IncludeModes
if change.Mode == modes.BanMask || change.Mode == modes.ExceptMask || change.Mode == modes.InviteMask {
includeFlags |= IncludeLists
}
}
if (banlistUpdated || exceptlistUpdated || invexlistUpdated) && channel.IsRegistered() {
go server.channelRegistry.StoreChannel(channel, true)
if channel.IsRegistered() && includeFlags != 0 {
go server.channelRegistry.StoreChannel(channel, includeFlags)
}
// send out changes

View File

@ -84,7 +84,7 @@ Options:
log.Println("database initialized: ", config.Datastore.Path)
}
} else if arguments["upgradedb"].(bool) {
irc.UpgradeDB(config.Datastore.Path)
irc.UpgradeDB(config)
if !arguments["--quiet"].(bool) {
log.Println("database upgraded: ", config.Datastore.Path)
}