diff --git a/irc/accounts.go b/irc/accounts.go index f926e9c1..b75e34b6 100644 --- a/irc/accounts.go +++ b/irc/accounts.go @@ -1169,13 +1169,9 @@ func (vh *VHostInfo) checkThrottle(cooldown time.Duration) (err error) { // callback type implementing the actual business logic of vhost operations type vhostMunger func(input VHostInfo) (output VHostInfo, err error) -func (am *AccountManager) VHostSet(account string, vhost string, cooldown time.Duration) (result VHostInfo, err error) { +func (am *AccountManager) VHostSet(account string, vhost string) (result VHostInfo, err error) { munger := func(input VHostInfo) (output VHostInfo, err error) { output = input - err = output.checkThrottle(cooldown) - if err != nil { - return - } output.Enabled = true output.ApprovedVHost = vhost return @@ -1205,6 +1201,29 @@ func (am *AccountManager) VHostRequest(account string, vhost string, cooldown ti return am.performVHostChange(account, munger) } +func (am *AccountManager) VHostTake(account string, vhost string, cooldown time.Duration) (result VHostInfo, err error) { + munger := func(input VHostInfo) (output VHostInfo, err error) { + output = input + + // if you have a request pending, you can cancel it using take; + // otherwise, you're subject to the same throttling as if you were making a request + if output.RequestedVHost == "" { + err = output.checkThrottle(cooldown) + } + if err != nil { + return + } + output.ApprovedVHost = vhost + output.RequestedVHost = "" + output.RejectedVHost = "" + output.RejectionReason = "" + output.LastRequestTime = time.Now().UTC() + return + } + + return am.performVHostChange(account, munger) +} + func (am *AccountManager) VHostApprove(account string) (result VHostInfo, err error) { munger := func(input VHostInfo) (output VHostInfo, err error) { output = input diff --git a/irc/hostserv.go b/irc/hostserv.go index 8cfa9220..6424910b 100644 --- a/irc/hostserv.go +++ b/irc/hostserv.go @@ -298,7 +298,7 @@ func hsSetHandler(server *Server, client *Client, command string, params []strin } // else: command == "del", vhost == "" - _, err := server.accounts.VHostSet(user, vhost, 0) + _, err := server.accounts.VHostSet(user, vhost) if err != nil { hsNotice(rb, client.t("An error occurred")) } else if vhost != "" { @@ -404,7 +404,7 @@ func hsTakeHandler(server *Server, client *Client, command string, params []stri return } - _, err := server.accounts.VHostSet(client.Account(), vhost, config.Accounts.VHosts.UserRequests.Cooldown) + _, err := server.accounts.VHostTake(client.Account(), vhost, config.Accounts.VHosts.UserRequests.Cooldown) if err != nil { if throttled, ok := err.(*vhostThrottleExceeded); ok { hsNotice(rb, fmt.Sprintf(client.t("You must wait an additional %v before taking a vhost"), throttled.timeRemaining))