3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-29 07:29:31 +01:00

add mysql timeouts

This commit is contained in:
Shivaram Lingamneni 2020-02-20 18:33:48 -05:00
parent 8123e3c08f
commit 98a7b45d96
5 changed files with 86 additions and 49 deletions

View File

@ -504,14 +504,7 @@ type Config struct {
Datastore struct { Datastore struct {
Path string Path string
AutoUpgrade bool AutoUpgrade bool
MySQL struct { MySQL mysql.Config
Enabled bool
Host string
Port int
User string
Password string
HistoryDatabase string `yaml:"history-database"`
}
} }
Accounts AccountConfig Accounts AccountConfig
@ -1069,6 +1062,8 @@ func LoadConfig(filename string) (config *Config, err error) {
config.History.ZNCMax = config.History.ChathistoryMax config.History.ZNCMax = config.History.ChathistoryMax
} }
config.Datastore.MySQL.ExpireTime = time.Duration(config.History.Restrictions.ExpireTime)
config.Server.Cloaks.Initialize() config.Server.Cloaks.Initialize()
if config.Server.Cloaks.Enabled { if config.Server.Cloaks.Enabled {
if config.Server.Cloaks.Secret == "" || config.Server.Cloaks.Secret == "siaELnk6Kaeo65K3RCrwJjlWaZ-Bt3WuZ2L8MXLbNb4" { if config.Server.Cloaks.Secret == "" || config.Server.Cloaks.Secret == "siaELnk6Kaeo65K3RCrwJjlWaZ-Bt3WuZ2L8MXLbNb4" {

22
irc/mysql/config.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright (c) 2020 Shivaram Lingamneni
// released under the MIT license
package mysql
import (
"time"
)
type Config struct {
// these are intended to be written directly into the config file:
Enabled bool
Host string
Port int
User string
Password string
HistoryDatabase string `yaml:"history-database"`
Timeout time.Duration
// XXX these are copied from elsewhere in the config:
ExpireTime time.Duration
}

View File

@ -1,11 +1,16 @@
// Copyright (c) 2020 Shivaram Lingamneni
// released under the MIT license
package mysql package mysql
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"runtime/debug" "runtime/debug"
"sync" "sync"
"sync/atomic"
"time" "time"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -27,58 +32,59 @@ const (
) )
type MySQL struct { type MySQL struct {
db *sql.DB timeout int64
logger *logger.Manager db *sql.DB
logger *logger.Manager
insertHistory *sql.Stmt insertHistory *sql.Stmt
insertSequence *sql.Stmt insertSequence *sql.Stmt
insertConversation *sql.Stmt insertConversation *sql.Stmt
stateMutex sync.Mutex stateMutex sync.Mutex
expireTime time.Duration config Config
} }
func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) { func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
mysql.logger = logger mysql.logger = logger
mysql.expireTime = expireTime mysql.SetConfig(config)
} }
func (mysql *MySQL) SetExpireTime(expireTime time.Duration) { func (mysql *MySQL) SetConfig(config Config) {
atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
mysql.stateMutex.Lock() mysql.stateMutex.Lock()
mysql.expireTime = expireTime mysql.config = config
mysql.stateMutex.Unlock() mysql.stateMutex.Unlock()
} }
func (mysql *MySQL) getExpireTime() (expireTime time.Duration) { func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
mysql.stateMutex.Lock() mysql.stateMutex.Lock()
expireTime = mysql.expireTime expireTime = mysql.config.ExpireTime
mysql.stateMutex.Unlock() mysql.stateMutex.Unlock()
return return
} }
func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) { func (m *MySQL) Open() (err error) {
// TODO: timeouts!
var address string var address string
if port != 0 { if m.config.Port != 0 {
address = fmt.Sprintf("tcp(%s:%d)", host, port) address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
} }
mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database)) m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
if err != nil { if err != nil {
return err return err
} }
err = mysql.fixSchemas() err = m.fixSchemas()
if err != nil { if err != nil {
return err return err
} }
err = mysql.prepareStatements() err = m.prepareStatements()
if err != nil { if err != nil {
return err return err
} }
go mysql.cleanupLoop() go m.cleanupLoop()
return nil return nil
} }
@ -280,6 +286,10 @@ func (mysql *MySQL) prepareStatements() (err error) {
return return
} }
func (mysql *MySQL) getTimeout() time.Duration {
return time.Duration(atomic.LoadInt64(&mysql.timeout))
}
func (mysql *MySQL) logError(context string, err error) (quit bool) { func (mysql *MySQL) logError(context string, err error) (quit bool) {
if err != nil { if err != nil {
mysql.logger.Error("mysql", context, err.Error()) mysql.logger.Error("mysql", context, err.Error())
@ -297,29 +307,32 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
return utils.ErrInvalidParams return utils.ErrInvalidParams
} }
id, err := mysql.insertBase(item) ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
id, err := mysql.insertBase(ctx, item)
if err != nil { if err != nil {
return return
} }
err = mysql.insertSequenceEntry(target, item.Message.Time, id) err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
return return
} }
func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) { func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
_, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id) _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
mysql.logError("could not insert sequence entry", err) mysql.logError("could not insert sequence entry", err)
return return
} }
func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) { func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
lower, higher := stringMinMax(sender, recipient) lower, higher := stringMinMax(sender, recipient)
_, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id) _, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
mysql.logError("could not insert conversations entry", err) mysql.logError("could not insert conversations entry", err)
return return
} }
func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) { func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
value, err := marshalItem(&item) value, err := marshalItem(&item)
if mysql.logError("could not marshal item", err) { if mysql.logError("could not marshal item", err) {
return return
@ -330,7 +343,7 @@ func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
return return
} }
result, err := mysql.insertHistory.Exec(value, msgidBytes) result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
if mysql.logError("could not insert item", err) { if mysql.logError("could not insert item", err) {
return return
} }
@ -363,31 +376,34 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
return utils.ErrInvalidParams return utils.ErrInvalidParams
} }
id, err := mysql.insertBase(item) ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
id, err := mysql.insertBase(ctx, item)
if err != nil { if err != nil {
return return
} }
if senderPersistent { if senderPersistent {
mysql.insertSequenceEntry(sender, item.Message.Time, id) mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
if err != nil { if err != nil {
return return
} }
} }
if recipientPersistent && sender != recipient { if recipientPersistent && sender != recipient {
err = mysql.insertSequenceEntry(recipient, item.Message.Time, id) err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
if err != nil { if err != nil {
return return
} }
} }
err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id) err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
return return
} }
func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) { func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
// in theory, we could optimize out a roundtrip to the database by using a subquery instead: // in theory, we could optimize out a roundtrip to the database by using a subquery instead:
// sequence.nanotime > ( // sequence.nanotime > (
// SELECT sequence.nanotime FROM sequence, history // SELECT sequence.nanotime FROM sequence, history
@ -400,7 +416,7 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
if err != nil { if err != nil {
return return
} }
row := mysql.db.QueryRow(` row := mysql.db.QueryRowContext(ctx, `
SELECT sequence.nanotime FROM sequence SELECT sequence.nanotime FROM sequence
INNER JOIN history ON history.id = sequence.history_id INNER JOIN history ON history.id = sequence.history_id
WHERE history.msgid = ? LIMIT 1;`, decoded) WHERE history.msgid = ? LIMIT 1;`, decoded)
@ -413,8 +429,8 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
return return
} }
func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) { func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
rows, err := mysql.db.Query(query, args...) rows, err := mysql.db.QueryContext(ctx, query, args...)
if mysql.logError("could not select history items", err) { if mysql.logError("could not select history items", err) {
return return
} }
@ -437,7 +453,7 @@ func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []hi
return return
} }
func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
useSequence := true useSequence := true
var lowerTarget, upperTarget string var lowerTarget, upperTarget string
if sender != "" { if sender != "" {
@ -480,7 +496,7 @@ func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, c
fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction) fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
args = append(args, limit) args = append(args, limit)
results, err = mysql.selectItems(queryBuf.String(), args...) results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
if err == nil && !ascending { if err == nil && !ascending {
history.Reverse(results) history.Reverse(results)
} }
@ -505,22 +521,25 @@ type mySQLHistorySequence struct {
} }
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) { func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
defer cancel()
startTime := start.Time startTime := start.Time
if start.Msgid != "" { if start.Msgid != "" {
startTime, err = s.mysql.msgidToTime(start.Msgid) startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
} }
endTime := end.Time endTime := end.Time
if end.Msgid != "" { if end.Msgid != "" {
endTime, err = s.mysql.msgidToTime(end.Msgid) endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
} }
results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit) results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
return results, (err == nil), err return results, (err == nil), err
} }

View File

@ -669,8 +669,8 @@ func (server *Server) applyConfig(config *Config) (err error) {
return err return err
} }
} else { } else {
if config.Datastore.MySQL.Enabled { if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL {
server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime)) server.historyDB.SetConfig(config.Datastore.MySQL)
} }
} }
@ -793,8 +793,8 @@ func (server *Server) loadDatastore(config *Config) error {
server.accounts.Initialize(server) server.accounts.Initialize(server)
if config.Datastore.MySQL.Enabled { if config.Datastore.MySQL.Enabled {
server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime)) server.historyDB.Initialize(server.logger, config.Datastore.MySQL)
err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase) err = server.historyDB.Open()
if err != nil { if err != nil {
server.logger.Error("internal", "could not connect to mysql", err.Error()) server.logger.Error("internal", "could not connect to mysql", err.Error())
return err return err

View File

@ -608,6 +608,7 @@ datastore:
user: "oragono" user: "oragono"
password: "KOHw8WSaRwaoo-avo0qVpQ" password: "KOHw8WSaRwaoo-avo0qVpQ"
history-database: "oragono_history" history-database: "oragono_history"
timeout: 3s
# languages config # languages config
languages: languages: