mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-14 16:09:32 +01:00
Merge pull request #224 from slingamn/channelkeys.1
Updates to channel persistence
This commit is contained in:
commit
c75d2c91c5
@ -6,6 +6,7 @@
|
|||||||
package irc
|
package irc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@ -36,11 +37,12 @@ type Channel struct {
|
|||||||
topicSetBy string
|
topicSetBy string
|
||||||
topicSetTime time.Time
|
topicSetTime time.Time
|
||||||
userLimit uint64
|
userLimit uint64
|
||||||
|
accountToUMode map[string]modes.Mode
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChannel creates a new channel from a `Server` and a `name`
|
// NewChannel creates a new channel from a `Server` and a `name`
|
||||||
// string, which must be unique on the server.
|
// string, which must be unique on the server.
|
||||||
func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *RegisteredChannel) *Channel {
|
func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
|
||||||
casefoldedName, err := CasefoldChannel(name)
|
casefoldedName, err := CasefoldChannel(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err))
|
s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err))
|
||||||
@ -59,16 +61,15 @@ func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *Registere
|
|||||||
name: name,
|
name: name,
|
||||||
nameCasefolded: casefoldedName,
|
nameCasefolded: casefoldedName,
|
||||||
server: s,
|
server: s,
|
||||||
}
|
accountToUMode: make(map[string]modes.Mode),
|
||||||
|
|
||||||
if addDefaultModes {
|
|
||||||
for _, mode := range s.DefaultChannelModes() {
|
|
||||||
channel.flags[mode] = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if regInfo != nil {
|
if regInfo != nil {
|
||||||
channel.applyRegInfo(regInfo)
|
channel.applyRegInfo(regInfo)
|
||||||
|
} else {
|
||||||
|
for _, mode := range s.DefaultChannelModes() {
|
||||||
|
channel.flags[mode] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return channel
|
return channel
|
||||||
@ -83,6 +84,11 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
|
|||||||
channel.topicSetTime = chanReg.TopicSetTime
|
channel.topicSetTime = chanReg.TopicSetTime
|
||||||
channel.name = chanReg.Name
|
channel.name = chanReg.Name
|
||||||
channel.createdTime = chanReg.RegisteredAt
|
channel.createdTime = chanReg.RegisteredAt
|
||||||
|
channel.key = chanReg.Key
|
||||||
|
|
||||||
|
for _, mode := range chanReg.Modes {
|
||||||
|
channel.flags[mode] = true
|
||||||
|
}
|
||||||
for _, mask := range chanReg.Banlist {
|
for _, mask := range chanReg.Banlist {
|
||||||
channel.lists[modes.BanMask].Add(mask)
|
channel.lists[modes.BanMask].Add(mask)
|
||||||
}
|
}
|
||||||
@ -92,21 +98,34 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
|
|||||||
for _, mask := range chanReg.Invitelist {
|
for _, mask := range chanReg.Invitelist {
|
||||||
channel.lists[modes.InviteMask].Add(mask)
|
channel.lists[modes.InviteMask].Add(mask)
|
||||||
}
|
}
|
||||||
|
for account, mode := range chanReg.AccountToUMode {
|
||||||
|
channel.accountToUMode[account] = mode
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// obtain a consistent snapshot of the channel state that can be persisted to the DB
|
// obtain a consistent snapshot of the channel state that can be persisted to the DB
|
||||||
func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredChannel) {
|
func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) {
|
||||||
channel.stateMutex.RLock()
|
channel.stateMutex.RLock()
|
||||||
defer channel.stateMutex.RUnlock()
|
defer channel.stateMutex.RUnlock()
|
||||||
|
|
||||||
info.Name = channel.name
|
info.Name = channel.name
|
||||||
info.Topic = channel.topic
|
|
||||||
info.TopicSetBy = channel.topicSetBy
|
|
||||||
info.TopicSetTime = channel.topicSetTime
|
|
||||||
info.Founder = channel.registeredFounder
|
info.Founder = channel.registeredFounder
|
||||||
info.RegisteredAt = channel.registeredTime
|
info.RegisteredAt = channel.registeredTime
|
||||||
|
|
||||||
if includeLists {
|
if includeFlags&IncludeTopic != 0 {
|
||||||
|
info.Topic = channel.topic
|
||||||
|
info.TopicSetBy = channel.topicSetBy
|
||||||
|
info.TopicSetTime = channel.topicSetTime
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeFlags&IncludeModes != 0 {
|
||||||
|
info.Key = channel.key
|
||||||
|
for mode := range channel.flags {
|
||||||
|
info.Modes = append(info.Modes, mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeFlags&IncludeLists != 0 {
|
||||||
for mask := range channel.lists[modes.BanMask].masks {
|
for mask := range channel.lists[modes.BanMask].masks {
|
||||||
info.Banlist = append(info.Banlist, mask)
|
info.Banlist = append(info.Banlist, mask)
|
||||||
}
|
}
|
||||||
@ -116,6 +135,10 @@ func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredCh
|
|||||||
for mask := range channel.lists[modes.InviteMask].masks {
|
for mask := range channel.lists[modes.InviteMask].masks {
|
||||||
info.Invitelist = append(info.Invitelist, mask)
|
info.Invitelist = append(info.Invitelist, mask)
|
||||||
}
|
}
|
||||||
|
info.AccountToUMode = make(map[string]modes.Mode)
|
||||||
|
for account, mode := range channel.accountToUMode {
|
||||||
|
info.AccountToUMode[account] = mode
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -131,6 +154,7 @@ func (channel *Channel) SetRegistered(founder string) error {
|
|||||||
}
|
}
|
||||||
channel.registeredFounder = founder
|
channel.registeredFounder = founder
|
||||||
channel.registeredTime = time.Now()
|
channel.registeredTime = time.Now()
|
||||||
|
channel.accountToUMode[founder] = modes.ChannelFounder
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -338,7 +362,12 @@ func (channel *Channel) IsFull() bool {
|
|||||||
|
|
||||||
// CheckKey returns true if the key is not set or matches the given key.
|
// CheckKey returns true if the key is not set or matches the given key.
|
||||||
func (channel *Channel) CheckKey(key string) bool {
|
func (channel *Channel) CheckKey(key string) bool {
|
||||||
return (channel.key == "") || (channel.key == key)
|
chkey := channel.Key()
|
||||||
|
if chkey == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return subtle.ConstantTimeCompare([]byte(key), []byte(chkey)) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) IsEmpty() bool {
|
func (channel *Channel) IsEmpty() bool {
|
||||||
@ -404,21 +433,22 @@ func (channel *Channel) Join(client *Client, key string, rb *ResponseBuffer) {
|
|||||||
|
|
||||||
client.addChannel(channel)
|
client.addChannel(channel)
|
||||||
|
|
||||||
// give channel mode if necessary
|
|
||||||
newChannel := firstJoin && !channel.IsRegistered()
|
|
||||||
var givenMode *modes.Mode
|
|
||||||
account := client.Account()
|
account := client.Account()
|
||||||
cffounder, _ := CasefoldName(channel.registeredFounder)
|
|
||||||
if account != "" && account == cffounder {
|
// give channel mode if necessary
|
||||||
givenMode = &modes.ChannelFounder
|
channel.stateMutex.Lock()
|
||||||
|
newChannel := firstJoin && channel.registeredFounder == ""
|
||||||
|
mode, persistentModeExists := channel.accountToUMode[account]
|
||||||
|
var givenMode *modes.Mode
|
||||||
|
if persistentModeExists {
|
||||||
|
givenMode = &mode
|
||||||
} else if newChannel {
|
} else if newChannel {
|
||||||
givenMode = &modes.ChannelOperator
|
givenMode = &modes.ChannelOperator
|
||||||
}
|
}
|
||||||
if givenMode != nil {
|
if givenMode != nil {
|
||||||
channel.stateMutex.Lock()
|
|
||||||
channel.members[client][*givenMode] = true
|
channel.members[client][*givenMode] = true
|
||||||
channel.stateMutex.Unlock()
|
|
||||||
}
|
}
|
||||||
|
channel.stateMutex.Unlock()
|
||||||
|
|
||||||
if client.capabilities.Has(caps.ExtendedJoin) {
|
if client.capabilities.Has(caps.ExtendedJoin) {
|
||||||
rb.Add(nil, client.nickMaskString, "JOIN", channel.name, client.AccountName(), client.realname)
|
rb.Add(nil, client.nickMaskString, "JOIN", channel.name, client.AccountName(), client.realname)
|
||||||
@ -513,7 +543,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go channel.server.channelRegistry.StoreChannel(channel, false)
|
go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanSpeak returns true if the client can speak on this channel.
|
// CanSpeak returns true if the client can speak on this channel.
|
||||||
|
@ -65,7 +65,7 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, rb *Resp
|
|||||||
entry = cm.chans[casefoldedName]
|
entry = cm.chans[casefoldedName]
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
entry = &channelManagerEntry{
|
entry = &channelManagerEntry{
|
||||||
channel: NewChannel(server, name, true, info),
|
channel: NewChannel(server, name, info),
|
||||||
pendingJoins: 0,
|
pendingJoins: 0,
|
||||||
}
|
}
|
||||||
cm.chans[casefoldedName] = entry
|
cm.chans[casefoldedName] = entry
|
||||||
|
@ -6,11 +6,13 @@ package irc
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/oragono/oragono/irc/modes"
|
||||||
"github.com/tidwall/buntdb"
|
"github.com/tidwall/buntdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +30,9 @@ const (
|
|||||||
keyChannelBanlist = "channel.banlist %s"
|
keyChannelBanlist = "channel.banlist %s"
|
||||||
keyChannelExceptlist = "channel.exceptlist %s"
|
keyChannelExceptlist = "channel.exceptlist %s"
|
||||||
keyChannelInvitelist = "channel.invitelist %s"
|
keyChannelInvitelist = "channel.invitelist %s"
|
||||||
|
keyChannelPassword = "channel.key %s"
|
||||||
|
keyChannelModes = "channel.modes %s"
|
||||||
|
keyChannelAccountToUMode = "channel.accounttoumode %s"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -42,9 +47,26 @@ var (
|
|||||||
keyChannelBanlist,
|
keyChannelBanlist,
|
||||||
keyChannelExceptlist,
|
keyChannelExceptlist,
|
||||||
keyChannelInvitelist,
|
keyChannelInvitelist,
|
||||||
|
keyChannelPassword,
|
||||||
|
keyChannelModes,
|
||||||
|
keyChannelAccountToUMode,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// these are bit flags indicating what part of the channel status is "dirty"
|
||||||
|
// and needs to be read from memory and written to the db
|
||||||
|
const (
|
||||||
|
IncludeInitial uint = 1 << iota
|
||||||
|
IncludeTopic
|
||||||
|
IncludeModes
|
||||||
|
IncludeLists
|
||||||
|
)
|
||||||
|
|
||||||
|
// this is an OR of all possible flags
|
||||||
|
const (
|
||||||
|
IncludeAllChannelAttrs = ^uint(0)
|
||||||
|
)
|
||||||
|
|
||||||
// RegisteredChannel holds details about a given registered channel.
|
// RegisteredChannel holds details about a given registered channel.
|
||||||
type RegisteredChannel struct {
|
type RegisteredChannel struct {
|
||||||
// Name of the channel.
|
// Name of the channel.
|
||||||
@ -59,6 +81,12 @@ type RegisteredChannel struct {
|
|||||||
TopicSetBy string
|
TopicSetBy string
|
||||||
// TopicSetTime represents the time the topic was set.
|
// TopicSetTime represents the time the topic was set.
|
||||||
TopicSetTime time.Time
|
TopicSetTime time.Time
|
||||||
|
// Modes represents the channel modes
|
||||||
|
Modes []modes.Mode
|
||||||
|
// Key represents the channel key / password
|
||||||
|
Key string
|
||||||
|
// AccountToUMode maps user accounts to their persistent channel modes (e.g., +q, +h)
|
||||||
|
AccountToUMode map[string]modes.Mode
|
||||||
// Banlist represents the bans set on the channel.
|
// Banlist represents the bans set on the channel.
|
||||||
Banlist []string
|
Banlist []string
|
||||||
// Exceptlist represents the exceptions set on the channel.
|
// Exceptlist represents the exceptions set on the channel.
|
||||||
@ -87,7 +115,7 @@ func NewChannelRegistry(server *Server) *ChannelRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StoreChannel obtains a consistent view of a channel, then persists it to the store.
|
// StoreChannel obtains a consistent view of a channel, then persists it to the store.
|
||||||
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
|
func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) {
|
||||||
if !reg.server.ChannelRegistrationEnabled() {
|
if !reg.server.ChannelRegistrationEnabled() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -96,14 +124,14 @@ func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
|
|||||||
defer reg.Unlock()
|
defer reg.Unlock()
|
||||||
|
|
||||||
key := channel.NameCasefolded()
|
key := channel.NameCasefolded()
|
||||||
info := channel.ExportRegistration(includeLists)
|
info := channel.ExportRegistration(includeFlags)
|
||||||
if info.Founder == "" {
|
if info.Founder == "" {
|
||||||
// sanity check, don't try to store an unregistered channel
|
// sanity check, don't try to store an unregistered channel
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||||
reg.saveChannel(tx, key, info, includeLists)
|
reg.saveChannel(tx, key, info, includeFlags)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -132,9 +160,17 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
|
|||||||
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
|
topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
|
||||||
topicSetTime, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
|
topicSetTime, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
|
||||||
topicSetTimeInt, _ := strconv.ParseInt(topicSetTime, 10, 64)
|
topicSetTimeInt, _ := strconv.ParseInt(topicSetTime, 10, 64)
|
||||||
|
password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
|
||||||
|
modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
|
||||||
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
|
banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
|
||||||
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
|
exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
|
||||||
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
|
invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
|
||||||
|
accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
|
||||||
|
|
||||||
|
modeSlice := make([]modes.Mode, len(modeString))
|
||||||
|
for i, mode := range modeString {
|
||||||
|
modeSlice[i] = modes.Mode(mode)
|
||||||
|
}
|
||||||
|
|
||||||
var banlist []string
|
var banlist []string
|
||||||
_ = json.Unmarshal([]byte(banlistString), &banlist)
|
_ = json.Unmarshal([]byte(banlistString), &banlist)
|
||||||
@ -142,6 +178,8 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
|
|||||||
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
|
_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
|
||||||
var invitelist []string
|
var invitelist []string
|
||||||
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
|
_ = json.Unmarshal([]byte(invitelistString), &invitelist)
|
||||||
|
accountToUMode := make(map[string]modes.Mode)
|
||||||
|
_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
|
||||||
|
|
||||||
info = &RegisteredChannel{
|
info = &RegisteredChannel{
|
||||||
Name: name,
|
Name: name,
|
||||||
@ -150,9 +188,12 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
|
|||||||
Topic: topic,
|
Topic: topic,
|
||||||
TopicSetBy: topicSetBy,
|
TopicSetBy: topicSetBy,
|
||||||
TopicSetTime: time.Unix(topicSetTimeInt, 0),
|
TopicSetTime: time.Unix(topicSetTimeInt, 0),
|
||||||
|
Key: password,
|
||||||
|
Modes: modeSlice,
|
||||||
Banlist: banlist,
|
Banlist: banlist,
|
||||||
Exceptlist: exceptlist,
|
Exceptlist: exceptlist,
|
||||||
Invitelist: invitelist,
|
Invitelist: invitelist,
|
||||||
|
AccountToUMode: accountToUMode,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -170,17 +211,17 @@ func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
|
|||||||
reg.Lock()
|
reg.Lock()
|
||||||
defer reg.Unlock()
|
defer reg.Unlock()
|
||||||
|
|
||||||
includeLists := true
|
includeFlags := IncludeAllChannelAttrs
|
||||||
oldKey := casefoldedOldName
|
oldKey := casefoldedOldName
|
||||||
key := channel.NameCasefolded()
|
key := channel.NameCasefolded()
|
||||||
info := channel.ExportRegistration(includeLists)
|
info := channel.ExportRegistration(includeFlags)
|
||||||
if info.Founder == "" {
|
if info.Founder == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
reg.server.store.Update(func(tx *buntdb.Tx) error {
|
||||||
reg.deleteChannel(tx, oldKey, info)
|
reg.deleteChannel(tx, oldKey, info)
|
||||||
reg.saveChannel(tx, key, info, includeLists)
|
reg.saveChannel(tx, key, info, includeFlags)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -204,21 +245,37 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
|
|||||||
}
|
}
|
||||||
|
|
||||||
// saveChannel saves a channel to the store.
|
// saveChannel saves a channel to the store.
|
||||||
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeLists bool) {
|
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
|
||||||
|
if includeFlags&IncludeInitial != 0 {
|
||||||
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
|
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
|
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
|
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
|
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
|
}
|
||||||
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
|
|
||||||
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
|
|
||||||
|
|
||||||
if includeLists {
|
if includeFlags&IncludeTopic != 0 {
|
||||||
|
tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
|
||||||
|
tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
|
||||||
|
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeFlags&IncludeModes != 0 {
|
||||||
|
tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil)
|
||||||
|
modeStrings := make([]string, len(channelInfo.Modes))
|
||||||
|
for i, mode := range channelInfo.Modes {
|
||||||
|
modeStrings[i] = string(mode)
|
||||||
|
}
|
||||||
|
tx.Set(fmt.Sprintf(keyChannelModes, channelKey), strings.Join(modeStrings, ""), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeFlags&IncludeLists != 0 {
|
||||||
banlistString, _ := json.Marshal(channelInfo.Banlist)
|
banlistString, _ := json.Marshal(channelInfo.Banlist)
|
||||||
tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
|
tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
|
||||||
exceptlistString, _ := json.Marshal(channelInfo.Exceptlist)
|
exceptlistString, _ := json.Marshal(channelInfo.Exceptlist)
|
||||||
tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
|
tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
|
||||||
invitelistString, _ := json.Marshal(channelInfo.Invitelist)
|
invitelistString, _ := json.Marshal(channelInfo.Invitelist)
|
||||||
tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
|
tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
|
||||||
|
accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode)
|
||||||
|
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -252,7 +252,7 @@ func csRegisterHandler(server *Server, client *Client, command, params string, r
|
|||||||
}
|
}
|
||||||
|
|
||||||
// registration was successful: make the database reflect it
|
// registration was successful: make the database reflect it
|
||||||
go server.channelRegistry.StoreChannel(channelInfo, true)
|
go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
|
||||||
|
|
||||||
csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))
|
csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))
|
||||||
|
|
||||||
|
121
irc/database.go
121
irc/database.go
@ -6,11 +6,13 @@ package irc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/oragono/oragono/irc/modes"
|
||||||
"github.com/oragono/oragono/irc/passwd"
|
"github.com/oragono/oragono/irc/passwd"
|
||||||
|
|
||||||
"github.com/tidwall/buntdb"
|
"github.com/tidwall/buntdb"
|
||||||
@ -20,11 +22,22 @@ const (
|
|||||||
// 'version' of the database schema
|
// 'version' of the database schema
|
||||||
keySchemaVersion = "db.version"
|
keySchemaVersion = "db.version"
|
||||||
// latest schema of the db
|
// latest schema of the db
|
||||||
latestDbSchema = "2"
|
latestDbSchema = "3"
|
||||||
// key for the primary salt used by the ircd
|
// key for the primary salt used by the ircd
|
||||||
keySalt = "crypto.salt"
|
keySalt = "crypto.salt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type SchemaChanger func(*Config, *buntdb.Tx) error
|
||||||
|
|
||||||
|
type SchemaChange struct {
|
||||||
|
InitialVersion string // the change will take this version
|
||||||
|
TargetVersion string // and transform it into this version
|
||||||
|
Changer SchemaChanger
|
||||||
|
}
|
||||||
|
|
||||||
|
// maps an initial version to a schema change capable of upgrading it
|
||||||
|
var schemaChanges map[string]SchemaChange
|
||||||
|
|
||||||
// InitDB creates the database.
|
// InitDB creates the database.
|
||||||
func InitDB(path string) {
|
func InitDB(path string) {
|
||||||
// prepare kvstore db
|
// prepare kvstore db
|
||||||
@ -46,7 +59,7 @@ func InitDB(path string) {
|
|||||||
tx.Set(keySalt, encodedSalt, nil)
|
tx.Set(keySalt, encodedSalt, nil)
|
||||||
|
|
||||||
// set schema version
|
// set schema version
|
||||||
tx.Set(keySchemaVersion, "2", nil)
|
tx.Set(keySchemaVersion, latestDbSchema, nil)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -82,20 +95,45 @@ func OpenDatabase(path string) (*buntdb.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpgradeDB upgrades the datastore to the latest schema.
|
// UpgradeDB upgrades the datastore to the latest schema.
|
||||||
func UpgradeDB(path string) {
|
func UpgradeDB(config *Config) {
|
||||||
store, err := buntdb.Open(path)
|
store, err := buntdb.Open(config.Datastore.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(fmt.Sprintf("Failed to open datastore: %s", err.Error()))
|
log.Fatal(fmt.Sprintf("Failed to open datastore: %s", err.Error()))
|
||||||
}
|
}
|
||||||
defer store.Close()
|
defer store.Close()
|
||||||
|
|
||||||
|
var version string
|
||||||
err = store.Update(func(tx *buntdb.Tx) error {
|
err = store.Update(func(tx *buntdb.Tx) error {
|
||||||
version, _ := tx.Get(keySchemaVersion)
|
for {
|
||||||
|
version, _ = tx.Get(keySchemaVersion)
|
||||||
|
change, schemaNeedsChange := schemaChanges[version]
|
||||||
|
if !schemaNeedsChange {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Println("attempting to update store from version " + version)
|
||||||
|
err := change.Changer(config, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, _, err = tx.Set(keySchemaVersion, change.TargetVersion, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println("successfully updated store to version " + change.TargetVersion)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Could not update datastore:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error {
|
||||||
// == version 1 -> 2 ==
|
// == version 1 -> 2 ==
|
||||||
// account key changes and account.verified key bugfix.
|
// account key changes and account.verified key bugfix.
|
||||||
if version == "1" {
|
|
||||||
log.Println("Updating store v1 to v2")
|
|
||||||
|
|
||||||
var keysToRemove []string
|
var keysToRemove []string
|
||||||
newKeys := make(map[string]string)
|
newKeys := make(map[string]string)
|
||||||
@ -126,14 +164,73 @@ func UpgradeDB(path string) {
|
|||||||
tx.Set(key, value, nil)
|
tx.Set(key, value, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Set(keySchemaVersion, "2", nil)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. channel founder names should be casefolded
|
||||||
|
// 2. founder should be explicitly granted the ChannelFounder user mode
|
||||||
|
// 3. explicitly initialize stored channel modes to the server default values
|
||||||
|
func schemaChangeV2ToV3(config *Config, tx *buntdb.Tx) error {
|
||||||
|
var channels []string
|
||||||
|
prefix := "channel.exists "
|
||||||
|
tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
|
||||||
|
if !strings.HasPrefix(key, prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
chname := strings.TrimPrefix(key, prefix)
|
||||||
|
channels = append(channels, chname)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// founder names should be casefolded
|
||||||
|
// founder should be explicitly granted the ChannelFounder user mode
|
||||||
|
for _, channel := range channels {
|
||||||
|
founderKey := "channel.founder " + channel
|
||||||
|
founder, _ := tx.Get(founderKey)
|
||||||
|
if founder != "" {
|
||||||
|
founder, err := CasefoldName(founder)
|
||||||
|
if err == nil {
|
||||||
|
tx.Set(founderKey, founder, nil)
|
||||||
|
accountToUmode := map[string]modes.Mode{
|
||||||
|
founder: modes.ChannelFounder,
|
||||||
|
}
|
||||||
|
atustr, _ := json.Marshal(accountToUmode)
|
||||||
|
tx.Set("channel.accounttoumode "+channel, string(atustr), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// explicitly store the channel modes
|
||||||
|
defaultModes := ParseDefaultChannelModes(config)
|
||||||
|
modeStrings := make([]string, len(defaultModes))
|
||||||
|
for i, mode := range defaultModes {
|
||||||
|
modeStrings[i] = string(mode)
|
||||||
|
}
|
||||||
|
defaultModeString := strings.Join(modeStrings, "")
|
||||||
|
for _, channel := range channels {
|
||||||
|
tx.Set("channel.modes "+channel, defaultModeString, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Could not update datastore:", err.Error())
|
func init() {
|
||||||
|
allChanges := []SchemaChange{
|
||||||
|
SchemaChange{
|
||||||
|
InitialVersion: "1",
|
||||||
|
TargetVersion: "2",
|
||||||
|
Changer: schemaChangeV1toV2,
|
||||||
|
},
|
||||||
|
SchemaChange{
|
||||||
|
InitialVersion: "2",
|
||||||
|
TargetVersion: "3",
|
||||||
|
Changer: schemaChangeV2ToV3,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
// build the index
|
||||||
|
schemaChanges = make(map[string]SchemaChange)
|
||||||
|
for _, change := range allChanges {
|
||||||
|
schemaChanges[change.InitialVersion] = change
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -256,8 +256,8 @@ func (channel *Channel) Key() string {
|
|||||||
|
|
||||||
func (channel *Channel) setKey(key string) {
|
func (channel *Channel) setKey(key string) {
|
||||||
channel.stateMutex.Lock()
|
channel.stateMutex.Lock()
|
||||||
|
defer channel.stateMutex.Unlock()
|
||||||
channel.key = key
|
channel.key = key
|
||||||
channel.stateMutex.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) HasMode(mode modes.Mode) bool {
|
func (channel *Channel) HasMode(mode modes.Mode) bool {
|
||||||
|
@ -1351,20 +1351,17 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
|
|||||||
applied = channel.ApplyChannelModeChanges(client, msg.Command == "SAMODE", changes, rb)
|
applied = channel.ApplyChannelModeChanges(client, msg.Command == "SAMODE", changes, rb)
|
||||||
}
|
}
|
||||||
|
|
||||||
// save changes to banlist/exceptlist/invexlist
|
// save changes
|
||||||
var banlistUpdated, exceptlistUpdated, invexlistUpdated bool
|
var includeFlags uint
|
||||||
for _, change := range applied {
|
for _, change := range applied {
|
||||||
if change.Mode == modes.BanMask {
|
includeFlags |= IncludeModes
|
||||||
banlistUpdated = true
|
if change.Mode == modes.BanMask || change.Mode == modes.ExceptMask || change.Mode == modes.InviteMask {
|
||||||
} else if change.Mode == modes.ExceptMask {
|
includeFlags |= IncludeLists
|
||||||
exceptlistUpdated = true
|
|
||||||
} else if change.Mode == modes.InviteMask {
|
|
||||||
invexlistUpdated = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (banlistUpdated || exceptlistUpdated || invexlistUpdated) && channel.IsRegistered() {
|
if channel.IsRegistered() && includeFlags != 0 {
|
||||||
go server.channelRegistry.StoreChannel(channel, true)
|
go server.channelRegistry.StoreChannel(channel, includeFlags)
|
||||||
}
|
}
|
||||||
|
|
||||||
// send out changes
|
// send out changes
|
||||||
|
@ -84,7 +84,7 @@ Options:
|
|||||||
log.Println("database initialized: ", config.Datastore.Path)
|
log.Println("database initialized: ", config.Datastore.Path)
|
||||||
}
|
}
|
||||||
} else if arguments["upgradedb"].(bool) {
|
} else if arguments["upgradedb"].(bool) {
|
||||||
irc.UpgradeDB(config.Datastore.Path)
|
irc.UpgradeDB(config)
|
||||||
if !arguments["--quiet"].(bool) {
|
if !arguments["--quiet"].(bool) {
|
||||||
log.Println("database upgraded: ", config.Datastore.Path)
|
log.Println("database upgraded: ", config.Datastore.Path)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user