3
0
mirror of https://github.com/ergochat/ergo.git synced 2025-01-05 09:32:32 +01:00

Merge pull request #834 from slingamn/issue833_dm_privacy.1

fix #833
This commit is contained in:
Shivaram Lingamneni 2020-02-28 02:51:47 -08:00 committed by GitHub
commit 98efaf25e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 75 deletions

View File

@ -1563,15 +1563,18 @@ func (client *Client) historyStatus(config *Config) (status HistoryStatus, targe
} }
client.stateMutex.RLock() client.stateMutex.RLock()
loggedIn := client.account != "" target = client.account
historyStatus := client.accountSettings.DMHistory historyStatus := client.accountSettings.DMHistory
target = client.nickCasefolded
client.stateMutex.RUnlock() client.stateMutex.RUnlock()
if !loggedIn { if target == "" {
return HistoryEphemeral, "" return HistoryEphemeral, ""
} }
return historyEnabled(config.History.Persistent.DirectMessages, historyStatus), target status = historyEnabled(config.History.Persistent.DirectMessages, historyStatus)
if status != HistoryPersistent {
target = ""
}
return
} }
// these are bit flags indicating what part of the client status is "dirty" // these are bit flags indicating what part of the client status is "dirty"

View File

@ -2011,11 +2011,9 @@ func dispatchMessageToTarget(client *Client, tags map[string]string, histType hi
item.CfCorrespondent = details.nickCasefolded item.CfCorrespondent = details.nickCasefolded
user.history.Add(item) user.history.Add(item)
} }
cPersistent := cStatus == HistoryPersistent if cStatus == HistoryPersistent || tStatus == HistoryPersistent {
tPersistent := tStatus == HistoryPersistent targetedItem.CfCorrespondent = ""
if cPersistent || tPersistent { server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem)
item.CfCorrespondent = ""
server.historyDB.AddDirectMessage(details.nickCasefolded, user.NickCasefolded(), cPersistent, tPersistent, targetedItem)
} }
} }
} }

View File

