3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-10 22:19:31 +01:00

add mysql timeouts

This commit is contained in:
Shivaram Lingamneni 2020-02-20 18:33:48 -05:00
parent 8123e3c08f
commit 98a7b45d96
5 changed files with 86 additions and 49 deletions

View File

@ -504,14 +504,7 @@ type Config struct {
Datastore struct {
Path string
AutoUpgrade bool
MySQL struct {
Enabled bool
Host string
Port int
User string
Password string
HistoryDatabase string `yaml:"history-database"`
}
MySQL mysql.Config
}
Accounts AccountConfig
@ -1069,6 +1062,8 @@ func LoadConfig(filename string) (config *Config, err error) {
config.History.ZNCMax = config.History.ChathistoryMax
}
config.Datastore.MySQL.ExpireTime = time.Duration(config.History.Restrictions.ExpireTime)
config.Server.Cloaks.Initialize()
if config.Server.Cloaks.Enabled {
if config.Server.Cloaks.Secret == "" || config.Server.Cloaks.Secret == "siaELnk6Kaeo65K3RCrwJjlWaZ-Bt3WuZ2L8MXLbNb4" {

22
irc/mysql/config.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright (c) 2020 Shivaram Lingamneni
// released under the MIT license
package mysql
import (
"time"
)
type Config struct {
// these are intended to be written directly into the config file:
Enabled bool
Host string
Port int
User string
Password string
HistoryDatabase string `yaml:"history-database"`
Timeout time.Duration
// XXX these are copied from elsewhere in the config:
ExpireTime time.Duration
}

View File

@ -1,11 +1,16 @@
// Copyright (c) 2020 Shivaram Lingamneni
// released under the MIT license
package mysql
import (
"bytes"
"context"
"database/sql"
"fmt"
"runtime/debug"
"sync"
"sync/atomic"
"time"
_ "github.com/go-sql-driver/mysql"
@ -27,58 +32,59 @@ const (
)
type MySQL struct {
db *sql.DB
logger *logger.Manager
timeout int64
db *sql.DB
logger *logger.Manager
insertHistory *sql.Stmt
insertSequence *sql.Stmt
insertConversation *sql.Stmt
stateMutex sync.Mutex
expireTime time.Duration
config Config
}
func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) {
func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
mysql.logger = logger
mysql.expireTime = expireTime
mysql.SetConfig(config)
}
func (mysql *MySQL) SetExpireTime(expireTime time.Duration) {
func (mysql *MySQL) SetConfig(config Config) {
atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
mysql.stateMutex.Lock()
mysql.expireTime = expireTime
mysql.config = config
mysql.stateMutex.Unlock()
}
func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
mysql.stateMutex.Lock()
expireTime = mysql.expireTime
expireTime = mysql.config.ExpireTime
mysql.stateMutex.Unlock()
return
}
func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) {
// TODO: timeouts!
func (m *MySQL) Open() (err error) {
var address string
if port != 0 {
address = fmt.Sprintf("tcp(%s:%d)", host, port)
if m.config.Port != 0 {
address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
}
mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database))
m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
if err != nil {
return err
}
err = mysql.fixSchemas()
err = m.fixSchemas()
if err != nil {
return err
}
err = mysql.prepareStatements()
err = m.prepareStatements()
if err != nil {
return err
}
go mysql.cleanupLoop()
go m.cleanupLoop()
return nil
}
@ -280,6 +286,10 @@ func (mysql *MySQL) prepareStatements() (err error) {
return
}
func (mysql *MySQL) getTimeout() time.Duration {
return time.Duration(atomic.LoadInt64(&mysql.timeout))
}
func (mysql *MySQL) logError(context string, err error) (quit bool) {
if err != nil {
mysql.logger.Error("mysql", context, err.Error())
@ -297,29 +307,32 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
return utils.ErrInvalidParams
}
id, err := mysql.insertBase(item)
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
id, err := mysql.insertBase(ctx, item)
if err != nil {
return
}
err = mysql.insertSequenceEntry(target, item.Message.Time, id)
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
return
}
func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) {
_, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id)
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
mysql.logError("could not insert sequence entry", err)
return
}
func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) {
func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
lower, higher := stringMinMax(sender, recipient)
_, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id)
_, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
mysql.logError("could not insert conversations entry", err)
return
}
func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
value, err := marshalItem(&item)
if mysql.logError("could not marshal item", err) {
return
@ -330,7 +343,7 @@ func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
return
}
result, err := mysql.insertHistory.Exec(value, msgidBytes)
result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
if mysql.logError("could not insert item", err) {
return
}
@ -363,31 +376,34 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
return utils.ErrInvalidParams
}
id, err := mysql.insertBase(item)
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
id, err := mysql.insertBase(ctx, item)
if err != nil {
return
}
if senderPersistent {
mysql.insertSequenceEntry(sender, item.Message.Time, id)
mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
if err != nil {
return
}
}
if recipientPersistent && sender != recipient {
err = mysql.insertSequenceEntry(recipient, item.Message.Time, id)
err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
if err != nil {
return
}
}
err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id)
err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
return
}
func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
// sequence.nanotime > (
// SELECT sequence.nanotime FROM sequence, history
@ -400,7 +416,7 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
if err != nil {
return
}
row := mysql.db.QueryRow(`
row := mysql.db.QueryRowContext(ctx, `
SELECT sequence.nanotime FROM sequence
INNER JOIN history ON history.id = sequence.history_id
WHERE history.msgid = ? LIMIT 1;`, decoded)
@ -413,8 +429,8 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
return
}
func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) {
rows, err := mysql.db.Query(query, args...)
func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
rows, err := mysql.db.QueryContext(ctx, query, args...)
if mysql.logError("could not select history items", err) {
return
}
@ -437,7 +453,7 @@ func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []hi
return
}
func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
useSequence := true
var lowerTarget, upperTarget string
if sender != "" {
@ -480,7 +496,7 @@ func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, c
fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
args = append(args, limit)
results, err = mysql.selectItems(queryBuf.String(), args...)
results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
if err == nil && !ascending {
history.Reverse(results)
}
@ -505,22 +521,25 @@ type mySQLHistorySequence struct {
}
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
defer cancel()
startTime := start.Time
if start.Msgid != "" {
startTime, err = s.mysql.msgidToTime(start.Msgid)
startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
if err != nil {
return nil, false, err
}
}
endTime := end.Time
if end.Msgid != "" {
endTime, err = s.mysql.msgidToTime(end.Msgid)
endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
if err != nil {
return nil, false, err
}
}
results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
return results, (err == nil), err
}

View File

@ -669,8 +669,8 @@ func (server *Server) applyConfig(config *Config) (err error) {
return err
}
} else {
if config.Datastore.MySQL.Enabled {
server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime))
if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL {
server.historyDB.SetConfig(config.Datastore.MySQL)
}
}
@ -793,8 +793,8 @@ func (server *Server) loadDatastore(config *Config) error {
server.accounts.Initialize(server)
if config.Datastore.MySQL.Enabled {
server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime))
err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase)
server.historyDB.Initialize(server.logger, config.Datastore.MySQL)
err = server.historyDB.Open()
if err != nil {
server.logger.Error("internal", "could not connect to mysql", err.Error())
return err

View File

@ -608,6 +608,7 @@ datastore:
user: "oragono"
password: "KOHw8WSaRwaoo-avo0qVpQ"
history-database: "oragono_history"
timeout: 3s
# languages config
languages: