From d967129446121dd47ba9aca8255dac28f359a2c4 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Fri, 28 Feb 2020 05:41:08 -0500 Subject: [PATCH] fix #833 --- irc/client.go | 11 ++++-- irc/handlers.go | 8 ++-- irc/mysql/history.go | 91 ++++++++++++++++++++------------------------ irc/server.go | 34 ++++++++--------- 4 files changed, 69 insertions(+), 75 deletions(-) diff --git a/irc/client.go b/irc/client.go index 2bda49ae..9e1e3ff3 100644 --- a/irc/client.go +++ b/irc/client.go @@ -1563,15 +1563,18 @@ func (client *Client) historyStatus(config *Config) (status HistoryStatus, targe } client.stateMutex.RLock() - loggedIn := client.account != "" + target = client.account historyStatus := client.accountSettings.DMHistory - target = client.nickCasefolded client.stateMutex.RUnlock() - if !loggedIn { + if target == "" { return HistoryEphemeral, "" } - return historyEnabled(config.History.Persistent.DirectMessages, historyStatus), target + status = historyEnabled(config.History.Persistent.DirectMessages, historyStatus) + if status != HistoryPersistent { + target = "" + } + return } // these are bit flags indicating what part of the client status is "dirty" diff --git a/irc/handlers.go b/irc/handlers.go index 4e898e11..1d5592df 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -2011,11 +2011,9 @@ func dispatchMessageToTarget(client *Client, tags map[string]string, histType hi item.CfCorrespondent = details.nickCasefolded user.history.Add(item) } - cPersistent := cStatus == HistoryPersistent - tPersistent := tStatus == HistoryPersistent - if cPersistent || tPersistent { - item.CfCorrespondent = "" - server.historyDB.AddDirectMessage(details.nickCasefolded, user.NickCasefolded(), cPersistent, tPersistent, targetedItem) + if cStatus == HistoryPersistent || tStatus == HistoryPersistent { + targetedItem.CfCorrespondent = "" + server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem) } } } diff --git a/irc/mysql/history.go b/irc/mysql/history.go index c14c9d37..0081ccce 100644 --- a/irc/mysql/history.go +++ b/irc/mysql/history.go @@ -25,7 +25,7 @@ const ( MaxTargetLength = 64 // latest schema of the db - latestDbSchema = "1" + latestDbSchema = "2" keySchemaVersion = "db.version" cleanupRowLimit = 50 cleanupPauseTime = 10 * time.Minute @@ -144,11 +144,11 @@ func (mysql *MySQL) createTables() (err error) { _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations ( id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, - lower_target VARBINARY(%[1]d) NOT NULL, - upper_target VARBINARY(%[1]d) NOT NULL, + target VARBINARY(%[1]d) NOT NULL, + correspondent VARBINARY(%[1]d) NOT NULL, nanotime BIGINT UNSIGNED NOT NULL, history_id BIGINT NOT NULL, - KEY (lower_target, upper_target, nanotime), + KEY (target, correspondent, nanotime), KEY (history_id) ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)) if err != nil { @@ -278,7 +278,7 @@ func (mysql *MySQL) prepareStatements() (err error) { return } mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations - (lower_target, upper_target, nanotime, history_id) VALUES (?, ?, ?, ?);`) + (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`) if err != nil { return } @@ -315,19 +315,18 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) return } - err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id) + err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id) return } -func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) { - _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id) +func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) { + _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id) mysql.logError("could not insert sequence entry", err) return } -func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) { - lower, higher := stringMinMax(sender, recipient) - _, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id) +func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) { + _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id) mysql.logError("could not insert conversations entry", err) return } @@ -355,20 +354,12 @@ func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64 return } -func stringMinMax(first, second string) (min, max string) { - if first < second { - return first, second - } else { - return second, first - } -} - -func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, recipientPersistent bool, item history.Item) (err error) { +func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) { if mysql.db == nil { return } - if !(senderPersistent || recipientPersistent) { + if senderAccount == "" && recipientAccount == "" { return } @@ -384,22 +375,30 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, return } - if senderPersistent { - mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id) + nanotime := item.Message.Time.UnixNano() + + if senderAccount != "" { + err = mysql.insertSequenceEntry(ctx, senderAccount, nanotime, id) + if err != nil { + return + } + err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id) if err != nil { return } } - if recipientPersistent && sender != recipient { - err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id) + if recipientAccount != "" && sender != recipient { + err = mysql.insertSequenceEntry(ctx, recipientAccount, nanotime, id) + if err != nil { + return + } + err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id) if err != nil { return } } - err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id) - return } @@ -453,14 +452,8 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter return } -func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { - useSequence := true - var lowerTarget, upperTarget string - if sender != "" { - lowerTarget, upperTarget = stringMinMax(sender, recipient) - useSequence = false - } - +func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { + useSequence := correspondent == "" table := "sequence" if !useSequence { table = "conversations" @@ -479,11 +472,11 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient str "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table) if useSequence { fmt.Fprintf(&queryBuf, " sequence.target = ?") - args = append(args, recipient) + args = append(args, target) } else { - fmt.Fprintf(&queryBuf, " conversations.lower_target = ? AND conversations.upper_target = ?") - args = append(args, lowerTarget) - args = append(args, upperTarget) + fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?") + args = append(args, target) + args = append(args, correspondent) } if !after.IsZero() { fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table) @@ -514,10 +507,10 @@ func (mysql *MySQL) Close() { // implements history.Sequence, emulating a single history buffer (for a channel, // a single user's DMs, or a DM conversation) type mySQLHistorySequence struct { - mysql *MySQL - sender string - recipient string - cutoff time.Time + mysql *MySQL + target string + correspondent string + cutoff time.Time } func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) { @@ -539,7 +532,7 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) ( } } - results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit) + results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit) return results, (err == nil), err } @@ -547,11 +540,11 @@ func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (result return history.GenericAround(s, start, limit) } -func (mysql *MySQL) MakeSequence(sender, recipient string, cutoff time.Time) history.Sequence { +func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence { return &mySQLHistorySequence{ - sender: sender, - recipient: recipient, - mysql: mysql, - cutoff: cutoff, + target: target, + correspondent: correspondent, + mysql: mysql, + cutoff: cutoff, } } diff --git a/irc/server.go b/irc/server.go index ca1ea485..c7007fe6 100644 --- a/irc/server.go +++ b/irc/server.go @@ -862,22 +862,22 @@ func (server *Server) setupListeners(config *Config) (err error) { // Gets the abstract sequence from which we're going to query history; // we may already know the channel we're querying, or we may have -// to look it up via a string target. This function is responsible for +// to look it up via a string query. This function is responsible for // privilege checking. -func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, target string) (channel *Channel, sequence history.Sequence, err error) { +func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, query string) (channel *Channel, sequence history.Sequence, err error) { config := server.Config() // 4 cases: {persistent, ephemeral} x {normal, conversation} - // with ephemeral history, recipient is implicit in the choice of `hist`, - // and sender is "" if we're retrieving a channel or *, and the correspondent's name + // with ephemeral history, target is implicit in the choice of `hist`, + // and correspondent is "" if we're retrieving a channel or *, and the correspondent's name // if we're retrieving a DM conversation ("query buffer"). with persistent history, - // recipient is always nonempty, and sender is either empty or nonempty as before. + // target is always nonempty, and correspondent is either empty or nonempty as before. var status HistoryStatus - var sender, recipient string + var target, correspondent string var hist *history.Buffer channel = providedChannel if channel == nil { - if strings.HasPrefix(target, "#") { - channel = server.channels.Get(target) + if strings.HasPrefix(query, "#") { + channel = server.channels.Get(query) if channel == nil { return } @@ -888,19 +888,19 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien err = errInsufficientPrivs return } - status, recipient = channel.historyStatus(config) + status, target = channel.historyStatus(config) switch status { case HistoryEphemeral: hist = &channel.history case HistoryPersistent: - // already set `recipient` + // already set `target` default: return } } else { - status, recipient = client.historyStatus(config) - if target != "*" { - sender, err = CasefoldName(target) + status, target = client.historyStatus(config) + if query != "*" { + correspondent, err = CasefoldName(query) if err != nil { return } @@ -909,7 +909,7 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien case HistoryEphemeral: hist = &client.history case HistoryPersistent: - // already set `recipient`, and `sender` if necessary + // already set `target`, and `correspondent` if necessary default: return } @@ -931,9 +931,9 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien } if hist != nil { - sequence = hist.MakeSequence(sender, cutoff) - } else if recipient != "" { - sequence = server.historyDB.MakeSequence(sender, recipient, cutoff) + sequence = hist.MakeSequence(correspondent, cutoff) + } else if target != "" { + sequence = server.historyDB.MakeSequence(target, correspondent, cutoff) } return }