3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-25 05:19:25 +01:00

use custime.Duration for more config fields

This commit is contained in:
Shivaram Lingamneni 2020-02-20 00:09:39 -05:00
parent 7b3caed20f
commit ef161c47ed
6 changed files with 38 additions and 27 deletions

View File

@ -381,7 +381,7 @@ func (am *AccountManager) Register(client *Client, account string, callbackNames
callbackSpec := fmt.Sprintf("%s:%s", callbackNamespace, callbackValue) callbackSpec := fmt.Sprintf("%s:%s", callbackNamespace, callbackValue)
var setOptions *buntdb.SetOptions var setOptions *buntdb.SetOptions
ttl := config.Registration.VerifyTimeout ttl := time.Duration(config.Registration.VerifyTimeout)
if ttl != 0 { if ttl != 0 {
setOptions = &buntdb.SetOptions{Expires: true, TTL: ttl} setOptions = &buntdb.SetOptions{Expires: true, TTL: ttl}
} }

View File

@ -233,9 +233,9 @@ type AccountConfig struct {
// AccountRegistrationConfig controls account registration. // AccountRegistrationConfig controls account registration.
type AccountRegistrationConfig struct { type AccountRegistrationConfig struct {
Enabled bool Enabled bool
EnabledCallbacks []string `yaml:"enabled-callbacks"` EnabledCallbacks []string `yaml:"enabled-callbacks"`
EnabledCredentialTypes []string `yaml:"-"` EnabledCredentialTypes []string `yaml:"-"`
VerifyTimeout time.Duration `yaml:"verify-timeout"` VerifyTimeout custime.Duration `yaml:"verify-timeout"`
Callbacks struct { Callbacks struct {
Mailto struct { Mailto struct {
Server string Server string
@ -263,7 +263,7 @@ type VHostConfig struct {
UserRequests struct { UserRequests struct {
Enabled bool Enabled bool
Channel string Channel string
Cooldown time.Duration Cooldown custime.Duration
} `yaml:"user-requests"` } `yaml:"user-requests"`
OfferList []string `yaml:"offer-list"` OfferList []string `yaml:"offer-list"`
} }
@ -406,18 +406,17 @@ type Limits struct {
// STSConfig controls the STS configuration/ // STSConfig controls the STS configuration/
type STSConfig struct { type STSConfig struct {
Enabled bool Enabled bool
Duration time.Duration `yaml:"duration-real"` Duration custime.Duration
DurationString string `yaml:"duration"` Port int
Port int Preload bool
Preload bool STSOnlyBanner string `yaml:"sts-only-banner"`
STSOnlyBanner string `yaml:"sts-only-banner"` bannerLines []string
bannerLines []string
} }
// Value returns the STS value to advertise in CAP // Value returns the STS value to advertise in CAP
func (sts *STSConfig) Value() string { func (sts *STSConfig) Value() string {
val := fmt.Sprintf("duration=%d", int(sts.Duration.Seconds())) val := fmt.Sprintf("duration=%d", int(time.Duration(sts.Duration).Seconds()))
if sts.Enabled && sts.Port > 0 { if sts.Enabled && sts.Port > 0 {
val += fmt.Sprintf(",port=%d", sts.Port) val += fmt.Sprintf(",port=%d", sts.Port)
} }
@ -553,9 +552,9 @@ type Config struct {
ChathistoryMax int `yaml:"chathistory-maxmessages"` ChathistoryMax int `yaml:"chathistory-maxmessages"`
ZNCMax int `yaml:"znc-maxmessages"` ZNCMax int `yaml:"znc-maxmessages"`
Restrictions struct { Restrictions struct {
ExpireTime time.Duration `yaml:"expire-time"` ExpireTime custime.Duration `yaml:"expire-time"`
EnforceRegistrationDate bool `yaml:"enforce-registration-date"` EnforceRegistrationDate bool `yaml:"enforce-registration-date"`
GracePeriod time.Duration `yaml:"grace-period"` GracePeriod custime.Duration `yaml:"grace-period"`
} }
Persistent struct { Persistent struct {
Enabled bool Enabled bool
@ -828,10 +827,6 @@ func LoadConfig(filename string) (config *Config, err error) {
} }
if config.Server.STS.Enabled { if config.Server.STS.Enabled {
config.Server.STS.Duration, err = custime.ParseDuration(config.Server.STS.DurationString)
if err != nil {
return nil, fmt.Errorf("Could not parse STS duration: %s", err.Error())
}
if config.Server.STS.Port < 0 || config.Server.STS.Port > 65535 { if config.Server.STS.Port < 0 || config.Server.STS.Port > 65535 {
return nil, fmt.Errorf("STS port is incorrect, should be 0 if disabled: %d", config.Server.STS.Port) return nil, fmt.Errorf("STS port is incorrect, should be 0 if disabled: %d", config.Server.STS.Port)
} }

View File

@ -182,3 +182,18 @@ func ParseDuration(s string) (time.Duration, error) {
} }
return time.Duration(d), nil return time.Duration(d), nil
} }
type Duration time.Duration
func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
var orig string
var err error
if err = unmarshal(&orig); err != nil {
return err
}
result, err := ParseDuration(orig)
if err == nil {
*d = Duration(result)
}
return err
}

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"time"
"github.com/oragono/oragono/irc/sno" "github.com/oragono/oragono/irc/sno"
) )
@ -214,7 +215,7 @@ func hsRequestHandler(server *Server, client *Client, command string, params []s
} }
accountName := client.Account() accountName := client.Account()
_, err := server.accounts.VHostRequest(accountName, vhost, server.Config().Accounts.VHosts.UserRequests.Cooldown) _, err := server.accounts.VHostRequest(accountName, vhost, time.Duration(server.Config().Accounts.VHosts.UserRequests.Cooldown))
if err != nil { if err != nil {
if throttled, ok := err.(*vhostThrottleExceeded); ok { if throttled, ok := err.(*vhostThrottleExceeded); ok {
hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before making another request"), throttled.timeRemaining)) hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before making another request"), throttled.timeRemaining))
@ -411,7 +412,7 @@ func hsTakeHandler(server *Server, client *Client, command string, params []stri
} }
account := client.Account() account := client.Account()
_, err := server.accounts.VHostTake(account, vhost, config.Accounts.VHosts.UserRequests.Cooldown) _, err := server.accounts.VHostTake(account, vhost, time.Duration(config.Accounts.VHosts.UserRequests.Cooldown))
if err != nil { if err != nil {
if throttled, ok := err.(*vhostThrottleExceeded); ok { if throttled, ok := err.(*vhostThrottleExceeded); ok {
hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before taking a vhost"), throttled.timeRemaining)) hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before taking a vhost"), throttled.timeRemaining))

View File

@ -670,7 +670,7 @@ func (server *Server) applyConfig(config *Config) (err error) {
} }
} else { } else {
if config.Datastore.MySQL.Enabled { if config.Datastore.MySQL.Enabled {
server.historyDB.SetExpireTime(config.History.Restrictions.ExpireTime) server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime))
} }
} }
@ -793,7 +793,7 @@ func (server *Server) loadDatastore(config *Config) error {
server.accounts.Initialize(server) server.accounts.Initialize(server)
if config.Datastore.MySQL.Enabled { if config.Datastore.MySQL.Enabled {
server.historyDB.Initialize(server.logger, config.History.Restrictions.ExpireTime) server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime))
err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase) err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase)
if err != nil { if err != nil {
server.logger.Error("internal", "could not connect to mysql", err.Error()) server.logger.Error("internal", "could not connect to mysql", err.Error())
@ -906,11 +906,11 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
var cutoff time.Time var cutoff time.Time
if config.History.Restrictions.ExpireTime != 0 { if config.History.Restrictions.ExpireTime != 0 {
cutoff = time.Now().UTC().Add(-config.History.Restrictions.ExpireTime) cutoff = time.Now().UTC().Add(-time.Duration(config.History.Restrictions.ExpireTime))
} }
if config.History.Restrictions.EnforceRegistrationDate { if config.History.Restrictions.EnforceRegistrationDate {
regCutoff := client.historyCutoff() regCutoff := client.historyCutoff()
regCutoff.Add(-config.History.Restrictions.GracePeriod) regCutoff.Add(-time.Duration(config.History.Restrictions.GracePeriod))
// take the earlier of the two cutoffs // take the earlier of the two cutoffs
if regCutoff.After(cutoff) { if regCutoff.After(cutoff) {
cutoff = regCutoff cutoff = regCutoff

View File

@ -716,7 +716,7 @@ history:
restrictions: restrictions:
# if this is set, messages older than this cannot be retrieved by anyone # if this is set, messages older than this cannot be retrieved by anyone
# (and will eventually be deleted from persistent storage, if that's enabled) # (and will eventually be deleted from persistent storage, if that's enabled)
#expire-time: 168h # 7 days #expire-time: 1w
# if this is set, logged-in users cannot retrieve messages older than their # if this is set, logged-in users cannot retrieve messages older than their
# account registration date, and logged-out users cannot retrieve messages # account registration date, and logged-out users cannot retrieve messages