From 9f54ea07b7b392771f2a33f6deb2b227e0be5626 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Tue, 30 Dec 2025 23:12:30 -0500 Subject: [PATCH] prep for alternative history databases (#2316) * abstract history DB interface * make mysql error logging consistent Consistently propagate database errors to the client, making the client responsible for logging them. * move ListCorrespondents from Sequence to Database/Buffer --- irc/channel.go | 3 + irc/client.go | 39 ++++++++----- irc/handlers.go | 7 ++- irc/history/database.go | 120 ++++++++++++++++++++++++++++++++++++++++ irc/history/history.go | 10 +--- irc/history/queries.go | 2 - irc/mysql/history.go | 105 ++++++++++++++++++++--------------- irc/server.go | 26 +++++---- irc/znc.go | 2 +- 9 files changed, 232 insertions(+), 82 deletions(-) create mode 100644 irc/history/database.go diff --git a/irc/channel.go b/irc/channel.go index ade4b5b4..7f8f12a0 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -730,6 +730,9 @@ func (channel *Channel) AddHistoryItem(item history.Item, account string) (err e status, target, _ := channel.historyStatus(channel.server.Config()) if status == HistoryPersistent { err = channel.server.historyDB.AddChannelItem(target, item, account) + if err != nil { + channel.server.logger.Error("history", "could not add channel message to history", err.Error()) + } } else if status == HistoryEphemeral { channel.history.Add(item) } diff --git a/irc/client.go b/irc/client.go index 7e7b4ba4..bf150d82 100644 --- a/irc/client.go +++ b/irc/client.go @@ -1774,12 +1774,15 @@ func (client *Client) addHistoryItem(target *Client, item history.Item, details, } if cStatus == HistoryPersistent || tStatus == HistoryPersistent { targetedItem.CfCorrespondent = "" - client.server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem) + err = client.server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem) + if err != nil { + client.server.logger.Error("history", "could not add direct message to history", err.Error()) + } } return nil } -func (client *Client) listTargets(start, end history.Selector, limit int) (results []history.TargetListing, err error) { +func (client *Client) listTargets(start, end time.Time, limit int) (results []history.TargetListing, err error) { var base, extras []history.TargetListing var chcfnames []string for _, channel := range client.Channels() { @@ -1800,27 +1803,35 @@ func (client *Client) listTargets(start, end history.Selector, limit int) (resul } } persistentExtras, err := client.server.historyDB.ListChannels(chcfnames) - if err == nil && len(persistentExtras) != 0 { + if err != nil { + client.server.logger.Error("history", "could not list persistent channels", err.Error()) + } else if len(persistentExtras) != 0 { extras = append(extras, persistentExtras...) } - _, cSeq, err := client.server.GetHistorySequence(nil, client, "") - if err == nil && cSeq != nil { - correspondents, err := cSeq.ListCorrespondents(start, end, limit) - if err == nil { - base = correspondents - } + // get DM correspondents from the in-memory buffer or the database, as applicable + var cErr error + status, target := client.historyStatus(client.server.Config()) + switch status { + case HistoryEphemeral: + base, cErr = client.history.ListCorrespondents(start, end, limit) + case HistoryPersistent: + base, cErr = client.server.historyDB.ListCorrespondents(target, start, end, limit) + default: + // nothing to do + } + if cErr != nil { + base = nil + client.server.logger.Error("history", "could not list correspondents", cErr.Error()) } - results = history.MergeTargets(base, extras, start.Time, end.Time, limit) + results = history.MergeTargets(base, extras, start, end, limit) return results, nil } // latest PRIVMSG from all DM targets func (client *Client) privmsgsBetween(startTime, endTime time.Time, targetLimit, messageLimit int) (results []history.Item, err error) { - start := history.Selector{Time: startTime} - end := history.Selector{Time: endTime} - targets, err := client.listTargets(start, end, targetLimit) + targets, err := client.listTargets(startTime, endTime, targetLimit) if err != nil { return } @@ -1830,7 +1841,7 @@ func (client *Client) privmsgsBetween(startTime, endTime time.Time, targetLimit, } _, seq, err := client.server.GetHistorySequence(nil, client, target.CfName) if err == nil && seq != nil { - items, err := seq.Between(start, end, messageLimit) + items, err := seq.Between(history.Selector{Time: startTime}, history.Selector{Time: endTime}, messageLimit) if err == nil { results = append(results, items...) } else { diff --git a/irc/handlers.go b/irc/handlers.go index c1a90c96..10cb18d5 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -844,7 +844,12 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb * } if listTargets { - targets, err = client.listTargets(start, end, limit) + // TARGETS must take time= selectors + if start.Time.IsZero() || end.Time.IsZero() { + err = utils.ErrInvalidParams + return + } + targets, err = client.listTargets(start.Time, end.Time, limit) } else { channel, sequence, err = server.GetHistorySequence(nil, client, target) if err != nil || sequence == nil { diff --git a/irc/history/database.go b/irc/history/database.go new file mode 100644 index 00000000..3ec44ea4 --- /dev/null +++ b/irc/history/database.go @@ -0,0 +1,120 @@ +// Copyright (c) 2025 Shivaram Lingamneni +// released under the MIT license + +package history + +import ( + "io" + "time" +) + +// Database is an interface for persistent history storage backends. +type Database interface { + // Close closes the database connection and releases resources. + io.Closer + + // AddChannelItem adds a history item for a channel. + // target is the casefolded channel name. + // account is the sender's casefolded account name ("" for no account). + AddChannelItem(target string, item Item, account string) error + + // AddDirectMessage adds a history item for a direct message. + // All identifiers are casefolded; account identifiers are "" for no account. + AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item Item) error + + // DeleteMsgid deletes a message by its msgid. + // accountName is the unfolded account name, or "*" to skip + // account validation + DeleteMsgid(msgid, accountName string) error + + // MakeSequence creates a Sequence for querying history. + // target is the primary target (channel or account), casefolded. + // correspondent is the casefolded DM correspondent (empty for channels). + // cutoff is the earliest time to include in results. + MakeSequence(target, correspondent string, cutoff time.Time) Sequence + + // ListChannels returns the timestamp of the latest message in each + // of the given channels (specified as casefolded names). + ListChannels(cfchannels []string) (results []TargetListing, err error) + + // ListCorrespondents lists the DM correspondents associated with an account, + // in order to implement CHATHISTORY TARGETS. + ListCorrespondents(cftarget string, start, end time.Time, limit int) ([]TargetListing, error) + + // these are for theoretical GDPR compliance, not actual chat functionality, + // and are not essential: + + // Forget enqueues an account (casefolded) for message deletion. + // This is used for GDPR-style "right to be forgotten" requests. + // The actual deletion happens asynchronously. + Forget(account string) + + // Export exports all messages for an account (casefolded) to the given writer. + Export(account string, writer io.Writer) +} + +type noopDatabase struct{} + +// NewNoopDatabase returns a Database implementation that does nothing. +func NewNoopDatabase() Database { + return noopDatabase{} +} + +func (n noopDatabase) Close() error { + return nil +} + +func (n noopDatabase) AddChannelItem(target string, item Item, account string) error { + return nil +} + +func (n noopDatabase) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item Item) error { + return nil +} + +func (n noopDatabase) DeleteMsgid(msgid, accountName string) error { + return nil +} + +func (n noopDatabase) Forget(account string) { + // no-op +} + +func (n noopDatabase) Export(account string, writer io.Writer) { + // no-op +} + +func (n noopDatabase) ListChannels(cfchannels []string) (results []TargetListing, err error) { + return nil, nil +} + +func (n noopDatabase) ListCorrespondents(target string, start, end time.Time, limit int) (results []TargetListing, err error) { + return nil, nil +} + +func (n noopDatabase) MakeSequence(target, correspondent string, cutoff time.Time) Sequence { + return noopSequence{} +} + +// noopSequence is a no-op implementation of Sequence. +// XXX: this should never be accessed, because if persistent history is disabled, +// we should always be working with a bufferSequence instead. But we might as well +// be defensive in case there's an edge case where (noopDatabase).MakeSequence ends +// up getting called. +type noopSequence struct{} + +func (n noopSequence) Between(start, end Selector, limit int) (results []Item, err error) { + return nil, nil +} + +func (n noopSequence) Around(start Selector, limit int) (results []Item, err error) { + return nil, nil +} + +func (n noopSequence) Cutoff() time.Time { + return time.Time{} +} + +func (n noopSequence) Ephemeral() bool { + return true +} diff --git a/irc/history/history.go b/irc/history/history.go index ef7e4d9f..6a2c2849 100644 --- a/irc/history/history.go +++ b/irc/history/history.go @@ -230,10 +230,8 @@ func (list *Buffer) allCorrespondents() (results []TargetListing) { } // list DM correspondents, as one input to CHATHISTORY TARGETS -func (list *Buffer) listCorrespondents(start, end Selector, cutoff time.Time, limit int) (results []TargetListing, err error) { - after := start.Time - before := end.Time - after, before, ascending := MinMaxAsc(after, before, cutoff) +func (list *Buffer) ListCorrespondents(start, end time.Time, limit int) (results []TargetListing, err error) { + after, before, ascending := MinMaxAsc(start, end, time.Time{}) correspondents := list.allCorrespondents() if len(correspondents) == 0 { @@ -300,10 +298,6 @@ func (seq *bufferSequence) Around(start Selector, limit int) (results []Item, er return GenericAround(seq, start, limit) } -func (seq *bufferSequence) ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error) { - return seq.list.listCorrespondents(start, end, seq.cutoff, limit) -} - func (seq *bufferSequence) Cutoff() time.Time { return seq.cutoff } diff --git a/irc/history/queries.go b/irc/history/queries.go index 72b1168f..12f7a0d5 100644 --- a/irc/history/queries.go +++ b/irc/history/queries.go @@ -21,8 +21,6 @@ type Sequence interface { Between(start, end Selector, limit int) (results []Item, err error) Around(start Selector, limit int) (results []Item, err error) - ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error) - // this are weird hacks that violate the encapsulation of Sequence to some extent; // Cutoff() returns the cutoff time for other code to use (it returns the zero time // if none is set), and Ephemeral() returns whether the backing store is in-memory diff --git a/irc/mysql/history.go b/irc/mysql/history.go index 286a6106..9ec74c2f 100644 --- a/irc/mysql/history.go +++ b/irc/mysql/history.go @@ -64,10 +64,16 @@ type MySQL struct { trackAccountMessages atomic.Uint32 } -func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) { +var _ history.Database = (*MySQL)(nil) + +func NewMySQLDatabase(logger *logger.Manager, config Config) (*MySQL, error) { + var mysql MySQL + mysql.logger = logger mysql.wakeForgetter = make(chan e, 1) mysql.SetConfig(config) + + return &mysql, mysql.open() } func (mysql *MySQL) SetConfig(config Config) { @@ -89,7 +95,7 @@ func (mysql *MySQL) getExpireTime() (expireTime time.Duration) { return } -func (m *MySQL) Open() (err error) { +func (m *MySQL) open() (err error) { var address string if m.config.SocketPath != "" { address = fmt.Sprintf("unix(%s)", m.config.SocketPath) @@ -623,40 +629,46 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item, account str func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) { _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id) - mysql.logError("could not insert sequence entry", err) + if err != nil { + return fmt.Errorf("could not insert sequence entry: %w", err) + } return } func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) { _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id) - mysql.logError("could not insert conversations entry", err) + if err != nil { + return fmt.Errorf("could not insert conversations entry: %w", err) + } return } func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) { _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime) - mysql.logError("could not insert conversations entry", err) + if err != nil { + return fmt.Errorf("could not insert correspondents entry: %w", err) + } return } func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) { value, err := marshalItem(&item) - if mysql.logError("could not marshal item", err) { - return + if err != nil { + return 0, fmt.Errorf("could not marshal item: %w", err) } msgidBytes, err := decodeMsgid(item.Message.Msgid) - if mysql.logError("could not decode msgid", err) { - return + if err != nil { + return 0, fmt.Errorf("could not decode msgid: %w", err) } result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes) - if mysql.logError("could not insert item", err) { - return + if err != nil { + return 0, fmt.Errorf("could not insert item: %w", err) } id, err = result.LastInsertId() - if mysql.logError("could not insert item", err) { - return + if err != nil { + return 0, fmt.Errorf("could not insert item: %w", err) } return @@ -667,7 +679,9 @@ func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, acc return } _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account) - mysql.logError("could not insert account-message entry", err) + if err != nil { + return fmt.Errorf("could not insert account-message entry: %w", err) + } return } @@ -748,7 +762,9 @@ func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) { } err = mysql.deleteHistoryIDs(ctx, []uint64{id}) - mysql.logError("couldn't delete msgid", err) + if err != nil { + return fmt.Errorf("couldn't delete msgid: %w", err) + } return } @@ -831,10 +847,10 @@ func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData b } else { err = row.Scan(&nanoSeq, &nanoConv, &id, &data) } - if err != sql.ErrNoRows { - mysql.logError("could not resolve msgid to time", err) - } if err != nil { + if err != sql.ErrNoRows { + err = fmt.Errorf("could not resolve msgid to time: %w", err) + } return } nanotime := extractNanotime(nanoSeq, nanoConv) @@ -857,8 +873,8 @@ func extractNanotime(seq, conv sql.NullInt64) (result int64) { func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) { rows, err := mysql.db.QueryContext(ctx, query, args...) - if mysql.logError("could not select history items", err) { - return + if err != nil { + return nil, fmt.Errorf("could not select history items: %w", err) } defer rows.Close() @@ -867,12 +883,12 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter var blob []byte var item history.Item err = rows.Scan(&blob) - if mysql.logError("could not scan history item", err) { - return + if err != nil { + return nil, fmt.Errorf("could not scan history item: %w", err) } err = unmarshalItem(blob, &item) - if mysql.logError("could not unmarshal history item", err) { - return + if err != nil { + return nil, fmt.Errorf("could not unmarshal history item: %w", err) } results = append(results, item) } @@ -949,7 +965,7 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin rows, err := mysql.db.QueryContext(ctx, query, args...) if err != nil { - return + return nil, fmt.Errorf("could not query correspondents: %w", err) } defer rows.Close() var correspondent string @@ -957,7 +973,7 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin for rows.Next() { err = rows.Scan(&correspondent, &nanotime) if err != nil { - return + return nil, fmt.Errorf("could not scan correspondents: %w", err) } results = append(results, history.TargetListing{ CfName: correspondent, @@ -972,6 +988,19 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin return } +func (mysql *MySQL) ListCorrespondents(cftarget string, start, end time.Time, limit int) (results []history.TargetListing, err error) { + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + + // TODO accept msgids here? + + results, err = mysql.listCorrespondentsInternal(ctx, cftarget, start, end, time.Time{}, limit) + if err != nil { + return nil, fmt.Errorf("could not read correspondents: %w", err) + } + return +} + func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) { if mysql.db == nil { return @@ -1000,8 +1029,8 @@ func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetL queryBuf.WriteString(") GROUP BY sequence.target;") rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...) - if mysql.logError("could not query channel listings", err) { - return + if err != nil { + return nil, fmt.Errorf("could not query channel listings: %w", err) } defer rows.Close() @@ -1009,8 +1038,8 @@ func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetL var nanotime int64 for rows.Next() { err = rows.Scan(&target, &nanotime) - if mysql.logError("could not scan channel listings", err) { - return + if err != nil { + return nil, fmt.Errorf("could not scan channel listings: %w", err) } results = append(results, history.TargetListing{ CfName: target, @@ -1020,12 +1049,13 @@ func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetL return } -func (mysql *MySQL) Close() { +func (mysql *MySQL) Close() error { // closing the database will close our prepared statements as well if mysql.db != nil { mysql.db.Close() } mysql.db = nil + return nil } // implements history.Sequence, emulating a single history buffer (for a channel, @@ -1072,19 +1102,6 @@ func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (result return history.GenericAround(s, start, limit) } -func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) { - ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout()) - defer cancel() - - // TODO accept msgids here? - startTime := start.Time - endTime := end.Time - - results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit) - seq.mysql.logError("could not read correspondents", err) - return -} - func (seq *mySQLHistorySequence) Cutoff() time.Time { return seq.cutoff } diff --git a/irc/server.go b/irc/server.go index 641d63f2..94a1021b 100644 --- a/irc/server.go +++ b/irc/server.go @@ -90,7 +90,8 @@ type Server struct { snomasks SnoManager store *buntdb.DB dstore datastore.Datastore - historyDB mysql.MySQL + mysqlHistoryDB *mysql.MySQL + historyDB history.Database torLimiter connection_limits.TorLimiter whoWas WhoWasList stats Stats @@ -153,7 +154,6 @@ func (server *Server) Shutdown() { sdnotify.Stopping() server.logger.Info("server", "Stopping server") - //TODO(dan): Make sure we disallow new nicks for _, client := range server.clients.AllClients() { client.Notice("Server is shutting down") } @@ -162,10 +162,12 @@ func (server *Server) Shutdown() { server.performAlwaysOnMaintenance(false, true) if err := server.store.Close(); err != nil { - server.logger.Error("shutdown", fmt.Sprintln("Could not close datastore:", err)) + server.logger.Error("shutdown", "Could not close datastore", err.Error()) } - server.historyDB.Close() + if err := server.historyDB.Close(); err != nil { + server.logger.Error("shutdown", "Could not close history database", err.Error()) + } server.logger.Info("server", fmt.Sprintf("%s exiting", Ver)) } @@ -804,8 +806,10 @@ func (server *Server) applyConfig(config *Config) (err error) { return err } } else { - if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL { - server.historyDB.SetConfig(config.Datastore.MySQL) + if config.Datastore.MySQL.Enabled && server.mysqlHistoryDB != nil { + if config.Datastore.MySQL != oldConfig.Datastore.MySQL { + server.mysqlHistoryDB.SetConfig(config.Datastore.MySQL) + } } } @@ -1015,12 +1019,14 @@ func (server *Server) loadFromDatastore(config *Config) (err error) { server.accounts.Initialize(server) if config.Datastore.MySQL.Enabled { - server.historyDB.Initialize(server.logger, config.Datastore.MySQL) - err = server.historyDB.Open() + server.mysqlHistoryDB, err = mysql.NewMySQLDatabase(server.logger, config.Datastore.MySQL) if err != nil { server.logger.Error("internal", "could not connect to mysql", err.Error()) return err } + server.historyDB = server.mysqlHistoryDB + } else { + server.historyDB = history.NewNoopDatabase() } return nil @@ -1085,10 +1091,6 @@ func (server *Server) setupListeners(config *Config) (err error) { // we may already know the channel we're querying, or we may have // to look it up via a string query. This function is responsible for // privilege checking. -// XXX: call this with providedChannel==nil and query=="" to get a sequence -// suitable for ListCorrespondents (i.e., this function is still used to -// decide whether the ringbuf or mysql is authoritative about the client's -// message history). func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, query string) (channel *Channel, sequence history.Sequence, err error) { config := server.Config() // 4 cases: {persistent, ephemeral} x {normal, conversation} diff --git a/irc/znc.go b/irc/znc.go index a118e839..27c4df2c 100644 --- a/irc/znc.go +++ b/irc/znc.go @@ -218,7 +218,7 @@ func zncPlayPrivmsgsFromAll(client *Client, rb *ResponseBuffer, start, end time. // PRIVMSG *playback :list func zncPlaybackListHandler(client *Client, command string, params []string, rb *ResponseBuffer) { limit := client.server.Config().History.ChathistoryMax - correspondents, err := client.listTargets(history.Selector{}, history.Selector{}, limit) + correspondents, err := client.listTargets(time.Time{}, time.Time{}, limit) if err != nil { client.server.logger.Error("internal", "couldn't get history for ZNC list", err.Error()) return