use custime.Duration for more config fields

This commit is contained in:
Shivaram Lingamneni 2020-02-20 00:09:39 -05:00
parent 7b3caed20f
commit ef161c47ed
6 changed files with 38 additions and 27 deletions

View File

@ -381,7 +381,7 @@ func (am *AccountManager) Register(client *Client, account string, callbackNames
callbackSpec := fmt.Sprintf("%s:%s", callbackNamespace, callbackValue)
var setOptions *buntdb.SetOptions
ttl := config.Registration.VerifyTimeout
ttl := time.Duration(config.Registration.VerifyTimeout)
if ttl != 0 {
setOptions = &buntdb.SetOptions{Expires: true, TTL: ttl}
}

View File

@ -233,9 +233,9 @@ type AccountConfig struct {
// AccountRegistrationConfig controls account registration.
type AccountRegistrationConfig struct {
Enabled bool
EnabledCallbacks []string `yaml:"enabled-callbacks"`
EnabledCredentialTypes []string `yaml:"-"`
VerifyTimeout time.Duration `yaml:"verify-timeout"`
EnabledCallbacks []string `yaml:"enabled-callbacks"`
EnabledCredentialTypes []string `yaml:"-"`
VerifyTimeout custime.Duration `yaml:"verify-timeout"`
Callbacks struct {
Mailto struct {
Server string
@ -263,7 +263,7 @@ type VHostConfig struct {
UserRequests struct {
Enabled bool
Channel string
Cooldown time.Duration
Cooldown custime.Duration
} `yaml:"user-requests"`
OfferList []string `yaml:"offer-list"`
}
@ -406,18 +406,17 @@ type Limits struct {
// STSConfig controls the STS configuration/
type STSConfig struct {
Enabled bool
Duration time.Duration `yaml:"duration-real"`
DurationString string `yaml:"duration"`
Port int
Preload bool
STSOnlyBanner string `yaml:"sts-only-banner"`
bannerLines []string
Enabled bool
Duration custime.Duration
Port int
Preload bool
STSOnlyBanner string `yaml:"sts-only-banner"`
bannerLines []string
}
// Value returns the STS value to advertise in CAP
func (sts *STSConfig) Value() string {
val := fmt.Sprintf("duration=%d", int(sts.Duration.Seconds()))
val := fmt.Sprintf("duration=%d", int(time.Duration(sts.Duration).Seconds()))
if sts.Enabled && sts.Port > 0 {
val += fmt.Sprintf(",port=%d", sts.Port)
}
@ -553,9 +552,9 @@ type Config struct {
ChathistoryMax int `yaml:"chathistory-maxmessages"`
ZNCMax int `yaml:"znc-maxmessages"`
Restrictions struct {
ExpireTime time.Duration `yaml:"expire-time"`
EnforceRegistrationDate bool `yaml:"enforce-registration-date"`
GracePeriod time.Duration `yaml:"grace-period"`
ExpireTime custime.Duration `yaml:"expire-time"`
EnforceRegistrationDate bool `yaml:"enforce-registration-date"`
GracePeriod custime.Duration `yaml:"grace-period"`
}
Persistent struct {
Enabled bool
@ -828,10 +827,6 @@ func LoadConfig(filename string) (config *Config, err error) {
}
if config.Server.STS.Enabled {
config.Server.STS.Duration, err = custime.ParseDuration(config.Server.STS.DurationString)
if err != nil {
return nil, fmt.Errorf("Could not parse STS duration: %s", err.Error())
}
if config.Server.STS.Port < 0 || config.Server.STS.Port > 65535 {
return nil, fmt.Errorf("STS port is incorrect, should be 0 if disabled: %d", config.Server.STS.Port)
}

View File

@ -182,3 +182,18 @@ func ParseDuration(s string) (time.Duration, error) {
}
return time.Duration(d), nil
}
type Duration time.Duration
func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
var orig string
var err error
if err = unmarshal(&orig); err != nil {
return err
}
result, err := ParseDuration(orig)
if err == nil {
*d = Duration(result)
}
return err
}

View File

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"regexp"
"time"
"github.com/oragono/oragono/irc/sno"
)
@ -214,7 +215,7 @@ func hsRequestHandler(server *Server, client *Client, command string, params []s
}
accountName := client.Account()
_, err := server.accounts.VHostRequest(accountName, vhost, server.Config().Accounts.VHosts.UserRequests.Cooldown)
_, err := server.accounts.VHostRequest(accountName, vhost, time.Duration(server.Config().Accounts.VHosts.UserRequests.Cooldown))
if err != nil {
if throttled, ok := err.(*vhostThrottleExceeded); ok {
hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before making another request"), throttled.timeRemaining))
@ -411,7 +412,7 @@ func hsTakeHandler(server *Server, client *Client, command string, params []stri
}
account := client.Account()
_, err := server.accounts.VHostTake(account, vhost, config.Accounts.VHosts.UserRequests.Cooldown)
_, err := server.accounts.VHostTake(account, vhost, time.Duration(config.Accounts.VHosts.UserRequests.Cooldown))
if err != nil {
if throttled, ok := err.(*vhostThrottleExceeded); ok {
hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before taking a vhost"), throttled.timeRemaining))

View File

@ -670,7 +670,7 @@ func (server *Server) applyConfig(config *Config) (err error) {
}
} else {
if config.Datastore.MySQL.Enabled {
server.historyDB.SetExpireTime(config.History.Restrictions.ExpireTime)
server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime))
}
}
@ -793,7 +793,7 @@ func (server *Server) loadDatastore(config *Config) error {
server.accounts.Initialize(server)
if config.Datastore.MySQL.Enabled {
server.historyDB.Initialize(server.logger, config.History.Restrictions.ExpireTime)
server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime))
err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase)
if err != nil {
server.logger.Error("internal", "could not connect to mysql", err.Error())
@ -906,11 +906,11 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
var cutoff time.Time
if config.History.Restrictions.ExpireTime != 0 {
cutoff = time.Now().UTC().Add(-config.History.Restrictions.ExpireTime)
cutoff = time.Now().UTC().Add(-time.Duration(config.History.Restrictions.ExpireTime))
}
if config.History.Restrictions.EnforceRegistrationDate {
regCutoff := client.historyCutoff()
regCutoff.Add(-config.History.Restrictions.GracePeriod)
regCutoff.Add(-time.Duration(config.History.Restrictions.GracePeriod))
// take the earlier of the two cutoffs
if regCutoff.After(cutoff) {
cutoff = regCutoff

View File

@ -716,7 +716,7 @@ history:
restrictions:
# if this is set, messages older than this cannot be retrieved by anyone
# (and will eventually be deleted from persistent storage, if that's enabled)
#expire-time: 168h # 7 days
#expire-time: 1w
# if this is set, logged-in users cannot retrieve messages older than their
# account registration date, and logged-out users cannot retrieve messages