mirror of
				https://github.com/ergochat/ergo.git
				synced 2025-10-31 05:47:22 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			536 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			536 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package mysql
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"runtime/debug"
 | |
| 	"strconv"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	_ "github.com/go-sql-driver/mysql"
 | |
| 	"github.com/oragono/oragono/irc/history"
 | |
| 	"github.com/oragono/oragono/irc/logger"
 | |
| 	"github.com/oragono/oragono/irc/utils"
 | |
| )
 | |
| 
 | |
| const (
 | |
| 	// latest schema of the db
 | |
| 	latestDbSchema   = "1"
 | |
| 	keySchemaVersion = "db.version"
 | |
| 	cleanupRowLimit  = 50
 | |
| 	cleanupPauseTime = 10 * time.Minute
 | |
| )
 | |
| 
 | |
| type MySQL struct {
 | |
| 	db     *sql.DB
 | |
| 	logger *logger.Manager
 | |
| 
 | |
| 	insertHistory      *sql.Stmt
 | |
| 	insertSequence     *sql.Stmt
 | |
| 	insertConversation *sql.Stmt
 | |
| 
 | |
| 	stateMutex sync.Mutex
 | |
| 	expireTime time.Duration
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) {
 | |
| 	mysql.logger = logger
 | |
| 	mysql.expireTime = expireTime
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) SetExpireTime(expireTime time.Duration) {
 | |
| 	mysql.stateMutex.Lock()
 | |
| 	mysql.expireTime = expireTime
 | |
| 	mysql.stateMutex.Unlock()
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
 | |
| 	mysql.stateMutex.Lock()
 | |
| 	expireTime = mysql.expireTime
 | |
| 	mysql.stateMutex.Unlock()
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) {
 | |
| 	// TODO: timeouts!
 | |
| 	var address string
 | |
| 	if port != 0 {
 | |
| 		address = fmt.Sprintf("tcp(%s:%d)", host, port)
 | |
| 	}
 | |
| 
 | |
| 	mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database))
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	err = mysql.fixSchemas()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	err = mysql.prepareStatements()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	go mysql.cleanupLoop()
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) fixSchemas() (err error) {
 | |
| 	_, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
 | |
| 		key_name VARCHAR(32) primary key,
 | |
| 		value VARCHAR(32) NOT NULL
 | |
| 	) CHARSET=ascii COLLATE=ascii_bin;`)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	var schema string
 | |
| 	err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
 | |
| 	if err == sql.ErrNoRows {
 | |
| 		err = mysql.createTables()
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 	} else if err == nil && schema != latestDbSchema {
 | |
| 		// TODO figure out what to do about schema changes
 | |
| 		return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
 | |
| 	} else {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) createTables() (err error) {
 | |
| 	_, err = mysql.db.Exec(`CREATE TABLE history (
 | |
| 		id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
 | |
| 		data BLOB NOT NULL,
 | |
| 		msgid BINARY(16) NOT NULL,
 | |
| 		KEY (msgid(4))
 | |
| 	) CHARSET=ascii COLLATE=ascii_bin;`)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	_, err = mysql.db.Exec(`CREATE TABLE sequence (
 | |
| 		id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
 | |
| 		target VARBINARY(64) NOT NULL,
 | |
| 		nanotime BIGINT UNSIGNED NOT NULL,
 | |
| 		history_id BIGINT NOT NULL,
 | |
| 		KEY (target, nanotime),
 | |
| 		KEY (history_id)
 | |
| 	) CHARSET=ascii COLLATE=ascii_bin;`)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	_, err = mysql.db.Exec(`CREATE TABLE conversations (
 | |
| 		id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
 | |
| 		lower_target VARBINARY(64) NOT NULL,
 | |
| 		upper_target VARBINARY(64) NOT NULL,
 | |
| 		nanotime BIGINT UNSIGNED NOT NULL,
 | |
| 		history_id BIGINT NOT NULL,
 | |
| 		KEY (lower_target, upper_target, nanotime),
 | |
| 		KEY (history_id)
 | |
| 	) CHARSET=ascii COLLATE=ascii_bin;`)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) cleanupLoop() {
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			mysql.logger.Error("mysql",
 | |
| 				fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
 | |
| 			time.Sleep(cleanupPauseTime)
 | |
| 			go mysql.cleanupLoop()
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	for {
 | |
| 		expireTime := mysql.getExpireTime()
 | |
| 		if expireTime != 0 {
 | |
| 			for {
 | |
| 				startTime := time.Now()
 | |
| 				rowsDeleted, err := mysql.doCleanup(expireTime)
 | |
| 				elapsed := time.Now().Sub(startTime)
 | |
| 				mysql.logError("error during row cleanup", err)
 | |
| 				// keep going as long as we're accomplishing significant work
 | |
| 				// (don't busy-wait on small numbers of rows expiring):
 | |
| 				if rowsDeleted < (cleanupRowLimit / 10) {
 | |
| 					break
 | |
| 				}
 | |
| 				// crude backpressure mechanism: if the database is slow,
 | |
| 				// give it time to process other queries
 | |
| 				time.Sleep(elapsed)
 | |
| 			}
 | |
| 		}
 | |
| 		time.Sleep(cleanupPauseTime)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
 | |
| 	ids, maxNanotime, err := mysql.selectCleanupIDs(age)
 | |
| 	if len(ids) == 0 {
 | |
| 		mysql.logger.Debug("mysql", "found no rows to clean up")
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
 | |
| 
 | |
| 	// can't use ? binding for a variable number of arguments, build the IN clause manually
 | |
| 	var inBuf bytes.Buffer
 | |
| 	inBuf.WriteByte('(')
 | |
| 	for i, id := range ids {
 | |
| 		if i != 0 {
 | |
| 			inBuf.WriteRune(',')
 | |
| 		}
 | |
| 		inBuf.WriteString(strconv.FormatInt(int64(id), 10))
 | |
| 	}
 | |
| 	inBuf.WriteRune(')')
 | |
| 
 | |
| 	_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	count = len(ids)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
 | |
| 	rows, err := mysql.db.Query(`
 | |
