3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-14 07:59:31 +01:00
This commit is contained in:
Shivaram Lingamneni 2020-06-12 15:51:48 -04:00
parent c4c4ec027e
commit 218bea5a3e
6 changed files with 146 additions and 48 deletions

View File

@ -617,11 +617,12 @@ func (am *AccountManager) loadModes(account string) (uModes modes.Modes) {
return
}
func (am *AccountManager) saveLastSeen(account string, lastSeen time.Time) {
func (am *AccountManager) saveLastSeen(account string, lastSeen map[string]time.Time) {
key := fmt.Sprintf(keyAccountLastSeen, account)
var val string
if !lastSeen.IsZero() {
val = strconv.FormatInt(lastSeen.UnixNano(), 10)
if len(lastSeen) != 0 {
text, _ := json.Marshal(lastSeen)
val = string(text)
}
am.server.store.Update(func(tx *buntdb.Tx) error {
if val != "" {
@ -633,20 +634,19 @@ func (am *AccountManager) saveLastSeen(account string, lastSeen time.Time) {
})
}
func (am *AccountManager) loadLastSeen(account string) (lastSeen time.Time) {
func (am *AccountManager) loadLastSeen(account string) (lastSeen map[string]time.Time) {
key := fmt.Sprintf(keyAccountLastSeen, account)
var lsText string
am.server.store.Update(func(tx *buntdb.Tx) error {
lsText, _ = tx.Get(key)
// XXX clear this on startup, because it's not clear when it's
// going to be overwritten, and restarting the server twice in a row
// could result in a large amount of duplicated history replay
tx.Delete(key)
return nil
})
lsNum, err := strconv.ParseInt(lsText, 10, 64)
if err == nil {
return time.Unix(0, lsNum).UTC()
if lsText == "" {
return nil
}
err := json.Unmarshal([]byte(lsText), &lastSeen)
if err != nil {
return nil
}
return
}
@ -1052,6 +1052,10 @@ func (am *AccountManager) AuthenticateByPassphrase(client *Client, accountName s
}
}
if throttled, remainingTime := client.checkLoginThrottle(); throttled {
return &ThrottleError{remainingTime}
}
var account ClientAccount
defer func() {

View File

@ -30,6 +30,11 @@ const (
// IdentTimeout is how long before our ident (username) check times out.
IdentTimeout = time.Second + 500*time.Millisecond
IRCv3TimestampFormat = utils.IRCv3TimestampFormat
// limit the number of device IDs a client can use, as a DoS mitigation
maxDeviceIDsPerClient = 64
// controls how often often we write an autoreplay-missed client's
// deviceid->lastseentime mapping to the database
lastSeenWriteInterval = time.Minute * 10
)
// ResumeDetails is a place to stash data at various stages of
@ -60,8 +65,9 @@ type Client struct {
invitedTo map[string]bool
isSTSOnly bool
languages []string
lastActive time.Time // last time they sent a command that wasn't PONG or similar
lastSeen time.Time // last time they sent any kind of command
lastActive time.Time // last time they sent a command that wasn't PONG or similar
lastSeen map[string]time.Time // maps device ID (including "") to time of last received command
lastSeenLastWrite time.Time // last time `lastSeen` was written to the datastore
loginThrottle connection_limits.GenericThrottle
nick string
nickCasefolded string
@ -112,6 +118,8 @@ const (
type Session struct {
client *Client
deviceID string
ctime time.Time
lastActive time.Time
@ -299,7 +307,7 @@ func (server *Server) RunClient(conn IRCConn) {
// give them 1k of grace over the limit:
socket := NewSocket(conn, config.Server.MaxSendQBytes)
client := &Client{
lastSeen: now,
lastSeen: make(map[string]time.Time),
lastActive: now,
channels: make(ChannelSet),
ctime: now,
@ -358,11 +366,14 @@ func (server *Server) RunClient(conn IRCConn) {
client.run(session)
}
func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string, lastSeen time.Time, uModes modes.Modes) {
func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string, lastSeen map[string]time.Time, uModes modes.Modes) {
now := time.Now().UTC()
config := server.Config()
if lastSeen.IsZero() {
lastSeen = now
if lastSeen == nil {
lastSeen = make(map[string]time.Time)
if account.Settings.AutoreplayMissed {
lastSeen[""] = now
}
}
client := &Client{
@ -714,14 +725,39 @@ func (client *Client) playReattachMessages(session *Session) {
// at this time, modulo network latency and fakelag). `active` means not a PING or suchlike
// (i.e. the user should be sitting in front of their client).
func (client *Client) Touch(active bool, session *Session) {
var markDirty bool
now := time.Now().UTC()
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
client.lastSeen = now
if active {
client.lastActive = now
session.lastActive = now
}
if client.accountSettings.AutoreplayMissed {
client.setLastSeen(now, session.deviceID)
if now.Sub(client.lastSeenLastWrite) > lastSeenWriteInterval {
markDirty = true
client.lastSeenLastWrite = now
}
}
client.stateMutex.Unlock()
if markDirty {
client.markDirty(IncludeLastSeen)
}
}
func (client *Client) setLastSeen(now time.Time, deviceID string) {
client.lastSeen[deviceID] = now
// evict the least-recently-used entry if necessary
if maxDeviceIDsPerClient < len(client.lastSeen) {
var minLastSeen time.Time
var minClientId string
for deviceID, lastSeen := range client.lastSeen {
if minLastSeen.IsZero() || lastSeen.Before(minLastSeen) {
minClientId, minLastSeen = deviceID, lastSeen
}
}
delete(client.lastSeen, minClientId)
}
}
// Ping sends the client a PING message.
@ -1604,6 +1640,12 @@ func (client *Client) attemptAutoOper(session *Session) {
}
}
func (client *Client) checkLoginThrottle() (throttled bool, remainingTime time.Duration) {
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
return client.loginThrottle.Touch()
}
func (client *Client) historyStatus(config *Config) (status HistoryStatus, target string) {
if !config.History.Enabled {
return HistoryDisabled, ""
@ -1624,6 +1666,16 @@ func (client *Client) historyStatus(config *Config) (status HistoryStatus, targe
return
}
func (client *Client) copyLastSeen() (result map[string]time.Time) {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
result = make(map[string]time.Time, len(client.lastSeen))
for id, lastSeen := range client.lastSeen {
result[id] = lastSeen
}
return
}
// these are bit flags indicating what part of the client status is "dirty"
// and needs to be read from memory and written to the db
const (
@ -1669,7 +1721,6 @@ func (client *Client) performWrite() {
dirtyBits := client.dirtyBits
client.dirtyBits = 0
account := client.account
lastSeen := client.lastSeen
client.stateMutex.Unlock()
if account == "" {
@ -1686,7 +1737,7 @@ func (client *Client) performWrite() {
client.server.accounts.saveChannels(account, channelNames)
}
if (dirtyBits & IncludeLastSeen) != 0 {
client.server.accounts.saveLastSeen(account, lastSeen)
client.server.accounts.saveLastSeen(account, client.copyLastSeen())
}
if (dirtyBits & IncludeUserModes) != 0 {
uModes := make(modes.Modes, 0, len(modes.SupportedUserModes))

View File

@ -8,6 +8,7 @@ package irc
import (
"errors"
"fmt"
"time"
"github.com/oragono/oragono/irc/utils"
)
@ -89,6 +90,14 @@ func (ck *CertKeyError) Error() string {
return fmt.Sprintf("Invalid TLS cert/key pair: %v", ck.Err)
}
type ThrottleError struct {
time.Duration
}
func (te *ThrottleError) Error() string {
return fmt.Sprintf(`Please wait at least %v and try again`, te.Duration)
}
// Config Errors
var (
ErrDatastorePathMissing = errors.New("Datastore path missing")

View File

@ -62,6 +62,7 @@ type SessionData struct {
ip net.IP
hostname string
certfp string
deviceID string
}
func (client *Client) AllSessionData(currentSession *Session) (data []SessionData, currentIndex int) {
@ -79,6 +80,7 @@ func (client *Client) AllSessionData(currentSession *Session) (data []SessionDat
ctime: session.ctime,
hostname: session.rawHostname,
certfp: session.certfp,
deviceID: session.deviceID,
}
if session.proxiedIP != nil {
data[i].ip = session.proxiedIP
@ -103,7 +105,7 @@ func (client *Client) AddSession(session *Session) (success bool, numSessions in
copy(newSessions, client.sessions)
newSessions[len(newSessions)-1] = session
if client.accountSettings.AutoreplayMissed {
lastSeen = client.lastSeen
lastSeen = client.lastSeen[session.deviceID]
}
client.sessions = newSessions
if client.autoAway {
@ -324,17 +326,23 @@ func (client *Client) AccountSettings() (result AccountSettings) {
func (client *Client) SetAccountSettings(settings AccountSettings) {
// we mark dirty if the client is transitioning to always-on
markDirty := false
var becameAlwaysOn, autoreplayMissedDisabled bool
alwaysOn := persistenceEnabled(client.server.Config().Accounts.Multiclient.AlwaysOn, settings.AlwaysOn)
client.stateMutex.Lock()
client.accountSettings = settings
if client.registered {
markDirty = !client.alwaysOn && alwaysOn
autoreplayMissedDisabled = (client.accountSettings.AutoreplayMissed && !settings.AutoreplayMissed)
becameAlwaysOn = (!client.alwaysOn && alwaysOn)
client.alwaysOn = alwaysOn
if autoreplayMissedDisabled {
client.lastSeen = make(map[string]time.Time)
}
}
client.accountSettings = settings
client.stateMutex.Unlock()
if markDirty {
if becameAlwaysOn {
client.markDirty(IncludeAllAttrs)
} else if autoreplayMissedDisabled {
client.markDirty(IncludeLastSeen)
}
}

View File

@ -236,6 +236,11 @@ func authPlainHandler(server *Server, client *Client, mechanism string, value []
return false
}
// see #843: strip the device ID for the benefit of clients that don't
// distinguish user/ident from account name
if strudelIndex := strings.IndexByte(authcid, '@'); strudelIndex != -1 {
authcid = authcid[:strudelIndex]
}
password := string(splitValue[2])
err := server.accounts.AuthenticateByPassphrase(client, authcid, password)
if err != nil {
@ -251,6 +256,10 @@ func authPlainHandler(server *Server, client *Client, mechanism string, value []
}
func authErrorToMessage(server *Server, err error) (msg string) {
if throttled, ok := err.(*ThrottleError); ok {
return throttled.Error()
}
switch err {
case errAccountDoesNotExist, errAccountUnverified, errAccountInvalidCredentials, errAuthzidAuthcidMismatch, errNickAccountMismatch:
return err.Error()
@ -280,6 +289,11 @@ func authExternalHandler(server *Server, client *Client, mechanism string, value
}
if err == nil {
// see #843: strip the device ID for the benefit of clients that don't
// distinguish user/ident from account name
if strudelIndex := strings.IndexByte(authzid, '@'); strudelIndex != -1 {
authzid = authzid[:strudelIndex]
}
err = server.accounts.AuthenticateByCertFP(client, rb.session.certfp, authzid)
}
@ -2180,8 +2194,8 @@ func passHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
rb.Add(nil, server.name, ERR_ALREADYREGISTRED, client.nick, client.t("You may not reregister"))
return false
}
// only give them one try to run the PASS command (all code paths end with this
// variable being set):
// only give them one try to run the PASS command (if a server password is set,
// then all code paths end with this variable being set):
if rb.session.passStatus != serverPassUnsent {
return false
}
@ -2192,18 +2206,17 @@ func passHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
if config.Accounts.LoginViaPassCommand {
colonIndex := strings.IndexByte(password, ':')
if colonIndex != -1 && client.Account() == "" {
// TODO consolidate all login throttle checks into AccountManager
throttled, _ := client.loginThrottle.Touch()
if !throttled {
account, accountPass := password[:colonIndex], password[colonIndex+1:]
err := server.accounts.AuthenticateByPassphrase(client, account, accountPass)
if err == nil {
sendSuccessfulAccountAuth(client, rb, false, true)
// login-via-pass-command entails that we do not need to check
// an actual server password (either no password or skip-server-password)
rb.session.passStatus = serverPassSuccessful
return false
}
account, accountPass := password[:colonIndex], password[colonIndex+1:]
if strudelIndex := strings.IndexByte(account, '@'); strudelIndex != -1 {
account, rb.session.deviceID = account[:strudelIndex], account[strudelIndex+1:]
}
err := server.accounts.AuthenticateByPassphrase(client, account, accountPass)
if err == nil {
sendSuccessfulAccountAuth(client, rb, false, true)
// login-via-pass-command entails that we do not need to check
// an actual server password (either no password or skip-server-password)
rb.session.passStatus = serverPassSuccessful
return false
}
}
}
@ -2521,6 +2534,22 @@ func userHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
return false
}
// #843: we accept either: `USER user:pass@clientid` or `USER user@clientid`
if strudelIndex := strings.IndexByte(username, '@'); strudelIndex != -1 {
username, rb.session.deviceID = username[:strudelIndex], username[strudelIndex+1:]
if colonIndex := strings.IndexByte(username, ':'); colonIndex != -1 {
var password string
username, password = username[:colonIndex], username[colonIndex+1:]
err := server.accounts.AuthenticateByPassphrase(client, username, password)
if err == nil {
sendSuccessfulAccountAuth(client, rb, false, true)
} else {
// this is wrong, but send something for debugging that will show up in a raw transcript
rb.Add(nil, server.name, ERR_SASLFAIL, client.Nick(), client.t("SASL authentication failed"))
}
}
}
err := client.SetNames(username, realname, false)
if err == errInvalidUsername {
// if client's using a unicode nick or something weird, let's just set 'em up with a stock username instead.

View File

@ -649,14 +649,11 @@ func nsGroupHandler(server *Server, client *Client, command string, params []str
}
func nsLoginThrottleCheck(client *Client, rb *ResponseBuffer) (success bool) {
client.stateMutex.Lock()
throttled, remainingTime := client.loginThrottle.Touch()
client.stateMutex.Unlock()
throttled, remainingTime := client.checkLoginThrottle()
if throttled {
nsNotice(rb, fmt.Sprintf(client.t("Please wait at least %v and try again"), remainingTime))
return false
}
return true
return !throttled
}
func nsIdentifyHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
@ -685,9 +682,6 @@ func nsIdentifyHandler(server *Server, client *Client, command string, params []
// try passphrase
if passphrase != "" {
if !nsLoginThrottleCheck(client, rb) {
return
}
err = server.accounts.AuthenticateByPassphrase(client, username, passphrase)
loginSuccessful = (err == nil)
}
@ -1070,6 +1064,9 @@ func nsSessionsHandler(server *Server, client *Client, command string, params []
} else {
nsNotice(rb, fmt.Sprintf(client.t("Session %d:"), i+1))
}
if session.deviceID != "" {
nsNotice(rb, fmt.Sprintf(client.t("Device ID: %s"), session.deviceID))
}
nsNotice(rb, fmt.Sprintf(client.t("IP address: %s"), session.ip.String()))
nsNotice(rb, fmt.Sprintf(client.t("Hostname: %s"), session.hostname))
nsNotice(rb, fmt.Sprintf(client.t("Created at: %s"), session.ctime.Format(time.RFC1123)))