mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-10 22:19:31 +01:00
add mysql timeouts
This commit is contained in:
parent
8123e3c08f
commit
98a7b45d96
@ -504,14 +504,7 @@ type Config struct {
|
||||
Datastore struct {
|
||||
Path string
|
||||
AutoUpgrade bool
|
||||
MySQL struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
HistoryDatabase string `yaml:"history-database"`
|
||||
}
|
||||
MySQL mysql.Config
|
||||
}
|
||||
|
||||
Accounts AccountConfig
|
||||
@ -1069,6 +1062,8 @@ func LoadConfig(filename string) (config *Config, err error) {
|
||||
config.History.ZNCMax = config.History.ChathistoryMax
|
||||
}
|
||||
|
||||
config.Datastore.MySQL.ExpireTime = time.Duration(config.History.Restrictions.ExpireTime)
|
||||
|
||||
config.Server.Cloaks.Initialize()
|
||||
if config.Server.Cloaks.Enabled {
|
||||
if config.Server.Cloaks.Secret == "" || config.Server.Cloaks.Secret == "siaELnk6Kaeo65K3RCrwJjlWaZ-Bt3WuZ2L8MXLbNb4" {
|
||||
|
22
irc/mysql/config.go
Normal file
22
irc/mysql/config.go
Normal file
@ -0,0 +1,22 @@
|
||||
// Copyright (c) 2020 Shivaram Lingamneni
|
||||
// released under the MIT license
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// these are intended to be written directly into the config file:
|
||||
Enabled bool
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
HistoryDatabase string `yaml:"history-database"`
|
||||
Timeout time.Duration
|
||||
|
||||
// XXX these are copied from elsewhere in the config:
|
||||
ExpireTime time.Duration
|
||||
}
|
@ -1,11 +1,16 @@
|
||||
// Copyright (c) 2020 Shivaram Lingamneni
|
||||
// released under the MIT license
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
@ -27,58 +32,59 @@ const (
|
||||
)
|
||||
|
||||
type MySQL struct {
|
||||
db *sql.DB
|
||||
logger *logger.Manager
|
||||
timeout int64
|
||||
db *sql.DB
|
||||
logger *logger.Manager
|
||||
|
||||
insertHistory *sql.Stmt
|
||||
insertSequence *sql.Stmt
|
||||
insertConversation *sql.Stmt
|
||||
|
||||
stateMutex sync.Mutex
|
||||
expireTime time.Duration
|
||||
config Config
|
||||
}
|
||||
|
||||
func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) {
|
||||
func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
|
||||
mysql.logger = logger
|
||||
mysql.expireTime = expireTime
|
||||
mysql.SetConfig(config)
|
||||
}
|
||||
|
||||
func (mysql *MySQL) SetExpireTime(expireTime time.Duration) {
|
||||
func (mysql *MySQL) SetConfig(config Config) {
|
||||
atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
|
||||
mysql.stateMutex.Lock()
|
||||
mysql.expireTime = expireTime
|
||||
mysql.config = config
|
||||
mysql.stateMutex.Unlock()
|
||||
}
|
||||
|
||||
func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
|
||||
mysql.stateMutex.Lock()
|
||||
expireTime = mysql.expireTime
|
||||
expireTime = mysql.config.ExpireTime
|
||||
mysql.stateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) {
|
||||
// TODO: timeouts!
|
||||
func (m *MySQL) Open() (err error) {
|
||||
var address string
|
||||
if port != 0 {
|
||||
address = fmt.Sprintf("tcp(%s:%d)", host, port)
|
||||
if m.config.Port != 0 {
|
||||
address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
|
||||
}
|
||||
|
||||
mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database))
|
||||
m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mysql.fixSchemas()
|
||||
err = m.fixSchemas()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mysql.prepareStatements()
|
||||
err = m.prepareStatements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go mysql.cleanupLoop()
|
||||
go m.cleanupLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -280,6 +286,10 @@ func (mysql *MySQL) prepareStatements() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) getTimeout() time.Duration {
|
||||
return time.Duration(atomic.LoadInt64(&mysql.timeout))
|
||||
}
|
||||
|
||||
func (mysql *MySQL) logError(context string, err error) (quit bool) {
|
||||
if err != nil {
|
||||
mysql.logger.Error("mysql", context, err.Error())
|
||||
@ -297,29 +307,32 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
|
||||
return utils.ErrInvalidParams
|
||||
}
|
||||
|
||||
id, err := mysql.insertBase(item)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
id, err := mysql.insertBase(ctx, item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = mysql.insertSequenceEntry(target, item.Message.Time, id)
|
||||
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) {
|
||||
_, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id)
|
||||
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
|
||||
_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
|
||||
mysql.logError("could not insert sequence entry", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) {
|
||||
func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
|
||||
lower, higher := stringMinMax(sender, recipient)
|
||||
_, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id)
|
||||
_, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
|
||||
mysql.logError("could not insert conversations entry", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
|
||||
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
|
||||
value, err := marshalItem(&item)
|
||||
if mysql.logError("could not marshal item", err) {
|
||||
return
|
||||
@ -330,7 +343,7 @@ func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := mysql.insertHistory.Exec(value, msgidBytes)
|
||||
result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
|
||||
if mysql.logError("could not insert item", err) {
|
||||
return
|
||||
}
|
||||
@ -363,31 +376,34 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
|
||||
return utils.ErrInvalidParams
|
||||
}
|
||||
|
||||
id, err := mysql.insertBase(item)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
id, err := mysql.insertBase(ctx, item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if senderPersistent {
|
||||
mysql.insertSequenceEntry(sender, item.Message.Time, id)
|
||||
mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if recipientPersistent && sender != recipient {
|
||||
err = mysql.insertSequenceEntry(recipient, item.Message.Time, id)
|
||||
err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id)
|
||||
err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
|
||||
func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
|
||||
// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
|
||||
// sequence.nanotime > (
|
||||
// SELECT sequence.nanotime FROM sequence, history
|
||||
@ -400,7 +416,7 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
row := mysql.db.QueryRow(`
|
||||
row := mysql.db.QueryRowContext(ctx, `
|
||||
SELECT sequence.nanotime FROM sequence
|
||||
INNER JOIN history ON history.id = sequence.history_id
|
||||
WHERE history.msgid = ? LIMIT 1;`, decoded)
|
||||
@ -413,8 +429,8 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) {
|
||||
rows, err := mysql.db.Query(query, args...)
|
||||
func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
|
||||
rows, err := mysql.db.QueryContext(ctx, query, args...)
|
||||
if mysql.logError("could not select history items", err) {
|
||||
return
|
||||
}
|
||||
@ -437,7 +453,7 @@ func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []hi
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
|
||||
func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
|
||||
useSequence := true
|
||||
var lowerTarget, upperTarget string
|
||||
if sender != "" {
|
||||
@ -480,7 +496,7 @@ func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, c
|
||||
fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
|
||||
args = append(args, limit)
|
||||
|
||||
results, err = mysql.selectItems(queryBuf.String(), args...)
|
||||
results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
|
||||
if err == nil && !ascending {
|
||||
history.Reverse(results)
|
||||
}
|
||||
@ -505,22 +521,25 @@ type mySQLHistorySequence struct {
|
||||
}
|
||||
|
||||
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
startTime := start.Time
|
||||
if start.Msgid != "" {
|
||||
startTime, err = s.mysql.msgidToTime(start.Msgid)
|
||||
startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
endTime := end.Time
|
||||
if end.Msgid != "" {
|
||||
endTime, err = s.mysql.msgidToTime(end.Msgid)
|
||||
endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
|
||||
results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
|
||||
return results, (err == nil), err
|
||||
}
|
||||
|
||||
|
@ -669,8 +669,8 @@ func (server *Server) applyConfig(config *Config) (err error) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if config.Datastore.MySQL.Enabled {
|
||||
server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime))
|
||||
if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL {
|
||||
server.historyDB.SetConfig(config.Datastore.MySQL)
|
||||
}
|
||||
}
|
||||
|
||||
@ -793,8 +793,8 @@ func (server *Server) loadDatastore(config *Config) error {
|
||||
server.accounts.Initialize(server)
|
||||
|
||||
if config.Datastore.MySQL.Enabled {
|
||||
server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime))
|
||||
err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase)
|
||||
server.historyDB.Initialize(server.logger, config.Datastore.MySQL)
|
||||
err = server.historyDB.Open()
|
||||
if err != nil {
|
||||
server.logger.Error("internal", "could not connect to mysql", err.Error())
|
||||
return err
|
||||
|
@ -608,6 +608,7 @@ datastore:
|
||||
user: "oragono"
|
||||
password: "KOHw8WSaRwaoo-avo0qVpQ"
|
||||
history-database: "oragono_history"
|
||||
timeout: 3s
|
||||
|
||||
# languages config
|
||||
languages:
|
||||
|
Loading…
Reference in New Issue
Block a user