3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-14 16:09:32 +01:00
This commit is contained in:
Shivaram Lingamneni 2020-06-12 15:51:48 -04:00
parent c4c4ec027e
commit 218bea5a3e
6 changed files with 146 additions and 48 deletions

View File

@ -617,11 +617,12 @@ func (am *AccountManager) loadModes(account string) (uModes modes.Modes) {
return return
} }
func (am *AccountManager) saveLastSeen(account string, lastSeen time.Time) { func (am *AccountManager) saveLastSeen(account string, lastSeen map[string]time.Time) {
key := fmt.Sprintf(keyAccountLastSeen, account) key := fmt.Sprintf(keyAccountLastSeen, account)
var val string var val string
if !lastSeen.IsZero() { if len(lastSeen) != 0 {
val = strconv.FormatInt(lastSeen.UnixNano(), 10) text, _ := json.Marshal(lastSeen)
val = string(text)
} }
am.server.store.Update(func(tx *buntdb.Tx) error { am.server.store.Update(func(tx *buntdb.Tx) error {
if val != "" { if val != "" {
@ -633,20 +634,19 @@ func (am *AccountManager) saveLastSeen(account string, lastSeen time.Time) {
}) })
} }
func (am *AccountManager) loadLastSeen(account string) (lastSeen time.Time) { func (am *AccountManager) loadLastSeen(account string) (lastSeen map[string]time.Time) {
key := fmt.Sprintf(keyAccountLastSeen, account) key := fmt.Sprintf(keyAccountLastSeen, account)
var lsText string var lsText string
am.server.store.Update(func(tx *buntdb.Tx) error { am.server.store.Update(func(tx *buntdb.Tx) error {
lsText, _ = tx.Get(key) lsText, _ = tx.Get(key)
// XXX clear this on startup, because it's not clear when it's
// going to be overwritten, and restarting the server twice in a row
// could result in a large amount of duplicated history replay
tx.Delete(key)
return nil return nil
}) })
lsNum, err := strconv.ParseInt(lsText, 10, 64) if lsText == "" {
if err == nil { return nil
return time.Unix(0, lsNum).UTC() }
err := json.Unmarshal([]byte(lsText), &lastSeen)
if err != nil {
return nil
} }
return return
} }
@ -1052,6 +1052,10 @@ func (am *AccountManager) AuthenticateByPassphrase(client *Client, accountName s
} }
} }
if throttled, remainingTime := client.checkLoginThrottle(); throttled {
return &ThrottleError{remainingTime}
}
var account ClientAccount var account ClientAccount
defer func() { defer func() {

View File

@ -30,6 +30,11 @@ const (
// IdentTimeout is how long before our ident (username) check times out. // IdentTimeout is how long before our ident (username) check times out.
IdentTimeout = time.Second + 500*time.Millisecond IdentTimeout = time.Second + 500*time.Millisecond
IRCv3TimestampFormat = utils.IRCv3TimestampFormat IRCv3TimestampFormat = utils.IRCv3TimestampFormat
// limit the number of device IDs a client can use, as a DoS mitigation
maxDeviceIDsPerClient = 64
// controls how often often we write an autoreplay-missed client's
// deviceid->lastseentime mapping to the database
lastSeenWriteInterval = time.Minute * 10
) )
// ResumeDetails is a place to stash data at various stages of // ResumeDetails is a place to stash data at various stages of
@ -61,7 +66,8 @@ type Client struct {
isSTSOnly bool isSTSOnly bool
languages []string languages []string
lastActive time.Time // last time they sent a command that wasn't PONG or similar lastActive time.Time // last time they sent a command that wasn't PONG or similar
lastSeen time.Time // last time they sent any kind of command lastSeen map[string]time.Time // maps device ID (including "") to time of last received command
lastSeenLastWrite time.Time // last time `lastSeen` was written to the datastore
loginThrottle connection_limits.GenericThrottle loginThrottle connection_limits.GenericThrottle
nick string nick string
nickCasefolded string nickCasefolded string
@ -112,6 +118,8 @@ const (
type Session struct { type Session struct {
client *Client client *Client
deviceID string
ctime time.Time ctime time.Time
lastActive time.Time lastActive time.Time
@ -299,7 +307,7 @@ func (server *Server) RunClient(conn IRCConn) {
// give them 1k of grace over the limit: // give them 1k of grace over the limit:
socket := NewSocket(conn, config.Server.MaxSendQBytes) socket := NewSocket(conn, config.Server.MaxSendQBytes)
client := &Client{ client := &Client{
lastSeen: now, lastSeen: make(map[string]time.Time),
lastActive: now, lastActive: now,
channels: make(ChannelSet), channels: make(ChannelSet),
ctime: now, ctime: now,
@ -358,11 +366,14 @@ func (server *Server) RunClient(conn IRCConn) {
client.run(session) client.run(session)
} }
func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string, lastSeen time.Time, uModes modes.Modes) { func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string, lastSeen map[string]time.Time, uModes modes.Modes) {
now := time.Now().UTC() now := time.Now().UTC()
config := server.Config() config := server.Config()
if lastSeen.IsZero() { if lastSeen == nil {
lastSeen = now lastSeen = make(map[string]time.Time)
if account.Settings.AutoreplayMissed {
lastSeen[""] = now
}
} }
client := &Client{ client := &Client{
@ -714,14 +725,39 @@ func (client *Client) playReattachMessages(session *Session) {
// at this time, modulo network latency and fakelag). `active` means not a PING or suchlike // at this time, modulo network latency and fakelag). `active` means not a PING or suchlike
// (i.e. the user should be sitting in front of their client). // (i.e. the user should be sitting in front of their client).
func (client *Client) Touch(active bool, session *Session) { func (client *Client) Touch(active bool, session *Session) {
var markDirty bool
now := time.Now().UTC() now := time.Now().UTC()
client.stateMutex.Lock() client.stateMutex.Lock()
defer client.stateMutex.Unlock()
client.lastSeen = now
if active { if active {
client.lastActive = now client.lastActive = now
session.lastActive = now session.lastActive = now
} }
if client.accountSettings.AutoreplayMissed {
client.setLastSeen(now, session.deviceID)
if now.Sub(client.lastSeenLastWrite) > lastSeenWriteInterval {
markDirty = true
client.lastSeenLastWrite = now
}
}
client.stateMutex.Unlock()
if markDirty {
client.markDirty(IncludeLastSeen)
}
}
func (client *Client) setLastSeen(now time.Time, deviceID string) {
client.lastSeen[deviceID] = now
// evict the least-recently-used entry if necessary
if maxDeviceIDsPerClient < len(client.lastSeen) {
var minLastSeen time.Time
var minClientId string
for deviceID, lastSeen := range client.lastSeen {
if minLastSeen.IsZero() || lastSeen.Before(minLastSeen) {
minClientId, minLastSeen = deviceID, lastSeen
}
}
delete(client.lastSeen, minClientId)
}
} }
// Ping sends the client a PING message. // Ping sends the client a PING message.
@ -1604,6 +1640,12 @@ func (client *Client) attemptAutoOper(session *Session) {
} }
} }
func (client *Client) checkLoginThrottle() (throttled bool, remainingTime time.Duration) {
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
return client.loginThrottle.Touch()
}
func (client *Client) historyStatus(config *Config) (status HistoryStatus, target string) { func (client *Client) historyStatus(config *Config) (status HistoryStatus, target string) {
if !config.History.Enabled { if !config.History.Enabled {
return HistoryDisabled, "" return HistoryDisabled, ""
@ -1624,6 +1666,16 @@ func (client *Client) historyStatus(config *Config) (status HistoryStatus, targe
return return
} }
func (client *Client) copyLastSeen() (result map[string]time.Time) {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
result = make(map[string]time.Time, len(client.lastSeen))
for id, lastSeen := range client.lastSeen {
result[id] = lastSeen
}
return
}
// these are bit flags indicating what part of the client status is "dirty" // these are bit flags indicating what part of the client status is "dirty"
// and needs to be read from memory and written to the db // and needs to be read from memory and written to the db
const ( const (
@ -1669,7 +1721,6 @@ func (client *Client) performWrite() {
dirtyBits := client.dirtyBits dirtyBits := client.dirtyBits
client.dirtyBits = 0 client.dirtyBits = 0
account := client.account account := client.account
lastSeen := client.lastSeen
client.stateMutex.Unlock() client.stateMutex.Unlock()
if account == "" { if account == "" {
@ -1686,7 +1737,7 @@ func (client *Client) performWrite() {
client.server.accounts.saveChannels(account, channelNames) client.server.accounts.saveChannels(account, channelNames)
} }
if (dirtyBits & IncludeLastSeen) != 0 { if (dirtyBits & IncludeLastSeen) != 0 {
client.server.accounts.saveLastSeen(account, lastSeen) client.server.accounts.saveLastSeen(account, client.copyLastSeen())
} }
if (dirtyBits & IncludeUserModes) != 0 { if (dirtyBits & IncludeUserModes) != 0 {
uModes := make(modes.Modes, 0, len(modes.SupportedUserModes)) uModes := make(modes.Modes, 0, len(modes.SupportedUserModes))

View File

@ -8,6 +8,7 @@ package irc
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/oragono/oragono/irc/utils" "github.com/oragono/oragono/irc/utils"
) )
@ -89,6 +90,14 @@ func (ck *CertKeyError) Error() string {
return fmt.Sprintf("Invalid TLS cert/key pair: %v", ck.Err) return fmt.Sprintf("Invalid TLS cert/key pair: %v", ck.Err)
} }
type ThrottleError struct {
time.Duration
}
func (te *ThrottleError) Error() string {
return fmt.Sprintf(`Please wait at least %v and try again`, te.Duration)
}
// Config Errors // Config Errors
var ( var (
ErrDatastorePathMissing = errors.New("Datastore path missing") ErrDatastorePathMissing = errors.New("Datastore path missing")

View File

@ -62,6 +62,7 @@ type SessionData struct {
ip net.IP ip net.IP
hostname string hostname string
certfp string certfp string
deviceID string
} }
func (client *Client) AllSessionData(currentSession *Session) (data []SessionData, currentIndex int) { func (client *Client) AllSessionData(currentSession *Session) (data []SessionData, currentIndex int) {
@ -79,6 +80,7 @@ func (client *Client) AllSessionData(currentSession *Session) (data []SessionDat
ctime: session.ctime, ctime: session.ctime,
hostname: session.rawHostname, hostname: session.rawHostname,
certfp: session.certfp, certfp: session.certfp,
deviceID: session.deviceID,
} }
if session.proxiedIP != nil { if session.proxiedIP != nil {
data[i].ip = session.proxiedIP data[i].ip = session.proxiedIP
@ -103,7 +105,7 @@ func (client *Client) AddSession(session *Session) (success bool, numSessions in
copy(newSessions, client.sessions) copy(newSessions, client.sessions)
newSessions[len(newSessions)-1] = session newSessions[len(newSessions)-1] = session
if client.accountSettings.AutoreplayMissed { if client.accountSettings.AutoreplayMissed {
lastSeen = client.lastSeen lastSeen = client.lastSeen[session.deviceID]
} }
client.sessions = newSessions client.sessions = newSessions
if client.autoAway { if client.autoAway {
@ -324,17 +326,23 @@ func (client *Client) AccountSettings() (result AccountSettings) {
func (client *Client) SetAccountSettings(settings AccountSettings) { func (client *Client) SetAccountSettings(settings AccountSettings) {
// we mark dirty if the client is transitioning to always-on // we mark dirty if the client is transitioning to always-on
markDirty := false var becameAlwaysOn, autoreplayMissedDisabled bool
alwaysOn := persistenceEnabled(client.server.Config().Accounts.Multiclient.AlwaysOn, settings.AlwaysOn) alwaysOn := persistenceEnabled(client.server.Config().Accounts.Multiclient.AlwaysOn, settings.AlwaysOn)
client.stateMutex.Lock() client.stateMutex.Lock()
client.accountSettings = settings
if client.registered { if client.registered {
markDirty = !client.alwaysOn && alwaysOn autoreplayMissedDisabled = (client.accountSettings.AutoreplayMissed && !settings.AutoreplayMissed)
becameAlwaysOn = (!client.alwaysOn && alwaysOn)
client.alwaysOn = alwaysOn client.alwaysOn = alwaysOn
if autoreplayMissedDisabled {
client.lastSeen = make(map[string]time.Time)
} }
}
client.accountSettings = settings
client.stateMutex.Unlock() client.stateMutex.Unlock()
if markDirty { if becameAlwaysOn {
client.markDirty(IncludeAllAttrs) client.markDirty(IncludeAllAttrs)
} else if autoreplayMissedDisabled {
client.markDirty(IncludeLastSeen)
} }
} }

View File

@ -236,6 +236,11 @@ func authPlainHandler(server *Server, client *Client, mechanism string, value []
return false return false
} }
// see #843: strip the device ID for the benefit of clients that don't
// distinguish user/ident from account name
if strudelIndex := strings.IndexByte(authcid, '@'); strudelIndex != -1 {
authcid = authcid[:strudelIndex]
}
password := string(splitValue[2]) password := string(splitValue[2])
err := server.accounts.AuthenticateByPassphrase(client, authcid, password) err := server.accounts.AuthenticateByPassphrase(client, authcid, password)
if err != nil { if err != nil {
@ -251,6 +256,10 @@ func authPlainHandler(server *Server, client *Client, mechanism string, value []
} }
func authErrorToMessage(server *Server, err error) (msg string) { func authErrorToMessage(server *Server, err error) (msg string) {
if throttled, ok := err.(*ThrottleError); ok {
return throttled.Error()
}
switch err { switch err {
case errAccountDoesNotExist, errAccountUnverified, errAccountInvalidCredentials, errAuthzidAuthcidMismatch, errNickAccountMismatch: case errAccountDoesNotExist, errAccountUnverified, errAccountInvalidCredentials, errAuthzidAuthcidMismatch, errNickAccountMismatch:
return err.Error() return err.Error()
@ -280,6 +289,11 @@ func authExternalHandler(server *Server, client *Client, mechanism string, value
} }
if err == nil { if err == nil {
// see #843: strip the device ID for the benefit of clients that don't
// distinguish user/ident from account name
if strudelIndex := strings.IndexByte(authzid, '@'); strudelIndex != -1 {
authzid = authzid[:strudelIndex]
}
err = server.accounts.AuthenticateByCertFP(client, rb.session.certfp, authzid) err = server.accounts.AuthenticateByCertFP(client, rb.session.certfp, authzid)
} }
@ -2180,8 +2194,8 @@ func passHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
rb.Add(nil, server.name, ERR_ALREADYREGISTRED, client.nick, client.t("You may not reregister")) rb.Add(nil, server.name, ERR_ALREADYREGISTRED, client.nick, client.t("You may not reregister"))
return false return false
} }
// only give them one try to run the PASS command (all code paths end with this // only give them one try to run the PASS command (if a server password is set,
// variable being set): // then all code paths end with this variable being set):
if rb.session.passStatus != serverPassUnsent { if rb.session.passStatus != serverPassUnsent {
return false return false
} }
@ -2192,10 +2206,10 @@ func passHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
if config.Accounts.LoginViaPassCommand { if config.Accounts.LoginViaPassCommand {
colonIndex := strings.IndexByte(password, ':') colonIndex := strings.IndexByte(password, ':')
if colonIndex != -1 && client.Account() == "" { if colonIndex != -1 && client.Account() == "" {
// TODO consolidate all login throttle checks into AccountManager
throttled, _ := client.loginThrottle.Touch()
if !throttled {
account, accountPass := password[:colonIndex], password[colonIndex+1:] account, accountPass := password[:colonIndex], password[colonIndex+1:]
if strudelIndex := strings.IndexByte(account, '@'); strudelIndex != -1 {
account, rb.session.deviceID = account[:strudelIndex], account[strudelIndex+1:]
}
err := server.accounts.AuthenticateByPassphrase(client, account, accountPass) err := server.accounts.AuthenticateByPassphrase(client, account, accountPass)
if err == nil { if err == nil {
sendSuccessfulAccountAuth(client, rb, false, true) sendSuccessfulAccountAuth(client, rb, false, true)
@ -2206,7 +2220,6 @@ func passHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
} }
} }
} }
}
// if login-via-PASS failed for any reason, proceed to try and interpret the // if login-via-PASS failed for any reason, proceed to try and interpret the
// provided password as the server password // provided password as the server password
@ -2521,6 +2534,22 @@ func userHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Resp
return false return false
} }
// #843: we accept either: `USER user:pass@clientid` or `USER user@clientid`
if strudelIndex := strings.IndexByte(username, '@'); strudelIndex != -1 {
username, rb.session.deviceID = username[:strudelIndex], username[strudelIndex+1:]
if colonIndex := strings.IndexByte(username, ':'); colonIndex != -1 {
var password string
username, password = username[:colonIndex], username[colonIndex+1:]
err := server.accounts.AuthenticateByPassphrase(client, username, password)
if err == nil {
sendSuccessfulAccountAuth(client, rb, false, true)
} else {
// this is wrong, but send something for debugging that will show up in a raw transcript
rb.Add(nil, server.name, ERR_SASLFAIL, client.Nick(), client.t("SASL authentication failed"))
}
}
}
err := client.SetNames(username, realname, false) err := client.SetNames(username, realname, false)
if err == errInvalidUsername { if err == errInvalidUsername {
// if client's using a unicode nick or something weird, let's just set 'em up with a stock username instead. // if client's using a unicode nick or something weird, let's just set 'em up with a stock username instead.

View File

@ -649,14 +649,11 @@ func nsGroupHandler(server *Server, client *Client, command string, params []str
} }
func nsLoginThrottleCheck(client *Client, rb *ResponseBuffer) (success bool) { func nsLoginThrottleCheck(client *Client, rb *ResponseBuffer) (success bool) {
client.stateMutex.Lock() throttled, remainingTime := client.checkLoginThrottle()
throttled, remainingTime := client.loginThrottle.Touch()
client.stateMutex.Unlock()
if throttled { if throttled {
nsNotice(rb, fmt.Sprintf(client.t("Please wait at least %v and try again"), remainingTime)) nsNotice(rb, fmt.Sprintf(client.t("Please wait at least %v and try again"), remainingTime))
return false
} }
return true return !throttled
} }
func nsIdentifyHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) { func nsIdentifyHandler(server *Server, client *Client, command string, params []string, rb *ResponseBuffer) {
@ -685,9 +682,6 @@ func nsIdentifyHandler(server *Server, client *Client, command string, params []
// try passphrase // try passphrase
if passphrase != "" { if passphrase != "" {
if !nsLoginThrottleCheck(client, rb) {
return
}
err = server.accounts.AuthenticateByPassphrase(client, username, passphrase) err = server.accounts.AuthenticateByPassphrase(client, username, passphrase)
loginSuccessful = (err == nil) loginSuccessful = (err == nil)
} }
@ -1070,6 +1064,9 @@ func nsSessionsHandler(server *Server, client *Client, command string, params []
} else { } else {
nsNotice(rb, fmt.Sprintf(client.t("Session %d:"), i+1)) nsNotice(rb, fmt.Sprintf(client.t("Session %d:"), i+1))
} }
if session.deviceID != "" {
nsNotice(rb, fmt.Sprintf(client.t("Device ID: %s"), session.deviceID))
}
nsNotice(rb, fmt.Sprintf(client.t("IP address: %s"), session.ip.String())) nsNotice(rb, fmt.Sprintf(client.t("IP address: %s"), session.ip.String()))
nsNotice(rb, fmt.Sprintf(client.t("Hostname: %s"), session.hostname)) nsNotice(rb, fmt.Sprintf(client.t("Hostname: %s"), session.hostname))
nsNotice(rb, fmt.Sprintf(client.t("Created at: %s"), session.ctime.Format(time.RFC1123))) nsNotice(rb, fmt.Sprintf(client.t("Created at: %s"), session.ctime.Format(time.RFC1123)))