| 		SELECT history.id, sequence.nanotime
 | |
| 		FROM history
 | |
| 		LEFT JOIN sequence ON history.id = sequence.history_id
 | |
| 		ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	defer rows.Close()
 | |
| 
 | |
| 	// a history ID may have 0-2 rows in sequence: 1 for a channel entry,
 | |
| 	// 2 for a DM, 0 if the data is inconsistent. therefore, deduplicate
 | |
| 	// and delete anything that doesn't have a sequence entry:
 | |
| 	idset := make(map[uint64]struct{}, cleanupRowLimit)
 | |
| 	threshold := time.Now().Add(-age).UnixNano()
 | |
| 	for rows.Next() {
 | |
| 		var id uint64
 | |
| 		var nanotime sql.NullInt64
 | |
| 		err = rows.Scan(&id, &nanotime)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		if !nanotime.Valid || nanotime.Int64 < threshold {
 | |
| 			idset[id] = struct{}{}
 | |
| 			if nanotime.Valid && nanotime.Int64 > maxNanotime {
 | |
| 				maxNanotime = nanotime.Int64
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	ids = make([]uint64, len(idset))
 | |
| 	i := 0
 | |
| 	for id := range idset {
 | |
| 		ids[i] = id
 | |
| 		i++
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) prepareStatements() (err error) {
 | |
| 	mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
 | |
| 		(data, msgid) VALUES (?, ?);`)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
 | |
| 		(target, nanotime, history_id) VALUES (?, ?, ?);`)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
 | |
| 		(lower_target, upper_target, nanotime, history_id) VALUES (?, ?, ?, ?);`)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) logError(context string, err error) (quit bool) {
 | |
| 	if err != nil {
 | |
| 		mysql.logger.Error("mysql", context, err.Error())
 | |
| 		return true
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
 | |
| 	if mysql.db == nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if target == "" {
 | |
| 		return utils.ErrInvalidParams
 | |
| 	}
 | |
| 
 | |
| 	id, err := mysql.insertBase(item)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	err = mysql.insertSequenceEntry(target, item.Message.Time, id)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) {
 | |
| 	_, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id)
 | |
| 	mysql.logError("could not insert sequence entry", err)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) {
 | |
| 	lower, higher := stringMinMax(sender, recipient)
 | |
| 	_, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id)
 | |
| 	mysql.logError("could not insert conversations entry", err)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
 | |
| 	value, err := marshalItem(&item)
 | |
| 	if mysql.logError("could not marshal item", err) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	msgidBytes, err := decodeMsgid(item.Message.Msgid)
 | |
| 	if mysql.logError("could not decode msgid", err) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	result, err := mysql.insertHistory.Exec(value, msgidBytes)
 | |
| 	if mysql.logError("could not insert item", err) {
 | |
| 		return
 | |
| 	}
 | |
| 	id, err = result.LastInsertId()
 | |
| 	if mysql.logError("could not insert item", err) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func stringMinMax(first, second string) (min, max string) {
 | |
| 	if first < second {
 | |
| 		return first, second
 | |
| 	} else {
 | |
| 		return second, first
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, recipientPersistent bool, item history.Item) (err error) {
 | |
| 	if mysql.db == nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if !(senderPersistent || recipientPersistent) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if sender == "" || recipient == "" {
 | |
| 		return utils.ErrInvalidParams
 | |
| 	}
 | |
| 
 | |
| 	id, err := mysql.insertBase(item)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if senderPersistent {
 | |
| 		mysql.insertSequenceEntry(sender, item.Message.Time, id)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if recipientPersistent && sender != recipient {
 | |
| 		err = mysql.insertSequenceEntry(recipient, item.Message.Time, id)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id)
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
 | |
| 	// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
 | |
| 	// sequence.nanotime > (
 | |
| 	//     SELECT sequence.nanotime FROM sequence, history
 | |
| 	//     WHERE sequence.history_id = history.id AND history.msgid = ?
 | |
| 	//     LIMIT 1)
 | |
| 	// however, this doesn't handle the BETWEEN case with one or two msgids, where we
 | |
| 	// don't initially know whether the interval is going forwards or backwards. to simplify
 | |
| 	// the logic,  resolve msgids to timestamps "manually" in all cases, using a separate query.
 | |
| 	decoded, err := decodeMsgid(msgid)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	row := mysql.db.QueryRow(`
 | |
| 		SELECT sequence.nanotime FROM sequence
 | |
| 		INNER JOIN history ON history.id = sequence.history_id
 | |
| 		WHERE history.msgid = ? LIMIT 1;`, decoded)
 | |
| 	var nanotime int64
 | |
| 	err = row.Scan(&nanotime)
 | |
| 	if mysql.logError("could not resolve msgid to time", err) {
 | |
| 		return
 | |
| 	}
 | |
| 	result = time.Unix(0, nanotime)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) {
 | |
| 	rows, err := mysql.db.Query(query, args...)
 | |
| 	if mysql.logError("could not select history items", err) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	defer rows.Close()
 | |
| 
 | |
| 	for rows.Next() {
 | |
| 		var blob []byte
 | |
| 		var item history.Item
 | |
| 		err = rows.Scan(&blob)
 | |
| 		if mysql.logError("could not scan history item", err) {
 | |
| 			return
 | |
| 		}
 | |
| 		err = unmarshalItem(blob, &item)
 | |
| 		if mysql.logError("could not unmarshal history item", err) {
 | |
| 			return
 | |
| 		}
 | |
| 		results = append(results, item)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
 | |
| 	useSequence := true
 | |
| 	var lowerTarget, upperTarget string
 | |
| 	if sender != "" {
 | |
| 		lowerTarget, upperTarget = stringMinMax(sender, recipient)
 | |
| 		useSequence = false
 | |
| 	}
 | |
| 
 | |
| 	table := "sequence"
 | |
| 	if !useSequence {
 | |
| 		table = "conversations"
 | |
| 	}
 | |
| 
 | |
| 	after, before, ascending := history.MinMaxAsc(after, before, cutoff)
 | |
| 	direction := "ASC"
 | |
| 	if !ascending {
 | |
| 		direction = "DESC"
 | |
| 	}
 | |
| 
 | |
| 	var queryBuf bytes.Buffer
 | |
| 
 | |
| 	args := make([]interface{}, 0, 6)
 | |
| 	fmt.Fprintf(&queryBuf,
 | |
| 		"SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
 | |
| 	if useSequence {
 | |
| 		fmt.Fprintf(&queryBuf, " sequence.target = ?")
 | |
| 		args = append(args, recipient)
 | |
| 	} else {
 | |
| 		fmt.Fprintf(&queryBuf, " conversations.lower_target = ? AND conversations.upper_target = ?")
 | |
| 		args = append(args, lowerTarget)
 | |
| 		args = append(args, upperTarget)
 | |
| 	}
 | |
| 	if !after.IsZero() {
 | |
| 		fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
 | |
| 		args = append(args, after.UnixNano())
 | |
| 	}
 | |
| 	if !before.IsZero() {
 | |
| 		fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
 | |
| 		args = append(args, before.UnixNano())
 | |
| 	}
 | |
| 	fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
 | |
| 	args = append(args, limit)
 | |
| 
 | |
| 	results, err = mysql.selectItems(queryBuf.String(), args...)
 | |
| 	if err == nil && !ascending {
 | |
| 		history.Reverse(results)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) Close() {
 | |
| 	// closing the database will close our prepared statements as well
 | |
| 	if mysql.db != nil {
 | |
| 		mysql.db.Close()
 | |
| 	}
 | |
| 	mysql.db = nil
 | |
| }
 | |
| 
 | |
| // implements history.Sequence, emulating a single history buffer (for a channel,
 | |
| // a single user's DMs, or a DM conversation)
 | |
| type mySQLHistorySequence struct {
 | |
| 	mysql     *MySQL
 | |
| 	sender    string
 | |
| 	recipient string
 | |
| 	cutoff    time.Time
 | |
| }
 | |
| 
 | |
| func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
 | |
| 	startTime := start.Time
 | |
| 	if start.Msgid != "" {
 | |
| 		startTime, err = s.mysql.msgidToTime(start.Msgid)
 | |
| 		if err != nil {
 | |
| 			return nil, false, err
 | |
| 		}
 | |
| 	}
 | |
| 	endTime := end.Time
 | |
| 	if end.Msgid != "" {
 | |
| 		endTime, err = s.mysql.msgidToTime(end.Msgid)
 | |
| 		if err != nil {
 | |
| 			return nil, false, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
 | |
| 	return results, (err == nil), err
 | |
| }
 | |
| 
 | |
| func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
 | |
| 	return history.GenericAround(s, start, limit)
 | |
| }
 | |
| 
 | |
| func (mysql *MySQL) MakeSequence(sender, recipient string, cutoff time.Time) history.Sequence {
 | |
| 	return &mySQLHistorySequence{
 | |
| 		sender:    sender,
 | |
| 		recipient: recipient,
 | |
| 		mysql:     mysql,
 | |
| 		cutoff:    cutoff,
 | |
| 	}
 | |
| }
 | 
