3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-22 10:42:52 +01:00

Merge pull request #352 from slingamn/chanreglimit.1

track channel registrations per account
This commit is contained in:
Daniel Oaks 2019-02-18 07:08:57 +10:00 committed by GitHub
commit 7cf8aaccf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 154 additions and 14 deletions

View File

@ -33,6 +33,7 @@ const (
keyAccountEnforcement = "account.customenforcement %s" keyAccountEnforcement = "account.customenforcement %s"
keyAccountVHost = "account.vhost %s" keyAccountVHost = "account.vhost %s"
keyCertToAccount = "account.creds.certfp %s" keyCertToAccount = "account.creds.certfp %s"
keyAccountChannels = "account.channels %s"
keyVHostQueueAcctToId = "vhostQueue %s" keyVHostQueueAcctToId = "vhostQueue %s"
vhostRequestIdx = "vhostQueue" vhostRequestIdx = "vhostQueue"
@ -856,9 +857,25 @@ func (am *AccountManager) Unregister(account string) error {
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount) nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount) vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount) vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
var clients []*Client var clients []*Client
var registeredChannels []string
// on our way out, unregister all the account's channels and delete them from the db
defer func() {
for _, channelName := range registeredChannels {
info := am.server.channelRegistry.LoadChannel(channelName)
if info != nil && info.Founder == casefoldedAccount {
am.server.channelRegistry.Delete(channelName, *info)
}
channel := am.server.channels.Get(channelName)
if channel != nil {
channel.SetUnregistered(casefoldedAccount)
}
}
}()
var credText string var credText string
var rawNicks string var rawNicks string
@ -866,6 +883,7 @@ func (am *AccountManager) Unregister(account string) error {
defer am.serialCacheUpdateMutex.Unlock() defer am.serialCacheUpdateMutex.Unlock()
var accountName string var accountName string
var channelsStr string
am.server.store.Update(func(tx *buntdb.Tx) error { am.server.store.Update(func(tx *buntdb.Tx) error {
tx.Delete(accountKey) tx.Delete(accountKey)
accountName, _ = tx.Get(accountNameKey) accountName, _ = tx.Get(accountNameKey)
@ -879,6 +897,9 @@ func (am *AccountManager) Unregister(account string) error {
credText, err = tx.Get(credentialsKey) credText, err = tx.Get(credentialsKey)
tx.Delete(credentialsKey) tx.Delete(credentialsKey)
tx.Delete(vhostKey) tx.Delete(vhostKey)
channelsStr, _ = tx.Get(channelsKey)
tx.Delete(channelsKey)
_, err := tx.Delete(vhostQueueKey) _, err := tx.Delete(vhostQueueKey)
am.decrementVHostQueueCount(casefoldedAccount, err) am.decrementVHostQueueCount(casefoldedAccount, err)
return nil return nil
@ -899,6 +920,7 @@ func (am *AccountManager) Unregister(account string) error {
skeleton, _ := Skeleton(accountName) skeleton, _ := Skeleton(accountName)
additionalNicks := unmarshalReservedNicks(rawNicks) additionalNicks := unmarshalReservedNicks(rawNicks)
registeredChannels = unmarshalRegisteredChannels(channelsStr)
am.Lock() am.Lock()
defer am.Unlock() defer am.Unlock()
@ -925,9 +947,32 @@ func (am *AccountManager) Unregister(account string) error {
if err != nil { if err != nil {
return errAccountDoesNotExist return errAccountDoesNotExist
} }
return nil return nil
} }
func unmarshalRegisteredChannels(channelsStr string) (result []string) {
if channelsStr != "" {
result = strings.Split(channelsStr, ",")
}
return
}
func (am *AccountManager) ChannelsForAccount(account string) (channels []string) {
cfaccount, err := CasefoldName(account)
if err != nil {
return
}
var channelStr string
key := fmt.Sprintf(keyAccountChannels, cfaccount)
am.server.store.View(func(tx *buntdb.Tx) error {
channelStr, _ = tx.Get(key)
return nil
})
return unmarshalRegisteredChannels(channelStr)
}
func (am *AccountManager) AuthenticateByCertFP(client *Client) error { func (am *AccountManager) AuthenticateByCertFP(client *Client) error {
if client.certfp == "" { if client.certfp == "" {
return errAccountInvalidCredentials return errAccountInvalidCredentials

View File

@ -165,10 +165,13 @@ func (channel *Channel) SetRegistered(founder string) error {
} }
// SetUnregistered deletes the channel's registration information. // SetUnregistered deletes the channel's registration information.
func (channel *Channel) SetUnregistered() { func (channel *Channel) SetUnregistered(expectedFounder string) {
channel.stateMutex.Lock() channel.stateMutex.Lock()
defer channel.stateMutex.Unlock() defer channel.stateMutex.Unlock()
if channel.registeredFounder != expectedFounder {
return
}
channel.registeredFounder = "" channel.registeredFounder = ""
var zeroTime time.Time var zeroTime time.Time
channel.registeredTime = zeroTime channel.registeredTime = zeroTime

View File

@ -254,14 +254,43 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
for _, keyFmt := range channelKeyStrings { for _, keyFmt := range channelKeyStrings {
tx.Delete(fmt.Sprintf(keyFmt, key)) tx.Delete(fmt.Sprintf(keyFmt, key))
} }
// remove this channel from the client's list of registered channels
channelsKey := fmt.Sprintf(keyAccountChannels, info.Founder)
channelsStr, err := tx.Get(channelsKey)
if err == buntdb.ErrNotFound {
return
}
registeredChannels := unmarshalRegisteredChannels(channelsStr)
var nowRegisteredChannels []string
for _, channel := range registeredChannels {
if channel != key {
nowRegisteredChannels = append(nowRegisteredChannels, channel)
}
}
tx.Set(channelsKey, strings.Join(nowRegisteredChannels, ","), nil)
} }
} }
} }
// saveChannel saves a channel to the store. // saveChannel saves a channel to the store.
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) { func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
// maintain the mapping of account -> registered channels
chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey)
_, existsErr := tx.Get(chanExistsKey)
if existsErr == buntdb.ErrNotFound {
// this is a new registration, need to update account-to-channels
accountChannelsKey := fmt.Sprintf(keyAccountChannels, channelInfo.Founder)
alreadyChannels, _ := tx.Get(accountChannelsKey)
newChannels := channelKey // this is the casefolded channel name
if alreadyChannels != "" {
newChannels = fmt.Sprintf("%s,%s", alreadyChannels, newChannels)
}
tx.Set(accountChannelsKey, newChannels, nil)
}
if includeFlags&IncludeInitial != 0 { if includeFlags&IncludeInitial != 0 {
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil) tx.Set(chanExistsKey, "1", nil)
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil) tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil) tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil) tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)

View File

@ -224,8 +224,15 @@ func csRegisterHandler(server *Server, client *Client, command string, params []
return return
} }
account := client.Account()
channelsAlreadyRegistered := server.accounts.ChannelsForAccount(account)
if server.Config().Channels.Registration.MaxChannelsPerAccount <= len(channelsAlreadyRegistered) {
csNotice(rb, client.t("You have already registered the maximum number of channels; try dropping some with /CS UNREGISTER"))
return
}
// this provides the synchronization that allows exactly one registration of the channel: // this provides the synchronization that allows exactly one registration of the channel:
err = channelInfo.SetRegistered(client.Account()) err = channelInfo.SetRegistered(account)
if err != nil { if err != nil {
csNotice(rb, err.Error()) csNotice(rb, err.Error())
return return
@ -270,11 +277,13 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
return return
} }
hasPrivs := client.HasRoleCapabs("chanreg") founder := channel.Founder()
if !hasPrivs { if founder == "" {
founder := channel.Founder() csNotice(rb, client.t("That channel is not registered"))
hasPrivs = founder != "" && founder == client.Account() return
} }
hasPrivs := client.HasRoleCapabs("chanreg") || founder == client.Account()
if !hasPrivs { if !hasPrivs {
csNotice(rb, client.t("Insufficient privileges")) csNotice(rb, client.t("Insufficient privileges"))
return return
@ -288,8 +297,8 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
return return
} }
channel.SetUnregistered() channel.SetUnregistered(founder)
go server.channelRegistry.Delete(channelKey, info) server.channelRegistry.Delete(channelKey, info)
csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey)) csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey))
} }

