This commit is contained in:
Shivaram Lingamneni 2020-02-28 05:41:08 -05:00
parent c414ac383c
commit d967129446
4 changed files with 69 additions and 75 deletions

View File

@ -1563,15 +1563,18 @@ func (client *Client) historyStatus(config *Config) (status HistoryStatus, targe
}
client.stateMutex.RLock()
loggedIn := client.account != ""
target = client.account
historyStatus := client.accountSettings.DMHistory
target = client.nickCasefolded
client.stateMutex.RUnlock()
if !loggedIn {
if target == "" {
return HistoryEphemeral, ""
}
return historyEnabled(config.History.Persistent.DirectMessages, historyStatus), target
status = historyEnabled(config.History.Persistent.DirectMessages, historyStatus)
if status != HistoryPersistent {
target = ""
}
return
}
// these are bit flags indicating what part of the client status is "dirty"

View File

@ -2011,11 +2011,9 @@ func dispatchMessageToTarget(client *Client, tags map[string]string, histType hi
item.CfCorrespondent = details.nickCasefolded
user.history.Add(item)
}
cPersistent := cStatus == HistoryPersistent
tPersistent := tStatus == HistoryPersistent
if cPersistent || tPersistent {
item.CfCorrespondent = ""
server.historyDB.AddDirectMessage(details.nickCasefolded, user.NickCasefolded(), cPersistent, tPersistent, targetedItem)
if cStatus == HistoryPersistent || tStatus == HistoryPersistent {
targetedItem.CfCorrespondent = ""
server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem)
}
}
}

View File

