3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-14 07:59:31 +01:00
Implements the new `CHATHISTORY LISTCORRESPONDENTS` API.
This commit is contained in:
Shivaram Lingamneni 2021-04-06 00:46:07 -04:00
parent 2e9a0d4b2d
commit 4052cd12fe
9 changed files with 320 additions and 52 deletions

View File

@ -920,7 +920,7 @@ func (channel *Channel) autoReplayHistory(client *Client, rb *ResponseBuffer, sk
_, seq, _ := channel.server.GetHistorySequence(channel, client, "") _, seq, _ := channel.server.GetHistorySequence(channel, client, "")
if seq != nil { if seq != nil {
zncMax := channel.server.Config().History.ZNCMax zncMax := channel.server.Config().History.ZNCMax
items, _, _ = seq.Between(history.Selector{Time: start}, history.Selector{Time: end}, zncMax) items, _ = seq.Between(history.Selector{Time: start}, history.Selector{Time: end}, zncMax)
} }
} else if !rb.session.HasHistoryCaps() { } else if !rb.session.HasHistoryCaps() {
var replayLimit int var replayLimit int
@ -937,7 +937,7 @@ func (channel *Channel) autoReplayHistory(client *Client, rb *ResponseBuffer, sk
if 0 < replayLimit { if 0 < replayLimit {
_, seq, _ := channel.server.GetHistorySequence(channel, client, "") _, seq, _ := channel.server.GetHistorySequence(channel, client, "")
if seq != nil { if seq != nil {
items, _, _ = seq.Between(history.Selector{}, history.Selector{}, replayLimit) items, _ = seq.Between(history.Selector{}, history.Selector{}, replayLimit)
} }
} }
} }
@ -1097,20 +1097,15 @@ func (channel *Channel) resumeAndAnnounce(session *Session) {
func (channel *Channel) replayHistoryForResume(session *Session, after time.Time, before time.Time) { func (channel *Channel) replayHistoryForResume(session *Session, after time.Time, before time.Time) {
var items []history.Item var items []history.Item
var complete bool
afterS, beforeS := history.Selector{Time: after}, history.Selector{Time: before} afterS, beforeS := history.Selector{Time: after}, history.Selector{Time: before}
_, seq, _ := channel.server.GetHistorySequence(channel, session.client, "") _, seq, _ := channel.server.GetHistorySequence(channel, session.client, "")
if seq != nil { if seq != nil {
items, complete, _ = seq.Between(afterS, beforeS, channel.server.Config().History.ZNCMax) items, _ = seq.Between(afterS, beforeS, channel.server.Config().History.ZNCMax)
} }
rb := NewResponseBuffer(session) rb := NewResponseBuffer(session)
if len(items) != 0 { if len(items) != 0 {
channel.replayHistoryItems(rb, items, false) channel.replayHistoryItems(rb, items, false)
} }
if !complete && !session.resumeDetails.HistoryIncomplete {
// warn here if we didn't warn already
rb.Add(nil, histservService.prefix, "NOTICE", channel.Name(), session.client.t("Some additional message history may have been lost"))
}
rb.Send(true) rb.Send(true)
} }

View File

@ -990,7 +990,7 @@ func (session *Session) playResume() {
} }
_, privmsgSeq, _ := server.GetHistorySequence(nil, client, "*") _, privmsgSeq, _ := server.GetHistorySequence(nil, client, "*")
if privmsgSeq != nil { if privmsgSeq != nil {
privmsgs, _, _ := privmsgSeq.Between(history.Selector{}, history.Selector{}, config.History.ClientLength) privmsgs, _ := privmsgSeq.Between(history.Selector{}, history.Selector{}, config.History.ClientLength)
for _, item := range privmsgs { for _, item := range privmsgs {
sender := server.clients.Get(NUHToNick(item.Nick)) sender := server.clients.Get(NUHToNick(item.Nick))
if sender != nil { if sender != nil {
@ -1055,10 +1055,10 @@ func (session *Session) playResume() {
// replay direct PRIVSMG history // replay direct PRIVSMG history
if !timestamp.IsZero() && privmsgSeq != nil { if !timestamp.IsZero() && privmsgSeq != nil {
after := history.Selector{Time: timestamp} after := history.Selector{Time: timestamp}
items, complete, _ := privmsgSeq.Between(after, history.Selector{}, config.History.ZNCMax) items, _ := privmsgSeq.Between(after, history.Selector{}, config.History.ZNCMax)
if len(items) != 0 { if len(items) != 0 {
rb := NewResponseBuffer(session) rb := NewResponseBuffer(session)
client.replayPrivmsgHistory(rb, items, "", complete) client.replayPrivmsgHistory(rb, items, "")
rb.Send(true) rb.Send(true)
} }
} }
@ -1066,7 +1066,7 @@ func (session *Session) playResume() {
session.resumeDetails = nil session.resumeDetails = nil
} }
func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.Item, target string, complete bool) { func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.Item, target string) {
var batchID string var batchID string
details := client.Details() details := client.Details()
nick := details.nick nick := details.nick
@ -1126,9 +1126,6 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I
} }
rb.EndNestedBatch(batchID) rb.EndNestedBatch(batchID)
if !complete {
rb.Add(nil, histservService.prefix, "NOTICE", nick, client.t("Some additional message history may have been lost"))
}
} }
// IdleTime returns how long this client's been idle. // IdleTime returns how long this client's been idle.

View File

@ -308,3 +308,16 @@ func (clients *ClientManager) FindAll(userhost string) (set ClientSet) {
return set return set
} }
// Determine the canonical / unfolded form of a nick, if a client matching it
// is present (or always-on).
func (clients *ClientManager) UnfoldNick(cfnick string) (nick string) {
clients.RLock()
c := clients.byNick[cfnick]
clients.RUnlock()
if c != nil {
return c.Nick()
} else {
return cfnick
}
}

View File

@ -566,16 +566,15 @@ func capHandler(server *Server, client *Client, msg ircmsg.Message, rb *Response
// e.g., CHATHISTORY #ircv3 BETWEEN timestamp=YYYY-MM-DDThh:mm:ss.sssZ timestamp=YYYY-MM-DDThh:mm:ss.sssZ + 100 // e.g., CHATHISTORY #ircv3 BETWEEN timestamp=YYYY-MM-DDThh:mm:ss.sssZ timestamp=YYYY-MM-DDThh:mm:ss.sssZ + 100
func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) (exiting bool) { func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) (exiting bool) {
var items []history.Item var items []history.Item
unknown_command := false
var target string var target string
var channel *Channel var channel *Channel
var sequence history.Sequence var sequence history.Sequence
var err error var err error
var listCorrespondents bool
var correspondents []history.CorrespondentListing
defer func() { defer func() {
// errors are sent either without a batch, or in a draft/labeled-response batch as usual // errors are sent either without a batch, or in a draft/labeled-response batch as usual
if unknown_command { if err == utils.ErrInvalidParams {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "UNKNOWN_COMMAND", utils.SafeErrorParam(msg.Params[0]), client.t("Unknown command"))
} else if err == utils.ErrInvalidParams {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters"))
} else if sequence == nil { } else if sequence == nil {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", utils.SafeErrorParam(target), client.t("Messages could not be retrieved")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", utils.SafeErrorParam(target), client.t("Messages could not be retrieved"))
@ -583,10 +582,18 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "MESSAGE_ERROR", msg.Params[0], client.t("Messages could not be retrieved")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "MESSAGE_ERROR", msg.Params[0], client.t("Messages could not be retrieved"))
} else { } else {
// successful responses are sent as a chathistory or history batch // successful responses are sent as a chathistory or history batch
if channel != nil { if listCorrespondents {
batchID := rb.StartNestedBatch("draft/chathistory-listcorrespondents")
defer rb.EndNestedBatch(batchID)
for _, correspondent := range correspondents {
nick := server.clients.UnfoldNick(correspondent.CfCorrespondent)
rb.Add(nil, server.name, "CHATHISTORY", "CORRESPONDENT", nick,
correspondent.Time.Format(IRCv3TimestampFormat))
}
} else if channel != nil {
channel.replayHistoryItems(rb, items, false) channel.replayHistoryItems(rb, items, false)
} else { } else {
client.replayPrivmsgHistory(rb, items, target, true) client.replayPrivmsgHistory(rb, items, target)
} }
} }
}() }()
@ -598,6 +605,9 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
} }
preposition := strings.ToLower(msg.Params[0]) preposition := strings.ToLower(msg.Params[0])
target = msg.Params[1] target = msg.Params[1]
if preposition == "listcorrespondents" {
target = "*"
}
parseQueryParam := func(param string) (msgid string, timestamp time.Time, err error) { parseQueryParam := func(param string) (msgid string, timestamp time.Time, err error) {
if param == "*" && (preposition == "before" || preposition == "between") { if param == "*" && (preposition == "before" || preposition == "between") {
@ -641,15 +651,22 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
return endpoint.Truncate(time.Millisecond).Add(time.Millisecond) return endpoint.Truncate(time.Millisecond).Add(time.Millisecond)
} }
paramPos := 2
var start, end history.Selector var start, end history.Selector
var limit int var limit int
switch preposition { switch preposition {
case "listcorrespondents":
listCorrespondents = true
// use the same selector parsing as BETWEEN,
// except that we have no target so we have one fewer parameter
paramPos = 1
fallthrough
case "between": case "between":
start.Msgid, start.Time, err = parseQueryParam(msg.Params[2]) start.Msgid, start.Time, err = parseQueryParam(msg.Params[paramPos])
if err != nil { if err != nil {
return return
} }
end.Msgid, end.Time, err = parseQueryParam(msg.Params[3]) end.Msgid, end.Time, err = parseQueryParam(msg.Params[paramPos+1])
if err != nil { if err != nil {
return return
} }
@ -662,7 +679,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
end.Time = roundUp(end.Time) end.Time = roundUp(end.Time)
} }
} }
limit = parseHistoryLimit(4) limit = parseHistoryLimit(paramPos + 2)
case "before", "after", "around": case "before", "after", "around":
start.Msgid, start.Time, err = parseQueryParam(msg.Params[2]) start.Msgid, start.Time, err = parseQueryParam(msg.Params[2])
if err != nil { if err != nil {
@ -689,14 +706,16 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
} }
limit = parseHistoryLimit(3) limit = parseHistoryLimit(3)
default: default:
unknown_command = true err = utils.ErrInvalidParams
return return
} }
if preposition == "around" { if listCorrespondents {
correspondents, err = sequence.ListCorrespondents(start, end, limit)
} else if preposition == "around" {
items, err = sequence.Around(start, limit) items, err = sequence.Around(start, limit)
} else { } else {
items, _, err = sequence.Between(start, end, limit) items, err = sequence.Between(start, end, limit)
} }
return return
} }
@ -1086,7 +1105,7 @@ func historyHandler(server *Server, client *Client, msg ircmsg.Message, rb *Resp
if channel != nil { if channel != nil {
channel.replayHistoryItems(rb, items, false) channel.replayHistoryItems(rb, items, false)
} else { } else {
client.replayPrivmsgHistory(rb, items, "", true) client.replayPrivmsgHistory(rb, items, "")
} }
} }
return false return false