@ -25,7 +25,7 @@ const (
MaxTargetLength = 64 MaxTargetLength = 64
// latest schema of the db // latest schema of the db
latestDbSchema = "1" latestDbSchema = "2"
keySchemaVersion = "db.version" keySchemaVersion = "db.version"
cleanupRowLimit = 50 cleanupRowLimit = 50
cleanupPauseTime = 10 * time.Minute cleanupPauseTime = 10 * time.Minute
@ -144,11 +144,11 @@ func (mysql *MySQL) createTables() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations ( _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
lower_target VARBINARY(%[1]d) NOT NULL, target VARBINARY(%[1]d) NOT NULL,
upper_target VARBINARY(%[1]d) NOT NULL, correspondent VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL, nanotime BIGINT UNSIGNED NOT NULL,
history_id BIGINT NOT NULL, history_id BIGINT NOT NULL,
KEY (lower_target, upper_target, nanotime), KEY (target, correspondent, nanotime),
KEY (history_id) KEY (history_id)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)) ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
if err != nil { if err != nil {
@ -278,7 +278,7 @@ func (mysql *MySQL) prepareStatements() (err error) {
return return
} }
mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
(lower_target, upper_target, nanotime, history_id) VALUES (?, ?, ?, ?);`) (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
if err != nil { if err != nil {
return return
} }
@ -315,19 +315,18 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
return return
} }
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id) err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
return return
} }
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) { func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id) _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
mysql.logError("could not insert sequence entry", err) mysql.logError("could not insert sequence entry", err)
return return
} }
func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) { func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
lower, higher := stringMinMax(sender, recipient) _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
_, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
mysql.logError("could not insert conversations entry", err) mysql.logError("could not insert conversations entry", err)
return return
} }
@ -355,20 +354,12 @@ func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64
return return
} }
func stringMinMax(first, second string) (min, max string) { func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
if first < second {
return first, second
} else {
return second, first
}
}
func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, recipientPersistent bool, item history.Item) (err error) {
if mysql.db == nil { if mysql.db == nil {
return return
} }
if !(senderPersistent || recipientPersistent) { if senderAccount == "" && recipientAccount == "" {
return return
} }
@ -384,22 +375,30 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
return return
} }
if senderPersistent { nanotime := item.Message.Time.UnixNano()
mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
if senderAccount != "" {
err = mysql.insertSequenceEntry(ctx, senderAccount, nanotime, id)
if err != nil {
return
}
err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
if err != nil { if err != nil {
return return
} }
} }
if recipientPersistent && sender != recipient { if recipientAccount != "" && sender != recipient {
err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id) err = mysql.insertSequenceEntry(ctx, recipientAccount, nanotime, id)
if err != nil {
return
}
err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
if err != nil { if err != nil {
return return
} }
} }
err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
return return
} }
@ -453,14 +452,8 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter
return return
} }
func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
useSequence := true useSequence := correspondent == ""
var lowerTarget, upperTarget string
if sender != "" {
lowerTarget, upperTarget = stringMinMax(sender, recipient)
useSequence = false
}
table := "sequence" table := "sequence"
if !useSequence { if !useSequence {
table = "conversations" table = "conversations"
@ -479,11 +472,11 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient str
"SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table) "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
if useSequence { if useSequence {
fmt.Fprintf(&queryBuf, " sequence.target = ?") fmt.Fprintf(&queryBuf, " sequence.target = ?")
args = append(args, recipient) args = append(args, target)
} else { } else {
fmt.Fprintf(&queryBuf, " conversations.lower_target = ? AND conversations.upper_target = ?") fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
args = append(args, lowerTarget) args = append(args, target)
args = append(args, upperTarget) args = append(args, correspondent)
} }
if !after.IsZero() { if !after.IsZero() {
fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table) fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
@ -515,8 +508,8 @@ func (mysql *MySQL) Close() {
// a single user's DMs, or a DM conversation) // a single user's DMs, or a DM conversation)
type mySQLHistorySequence struct { type mySQLHistorySequence struct {
mysql *MySQL mysql *MySQL
sender string target string
recipient string correspondent string
cutoff time.Time cutoff time.Time
} }
@ -539,7 +532,7 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
} }
} }
results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit) results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
return results, (err == nil), err return results, (err == nil), err
} }
@ -547,10 +540,10 @@ func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (result
return history.GenericAround(s, start, limit) return history.GenericAround(s, start, limit)
} }
func (mysql *MySQL) MakeSequence(sender, recipient string, cutoff time.Time) history.Sequence { func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
return &mySQLHistorySequence{ return &mySQLHistorySequence{
sender: sender, target: target,
recipient: recipient, correspondent: correspondent,
mysql: mysql, mysql: mysql,
cutoff: cutoff, cutoff: cutoff,
} }

View File

@ -862,22 +862,22 @@ func (server *Server) setupListeners(config *Config) (err error) {
// Gets the abstract sequence from which we're going to query history; // Gets the abstract sequence from which we're going to query history;
// we may already know the channel we're querying, or we may have // we may already know the channel we're querying, or we may have
// to look it up via a string target. This function is responsible for // to look it up via a string query. This function is responsible for
// privilege checking. // privilege checking.
func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, target string) (channel *Channel, sequence history.Sequence, err error) { func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, query string) (channel *Channel, sequence history.Sequence, err error) {
config := server.Config() config := server.Config()
// 4 cases: {persistent, ephemeral} x {normal, conversation} // 4 cases: {persistent, ephemeral} x {normal, conversation}
// with ephemeral history, recipient is implicit in the choice of `hist`, // with ephemeral history, target is implicit in the choice of `hist`,
// and sender is "" if we're retrieving a channel or *, and the correspondent's name // and correspondent is "" if we're retrieving a channel or *, and the correspondent's name
// if we're retrieving a DM conversation ("query buffer"). with persistent history, // if we're retrieving a DM conversation ("query buffer"). with persistent history,
// recipient is always nonempty, and sender is either empty or nonempty as before. // target is always nonempty, and correspondent is either empty or nonempty as before.
var status HistoryStatus var status HistoryStatus
var sender, recipient string var target, correspondent string
var hist *history.Buffer var hist *history.Buffer
channel = providedChannel channel = providedChannel
if channel == nil { if channel == nil {
if strings.HasPrefix(target, "#") { if strings.HasPrefix(query, "#") {
channel = server.channels.Get(target) channel = server.channels.Get(query)
if channel == nil { if channel == nil {
return return
} }
@ -888,19 +888,19 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
err = errInsufficientPrivs err = errInsufficientPrivs
return return
} }
status, recipient = channel.historyStatus(config) status, target = channel.historyStatus(config)
switch status { switch status {
case HistoryEphemeral: case HistoryEphemeral:
hist = &channel.history hist = &channel.history
case HistoryPersistent: case HistoryPersistent:
// already set `recipient` // already set `target`
default: default:
return return
} }
} else { } else {
status, recipient = client.historyStatus(config) status, target = client.historyStatus(config)
if target != "*" { if query != "*" {
sender, err = CasefoldName(target) correspondent, err = CasefoldName(query)
if err != nil { if err != nil {
return return
} }
@ -909,7 +909,7 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
case HistoryEphemeral: case HistoryEphemeral:
hist = &client.history hist = &client.history
case HistoryPersistent: case HistoryPersistent:
// already set `recipient`, and `sender` if necessary // already set `target`, and `correspondent` if necessary
default: default:
return return
} }
@ -931,9 +931,9 @@ func (server *Server) GetHistorySequence(providedChannel *Channel, client *Clien
} }
if hist != nil { if hist != nil {
sequence = hist.MakeSequence(sender, cutoff) sequence = hist.MakeSequence(correspondent, cutoff)
} else if recipient != "" { } else if target != "" {
sequence = server.historyDB.MakeSequence(sender, recipient, cutoff) sequence = server.historyDB.MakeSequence(target, correspondent, cutoff)
} }
return return
} }