View File

@ -181,7 +181,8 @@ type NickReservationConfig struct {
// ChannelRegistrationConfig controls channel registration. // ChannelRegistrationConfig controls channel registration.
type ChannelRegistrationConfig struct { type ChannelRegistrationConfig struct {
Enabled bool Enabled bool
MaxChannelsPerAccount int `yaml:"max-channels-per-account"`
} }
// OperClassConfig defines a specific operator class. // OperClassConfig defines a specific operator class.
@ -293,9 +294,10 @@ type Config struct {
Accounts AccountConfig Accounts AccountConfig
Channels struct { Channels struct {
DefaultModes *string `yaml:"default-modes"` DefaultModes *string `yaml:"default-modes"`
defaultModes modes.Modes defaultModes modes.Modes
Registration ChannelRegistrationConfig MaxChannelsPerClient int `yaml:"max-channels-per-client"`
Registration ChannelRegistrationConfig
} }
OperClasses map[string]*OperClassConfig `yaml:"oper-classes"` OperClasses map[string]*OperClassConfig `yaml:"oper-classes"`
@ -789,6 +791,13 @@ func LoadConfig(filename string) (config *Config, err error) {
config.Accounts.Registration.BcryptCost = passwd.DefaultCost config.Accounts.Registration.BcryptCost = passwd.DefaultCost
} }
if config.Channels.MaxChannelsPerClient == 0 {
config.Channels.MaxChannelsPerClient = 100
}
if config.Channels.Registration.MaxChannelsPerAccount == 0 {
config.Channels.Registration.MaxChannelsPerAccount = 15
}
// in the current implementation, we disable history by creating a history buffer // in the current implementation, we disable history by creating a history buffer
// with zero capacity. but the `enabled` config option MUST be respected regardless // with zero capacity. but the `enabled` config option MUST be respected regardless
// of this detail // of this detail

View File

@ -22,7 +22,7 @@ const (
// 'version' of the database schema // 'version' of the database schema
keySchemaVersion = "db.version" keySchemaVersion = "db.version"
// latest schema of the db // latest schema of the db
latestDbSchema = "4" latestDbSchema = "5"
) )
type SchemaChanger func(*Config, *buntdb.Tx) error type SchemaChanger func(*Config, *buntdb.Tx) error
@ -390,6 +390,25 @@ func schemaChangeV3ToV4(config *Config, tx *buntdb.Tx) error {
return nil return nil
} }
// create new key tracking channels that belong to an account
func schemaChangeV4ToV5(config *Config, tx *buntdb.Tx) error {
founderToChannels := make(map[string][]string)
prefix := "channel.founder "
tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
if !strings.HasPrefix(key, prefix) {
return false
}
channel := strings.TrimPrefix(key, prefix)
founderToChannels[value] = append(founderToChannels[value], channel)
return true
})
for founder, channels := range founderToChannels {
tx.Set(fmt.Sprintf("account.channels %s", founder), strings.Join(channels, ","), nil)
}
return nil
}
func init() { func init() {
allChanges := []SchemaChange{ allChanges := []SchemaChange{
{ {
@ -407,6 +426,11 @@ func init() {
TargetVersion: "4", TargetVersion: "4",
Changer: schemaChangeV3ToV4, Changer: schemaChangeV3ToV4,
}, },
{
InitialVersion: "4",
TargetVersion: "5",
Changer: schemaChangeV4ToV5,
},
} }
// build the index // build the index

View File

@ -200,6 +200,12 @@ func (client *Client) Channels() (result []*Channel) {
return return
} }
func (client *Client) NumChannels() int {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
return len(client.channels)
}
func (client *Client) WhoWas() (result WhoWas) { func (client *Client) WhoWas() (result WhoWas) {
return client.Details().WhoWas return client.Details().WhoWas
} }

View File

@ -1159,7 +1159,13 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
keys = strings.Split(msg.Params[1], ",") keys = strings.Split(msg.Params[1], ",")
} }
config := server.Config()
oper := client.Oper()
for i, name := range channels { for i, name := range channels {
if config.Channels.MaxChannelsPerClient <= client.NumChannels() && oper == nil {
rb.Add(nil, server.name, ERR_TOOMANYCHANNELS, client.Nick(), name, client.t("You have joined too many channels"))
return false
}
var key string var key string
if len(keys) > i { if len(keys) > i {
key = keys[i] key = keys[i]

View File

@ -327,6 +327,9 @@ func nsInfoHandler(server *Server, client *Client, command string, params []stri
for _, nick := range account.AdditionalNicks { for _, nick := range account.AdditionalNicks {
nsNotice(rb, fmt.Sprintf(client.t("Additional grouped nick: %s"), nick)) nsNotice(rb, fmt.Sprintf(client.t("Additional grouped nick: %s"), nick))
} }
for _, channel := range server.accounts.ChannelsForAccount(accountName) {
nsNotice(rb, fmt.Sprintf(client.t("Registered channel: %s"), channel))
}
} }
func nsRegisterHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) { func nsRegisterHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {

View File

@ -275,11 +275,17 @@ channels:
# see /QUOTE HELP cmodes for more channel modes # see /QUOTE HELP cmodes for more channel modes
default-modes: +nt default-modes: +nt
# how many channels can a client be in at once?
max-channels-per-client: 100
# channel registration - requires an account # channel registration - requires an account
registration: registration:
# can users register new channels? # can users register new channels?
enabled: true enabled: true
# how many channels can each account register?
max-channels-per-account: 15
# operator classes # operator classes
oper-classes: oper-classes:
# local operator # local operator