View File

@ -44,10 +44,15 @@ type Item struct {
// for a DM, this is the casefolded nickname of the other party (whether this is // for a DM, this is the casefolded nickname of the other party (whether this is
// an incoming or outgoing message). this lets us emulate the "query buffer" functionality // an incoming or outgoing message). this lets us emulate the "query buffer" functionality
// required by CHATHISTORY: // required by CHATHISTORY:
CfCorrespondent string CfCorrespondent string `json:"CfCorrespondent,omitempty"`
IsBot bool `json:"IsBot,omitempty"` IsBot bool `json:"IsBot,omitempty"`
} }
type CorrespondentListing struct {
CfCorrespondent string
Time time.Time
}
// HasMsgid tests whether a message has the message id `msgid`. // HasMsgid tests whether a message has the message id `msgid`.
func (item *Item) HasMsgid(msgid string) bool { func (item *Item) HasMsgid(msgid string) bool {
return item.Message.Msgid == msgid return item.Message.Msgid == msgid
@ -61,6 +66,13 @@ func Reverse(results []Item) {
} }
} }
func ReverseCorrespondents(results []CorrespondentListing) {
// lol, generics when?
for i, j := 0, len(results)-1; i < j; i, j = i+1, j-1 {
results[i], results[j] = results[j], results[i]
}
}
// Buffer is a ring buffer holding message/event history for a channel or user // Buffer is a ring buffer holding message/event history for a channel or user
type Buffer struct { type Buffer struct {
sync.RWMutex sync.RWMutex
@ -201,6 +213,78 @@ func (list *Buffer) betweenHelper(start, end Selector, cutoff time.Time, pred Pr
return list.matchInternal(satisfies, ascending, limit), complete, nil return list.matchInternal(satisfies, ascending, limit), complete, nil
} }
// returns all correspondents, in reverse time order
func (list *Buffer) allCorrespondents() (results []CorrespondentListing) {
seen := make(utils.StringSet)
list.RLock()
defer list.RUnlock()
if list.start == -1 || len(list.buffer) == 0 {
return
}
// XXX traverse in reverse order, so we get the latest timestamp
// of any message sent to/from the correspondent
pos := list.prev(list.end)
stop := list.start
for {
if !seen.Has(list.buffer[pos].CfCorrespondent) {
seen.Add(list.buffer[pos].CfCorrespondent)
results = append(results, CorrespondentListing{
CfCorrespondent: list.buffer[pos].CfCorrespondent,
Time: list.buffer[pos].Message.Time,
})
}
if pos == stop {
break
}
pos = list.prev(pos)
}
return
}
// implement LISTCORRESPONDENTS
func (list *Buffer) listCorrespondents(start, end Selector, cutoff time.Time, limit int) (results []CorrespondentListing, err error) {
after := start.Time
before := end.Time
after, before, ascending := MinMaxAsc(after, before, cutoff)
correspondents := list.allCorrespondents()
if len(correspondents) == 0 {
return
}
// XXX allCorrespondents returns results in reverse order,
// so if we're ascending, we actually go backwards
var i int
if ascending {
i = len(correspondents) - 1
} else {
i = 0
}
for 0 <= i && i < len(correspondents) && (limit == 0 || len(results) < limit) {
if (after.IsZero() || correspondents[i].Time.After(after)) &&
(before.IsZero() || correspondents[i].Time.Before(before)) {
results = append(results, correspondents[i])
}
if ascending {
i--
} else {
i++
}
}
if !ascending {
ReverseCorrespondents(results)
}
return
}
// implements history.Sequence, emulating a single history buffer (for a channel, // implements history.Sequence, emulating a single history buffer (for a channel,
// a single user's DMs, or a DM conversation) // a single user's DMs, or a DM conversation)
type bufferSequence struct { type bufferSequence struct {
@ -223,14 +307,19 @@ func (list *Buffer) MakeSequence(correspondent string, cutoff time.Time) Sequenc
} }
} }
func (seq *bufferSequence) Between(start, end Selector, limit int) (results []Item, complete bool, err error) { func (seq *bufferSequence) Between(start, end Selector, limit int) (results []Item, err error) {
return seq.list.betweenHelper(start, end, seq.cutoff, seq.pred, limit) results, _, err = seq.list.betweenHelper(start, end, seq.cutoff, seq.pred, limit)
return
} }
func (seq *bufferSequence) Around(start Selector, limit int) (results []Item, err error) { func (seq *bufferSequence) Around(start Selector, limit int) (results []Item, err error) {
return GenericAround(seq, start, limit) return GenericAround(seq, start, limit)
} }
func (seq *bufferSequence) ListCorrespondents(start, end Selector, limit int) (results []CorrespondentListing, err error) {
return seq.list.listCorrespondents(start, end, seq.cutoff, limit)
}
// you must be holding the read lock to call this // you must be holding the read lock to call this
func (list *Buffer) matchInternal(predicate Predicate, ascending bool, limit int) (results []Item) { func (list *Buffer) matchInternal(predicate Predicate, ascending bool, limit int) (results []Item) {
if list.start == -1 || len(list.buffer) == 0 { if list.start == -1 || len(list.buffer) == 0 {

View File

@ -17,15 +17,17 @@ type Selector struct {
// it encapsulates restrictions such as registration time cutoffs, or // it encapsulates restrictions such as registration time cutoffs, or
// only looking at a single "query buffer" (DMs with a particular correspondent) // only looking at a single "query buffer" (DMs with a particular correspondent)
type Sequence interface { type Sequence interface {
Between(start, end Selector, limit int) (results []Item, complete bool, err error) Between(start, end Selector, limit int) (results []Item, err error)
Around(start Selector, limit int) (results []Item, err error) Around(start Selector, limit int) (results []Item, err error)
ListCorrespondents(start, end Selector, limit int) (results []CorrespondentListing, err error)
} }
// This is a bad, slow implementation of CHATHISTORY AROUND using the BETWEEN semantics // This is a bad, slow implementation of CHATHISTORY AROUND using the BETWEEN semantics
func GenericAround(seq Sequence, start Selector, limit int) (results []Item, err error) { func GenericAround(seq Sequence, start Selector, limit int) (results []Item, err error) {
var halfLimit int var halfLimit int
halfLimit = (limit + 1) / 2 halfLimit = (limit + 1) / 2
initialResults, _, err := seq.Between(Selector{}, start, halfLimit) initialResults, err := seq.Between(Selector{}, start, halfLimit)
if err != nil { if err != nil {
return return
} else if len(initialResults) == 0 { } else if len(initialResults) == 0 {
@ -34,7 +36,7 @@ func GenericAround(seq Sequence, start Selector, limit int) (results []Item, err
return return
} }
newStart := Selector{Time: initialResults[0].Message.Time} newStart := Selector{Time: initialResults[0].Message.Time}
results, _, err = seq.Between(newStart, Selector{}, limit) results, err = seq.Between(newStart, Selector{}, limit)
return return
} }

View File

@ -238,12 +238,12 @@ func easySelectHistory(server *Server, client *Client, params []string) (items [
} }
if duration == 0 { if duration == 0 {
items, _, err = sequence.Between(history.Selector{}, history.Selector{}, limit) items, err = sequence.Between(history.Selector{}, history.Selector{}, limit)
} else { } else {
now := time.Now().UTC() now := time.Now().UTC()
start := history.Selector{Time: now} start := history.Selector{Time: now}
end := history.Selector{Time: now.Add(-duration)} end := history.Selector{Time: now.Add(-duration)}
items, _, err = sequence.Between(start, end, limit) items, err = sequence.Between(start, end, limit)
} }
return return
} }

View File

@ -4,7 +4,6 @@
package mysql package mysql
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -12,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"runtime/debug" "runtime/debug"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -36,7 +36,7 @@ const (
keySchemaVersion = "db.version" keySchemaVersion = "db.version"
// minor version indicates rollback-safe upgrades, i.e., // minor version indicates rollback-safe upgrades, i.e.,
// you can downgrade oragono and everything will work // you can downgrade oragono and everything will work
latestDbMinorVersion = "1" latestDbMinorVersion = "2"
keySchemaMinorVersion = "db.minorversion" keySchemaMinorVersion = "db.minorversion"
cleanupRowLimit = 50 cleanupRowLimit = 50
cleanupPauseTime = 10 * time.Minute cleanupPauseTime = 10 * time.Minute
@ -53,6 +53,7 @@ type MySQL struct {
insertHistory *sql.Stmt insertHistory *sql.Stmt
insertSequence *sql.Stmt insertSequence *sql.Stmt
insertConversation *sql.Stmt insertConversation *sql.Stmt
insertCorrespondent *sql.Stmt
insertAccountMessage *sql.Stmt insertAccountMessage *sql.Stmt
stateMutex sync.Mutex stateMutex sync.Mutex
@ -155,10 +156,24 @@ func (mysql *MySQL) fixSchemas() (err error) {
if err != nil { if err != nil {
return return
} }
err = mysql.createCorrespondentsTable()
if err != nil {
return
}
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion) _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
if err != nil { if err != nil {
return return
} }
} else if err == nil && minorVersion == "1" {
// upgrade from 2.1 to 2.2: create the correspondents table
err = mysql.createCorrespondentsTable()
if err != nil {
return
}
_, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
if err != nil {
return
}
} else if err == nil && minorVersion != latestDbMinorVersion { } else if err == nil && minorVersion != latestDbMinorVersion {
// TODO: if minorVersion < latestDbMinorVersion, upgrade, // TODO: if minorVersion < latestDbMinorVersion, upgrade,
// if latestDbMinorVersion < minorVersion, ignore because backwards compatible // if latestDbMinorVersion < minorVersion, ignore because backwards compatible
@ -202,6 +217,11 @@ func (mysql *MySQL) createTables() (err error) {
return err return err
} }
err = mysql.createCorrespondentsTable()
if err != nil {
return err
}
err = mysql.createComplianceTables() err = mysql.createComplianceTables()
if err != nil { if err != nil {
return err return err
@ -210,6 +230,19 @@ func (mysql *MySQL) createTables() (err error) {
return nil return nil
} }
func (mysql *MySQL) createCorrespondentsTable() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
target VARBINARY(%[1]d) NOT NULL,
correspondent VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL,
UNIQUE KEY (target, correspondent),
KEY (target, nanotime),
KEY (nanotime)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
return
}
func (mysql *MySQL) createComplianceTables() (err error) { func (mysql *MySQL) createComplianceTables() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages ( _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY, history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
@ -275,12 +308,16 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime))) mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
if maxNanotime != 0 {
mysql.deleteCorrespondents(ctx, maxNanotime)
}
return len(ids), mysql.deleteHistoryIDs(ctx, ids) return len(ids), mysql.deleteHistoryIDs(ctx, ids)
} }
func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) { func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
// can't use ? binding for a variable number of arguments, build the IN clause manually // can't use ? binding for a variable number of arguments, build the IN clause manually
var inBuf bytes.Buffer var inBuf strings.Builder
inBuf.WriteByte('(') inBuf.WriteByte('(')
for i, id := range ids { for i, id := range ids {
if i != 0 { if i != 0 {
@ -289,22 +326,23 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err
fmt.Fprintf(&inBuf, "%d", id) fmt.Fprintf(&inBuf, "%d", id)
} }
inBuf.WriteRune(')') inBuf.WriteRune(')')
inClause := inBuf.String()
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes())) _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
if err != nil { if err != nil {
return return
} }
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes())) _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
if err != nil { if err != nil {
return return
} }
if mysql.isTrackingAccountMessages() { if mysql.isTrackingAccountMessages() {
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes())) _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
if err != nil { if err != nil {
return return
} }
} }
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes())) _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
if err != nil { if err != nil {
return return
} }
@ -351,6 +389,18 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id
return return
} }
func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
if err != nil {
mysql.logError("error deleting correspondents", err)
} else {
count, err := result.RowsAffected()
if err != nil {
mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
}
}
}
// wait for forget queue items and process them one by one // wait for forget queue items and process them one by one
func (mysql *MySQL) forgetLoop() { func (mysql *MySQL) forgetLoop() {
defer func() { defer func() {
@ -470,6 +520,12 @@ func (mysql *MySQL) prepareStatements() (err error) {
if err != nil { if err != nil {
return return
} }
mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
(target, correspondent, nanotime) VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
if err != nil {
return
}
mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
(history_id, account) VALUES (?, ?);`) (history_id, account) VALUES (?, ?);`)
if err != nil { if err != nil {
@ -557,6 +613,12 @@ func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, corresp
return return
} }
func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
_, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
mysql.logError("could not insert conversations entry", err)
return
}
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) { func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
value, err := marshalItem(&item) value, err := marshalItem(&item)
if mysql.logError("could not marshal item", err) { if mysql.logError("could not marshal item", err) {
@ -621,6 +683,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
if err != nil { if err != nil {
return return
} }
err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
if err != nil {
return
}
} }
if recipientAccount != "" && sender != recipient { if recipientAccount != "" && sender != recipient {
@ -632,6 +698,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
if err != nil { if err != nil {
return return
} }
err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
if err != nil {
return
}
} }
err = mysql.insertAccountMessageEntry(ctx, id, senderAccount) err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
@ -804,7 +874,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
direction = "DESC" direction = "DESC"
} }
var queryBuf bytes.Buffer var queryBuf strings.Builder
args := make([]interface{}, 0, 6) args := make([]interface{}, 0, 6)
fmt.Fprintf(&queryBuf, fmt.Fprintf(&queryBuf,
@ -835,6 +905,55 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
return return
} }
func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.CorrespondentListing, err error) {
after, before, ascending := history.MinMaxAsc(after, before, cutoff)
direction := "ASC"
if !ascending {
direction = "DESC"
}
var queryBuf strings.Builder
args := make([]interface{}, 0, 4)
queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
WHERE target = ?`)
args = append(args, target)
if !after.IsZero() {
queryBuf.WriteString(" AND correspondents.nanotime > ?")
args = append(args, after.UnixNano())
}
if !before.IsZero() {
queryBuf.WriteString(" AND correspondents.nanotime < ?")
args = append(args, before.UnixNano())
}
fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
args = append(args, limit)
query := queryBuf.String()
rows, err := mysql.db.QueryContext(ctx, query, args...)
if err != nil {
return
}
defer rows.Close()
var correspondent string
var nanotime int64
for rows.Next() {
err = rows.Scan(&correspondent, &nanotime)
if err != nil {
return
}
results = append(results, history.CorrespondentListing{
CfCorrespondent: correspondent,
Time: time.Unix(0, nanotime),
})
}
if !ascending {
history.ReverseCorrespondents(results)
}
return
}
func (mysql *MySQL) Close() { func (mysql *MySQL) Close() {
// closing the database will close our prepared statements as well // closing the database will close our prepared statements as well
if mysql.db != nil { if mysql.db != nil {
@ -852,7 +971,7 @@ type mySQLHistorySequence struct {
cutoff time.Time cutoff time.Time
} }
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) { func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout()) ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
defer cancel() defer cancel()
@ -860,25 +979,38 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
if start.Msgid != "" { if start.Msgid != "" {
startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false) startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
} }
endTime := end.Time endTime := end.Time
if end.Msgid != "" { if end.Msgid != "" {
endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false) endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
} }
results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit) results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
return results, (err == nil), err return results, err
} }
func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) { func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
return history.GenericAround(s, start, limit) return history.GenericAround(s, start, limit)
} }
func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.CorrespondentListing, err error) {
ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
defer cancel()
// TODO accept msgids here?
startTime := start.Time
endTime := end.Time
results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
seq.mysql.logError("could not read correspondents", err)
return
}
func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence { func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
return &mySQLHistorySequence{ return &mySQLHistorySequence{
target: target, target: target,

View File

@ -16,6 +16,8 @@ import (
const ( const (
// #829, also see "Case 2" in the "three cases" below: // #829, also see "Case 2" in the "three cases" below:
zncPlaybackCommandExpiration = time.Second * 30 zncPlaybackCommandExpiration = time.Second * 30
zncPrefix = "*playback!znc@znc.in"
) )
type zncCommandHandler func(client *Client, command string, params []string, rb *ResponseBuffer) type zncCommandHandler func(client *Client, command string, params []string, rb *ResponseBuffer)
@ -192,9 +194,9 @@ func zncPlayPrivmsgs(client *Client, rb *ResponseBuffer, target string, after, b
return return
} }
zncMax := client.server.Config().History.ZNCMax zncMax := client.server.Config().History.ZNCMax
items, _, err := sequence.Between(history.Selector{Time: after}, history.Selector{Time: before}, zncMax) items, err := sequence.Between(history.Selector{Time: after}, history.Selector{Time: before}, zncMax)
if err == nil && len(items) != 0 { if err == nil && len(items) != 0 {
client.replayPrivmsgHistory(rb, items, "", true) client.replayPrivmsgHistory(rb, items, "")
} }
} }
@ -209,12 +211,31 @@ func zncPlaybackListHandler(client *Client, command string, params []string, rb
client.server.logger.Error("internal", "couldn't get history sequence for ZNC list", err.Error()) client.server.logger.Error("internal", "couldn't get history sequence for ZNC list", err.Error())
continue continue
} }
items, _, err := sequence.Between(history.Selector{}, history.Selector{}, 1) // i.e., LATEST * 1 items, err := sequence.Between(history.Selector{}, history.Selector{}, 1) // i.e., LATEST * 1
if err != nil { if err != nil {
client.server.logger.Error("internal", "couldn't query history for ZNC list", err.Error()) client.server.logger.Error("internal", "couldn't query history for ZNC list", err.Error())
} else if len(items) != 0 { } else if len(items) != 0 {
stamp := timeToZncWireTime(items[0].Message.Time) stamp := timeToZncWireTime(items[0].Message.Time)
rb.Add(nil, "*playback!znc@znc.in", "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", channel.Name(), stamp)) rb.Add(nil, zncPrefix, "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", channel.Name(), stamp))
} }
} }
_, seq, err := client.server.GetHistorySequence(nil, client, "*")
if seq == nil {
return
} else if err != nil {
client.server.logger.Error("internal", "couldn't get client history sequence for ZNC list", err.Error())
return
}
limit := client.server.Config().History.ChathistoryMax
correspondents, err := seq.ListCorrespondents(history.Selector{}, history.Selector{}, limit)
if err != nil {
client.server.logger.Error("internal", "couldn't get correspondents for ZNC list", err.Error())
return
}
for _, correspondent := range correspondents {
stamp := timeToZncWireTime(correspondent.Time)
correspondentNick := client.server.clients.UnfoldNick(correspondent.CfCorrespondent)
rb.Add(nil, zncPrefix, "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", correspondentNick, stamp))
}
} }