@ -25,7 +25,7 @@ const (
MaxTargetLength = 64
// latest schema of the db
latestDbSchema = "1"
latestDbSchema = "2"
keySchemaVersion = "db.version"
cleanupRowLimit = 50
cleanupPauseTime = 10 * time.Minute
@ -144,11 +144,11 @@ func (mysql *MySQL) createTables() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
lower_target VARBINARY(%[1]d) NOT NULL,
upper_target VARBINARY(%[1]d) NOT NULL,
target VARBINARY(%[1]d) NOT NULL,
correspondent VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL,
history_id BIGINT NOT NULL,
KEY (lower_target, upper_target, nanotime),
KEY (target, correspondent, nanotime),
KEY (history_id)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
if err != nil {
@ -278,7 +278,7 @@ func (mysql *MySQL) prepareStatements() (err error) {
return
}
mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
(lower_target, upper_target, nanotime, history_id) VALUES (?, ?, ?, ?);`)
(target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
if err != nil {
return
}
@ -315,19 +315,18 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
return
}
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
return
}
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
mysql.logError("could not insert sequence entry", err)
return
}
func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
lower, higher := stringMinMax(sender, recipient)
_, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
_, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
mysql.logError("could not insert conversations entry", err)
return
}
@ -355,20 +354,12 @@ func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64
return
}
func stringMinMax(first, second string) (min, max string) {
if first < second {
return first, second
} else {
return second, first
}
}
func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, recipientPersistent bool, item history.Item) (err error) {
func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
if mysql.db == nil {
return
}
if !(senderPersistent || recipientPersistent) {
if senderAccount == "" && recipientAccount == "" {
return
}
@ -384,22 +375,30 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
return
}
if senderPersistent {
mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
nanotime := item.Message.Time.UnixNano()
if senderAccount != "" {
err = mysql.insertSequenceEntry(ctx, senderAccount, nanotime, id)
if err != nil {
return
}
err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
if err != nil {
return
}
}
if recipientPersistent && sender != recipient {
err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
if recipientAccount != "" && sender != recipient {
err = mysql.insertSequenceEntry(ctx, recipientAccount, nanotime, id)
if err != nil {
return
}
err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
if err != nil {
return
}
}
err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
return
}
@ -453,14 +452,8 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter
return
}
func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
useSequence := true
var lowerTarget, upperTarget string
if sender != "" {
lowerTarget, upperTarget = stringMinMax(sender, recipient)
useSequence = false
}
func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
useSequence := correspondent == ""
table := "sequence"
if !useSequence {
table = "conversations"
@ -479,11 +472,11 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient str
"SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
if useSequence {
fmt.Fprintf(&queryBuf, " sequence.target = ?")
args = append(args, recipient)
args = append(args, target)
} else {
fmt.Fprintf(&queryBuf, " conversations.lower_target = ? AND conversations.upper_target = ?")
args = append(args, lowerTarget)
args = append(args, upperTarget)
fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
args = append(args, target)
args = append(args, correspondent)
}
if !after.IsZero() {
fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
@ -515,8 +508,8 @@ func (mysql *MySQL) Close() {
// a single user's DMs, or a DM conversation)
type mySQLHistorySequence struct {
mysql *MySQL
sender string
recipient string
target string
correspondent string
cutoff time.Time
}
@ -539,7 +532,7 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
}
}
results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
return results, (err == nil), err
}
@ -547,10 +540,10 @@ func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (result
return history.GenericAround(s, start, limit)
}
func (mysql *MySQL) MakeSequence(sender, recipient string, cutoff time.Time) history.Sequence {
func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
return &mySQLHistorySequence{
sender: sender,
recipient: recipient,
target: target,
correspondent: correspondent,
mysql: mysql,
cutoff: cutoff,
}

View File

@ -862,22 +862,22 @@ func (server *Server) setupListeners(config *Config) (err error) {
// Gets the abstract sequence from which we're going to query history;
// we may already know the channel we're querying, or we may have
// to look it up via a string target. This function is responsible for
// to look it up via a string query. This function is responsible for
// privilege checking.
func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, target string) (channel *Channel, sequence history.Sequence, err error) {
func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, query string) (channel *Channel, sequence history.Sequence, err error) {
config := server.Config()
// 4 cases: {persistent, ephemeral} x {normal, conversation}
// with ephemeral history, recipient is implicit in the choice of `hist`,
// and sender is "" if we're retrieving a channel or *, and the correspondent's name
// with ephemeral history, target is implicit in the choice of `hist`,
// and correspondent is "" if we're retrieving a channel or *, and the correspondent's name
// if we're retrieving a DM conversation ("query buffer"). with persistent history,
// recipient is always nonempty, and sender is either empty or nonempty as before.
// target is always nonempty, and correspondent is either empty or nonempty as before.
var status HistoryStatus
var sender, recipient string
var target, correspondent string
var hist *history.Buffer
channel = providedChannel
if channel == nil {
if strings.HasPrefix(target, "#") {
channel = server.channels.Get(target)
if strings.HasPrefix(query, "#") {
channel = server.channels.Get(query)
if channel == nil {
return
}
@ -888,19 +888,19 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
err = errInsufficientPrivs
return
}
status, recipient = channel.historyStatus(config)
status, target = channel.historyStatus(config)
switch status {
case HistoryEphemeral:
hist = &channel.history
case HistoryPersistent:
// already set `recipient`
// already set `target`
default:
return
}
} else {
status, recipient = client.historyStatus(config)
if target != "*" {
sender, err = CasefoldName(target)
status, target = client.historyStatus(config)
if query != "*" {
correspondent, err = CasefoldName(query)
if err != nil {
return
}
@ -909,7 +909,7 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
case HistoryEphemeral:
hist = &client.history
case HistoryPersistent:
// already set `recipient`, and `sender` if necessary
// already set `target`, and `correspondent` if necessary
default:
return
}
@ -931,9 +931,9 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
}
if hist != nil {
sequence = hist.MakeSequence(sender, cutoff)
} else if recipient != "" {
sequence = server.historyDB.MakeSequence(sender, recipient, cutoff)
sequence = hist.MakeSequence(correspondent, cutoff)
} else if target != "" {
sequence = server.historyDB.MakeSequence(target, correspondent, cutoff)
}
return
}