mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-13 07:29:30 +01:00
Merge pull request #352 from slingamn/chanreglimit.1
track channel registrations per account
This commit is contained in:
commit
7cf8aaccf6
@ -33,6 +33,7 @@ const (
|
|||||||
keyAccountEnforcement = "account.customenforcement %s"
|
keyAccountEnforcement = "account.customenforcement %s"
|
||||||
keyAccountVHost = "account.vhost %s"
|
keyAccountVHost = "account.vhost %s"
|
||||||
keyCertToAccount = "account.creds.certfp %s"
|
keyCertToAccount = "account.creds.certfp %s"
|
||||||
|
keyAccountChannels = "account.channels %s"
|
||||||
|
|
||||||
keyVHostQueueAcctToId = "vhostQueue %s"
|
keyVHostQueueAcctToId = "vhostQueue %s"
|
||||||
vhostRequestIdx = "vhostQueue"
|
vhostRequestIdx = "vhostQueue"
|
||||||
@ -856,9 +857,25 @@ func (am *AccountManager) Unregister(account string) error {
|
|||||||
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
|
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
|
||||||
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
|
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
|
||||||
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
|
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
|
||||||
|
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
|
||||||
|
|
||||||
var clients []*Client
|
var clients []*Client
|
||||||
|
|
||||||
|
var registeredChannels []string
|
||||||
|
// on our way out, unregister all the account's channels and delete them from the db
|
||||||
|
defer func() {
|
||||||
|
for _, channelName := range registeredChannels {
|
||||||
|
info := am.server.channelRegistry.LoadChannel(channelName)
|
||||||
|
if info != nil && info.Founder == casefoldedAccount {
|
||||||
|
am.server.channelRegistry.Delete(channelName, *info)
|
||||||
|
}
|
||||||
|
channel := am.server.channels.Get(channelName)
|
||||||
|
if channel != nil {
|
||||||
|
channel.SetUnregistered(casefoldedAccount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var credText string
|
var credText string
|
||||||
var rawNicks string
|
var rawNicks string
|
||||||
|
|
||||||
@ -866,6 +883,7 @@ func (am *AccountManager) Unregister(account string) error {
|
|||||||
defer am.serialCacheUpdateMutex.Unlock()
|
defer am.serialCacheUpdateMutex.Unlock()
|
||||||
|
|
||||||
var accountName string
|
var accountName string
|
||||||
|
var channelsStr string
|
||||||
am.server.store.Update(func(tx *buntdb.Tx) error {
|
am.server.store.Update(func(tx *buntdb.Tx) error {
|
||||||
tx.Delete(accountKey)
|
tx.Delete(accountKey)
|
||||||
accountName, _ = tx.Get(accountNameKey)
|
accountName, _ = tx.Get(accountNameKey)
|
||||||
@ -879,6 +897,9 @@ func (am *AccountManager) Unregister(account string) error {
|
|||||||
credText, err = tx.Get(credentialsKey)
|
credText, err = tx.Get(credentialsKey)
|
||||||
tx.Delete(credentialsKey)
|
tx.Delete(credentialsKey)
|
||||||
tx.Delete(vhostKey)
|
tx.Delete(vhostKey)
|
||||||
|
channelsStr, _ = tx.Get(channelsKey)
|
||||||
|
tx.Delete(channelsKey)
|
||||||
|
|
||||||
_, err := tx.Delete(vhostQueueKey)
|
_, err := tx.Delete(vhostQueueKey)
|
||||||
am.decrementVHostQueueCount(casefoldedAccount, err)
|
am.decrementVHostQueueCount(casefoldedAccount, err)
|
||||||
return nil
|
return nil
|
||||||
@ -899,6 +920,7 @@ func (am *AccountManager) Unregister(account string) error {
|
|||||||
|
|
||||||
skeleton, _ := Skeleton(accountName)
|
skeleton, _ := Skeleton(accountName)
|
||||||
additionalNicks := unmarshalReservedNicks(rawNicks)
|
additionalNicks := unmarshalReservedNicks(rawNicks)
|
||||||
|
registeredChannels = unmarshalRegisteredChannels(channelsStr)
|
||||||
|
|
||||||
am.Lock()
|
am.Lock()
|
||||||
defer am.Unlock()
|
defer am.Unlock()
|
||||||
@ -925,9 +947,32 @@ func (am *AccountManager) Unregister(account string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errAccountDoesNotExist
|
return errAccountDoesNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unmarshalRegisteredChannels(channelsStr string) (result []string) {
|
||||||
|
if channelsStr != "" {
|
||||||
|
result = strings.Split(channelsStr, ",")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AccountManager) ChannelsForAccount(account string) (channels []string) {
|
||||||
|
cfaccount, err := CasefoldName(account)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var channelStr string
|
||||||
|
key := fmt.Sprintf(keyAccountChannels, cfaccount)
|
||||||
|
am.server.store.View(func(tx *buntdb.Tx) error {
|
||||||
|
channelStr, _ = tx.Get(key)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return unmarshalRegisteredChannels(channelStr)
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AccountManager) AuthenticateByCertFP(client *Client) error {
|
func (am *AccountManager) AuthenticateByCertFP(client *Client) error {
|
||||||
if client.certfp == "" {
|
if client.certfp == "" {
|
||||||
return errAccountInvalidCredentials
|
return errAccountInvalidCredentials
|
||||||
|
@ -165,10 +165,13 @@ func (channel *Channel) SetRegistered(founder string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetUnregistered deletes the channel's registration information.
|
// SetUnregistered deletes the channel's registration information.
|
||||||
func (channel *Channel) SetUnregistered() {
|
func (channel *Channel) SetUnregistered(expectedFounder string) {
|
||||||
channel.stateMutex.Lock()
|
channel.stateMutex.Lock()
|
||||||
defer channel.stateMutex.Unlock()
|
defer channel.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if channel.registeredFounder != expectedFounder {
|
||||||
|
return
|
||||||
|
}
|
||||||
channel.registeredFounder = ""
|
channel.registeredFounder = ""
|
||||||
var zeroTime time.Time
|
var zeroTime time.Time
|
||||||
channel.registeredTime = zeroTime
|
channel.registeredTime = zeroTime
|
||||||
|
@ -254,14 +254,43 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
|
|||||||
for _, keyFmt := range channelKeyStrings {
|
for _, keyFmt := range channelKeyStrings {
|
||||||
tx.Delete(fmt.Sprintf(keyFmt, key))
|
tx.Delete(fmt.Sprintf(keyFmt, key))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove this channel from the client's list of registered channels
|
||||||
|
channelsKey := fmt.Sprintf(keyAccountChannels, info.Founder)
|
||||||
|
channelsStr, err := tx.Get(channelsKey)
|
||||||
|
if err == buntdb.ErrNotFound {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registeredChannels := unmarshalRegisteredChannels(channelsStr)
|
||||||
|
var nowRegisteredChannels []string
|
||||||
|
for _, channel := range registeredChannels {
|
||||||
|
if channel != key {
|
||||||
|
nowRegisteredChannels = append(nowRegisteredChannels, channel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tx.Set(channelsKey, strings.Join(nowRegisteredChannels, ","), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveChannel saves a channel to the store.
|
// saveChannel saves a channel to the store.
|
||||||
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
|
func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
|
||||||
|
// maintain the mapping of account -> registered channels
|
||||||
|
chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey)
|
||||||
|
_, existsErr := tx.Get(chanExistsKey)
|
||||||
|
if existsErr == buntdb.ErrNotFound {
|
||||||
|
// this is a new registration, need to update account-to-channels
|
||||||
|
accountChannelsKey := fmt.Sprintf(keyAccountChannels, channelInfo.Founder)
|
||||||
|
alreadyChannels, _ := tx.Get(accountChannelsKey)
|
||||||
|
newChannels := channelKey // this is the casefolded channel name
|
||||||
|
if alreadyChannels != "" {
|
||||||
|
newChannels = fmt.Sprintf("%s,%s", alreadyChannels, newChannels)
|
||||||
|
}
|
||||||
|
tx.Set(accountChannelsKey, newChannels, nil)
|
||||||
|
}
|
||||||
|
|
||||||
if includeFlags&IncludeInitial != 0 {
|
if includeFlags&IncludeInitial != 0 {
|
||||||
tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
|
tx.Set(chanExistsKey, "1", nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
|
tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
|
tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
|
||||||
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
|
tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
|
||||||
|
@ -224,8 +224,15 @@ func csRegisterHandler(server *Server, client *Client, command string, params []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account := client.Account()
|
||||||
|
channelsAlreadyRegistered := server.accounts.ChannelsForAccount(account)
|
||||||
|
if server.Config().Channels.Registration.MaxChannelsPerAccount <= len(channelsAlreadyRegistered) {
|
||||||
|
csNotice(rb, client.t("You have already registered the maximum number of channels; try dropping some with /CS UNREGISTER"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// this provides the synchronization that allows exactly one registration of the channel:
|
// this provides the synchronization that allows exactly one registration of the channel:
|
||||||
err = channelInfo.SetRegistered(client.Account())
|
err = channelInfo.SetRegistered(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
csNotice(rb, err.Error())
|
csNotice(rb, err.Error())
|
||||||
return
|
return
|
||||||
@ -270,11 +277,13 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hasPrivs := client.HasRoleCapabs("chanreg")
|
founder := channel.Founder()
|
||||||
if !hasPrivs {
|
if founder == "" {
|
||||||
founder := channel.Founder()
|
csNotice(rb, client.t("That channel is not registered"))
|
||||||
hasPrivs = founder != "" && founder == client.Account()
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hasPrivs := client.HasRoleCapabs("chanreg") || founder == client.Account()
|
||||||
if !hasPrivs {
|
if !hasPrivs {
|
||||||
csNotice(rb, client.t("Insufficient privileges"))
|
csNotice(rb, client.t("Insufficient privileges"))
|
||||||
return
|
return
|
||||||
@ -288,8 +297,8 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.SetUnregistered()
|
channel.SetUnregistered(founder)
|
||||||
go server.channelRegistry.Delete(channelKey, info)
|
server.channelRegistry.Delete(channelKey, info)
|
||||||
csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey))
|
csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,7 +181,8 @@ type NickReservationConfig struct {
|
|||||||
|
|
||||||
// ChannelRegistrationConfig controls channel registration.
|
// ChannelRegistrationConfig controls channel registration.
|
||||||
type ChannelRegistrationConfig struct {
|
type ChannelRegistrationConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
|
MaxChannelsPerAccount int `yaml:"max-channels-per-account"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OperClassConfig defines a specific operator class.
|
// OperClassConfig defines a specific operator class.
|
||||||
@ -293,9 +294,10 @@ type Config struct {
|
|||||||
Accounts AccountConfig
|
Accounts AccountConfig
|
||||||
|
|
||||||
Channels struct {
|
Channels struct {
|
||||||
DefaultModes *string `yaml:"default-modes"`
|
DefaultModes *string `yaml:"default-modes"`
|
||||||
defaultModes modes.Modes
|
defaultModes modes.Modes
|
||||||
Registration ChannelRegistrationConfig
|
MaxChannelsPerClient int `yaml:"max-channels-per-client"`
|
||||||
|
Registration ChannelRegistrationConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
OperClasses map[string]*OperClassConfig `yaml:"oper-classes"`
|
OperClasses map[string]*OperClassConfig `yaml:"oper-classes"`
|
||||||
@ -789,6 +791,13 @@ func LoadConfig(filename string) (config *Config, err error) {
|
|||||||
config.Accounts.Registration.BcryptCost = passwd.DefaultCost
|
config.Accounts.Registration.BcryptCost = passwd.DefaultCost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.Channels.MaxChannelsPerClient == 0 {
|
||||||
|
config.Channels.MaxChannelsPerClient = 100
|
||||||
|
}
|
||||||
|
if config.Channels.Registration.MaxChannelsPerAccount == 0 {
|
||||||
|
config.Channels.Registration.MaxChannelsPerAccount = 15
|
||||||
|
}
|
||||||
|
|
||||||
// in the current implementation, we disable history by creating a history buffer
|
// in the current implementation, we disable history by creating a history buffer
|
||||||
// with zero capacity. but the `enabled` config option MUST be respected regardless
|
// with zero capacity. but the `enabled` config option MUST be respected regardless
|
||||||
// of this detail
|
// of this detail
|
||||||
|
@ -22,7 +22,7 @@ const (
|
|||||||
// 'version' of the database schema
|
// 'version' of the database schema
|
||||||
keySchemaVersion = "db.version"
|
keySchemaVersion = "db.version"
|
||||||
// latest schema of the db
|
// latest schema of the db
|
||||||
latestDbSchema = "4"
|
latestDbSchema = "5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SchemaChanger func(*Config, *buntdb.Tx) error
|
type SchemaChanger func(*Config, *buntdb.Tx) error
|
||||||
@ -390,6 +390,25 @@ func schemaChangeV3ToV4(config *Config, tx *buntdb.Tx) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new key tracking channels that belong to an account
|
||||||
|
func schemaChangeV4ToV5(config *Config, tx *buntdb.Tx) error {
|
||||||
|
founderToChannels := make(map[string][]string)
|
||||||
|
prefix := "channel.founder "
|
||||||
|
tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
|
||||||
|
if !strings.HasPrefix(key, prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
channel := strings.TrimPrefix(key, prefix)
|
||||||
|
founderToChannels[value] = append(founderToChannels[value], channel)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
for founder, channels := range founderToChannels {
|
||||||
|
tx.Set(fmt.Sprintf("account.channels %s", founder), strings.Join(channels, ","), nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
allChanges := []SchemaChange{
|
allChanges := []SchemaChange{
|
||||||
{
|
{
|
||||||
@ -407,6 +426,11 @@ func init() {
|
|||||||
TargetVersion: "4",
|
TargetVersion: "4",
|
||||||
Changer: schemaChangeV3ToV4,
|
Changer: schemaChangeV3ToV4,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
InitialVersion: "4",
|
||||||
|
TargetVersion: "5",
|
||||||
|
Changer: schemaChangeV4ToV5,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// build the index
|
// build the index
|
||||||
|
@ -200,6 +200,12 @@ func (client *Client) Channels() (result []*Channel) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (client *Client) NumChannels() int {
|
||||||
|
client.stateMutex.RLock()
|
||||||
|
defer client.stateMutex.RUnlock()
|
||||||
|
return len(client.channels)
|
||||||
|
}
|
||||||
|
|
||||||
func (client *Client) WhoWas() (result WhoWas) {
|
func (client *Client) WhoWas() (result WhoWas) {
|
||||||
return client.Details().WhoWas
|
return client.Details().WhoWas
|
||||||
}
|
}
|
||||||
|
@ -1159,7 +1159,13 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
|
|||||||
keys = strings.Split(msg.Params[1], ",")
|
keys = strings.Split(msg.Params[1], ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config := server.Config()
|
||||||
|
oper := client.Oper()
|
||||||
for i, name := range channels {
|
for i, name := range channels {
|
||||||
|
if config.Channels.MaxChannelsPerClient <= client.NumChannels() && oper == nil {
|
||||||
|
rb.Add(nil, server.name, ERR_TOOMANYCHANNELS, client.Nick(), name, client.t("You have joined too many channels"))
|
||||||
|
return false
|
||||||
|
}
|
||||||
var key string
|
var key string
|
||||||
if len(keys) > i {
|
if len(keys) > i {
|
||||||
key = keys[i]
|
key = keys[i]
|
||||||
|
@ -327,6 +327,9 @@ func nsInfoHandler(server *Server, client *Client, command string, params []stri
|
|||||||
for _, nick := range account.AdditionalNicks {
|
for _, nick := range account.AdditionalNicks {
|
||||||
nsNotice(rb, fmt.Sprintf(client.t("Additional grouped nick: %s"), nick))
|
nsNotice(rb, fmt.Sprintf(client.t("Additional grouped nick: %s"), nick))
|
||||||
}
|
}
|
||||||
|
for _, channel := range server.accounts.ChannelsForAccount(accountName) {
|
||||||
|
nsNotice(rb, fmt.Sprintf(client.t("Registered channel: %s"), channel))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func nsRegisterHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
|
func nsRegisterHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
|
||||||
|
@ -275,11 +275,17 @@ channels:
|
|||||||
# see /QUOTE HELP cmodes for more channel modes
|
# see /QUOTE HELP cmodes for more channel modes
|
||||||
default-modes: +nt
|
default-modes: +nt
|
||||||
|
|
||||||
|
# how many channels can a client be in at once?
|
||||||
|
max-channels-per-client: 100
|
||||||
|
|
||||||
# channel registration - requires an account
|
# channel registration - requires an account
|
||||||
registration:
|
registration:
|
||||||
# can users register new channels?
|
# can users register new channels?
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|
||||||
|
# how many channels can each account register?
|
||||||
|
max-channels-per-account: 15
|
||||||
|
|
||||||
# operator classes
|
# operator classes
|
||||||
oper-classes:
|
oper-classes:
|
||||||
# local operator
|
# local operator
|
||||||
|
Loading…
Reference in New Issue
Block a user