3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-22 10:42:52 +01:00

persist lastSignoff in the database

This commit is contained in:
Shivaram Lingamneni 2020-02-20 02:33:49 -05:00
parent 17a89838b8
commit 4472683d58
3 changed files with 66 additions and 9 deletions

View File

@ -35,6 +35,7 @@ const (
keyCertToAccount = "account.creds.certfp %s" keyCertToAccount = "account.creds.certfp %s"
keyAccountChannels = "account.channels %s" // channels registered to the account keyAccountChannels = "account.channels %s" // channels registered to the account
keyAccountJoinedChannels = "account.joinedto %s" // channels a persistent client has joined keyAccountJoinedChannels = "account.joinedto %s" // channels a persistent client has joined
keyAccountLastSignoff = "account.lastsignoff %s"
keyVHostQueueAcctToId = "vhostQueue %s" keyVHostQueueAcctToId = "vhostQueue %s"
vhostRequestIdx = "vhostQueue" vhostRequestIdx = "vhostQueue"
@ -103,7 +104,7 @@ func (am *AccountManager) createAlwaysOnClients(config *Config) {
account, err := am.LoadAccount(accountName) account, err := am.LoadAccount(accountName)
if err == nil && account.Verified && if err == nil && account.Verified &&
persistenceEnabled(config.Accounts.Bouncer.AlwaysOn, account.Settings.AlwaysOn) { persistenceEnabled(config.Accounts.Bouncer.AlwaysOn, account.Settings.AlwaysOn) {
am.server.AddAlwaysOnClient(account, am.loadChannels(accountName)) am.server.AddAlwaysOnClient(account, am.loadChannels(accountName), am.loadLastSignoff(accountName))
} }
} }
} }
@ -534,6 +535,36 @@ func (am *AccountManager) loadChannels(account string) (channels []string) {
return return
} }
func (am *AccountManager) saveLastSignoff(account string, lastSignoff time.Time) {
key := fmt.Sprintf(keyAccountLastSignoff, account)
var val string
if !lastSignoff.IsZero() {
val = strconv.FormatInt(lastSignoff.UnixNano(), 10)
}
am.server.store.Update(func(tx *buntdb.Tx) error {
if val != "" {
tx.Set(key, val, nil)
} else {
tx.Delete(key)
}
return nil
})
}
func (am *AccountManager) loadLastSignoff(account string) (lastSignoff time.Time) {
key := fmt.Sprintf(keyAccountLastSignoff, account)
var lsText string
am.server.store.View(func(tx *buntdb.Tx) error {
lsText, _ = tx.Get(key)
return nil
})
lsNum, err := strconv.ParseInt(lsText, 10, 64)
if err != nil {
return time.Unix(0, lsNum)
}
return
}
func (am *AccountManager) addRemoveCertfp(account, certfp string, add bool, hasPrivs bool) (err error) { func (am *AccountManager) addRemoveCertfp(account, certfp string, add bool, hasPrivs bool) (err error) {
certfp, err = utils.NormalizeCertfp(certfp) certfp, err = utils.NormalizeCertfp(certfp)
if err != nil { if err != nil {
@ -1034,6 +1065,7 @@ func (am *AccountManager) Unregister(account string) error {
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount) vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount) channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
joinedChannelsKey := fmt.Sprintf(keyAccountJoinedChannels, casefoldedAccount) joinedChannelsKey := fmt.Sprintf(keyAccountJoinedChannels, casefoldedAccount)
lastSignoffKey := fmt.Sprintf(keyAccountLastSignoff, casefoldedAccount)
var clients []*Client var clients []*Client
@ -1070,6 +1102,7 @@ func (am *AccountManager) Unregister(account string) error {
channelsStr, _ = tx.Get(channelsKey) channelsStr, _ = tx.Get(channelsKey)
tx.Delete(channelsKey) tx.Delete(channelsKey)
tx.Delete(joinedChannelsKey) tx.Delete(joinedChannelsKey)
tx.Delete(lastSignoffKey)
_, err := tx.Delete(vhostQueueKey) _, err := tx.Delete(vhostQueueKey)
am.decrementVHostQueueCount(casefoldedAccount, err) am.decrementVHostQueueCount(casefoldedAccount, err)

View File

@ -306,7 +306,7 @@ func (server *Server) RunClient(conn clientConn, proxyLine string) {
client.run(session, proxyLine) client.run(session, proxyLine)
} }
func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string) { func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string, lastSignoff time.Time) {
now := time.Now().UTC() now := time.Now().UTC()
config := server.Config() config := server.Config()
@ -322,7 +322,8 @@ func (server *Server) AddAlwaysOnClient(account ClientAccount, chnames []string)
rawHostname: server.name, rawHostname: server.name,
realIP: utils.IPv4LoopbackAddress, realIP: utils.IPv4LoopbackAddress,
alwaysOn: true, alwaysOn: true,
lastSignoff: lastSignoff,
} }
client.SetMode(modes.TLS, true) client.SetMode(modes.TLS, true)
@ -1187,10 +1188,17 @@ func (client *Client) destroy(session *Session) {
} }
if alwaysOn && remainingSessions == 0 { if alwaysOn && remainingSessions == 0 {
client.lastSignoff = lastSignoff client.lastSignoff = lastSignoff
client.dirtyBits |= IncludeLastSignoff
} else {
lastSignoff = time.Time{}
} }
exitedSnomaskSent := client.exitedSnomaskSent exitedSnomaskSent := client.exitedSnomaskSent
client.stateMutex.Unlock() client.stateMutex.Unlock()
if !lastSignoff.IsZero() {
client.wakeWriter()
}
// destroy all applicable sessions: // destroy all applicable sessions:
var quitMessage string var quitMessage string
for _, session := range sessionsToDestroy { for _, session := range sessionsToDestroy {
@ -1573,6 +1581,7 @@ func (client *Client) historyStatus(config *Config) (persistent, ephemeral bool,
// TODO add a dirty flag for lastSignoff // TODO add a dirty flag for lastSignoff
const ( const (
IncludeChannels uint = 1 << iota IncludeChannels uint = 1 << iota
IncludeLastSignoff
) )
func (client *Client) markDirty(dirtyBits uint) { func (client *Client) markDirty(dirtyBits uint) {
@ -1609,7 +1618,7 @@ func (client *Client) writeLoop() {
func (client *Client) performWrite() { func (client *Client) performWrite() {
client.stateMutex.Lock() client.stateMutex.Lock()
// TODO actually read dirtyBits in the future dirtyBits := client.dirtyBits
client.dirtyBits = 0 client.dirtyBits = 0
account := client.account account := client.account
client.stateMutex.Unlock() client.stateMutex.Unlock()
@ -1619,10 +1628,18 @@ func (client *Client) performWrite() {
return return
} }
channels := client.Channels() if (dirtyBits & IncludeChannels) != 0 {
channelNames := make([]string, len(channels)) channels := client.Channels()
for i, channel := range channels { channelNames := make([]string, len(channels))
channelNames[i] = channel.Name() for i, channel := range channels {
channelNames[i] = channel.Name()
}
client.server.accounts.saveChannels(account, channelNames)
}
if (dirtyBits & IncludeLastSignoff) != 0 {
client.stateMutex.RLock()
lastSignoff := client.lastSignoff
client.stateMutex.RUnlock()
client.server.accounts.saveLastSignoff(account, lastSignoff)
} }
client.server.accounts.saveChannels(account, channelNames)
} }

View File

@ -94,6 +94,12 @@ func (client *Client) AllSessionData(currentSession *Session) (data []SessionDat
} }
func (client *Client) AddSession(session *Session) (success bool, numSessions int, lastSignoff time.Time) { func (client *Client) AddSession(session *Session) (success bool, numSessions int, lastSignoff time.Time) {
defer func() {
if !lastSignoff.IsZero() {
client.wakeWriter()
}
}()
client.stateMutex.Lock() client.stateMutex.Lock()
defer client.stateMutex.Unlock() defer client.stateMutex.Unlock()
@ -111,6 +117,7 @@ func (client *Client) AddSession(session *Session) (success bool, numSessions in
// on the server with no sessions: // on the server with no sessions:
lastSignoff = client.lastSignoff lastSignoff = client.lastSignoff
client.lastSignoff = time.Time{} client.lastSignoff = time.Time{}
client.dirtyBits |= IncludeLastSignoff
} }
client.sessions = newSessions client.sessions = newSessions
return true, len(client.sessions), lastSignoff return true, len(client.sessions), lastSignoff