diff --git a/irc/channel.go b/irc/channel.go index 4b341e3e..b1bc1ffc 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -920,7 +920,7 @@ func (channel *Channel) autoReplayHistory(client *Client, rb *ResponseBuffer, sk _, seq, _ := channel.server.GetHistorySequence(channel, client, "") if seq != nil { zncMax := channel.server.Config().History.ZNCMax - items, _, _ = seq.Between(history.Selector{Time: start}, history.Selector{Time: end}, zncMax) + items, _ = seq.Between(history.Selector{Time: start}, history.Selector{Time: end}, zncMax) } } else if !rb.session.HasHistoryCaps() { var replayLimit int @@ -937,7 +937,7 @@ func (channel *Channel) autoReplayHistory(client *Client, rb *ResponseBuffer, sk if 0 < replayLimit { _, seq, _ := channel.server.GetHistorySequence(channel, client, "") if seq != nil { - items, _, _ = seq.Between(history.Selector{}, history.Selector{}, replayLimit) + items, _ = seq.Between(history.Selector{}, history.Selector{}, replayLimit) } } } @@ -1097,20 +1097,15 @@ func (channel *Channel) resumeAndAnnounce(session *Session) { func (channel *Channel) replayHistoryForResume(session *Session, after time.Time, before time.Time) { var items []history.Item - var complete bool afterS, beforeS := history.Selector{Time: after}, history.Selector{Time: before} _, seq, _ := channel.server.GetHistorySequence(channel, session.client, "") if seq != nil { - items, complete, _ = seq.Between(afterS, beforeS, channel.server.Config().History.ZNCMax) + items, _ = seq.Between(afterS, beforeS, channel.server.Config().History.ZNCMax) } rb := NewResponseBuffer(session) if len(items) != 0 { channel.replayHistoryItems(rb, items, false) } - if !complete && !session.resumeDetails.HistoryIncomplete { - // warn here if we didn't warn already - rb.Add(nil, histservService.prefix, "NOTICE", channel.Name(), session.client.t("Some additional message history may have been lost")) - } rb.Send(true) } diff --git a/irc/channelmanager.go b/irc/channelmanager.go index a2b75b26..72ce0183 100644 --- a/irc/channelmanager.go +++ b/irc/channelmanager.go @@ -458,3 +458,13 @@ func (cm *ChannelManager) ListPurged() (result []string) { sort.Strings(result) return } + +func (cm *ChannelManager) UnfoldName(cfname string) (result string) { + cm.RLock() + entry := cm.chans[cfname] + cm.RUnlock() + if entry != nil && entry.channel.IsLoaded() { + return entry.channel.Name() + } + return cfname +} diff --git a/irc/client.go b/irc/client.go index c7e36b5c..6e857776 100644 --- a/irc/client.go +++ b/irc/client.go @@ -990,7 +990,7 @@ func (session *Session) playResume() { } _, privmsgSeq, _ := server.GetHistorySequence(nil, client, "*") if privmsgSeq != nil { - privmsgs, _, _ := privmsgSeq.Between(history.Selector{}, history.Selector{}, config.History.ClientLength) + privmsgs, _ := privmsgSeq.Between(history.Selector{}, history.Selector{}, config.History.ClientLength) for _, item := range privmsgs { sender := server.clients.Get(NUHToNick(item.Nick)) if sender != nil { @@ -1055,10 +1055,10 @@ func (session *Session) playResume() { // replay direct PRIVSMG history if !timestamp.IsZero() && privmsgSeq != nil { after := history.Selector{Time: timestamp} - items, complete, _ := privmsgSeq.Between(after, history.Selector{}, config.History.ZNCMax) + items, _ := privmsgSeq.Between(after, history.Selector{}, config.History.ZNCMax) if len(items) != 0 { rb := NewResponseBuffer(session) - client.replayPrivmsgHistory(rb, items, "", complete) + client.replayPrivmsgHistory(rb, items, "") rb.Send(true) } } @@ -1066,7 +1066,7 @@ func (session *Session) playResume() { session.resumeDetails = nil } -func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.Item, target string, complete bool) { +func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.Item, target string) { var batchID string details := client.Details() nick := details.nick @@ -1126,9 +1126,6 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I } rb.EndNestedBatch(batchID) - if !complete { - rb.Add(nil, histservService.prefix, "NOTICE", nick, client.t("Some additional message history may have been lost")) - } } // IdleTime returns how long this client's been idle. @@ -1934,6 +1931,43 @@ func (client *Client) addHistoryItem(target *Client, item history.Item, details, return nil } +func (client *Client) listTargets(start, end history.Selector, limit int) (results []history.TargetListing, err error) { + var base, extras []history.TargetListing + var chcfnames []string + for _, channel := range client.Channels() { + _, seq, err := client.server.GetHistorySequence(channel, client, "") + if seq == nil || err != nil { + continue + } + if seq.Ephemeral() { + items, err := seq.Between(history.Selector{}, history.Selector{}, 1) + if err == nil && len(items) != 0 { + extras = append(extras, history.TargetListing{ + Time: items[0].Message.Time, + CfName: channel.NameCasefolded(), + }) + } + } else { + chcfnames = append(chcfnames, channel.NameCasefolded()) + } + } + persistentExtras, err := client.server.historyDB.ListChannels(chcfnames) + if err == nil && len(persistentExtras) != 0 { + extras = append(extras, persistentExtras...) + } + + _, cSeq, err := client.server.GetHistorySequence(nil, client, "*") + if err == nil && cSeq != nil { + correspondents, err := cSeq.ListCorrespondents(start, end, limit) + if err == nil { + base = correspondents + } + } + + results = history.MergeTargets(base, extras, start.Time, end.Time, limit) + return results, nil +} + func (client *Client) handleRegisterTimeout() { client.Quit(fmt.Sprintf("Registration timeout: %v", RegisterTimeout), nil) client.destroy(nil) diff --git a/irc/client_lookup_set.go b/irc/client_lookup_set.go index 230b69da..6992ee6e 100644 --- a/irc/client_lookup_set.go +++ b/irc/client_lookup_set.go @@ -308,3 +308,16 @@ func (clients *ClientManager) FindAll(userhost string) (set ClientSet) { return set } + +// Determine the canonical / unfolded form of a nick, if a client matching it +// is present (or always-on). +func (clients *ClientManager) UnfoldNick(cfnick string) (nick string) { + clients.RLock() + c := clients.byNick[cfnick] + clients.RUnlock() + if c != nil { + return c.Nick() + } else { + return cfnick + } +} diff --git a/irc/handlers.go b/irc/handlers.go index 95a99a1a..80a8cf44 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -566,27 +566,34 @@ func capHandler(server *Server, client *Client, msg ircmsg.Message, rb *Response // e.g., CHATHISTORY #ircv3 BETWEEN timestamp=YYYY-MM-DDThh:mm:ss.sssZ timestamp=YYYY-MM-DDThh:mm:ss.sssZ + 100 func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) (exiting bool) { var items []history.Item - unknown_command := false var target string var channel *Channel var sequence history.Sequence var err error + var listTargets bool + var targets []history.TargetListing defer func() { // errors are sent either without a batch, or in a draft/labeled-response batch as usual - if unknown_command { - rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "UNKNOWN_COMMAND", utils.SafeErrorParam(msg.Params[0]), client.t("Unknown command")) - } else if err == utils.ErrInvalidParams { + if err == utils.ErrInvalidParams { rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters")) - } else if sequence == nil { + } else if !listTargets && sequence == nil { rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", utils.SafeErrorParam(target), client.t("Messages could not be retrieved")) } else if err != nil { rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "MESSAGE_ERROR", msg.Params[0], client.t("Messages could not be retrieved")) } else { // successful responses are sent as a chathistory or history batch - if channel != nil { + if listTargets { + batchID := rb.StartNestedBatch("draft/chathistory-targets") + defer rb.EndNestedBatch(batchID) + for _, target := range targets { + name := server.UnfoldName(target.CfName) + rb.Add(nil, server.name, "CHATHISTORY", "TARGETS", name, + target.Time.Format(IRCv3TimestampFormat)) + } + } else if channel != nil { channel.replayHistoryItems(rb, items, false) } else { - client.replayPrivmsgHistory(rb, items, target, true) + client.replayPrivmsgHistory(rb, items, target) } } }() @@ -598,6 +605,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb * } preposition := strings.ToLower(msg.Params[0]) target = msg.Params[1] + listTargets = (preposition == "targets") parseQueryParam := func(param string) (msgid string, timestamp time.Time, err error) { if param == "*" && (preposition == "before" || preposition == "between") { @@ -632,24 +640,25 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb * return } - channel, sequence, err = server.GetHistorySequence(nil, client, target) - if err != nil || sequence == nil { - return - } - roundUp := func(endpoint time.Time) (result time.Time) { return endpoint.Truncate(time.Millisecond).Add(time.Millisecond) } + paramPos := 2 var start, end history.Selector var limit int switch preposition { + case "targets": + // use the same selector parsing as BETWEEN, + // except that we have no target so we have one fewer parameter + paramPos = 1 + fallthrough case "between": - start.Msgid, start.Time, err = parseQueryParam(msg.Params[2]) + start.Msgid, start.Time, err = parseQueryParam(msg.Params[paramPos]) if err != nil { return } - end.Msgid, end.Time, err = parseQueryParam(msg.Params[3]) + end.Msgid, end.Time, err = parseQueryParam(msg.Params[paramPos+1]) if err != nil { return } @@ -662,7 +671,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb * end.Time = roundUp(end.Time) } } - limit = parseHistoryLimit(4) + limit = parseHistoryLimit(paramPos + 2) case "before", "after", "around": start.Msgid, start.Time, err = parseQueryParam(msg.Params[2]) if err != nil { @@ -689,14 +698,22 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb * } limit = parseHistoryLimit(3) default: - unknown_command = true + err = utils.ErrInvalidParams return } - if preposition == "around" { - items, err = sequence.Around(start, limit) + if listTargets { + targets, err = client.listTargets(start, end, limit) } else { - items, _, err = sequence.Between(start, end, limit) + channel, sequence, err = server.GetHistorySequence(nil, client, target) + if err != nil || sequence == nil { + return + } + if preposition == "around" { + items, err = sequence.Around(start, limit) + } else { + items, err = sequence.Between(start, end, limit) + } } return } @@ -1086,7 +1103,7 @@ func historyHandler(server *Server, client *Client, msg ircmsg.Message, rb *Resp if channel != nil { channel.replayHistoryItems(rb, items, false) } else { - client.replayPrivmsgHistory(rb, items, "", true) + client.replayPrivmsgHistory(rb, items, "") } } return false diff --git a/irc/history/history.go b/irc/history/history.go index 981108d8..0e10e856 100644 --- a/irc/history/history.go +++ b/irc/history/history.go @@ -44,8 +44,8 @@ type Item struct { // for a DM, this is the casefolded nickname of the other party (whether this is // an incoming or outgoing message). this lets us emulate the "query buffer" functionality // required by CHATHISTORY: - CfCorrespondent string - IsBot bool `json:"IsBot,omitempty"` + CfCorrespondent string `json:"CfCorrespondent,omitempty"` + IsBot bool `json:"IsBot,omitempty"` } // HasMsgid tests whether a message has the message id `msgid`. @@ -201,6 +201,78 @@ func (list *Buffer) betweenHelper(start, end Selector, cutoff time.Time, pred Pr return list.matchInternal(satisfies, ascending, limit), complete, nil } +// returns all correspondents, in reverse time order +func (list *Buffer) allCorrespondents() (results []TargetListing) { + seen := make(utils.StringSet) + + list.RLock() + defer list.RUnlock() + if list.start == -1 || len(list.buffer) == 0 { + return + } + + // XXX traverse in reverse order, so we get the latest timestamp + // of any message sent to/from the correspondent + pos := list.prev(list.end) + stop := list.start + + for { + if !seen.Has(list.buffer[pos].CfCorrespondent) { + seen.Add(list.buffer[pos].CfCorrespondent) + results = append(results, TargetListing{ + CfName: list.buffer[pos].CfCorrespondent, + Time: list.buffer[pos].Message.Time, + }) + } + + if pos == stop { + break + } + pos = list.prev(pos) + } + return +} + +// list DM correspondents, as one input to CHATHISTORY TARGETS +func (list *Buffer) listCorrespondents(start, end Selector, cutoff time.Time, limit int) (results []TargetListing, err error) { + after := start.Time + before := end.Time + after, before, ascending := MinMaxAsc(after, before, cutoff) + + correspondents := list.allCorrespondents() + if len(correspondents) == 0 { + return + } + + // XXX allCorrespondents returns results in reverse order, + // so if we're ascending, we actually go backwards + var i int + if ascending { + i = len(correspondents) - 1 + } else { + i = 0 + } + + for 0 <= i && i < len(correspondents) && (limit == 0 || len(results) < limit) { + if (after.IsZero() || correspondents[i].Time.After(after)) && + (before.IsZero() || correspondents[i].Time.Before(before)) { + results = append(results, correspondents[i]) + } + + if ascending { + i-- + } else { + i++ + } + } + + if !ascending { + ReverseCorrespondents(results) + } + + return +} + // implements history.Sequence, emulating a single history buffer (for a channel, // a single user's DMs, or a DM conversation) type bufferSequence struct { @@ -223,14 +295,27 @@ func (list *Buffer) MakeSequence(correspondent string, cutoff time.Time) Sequenc } } -func (seq *bufferSequence) Between(start, end Selector, limit int) (results []Item, complete bool, err error) { - return seq.list.betweenHelper(start, end, seq.cutoff, seq.pred, limit) +func (seq *bufferSequence) Between(start, end Selector, limit int) (results []Item, err error) { + results, _, err = seq.list.betweenHelper(start, end, seq.cutoff, seq.pred, limit) + return } func (seq *bufferSequence) Around(start Selector, limit int) (results []Item, err error) { return GenericAround(seq, start, limit) } +func (seq *bufferSequence) ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error) { + return seq.list.listCorrespondents(start, end, seq.cutoff, limit) +} + +func (seq *bufferSequence) Cutoff() time.Time { + return seq.cutoff +} + +func (seq *bufferSequence) Ephemeral() bool { + return true +} + // you must be holding the read lock to call this func (list *Buffer) matchInternal(predicate Predicate, ascending bool, limit int) (results []Item) { if list.start == -1 || len(list.buffer) == 0 { diff --git a/irc/history/queries.go b/irc/history/queries.go index 078e7270..2c7be322 100644 --- a/irc/history/queries.go +++ b/irc/history/queries.go @@ -17,15 +17,24 @@ type Selector struct { // it encapsulates restrictions such as registration time cutoffs, or // only looking at a single "query buffer" (DMs with a particular correspondent) type Sequence interface { - Between(start, end Selector, limit int) (results []Item, complete bool, err error) + Between(start, end Selector, limit int) (results []Item, err error) Around(start Selector, limit int) (results []Item, err error) + + ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error) + + // this are weird hacks that violate the encapsulation of Sequence to some extent; + // Cutoff() returns the cutoff time for other code to use (it returns the zero time + // if none is set), and Ephemeral() returns whether the backing store is in-memory + // or a persistent database. + Cutoff() time.Time + Ephemeral() bool } // This is a bad, slow implementation of CHATHISTORY AROUND using the BETWEEN semantics func GenericAround(seq Sequence, start Selector, limit int) (results []Item, err error) { var halfLimit int halfLimit = (limit + 1) / 2 - initialResults, _, err := seq.Between(Selector{}, start, halfLimit) + initialResults, err := seq.Between(Selector{}, start, halfLimit) if err != nil { return } else if len(initialResults) == 0 { @@ -34,7 +43,7 @@ func GenericAround(seq Sequence, start Selector, limit int) (results []Item, err return } newStart := Selector{Time: initialResults[0].Message.Time} - results, _, err = seq.Between(newStart, Selector{}, limit) + results, err = seq.Between(newStart, Selector{}, limit) return } diff --git a/irc/history/targets.go b/irc/history/targets.go new file mode 100644 index 00000000..39a0c209 --- /dev/null +++ b/irc/history/targets.go @@ -0,0 +1,83 @@ +// Copyright (c) 2021 Shivaram Lingamneni +// released under the MIT license + +package history + +import ( + "sort" + "time" +) + +type TargetListing struct { + CfName string + Time time.Time +} + +// Merge `base`, a paging window of targets, with `extras` (the target entries +// for all joined channels). +func MergeTargets(base []TargetListing, extra []TargetListing, start, end time.Time, limit int) (results []TargetListing) { + if len(extra) == 0 { + return base + } + SortCorrespondents(extra) + + start, end, ascending := MinMaxAsc(start, end, time.Time{}) + predicate := func(t time.Time) bool { + return (start.IsZero() || start.Before(t)) && (end.IsZero() || end.After(t)) + } + + prealloc := len(base) + len(extra) + if limit < prealloc { + prealloc = limit + } + results = make([]TargetListing, 0, prealloc) + + if !ascending { + ReverseCorrespondents(base) + ReverseCorrespondents(extra) + } + + for len(results) < limit { + if len(extra) != 0 { + if !predicate(extra[0].Time) { + extra = extra[1:] + continue + } + if len(base) != 0 { + if base[0].Time.Before(extra[0].Time) == ascending { + results = append(results, base[0]) + base = base[1:] + } else { + results = append(results, extra[0]) + extra = extra[1:] + } + } else { + results = append(results, extra[0]) + extra = extra[1:] + } + } else if len(base) != 0 { + results = append(results, base[0]) + base = base[1:] + } else { + break + } + } + + if !ascending { + ReverseCorrespondents(results) + } + return +} + +func ReverseCorrespondents(results []TargetListing) { + // lol, generics when? + for i, j := 0, len(results)-1; i < j; i, j = i+1, j-1 { + results[i], results[j] = results[j], results[i] + } +} + +func SortCorrespondents(list []TargetListing) { + sort.Slice(list, func(i, j int) bool { + return list[i].Time.Before(list[j].Time) + }) +} diff --git a/irc/histserv.go b/irc/histserv.go index c7c13257..3497463c 100644 --- a/irc/histserv.go +++ b/irc/histserv.go @@ -238,12 +238,12 @@ func easySelectHistory(server *Server, client *Client, params []string) (items [ } if duration == 0 { - items, _, err = sequence.Between(history.Selector{}, history.Selector{}, limit) + items, err = sequence.Between(history.Selector{}, history.Selector{}, limit) } else { now := time.Now().UTC() start := history.Selector{Time: now} end := history.Selector{Time: now.Add(-duration)} - items, _, err = sequence.Between(start, end, limit) + items, err = sequence.Between(start, end, limit) } return } diff --git a/irc/mysql/history.go b/irc/mysql/history.go index 1e055675..c744af1f 100644 --- a/irc/mysql/history.go +++ b/irc/mysql/history.go @@ -4,7 +4,6 @@ package mysql import ( - "bytes" "context" "database/sql" "encoding/json" @@ -12,6 +11,7 @@ import ( "fmt" "io" "runtime/debug" + "strings" "sync" "sync/atomic" "time" @@ -36,7 +36,7 @@ const ( keySchemaVersion = "db.version" // minor version indicates rollback-safe upgrades, i.e., // you can downgrade oragono and everything will work - latestDbMinorVersion = "1" + latestDbMinorVersion = "2" keySchemaMinorVersion = "db.minorversion" cleanupRowLimit = 50 cleanupPauseTime = 10 * time.Minute @@ -53,6 +53,7 @@ type MySQL struct { insertHistory *sql.Stmt insertSequence *sql.Stmt insertConversation *sql.Stmt + insertCorrespondent *sql.Stmt insertAccountMessage *sql.Stmt stateMutex sync.Mutex @@ -155,10 +156,24 @@ func (mysql *MySQL) fixSchemas() (err error) { if err != nil { return } + err = mysql.createCorrespondentsTable() + if err != nil { + return + } _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion) if err != nil { return } + } else if err == nil && minorVersion == "1" { + // upgrade from 2.1 to 2.2: create the correspondents table + err = mysql.createCorrespondentsTable() + if err != nil { + return + } + _, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion) + if err != nil { + return + } } else if err == nil && minorVersion != latestDbMinorVersion { // TODO: if minorVersion < latestDbMinorVersion, upgrade, // if latestDbMinorVersion < minorVersion, ignore because backwards compatible @@ -202,6 +217,11 @@ func (mysql *MySQL) createTables() (err error) { return err } + err = mysql.createCorrespondentsTable() + if err != nil { + return err + } + err = mysql.createComplianceTables() if err != nil { return err @@ -210,6 +230,19 @@ func (mysql *MySQL) createTables() (err error) { return nil } +func (mysql *MySQL) createCorrespondentsTable() (err error) { + _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + target VARBINARY(%[1]d) NOT NULL, + correspondent VARBINARY(%[1]d) NOT NULL, + nanotime BIGINT UNSIGNED NOT NULL, + UNIQUE KEY (target, correspondent), + KEY (target, nanotime), + KEY (nanotime) + ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)) + return +} + func (mysql *MySQL) createComplianceTables() (err error) { _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages ( history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY, @@ -275,12 +308,16 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) { mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime))) + if maxNanotime != 0 { + mysql.deleteCorrespondents(ctx, maxNanotime) + } + return len(ids), mysql.deleteHistoryIDs(ctx, ids) } func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) { // can't use ? binding for a variable number of arguments, build the IN clause manually - var inBuf bytes.Buffer + var inBuf strings.Builder inBuf.WriteByte('(') for i, id := range ids { if i != 0 { @@ -289,22 +326,23 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err fmt.Fprintf(&inBuf, "%d", id) } inBuf.WriteRune(')') + inClause := inBuf.String() - _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes())) + _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause)) if err != nil { return } - _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes())) + _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause)) if err != nil { return } if mysql.isTrackingAccountMessages() { - _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes())) + _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause)) if err != nil { return } } - _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes())) + _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause)) if err != nil { return } @@ -351,6 +389,18 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id return } +func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) { + result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold) + if err != nil { + mysql.logError("error deleting correspondents", err) + } else { + count, err := result.RowsAffected() + if err != nil { + mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count)) + } + } +} + // wait for forget queue items and process them one by one func (mysql *MySQL) forgetLoop() { defer func() { @@ -470,6 +520,12 @@ func (mysql *MySQL) prepareStatements() (err error) { if err != nil { return } + mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents + (target, correspondent, nanotime) VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`) + if err != nil { + return + } mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages (history_id, account) VALUES (?, ?);`) if err != nil { @@ -557,6 +613,12 @@ func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, corresp return } +func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) { + _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime) + mysql.logError("could not insert conversations entry", err) + return +} + func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) { value, err := marshalItem(&item) if mysql.logError("could not marshal item", err) { @@ -621,6 +683,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient if err != nil { return } + err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id) + if err != nil { + return + } } if recipientAccount != "" && sender != recipient { @@ -632,6 +698,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient if err != nil { return } + err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id) + if err != nil { + return + } } err = mysql.insertAccountMessageEntry(ctx, id, senderAccount) @@ -804,7 +874,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent direction = "DESC" } - var queryBuf bytes.Buffer + var queryBuf strings.Builder args := make([]interface{}, 0, 6) fmt.Fprintf(&queryBuf, @@ -835,6 +905,103 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent return } +func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) { + after, before, ascending := history.MinMaxAsc(after, before, cutoff) + direction := "ASC" + if !ascending { + direction = "DESC" + } + + var queryBuf strings.Builder + args := make([]interface{}, 0, 4) + queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents + WHERE target = ?`) + args = append(args, target) + if !after.IsZero() { + queryBuf.WriteString(" AND correspondents.nanotime > ?") + args = append(args, after.UnixNano()) + } + if !before.IsZero() { + queryBuf.WriteString(" AND correspondents.nanotime < ?") + args = append(args, before.UnixNano()) + } + fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction) + args = append(args, limit) + query := queryBuf.String() + + rows, err := mysql.db.QueryContext(ctx, query, args...) + if err != nil { + return + } + defer rows.Close() + var correspondent string + var nanotime int64 + for rows.Next() { + err = rows.Scan(&correspondent, &nanotime) + if err != nil { + return + } + results = append(results, history.TargetListing{ + CfName: correspondent, + Time: time.Unix(0, nanotime), + }) + } + + if !ascending { + history.ReverseCorrespondents(results) + } + + return +} + +func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) { + if mysql.db == nil { + return + } + + if len(cfchannels) == 0 { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + + var queryBuf strings.Builder + args := make([]interface{}, 0, len(results)) + // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html + // this should be a "loose index scan" + queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence + WHERE sequence.target IN (`) + for i, chname := range cfchannels { + if i != 0 { + queryBuf.WriteString(", ") + } + queryBuf.WriteByte('?') + args = append(args, chname) + } + queryBuf.WriteString(") GROUP BY sequence.target;") + + rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...) + if mysql.logError("could not query channel listings", err) { + return + } + defer rows.Close() + + var target string + var nanotime int64 + for rows.Next() { + err = rows.Scan(&target, &nanotime) + if mysql.logError("could not scan channel listings", err) { + return + } + results = append(results, history.TargetListing{ + CfName: target, + Time: time.Unix(0, nanotime), + }) + } + return +} + func (mysql *MySQL) Close() { // closing the database will close our prepared statements as well if mysql.db != nil { @@ -852,7 +1019,7 @@ type mySQLHistorySequence struct { cutoff time.Time } -func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) { +func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) { ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout()) defer cancel() @@ -860,25 +1027,46 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) ( if start.Msgid != "" { startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false) if err != nil { - return nil, false, err + return nil, err } } endTime := end.Time if end.Msgid != "" { endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false) if err != nil { - return nil, false, err + return nil, err } } results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit) - return results, (err == nil), err + return results, err } func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) { return history.GenericAround(s, start, limit) } +func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) { + ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout()) + defer cancel() + + // TODO accept msgids here? + startTime := start.Time + endTime := end.Time + + results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit) + seq.mysql.logError("could not read correspondents", err) + return +} + +func (seq *mySQLHistorySequence) Cutoff() time.Time { + return seq.cutoff +} + +func (seq *mySQLHistorySequence) Ephemeral() bool { + return false +} + func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence { return &mySQLHistorySequence{ target: target, diff --git a/irc/server.go b/irc/server.go index 1fece0c1..8a45ad79 100644 --- a/irc/server.go +++ b/irc/server.go @@ -1017,6 +1017,13 @@ func (server *Server) DeleteMessage(target, msgid, accountName string) (err erro return } +func (server *Server) UnfoldName(cfname string) (name string) { + if strings.HasPrefix(cfname, "#") { + return server.channels.UnfoldName(cfname) + } + return server.clients.UnfoldNick(cfname) +} + // elistMatcher takes and matches ELIST conditions type elistMatcher struct { MinClientsActive bool diff --git a/irc/znc.go b/irc/znc.go index 51749bfa..6f3ceae8 100644 --- a/irc/znc.go +++ b/irc/znc.go @@ -16,6 +16,8 @@ import ( const ( // #829, also see "Case 2" in the "three cases" below: zncPlaybackCommandExpiration = time.Second * 30 + + zncPrefix = "*playback!znc@znc.in" ) type zncCommandHandler func(client *Client, command string, params []string, rb *ResponseBuffer) @@ -192,29 +194,24 @@ func zncPlayPrivmsgs(client *Client, rb *ResponseBuffer, target string, after, b return } zncMax := client.server.Config().History.ZNCMax - items, _, err := sequence.Between(history.Selector{Time: after}, history.Selector{Time: before}, zncMax) + items, err := sequence.Between(history.Selector{Time: after}, history.Selector{Time: before}, zncMax) if err == nil && len(items) != 0 { - client.replayPrivmsgHistory(rb, items, "", true) + client.replayPrivmsgHistory(rb, items, "") } } // PRIVMSG *playback :list func zncPlaybackListHandler(client *Client, command string, params []string, rb *ResponseBuffer) { + limit := client.server.Config().History.ChathistoryMax + correspondents, err := client.listTargets(history.Selector{}, history.Selector{}, limit) + if err != nil { + client.server.logger.Error("internal", "couldn't get history for ZNC list", err.Error()) + return + } nick := client.Nick() - for _, channel := range client.Channels() { - _, sequence, err := client.server.GetHistorySequence(channel, client, "") - if sequence == nil { - continue - } else if err != nil { - client.server.logger.Error("internal", "couldn't get history sequence for ZNC list", err.Error()) - continue - } - items, _, err := sequence.Between(history.Selector{}, history.Selector{}, 1) // i.e., LATEST * 1 - if err != nil { - client.server.logger.Error("internal", "couldn't query history for ZNC list", err.Error()) - } else if len(items) != 0 { - stamp := timeToZncWireTime(items[0].Message.Time) - rb.Add(nil, "*playback!znc@znc.in", "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", channel.Name(), stamp)) - } + for _, correspondent := range correspondents { + stamp := timeToZncWireTime(correspondent.Time) + unfoldedTarget := client.server.UnfoldName(correspondent.CfName) + rb.Add(nil, zncPrefix, "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", unfoldedTarget, stamp)) } }