From 98a7b45d96069ca7455d92d7abf54ba8ee69df2f Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Thu, 20 Feb 2020 18:33:48 -0500 Subject: [PATCH] add mysql timeouts --- irc/config.go | 11 ++---- irc/mysql/config.go | 22 +++++++++++ irc/mysql/history.go | 93 ++++++++++++++++++++++++++------------------ irc/server.go | 8 ++-- oragono.yaml | 1 + 5 files changed, 86 insertions(+), 49 deletions(-) create mode 100644 irc/mysql/config.go diff --git a/irc/config.go b/irc/config.go index b13721d5..1652c9c1 100644 --- a/irc/config.go +++ b/irc/config.go @@ -504,14 +504,7 @@ type Config struct { Datastore struct { Path string AutoUpgrade bool - MySQL struct { - Enabled bool - Host string - Port int - User string - Password string - HistoryDatabase string `yaml:"history-database"` - } + MySQL mysql.Config } Accounts AccountConfig @@ -1069,6 +1062,8 @@ func LoadConfig(filename string) (config *Config, err error) { config.History.ZNCMax = config.History.ChathistoryMax } + config.Datastore.MySQL.ExpireTime = time.Duration(config.History.Restrictions.ExpireTime) + config.Server.Cloaks.Initialize() if config.Server.Cloaks.Enabled { if config.Server.Cloaks.Secret == "" || config.Server.Cloaks.Secret == "siaELnk6Kaeo65K3RCrwJjlWaZ-Bt3WuZ2L8MXLbNb4" { diff --git a/irc/mysql/config.go b/irc/mysql/config.go new file mode 100644 index 00000000..c9e39437 --- /dev/null +++ b/irc/mysql/config.go @@ -0,0 +1,22 @@ +// Copyright (c) 2020 Shivaram Lingamneni +// released under the MIT license + +package mysql + +import ( + "time" +) + +type Config struct { + // these are intended to be written directly into the config file: + Enabled bool + Host string + Port int + User string + Password string + HistoryDatabase string `yaml:"history-database"` + Timeout time.Duration + + // XXX these are copied from elsewhere in the config: + ExpireTime time.Duration +} diff --git a/irc/mysql/history.go b/irc/mysql/history.go index 59dc0d33..c14c9d37 100644 --- a/irc/mysql/history.go +++ b/irc/mysql/history.go @@ -1,11 +1,16 @@ +// Copyright (c) 2020 Shivaram Lingamneni +// released under the MIT license + package mysql import ( "bytes" + "context" "database/sql" "fmt" "runtime/debug" "sync" + "sync/atomic" "time" _ "github.com/go-sql-driver/mysql" @@ -27,58 +32,59 @@ const ( ) type MySQL struct { - db *sql.DB - logger *logger.Manager + timeout int64 + db *sql.DB + logger *logger.Manager insertHistory *sql.Stmt insertSequence *sql.Stmt insertConversation *sql.Stmt stateMutex sync.Mutex - expireTime time.Duration + config Config } -func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) { +func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) { mysql.logger = logger - mysql.expireTime = expireTime + mysql.SetConfig(config) } -func (mysql *MySQL) SetExpireTime(expireTime time.Duration) { +func (mysql *MySQL) SetConfig(config Config) { + atomic.StoreInt64(&mysql.timeout, int64(config.Timeout)) mysql.stateMutex.Lock() - mysql.expireTime = expireTime + mysql.config = config mysql.stateMutex.Unlock() } func (mysql *MySQL) getExpireTime() (expireTime time.Duration) { mysql.stateMutex.Lock() - expireTime = mysql.expireTime + expireTime = mysql.config.ExpireTime mysql.stateMutex.Unlock() return } -func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) { - // TODO: timeouts! +func (m *MySQL) Open() (err error) { var address string - if port != 0 { - address = fmt.Sprintf("tcp(%s:%d)", host, port) + if m.config.Port != 0 { + address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port) } - mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database)) + m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase)) if err != nil { return err } - err = mysql.fixSchemas() + err = m.fixSchemas() if err != nil { return err } - err = mysql.prepareStatements() + err = m.prepareStatements() if err != nil { return err } - go mysql.cleanupLoop() + go m.cleanupLoop() return nil } @@ -280,6 +286,10 @@ func (mysql *MySQL) prepareStatements() (err error) { return } +func (mysql *MySQL) getTimeout() time.Duration { + return time.Duration(atomic.LoadInt64(&mysql.timeout)) +} + func (mysql *MySQL) logError(context string, err error) (quit bool) { if err != nil { mysql.logger.Error("mysql", context, err.Error()) @@ -297,29 +307,32 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) return utils.ErrInvalidParams } - id, err := mysql.insertBase(item) + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + + id, err := mysql.insertBase(ctx, item) if err != nil { return } - err = mysql.insertSequenceEntry(target, item.Message.Time, id) + err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id) return } -func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) { - _, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id) +func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) { + _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id) mysql.logError("could not insert sequence entry", err) return } -func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) { +func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) { lower, higher := stringMinMax(sender, recipient) - _, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id) + _, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id) mysql.logError("could not insert conversations entry", err) return } -func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) { +func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) { value, err := marshalItem(&item) if mysql.logError("could not marshal item", err) { return @@ -330,7 +343,7 @@ func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) { return } - result, err := mysql.insertHistory.Exec(value, msgidBytes) + result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes) if mysql.logError("could not insert item", err) { return } @@ -363,31 +376,34 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, return utils.ErrInvalidParams } - id, err := mysql.insertBase(item) + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + + id, err := mysql.insertBase(ctx, item) if err != nil { return } if senderPersistent { - mysql.insertSequenceEntry(sender, item.Message.Time, id) + mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id) if err != nil { return } } if recipientPersistent && sender != recipient { - err = mysql.insertSequenceEntry(recipient, item.Message.Time, id) + err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id) if err != nil { return } } - err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id) + err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id) return } -func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) { +func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) { // in theory, we could optimize out a roundtrip to the database by using a subquery instead: // sequence.nanotime > ( // SELECT sequence.nanotime FROM sequence, history @@ -400,7 +416,7 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) { if err != nil { return } - row := mysql.db.QueryRow(` + row := mysql.db.QueryRowContext(ctx, ` SELECT sequence.nanotime FROM sequence INNER JOIN history ON history.id = sequence.history_id WHERE history.msgid = ? LIMIT 1;`, decoded) @@ -413,8 +429,8 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) { return } -func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) { - rows, err := mysql.db.Query(query, args...) +func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) { + rows, err := mysql.db.QueryContext(ctx, query, args...) if mysql.logError("could not select history items", err) { return } @@ -437,7 +453,7 @@ func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []hi return } -func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { +func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { useSequence := true var lowerTarget, upperTarget string if sender != "" { @@ -480,7 +496,7 @@ func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, c fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction) args = append(args, limit) - results, err = mysql.selectItems(queryBuf.String(), args...) + results, err = mysql.selectItems(ctx, queryBuf.String(), args...) if err == nil && !ascending { history.Reverse(results) } @@ -505,22 +521,25 @@ type mySQLHistorySequence struct { } func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) { + ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout()) + defer cancel() + startTime := start.Time if start.Msgid != "" { - startTime, err = s.mysql.msgidToTime(start.Msgid) + startTime, err = s.mysql.msgidToTime(ctx, start.Msgid) if err != nil { return nil, false, err } } endTime := end.Time if end.Msgid != "" { - endTime, err = s.mysql.msgidToTime(end.Msgid) + endTime, err = s.mysql.msgidToTime(ctx, end.Msgid) if err != nil { return nil, false, err } } - results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit) + results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit) return results, (err == nil), err } diff --git a/irc/server.go b/irc/server.go index 1fb20669..24b48094 100644 --- a/irc/server.go +++ b/irc/server.go @@ -669,8 +669,8 @@ func (server *Server) applyConfig(config *Config) (err error) { return err } } else { - if config.Datastore.MySQL.Enabled { - server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime)) + if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL { + server.historyDB.SetConfig(config.Datastore.MySQL) } } @@ -793,8 +793,8 @@ func (server *Server) loadDatastore(config *Config) error { server.accounts.Initialize(server) if config.Datastore.MySQL.Enabled { - server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime)) - err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase) + server.historyDB.Initialize(server.logger, config.Datastore.MySQL) + err = server.historyDB.Open() if err != nil { server.logger.Error("internal", "could not connect to mysql", err.Error()) return err diff --git a/oragono.yaml b/oragono.yaml index 126c9093..dcaa16eb 100644 --- a/oragono.yaml +++ b/oragono.yaml @@ -608,6 +608,7 @@ datastore: user: "oragono" password: "KOHw8WSaRwaoo-avo0qVpQ" history-database: "oragono_history" + timeout: 3s # languages config languages: