3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-22 18:52:41 +01:00

implement CHATHISTORY TARGETS

This commit is contained in:
Shivaram Lingamneni 2021-04-07 05:40:39 -04:00
parent 4052cd12fe
commit 18b6e2f1cd
9 changed files with 248 additions and 78 deletions

View File

@ -458,3 +458,13 @@ func (cm *ChannelManager) ListPurged() (result []string) {
sort.Strings(result) sort.Strings(result)
return return
} }
func (cm *ChannelManager) UnfoldName(cfname string) (result string) {
cm.RLock()
entry := cm.chans[cfname]
cm.RUnlock()
if entry != nil && entry.channel.IsLoaded() {
return entry.channel.Name()
}
return cfname
}

View File

@ -1931,6 +1931,43 @@ func (client *Client) addHistoryItem(target *Client, item history.Item, details,
return nil return nil
} }
func (client *Client) listTargets(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
var base, extras []history.TargetListing
var chcfnames []string
for _, channel := range client.Channels() {
_, seq, err := client.server.GetHistorySequence(channel, client, "")
if seq == nil || err != nil {
continue
}
if seq.Ephemeral() {
items, err := seq.Between(history.Selector{}, history.Selector{}, 1)
if err == nil && len(items) != 0 {
extras = append(extras, history.TargetListing{
Time: items[0].Message.Time,
CfName: channel.NameCasefolded(),
})
}
} else {
chcfnames = append(chcfnames, channel.NameCasefolded())
}
}
persistentExtras, err := client.server.historyDB.ListChannels(chcfnames)
if err == nil && len(persistentExtras) != 0 {
extras = append(extras, persistentExtras...)
}
_, cSeq, err := client.server.GetHistorySequence(nil, client, "*")
if err == nil && cSeq != nil {
correspondents, err := cSeq.ListCorrespondents(start, end, limit)
if err == nil {
base = correspondents
}
}
results = history.MergeTargets(base, extras, start.Time, end.Time, limit)
return results, nil
}
func (client *Client) handleRegisterTimeout() { func (client *Client) handleRegisterTimeout() {
client.Quit(fmt.Sprintf("Registration timeout: %v", RegisterTimeout), nil) client.Quit(fmt.Sprintf("Registration timeout: %v", RegisterTimeout), nil)
client.destroy(nil) client.destroy(nil)

View File

@ -570,25 +570,25 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
var channel *Channel var channel *Channel
var sequence history.Sequence var sequence history.Sequence
var err error var err error
var listCorrespondents bool var listTargets bool
var correspondents []history.CorrespondentListing var targets []history.TargetListing
defer func() { defer func() {
// errors are sent either without a batch, or in a draft/labeled-response batch as usual // errors are sent either without a batch, or in a draft/labeled-response batch as usual
if err == utils.ErrInvalidParams { if err == utils.ErrInvalidParams {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters"))
} else if sequence == nil { } else if !listTargets && sequence == nil {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", utils.SafeErrorParam(target), client.t("Messages could not be retrieved")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", utils.SafeErrorParam(target), client.t("Messages could not be retrieved"))
} else if err != nil { } else if err != nil {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "MESSAGE_ERROR", msg.Params[0], client.t("Messages could not be retrieved")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "MESSAGE_ERROR", msg.Params[0], client.t("Messages could not be retrieved"))
} else { } else {
// successful responses are sent as a chathistory or history batch // successful responses are sent as a chathistory or history batch
if listCorrespondents { if listTargets {
batchID := rb.StartNestedBatch("draft/chathistory-listcorrespondents") batchID := rb.StartNestedBatch("draft/chathistory-targets")
defer rb.EndNestedBatch(batchID) defer rb.EndNestedBatch(batchID)
for _, correspondent := range correspondents { for _, target := range targets {
nick := server.clients.UnfoldNick(correspondent.CfCorrespondent) name := server.UnfoldName(target.CfName)
rb.Add(nil, server.name, "CHATHISTORY", "CORRESPONDENT", nick, rb.Add(nil, server.name, "CHATHISTORY", "TARGETS", name,
correspondent.Time.Format(IRCv3TimestampFormat)) target.Time.Format(IRCv3TimestampFormat))
} }
} else if channel != nil { } else if channel != nil {
channel.replayHistoryItems(rb, items, false) channel.replayHistoryItems(rb, items, false)
@ -605,9 +605,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
} }
preposition := strings.ToLower(msg.Params[0]) preposition := strings.ToLower(msg.Params[0])
target = msg.Params[1] target = msg.Params[1]
if preposition == "listcorrespondents" { listTargets = (preposition == "targets")
target = "*"
}
parseQueryParam := func(param string) (msgid string, timestamp time.Time, err error) { parseQueryParam := func(param string) (msgid string, timestamp time.Time, err error) {
if param == "*" && (preposition == "before" || preposition == "between") { if param == "*" && (preposition == "before" || preposition == "between") {
@ -642,11 +640,6 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
return return
} }
channel, sequence, err = server.GetHistorySequence(nil, client, target)
if err != nil || sequence == nil {
return
}
roundUp := func(endpoint time.Time) (result time.Time) { roundUp := func(endpoint time.Time) (result time.Time) {
return endpoint.Truncate(time.Millisecond).Add(time.Millisecond) return endpoint.Truncate(time.Millisecond).Add(time.Millisecond)
} }
@ -655,8 +648,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
var start, end history.Selector var start, end history.Selector
var limit int var limit int
switch preposition { switch preposition {
case "listcorrespondents": case "targets":
listCorrespondents = true
// use the same selector parsing as BETWEEN, // use the same selector parsing as BETWEEN,
// except that we have no target so we have one fewer parameter // except that we have no target so we have one fewer parameter
paramPos = 1 paramPos = 1
@ -710,12 +702,18 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
return return
} }
if listCorrespondents { if listTargets {
correspondents, err = sequence.ListCorrespondents(start, end, limit) targets, err = client.listTargets(start, end, limit)
} else if preposition == "around" {
items, err = sequence.Around(start, limit)
} else { } else {
items, err = sequence.Between(start, end, limit) channel, sequence, err = server.GetHistorySequence(nil, client, target)
if err != nil || sequence == nil {
return
}
if preposition == "around" {
items, err = sequence.Around(start, limit)
} else {
items, err = sequence.Between(start, end, limit)
}
} }
return return
} }

View File

@ -48,11 +48,6 @@ type Item struct {
IsBot bool `json:"IsBot,omitempty"` IsBot bool `json:"IsBot,omitempty"`
} }
type CorrespondentListing struct {
CfCorrespondent string
Time time.Time
}
// HasMsgid tests whether a message has the message id `msgid`. // HasMsgid tests whether a message has the message id `msgid`.
func (item *Item) HasMsgid(msgid string) bool { func (item *Item) HasMsgid(msgid string) bool {
return item.Message.Msgid == msgid return item.Message.Msgid == msgid
@ -66,13 +61,6 @@ func Reverse(results []Item) {
} }
} }
func ReverseCorrespondents(results []CorrespondentListing) {
// lol, generics when?
for i, j := 0, len(results)-1; i < j; i, j = i+1, j-1 {
results[i], results[j] = results[j], results[i]
}
}
// Buffer is a ring buffer holding message/event history for a channel or user // Buffer is a ring buffer holding message/event history for a channel or user
type Buffer struct { type Buffer struct {
sync.RWMutex sync.RWMutex
@ -214,7 +202,7 @@ func (list *Buffer) betweenHelper(start, end Selector, cutoff time.Time, pred Pr
} }
// returns all correspondents, in reverse time order // returns all correspondents, in reverse time order
func (list *Buffer) allCorrespondents() (results []CorrespondentListing) { func (list *Buffer) allCorrespondents() (results []TargetListing) {
seen := make(utils.StringSet) seen := make(utils.StringSet)
list.RLock() list.RLock()
@ -231,9 +219,9 @@ func (list *Buffer) allCorrespondents() (results []CorrespondentListing) {
for { for {
if !seen.Has(list.buffer[pos].CfCorrespondent) { if !seen.Has(list.buffer[pos].CfCorrespondent) {
seen.Add(list.buffer[pos].CfCorrespondent) seen.Add(list.buffer[pos].CfCorrespondent)
results = append(results, CorrespondentListing{ results = append(results, TargetListing{
CfCorrespondent: list.buffer[pos].CfCorrespondent, CfName: list.buffer[pos].CfCorrespondent,
Time: list.buffer[pos].Message.Time, Time: list.buffer[pos].Message.Time,
}) })
} }
@ -245,8 +233,8 @@ func (list *Buffer) allCorrespondents() (results []CorrespondentListing) {
return return
} }
// implement LISTCORRESPONDENTS // list DM correspondents, as one input to CHATHISTORY TARGETS
func (list *Buffer) listCorrespondents(start, end Selector, cutoff time.Time, limit int) (results []CorrespondentListing, err error) { func (list *Buffer) listCorrespondents(start, end Selector, cutoff time.Time, limit int) (results []TargetListing, err error) {
after := start.Time after := start.Time
before := end.Time before := end.Time
after, before, ascending := MinMaxAsc(after, before, cutoff) after, before, ascending := MinMaxAsc(after, before, cutoff)
@ -316,10 +304,18 @@ func (seq *bufferSequence) Around(start Selector, limit int) (results []Item, er
return GenericAround(seq, start, limit) return GenericAround(seq, start, limit)
} }
func (seq *bufferSequence) ListCorrespondents(start, end Selector, limit int) (results []CorrespondentListing, err error) { func (seq *bufferSequence) ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error) {
return seq.list.listCorrespondents(start, end, seq.cutoff, limit) return seq.list.listCorrespondents(start, end, seq.cutoff, limit)
} }
func (seq *bufferSequence) Cutoff() time.Time {
return seq.cutoff
}
func (seq *bufferSequence) Ephemeral() bool {
return true
}
// you must be holding the read lock to call this // you must be holding the read lock to call this
func (list *Buffer) matchInternal(predicate Predicate, ascending bool, limit int) (results []Item) { func (list *Buffer) matchInternal(predicate Predicate, ascending bool, limit int) (results []Item) {
if list.start == -1 || len(list.buffer) == 0 { if list.start == -1 || len(list.buffer) == 0 {

View File

@ -20,7 +20,14 @@ type Sequence interface {
Between(start, end Selector, limit int) (results []Item, err error) Between(start, end Selector, limit int) (results []Item, err error)
Around(start Selector, limit int) (results []Item, err error) Around(start Selector, limit int) (results []Item, err error)
ListCorrespondents(start, end Selector, limit int) (results []CorrespondentListing, err error) ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error)
// this are weird hacks that violate the encapsulation of Sequence to some extent;
// Cutoff() returns the cutoff time for other code to use (it returns the zero time
// if none is set), and Ephemeral() returns whether the backing store is in-memory
// or a persistent database.
Cutoff() time.Time
Ephemeral() bool
} }
// This is a bad, slow implementation of CHATHISTORY AROUND using the BETWEEN semantics // This is a bad, slow implementation of CHATHISTORY AROUND using the BETWEEN semantics

83
irc/history/targets.go Normal file
View File

@ -0,0 +1,83 @@
// Copyright (c) 2021 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package history
import (
"sort"
"time"
)
type TargetListing struct {
CfName string
Time time.Time
}
// Merge `base`, a paging window of targets, with `extras` (the target entries
// for all joined channels).
func MergeTargets(base []TargetListing, extra []TargetListing, start, end time.Time, limit int) (results []TargetListing) {
if len(extra) == 0 {
return base
}
SortCorrespondents(extra)
start, end, ascending := MinMaxAsc(start, end, time.Time{})
predicate := func(t time.Time) bool {
return (start.IsZero() || start.Before(t)) && (end.IsZero() || end.After(t))
}
prealloc := len(base) + len(extra)
if limit < prealloc {
prealloc = limit
}
results = make([]TargetListing, 0, prealloc)
if !ascending {
ReverseCorrespondents(base)
ReverseCorrespondents(extra)
}
for len(results) < limit {
if len(extra) != 0 {
if !predicate(extra[0].Time) {
extra = extra[1:]
continue
}
if len(base) != 0 {
if base[0].Time.Before(extra[0].Time) == ascending {
results = append(results, base[0])
base = base[1:]
} else {
results = append(results, extra[0])
extra = extra[1:]
}
} else {
results = append(results, extra[0])
extra = extra[1:]
}
} else if len(base) != 0 {
results = append(results, base[0])
base = base[1:]
} else {
break
}
}
if !ascending {
ReverseCorrespondents(results)
}
return
}
func ReverseCorrespondents(results []TargetListing) {
// lol, generics when?
for i, j := 0, len(results)-1; i < j; i, j = i+1, j-1 {
results[i], results[j] = results[j], results[i]
}
}
func SortCorrespondents(list []TargetListing) {
sort.Slice(list, func(i, j int) bool {
return list[i].Time.Before(list[j].Time)
})
}

View File

@ -905,7 +905,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
return return
} }
func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.CorrespondentListing, err error) { func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) {
after, before, ascending := history.MinMaxAsc(after, before, cutoff) after, before, ascending := history.MinMaxAsc(after, before, cutoff)
direction := "ASC" direction := "ASC"
if !ascending { if !ascending {
@ -941,9 +941,9 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin
if err != nil { if err != nil {
return return
} }
results = append(results, history.CorrespondentListing{ results = append(results, history.TargetListing{
CfCorrespondent: correspondent, CfName: correspondent,
Time: time.Unix(0, nanotime), Time: time.Unix(0, nanotime),
}) })
} }
@ -954,6 +954,54 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin
return return
} }
func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
if mysql.db == nil {
return
}
if len(cfchannels) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
var queryBuf strings.Builder
args := make([]interface{}, 0, len(results))
// https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html
// this should be a "loose index scan"
queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence
WHERE sequence.target IN (`)
for i, chname := range cfchannels {
if i != 0 {
queryBuf.WriteString(", ")
}
queryBuf.WriteByte('?')
args = append(args, chname)
}
queryBuf.WriteString(") GROUP BY sequence.target;")
rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
if mysql.logError("could not query channel listings", err) {
return
}
defer rows.Close()
var target string
var nanotime int64
for rows.Next() {
err = rows.Scan(&target, &nanotime)
if mysql.logError("could not scan channel listings", err) {
return
}
results = append(results, history.TargetListing{
CfName: target,
Time: time.Unix(0, nanotime),
})
}
return
}
func (mysql *MySQL) Close() { func (mysql *MySQL) Close() {
// closing the database will close our prepared statements as well // closing the database will close our prepared statements as well
if mysql.db != nil { if mysql.db != nil {
@ -998,7 +1046,7 @@ func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (result
return history.GenericAround(s, start, limit) return history.GenericAround(s, start, limit)
} }
func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.CorrespondentListing, err error) { func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout()) ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
defer cancel() defer cancel()
@ -1011,6 +1059,14 @@ func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector,
return return
} }
func (seq *mySQLHistorySequence) Cutoff() time.Time {
return seq.cutoff
}
func (seq *mySQLHistorySequence) Ephemeral() bool {
return false
}
func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence { func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
return &mySQLHistorySequence{ return &mySQLHistorySequence{
target: target, target: target,

View File

@ -1017,6 +1017,13 @@ func (server *Server) DeleteMessage(target, msgid, accountName string) (err erro
return return
} }
func (server *Server) UnfoldName(cfname string) (name string) {
if strings.HasPrefix(cfname, "#") {
return server.channels.UnfoldName(cfname)
}
return server.clients.UnfoldNick(cfname)
}
// elistMatcher takes and matches ELIST conditions // elistMatcher takes and matches ELIST conditions
type elistMatcher struct { type elistMatcher struct {
MinClientsActive bool MinClientsActive bool

View File

@ -202,40 +202,16 @@ func zncPlayPrivmsgs(client *Client, rb *ResponseBuffer, target string, after, b
// PRIVMSG *playback :list // PRIVMSG *playback :list
func zncPlaybackListHandler(client *Client, command string, params []string, rb *ResponseBuffer) { func zncPlaybackListHandler(client *Client, command string, params []string, rb *ResponseBuffer) {
nick := client.Nick()
for _, channel := range client.Channels() {
_, sequence, err := client.server.GetHistorySequence(channel, client, "")
if sequence == nil {
continue
} else if err != nil {
client.server.logger.Error("internal", "couldn't get history sequence for ZNC list", err.Error())
continue
}
items, err := sequence.Between(history.Selector{}, history.Selector{}, 1) // i.e., LATEST * 1
if err != nil {
client.server.logger.Error("internal", "couldn't query history for ZNC list", err.Error())
} else if len(items) != 0 {
stamp := timeToZncWireTime(items[0].Message.Time)
rb.Add(nil, zncPrefix, "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", channel.Name(), stamp))
}
}
_, seq, err := client.server.GetHistorySequence(nil, client, "*")
if seq == nil {
return
} else if err != nil {
client.server.logger.Error("internal", "couldn't get client history sequence for ZNC list", err.Error())
return
}
limit := client.server.Config().History.ChathistoryMax limit := client.server.Config().History.ChathistoryMax
correspondents, err := seq.ListCorrespondents(history.Selector{}, history.Selector{}, limit) correspondents, err := client.listTargets(history.Selector{}, history.Selector{}, limit)
if err != nil { if err != nil {
client.server.logger.Error("internal", "couldn't get correspondents for ZNC list", err.Error()) client.server.logger.Error("internal", "couldn't get history for ZNC list", err.Error())
return return
} }
nick := client.Nick()
for _, correspondent := range correspondents { for _, correspondent := range correspondents {
stamp := timeToZncWireTime(correspondent.Time) stamp := timeToZncWireTime(correspondent.Time)
correspondentNick := client.server.clients.UnfoldNick(correspondent.CfCorrespondent) unfoldedTarget := client.server.UnfoldName(correspondent.CfName)
rb.Add(nil, zncPrefix, "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", correspondentNick, stamp)) rb.Add(nil, zncPrefix, "PRIVMSG", nick, fmt.Sprintf("%s 0 %s", unfoldedTarget, stamp))
} }
} }