persist and load channel mask lists

This commit is contained in:
Jeremy Latt 2014-03-07 18:14:02 -08:00
parent 04c30c8c9b
commit cf76d2bd77
5 changed files with 61 additions and 11 deletions

View File

@ -12,6 +12,7 @@ import (
func main() { func main() {
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file") conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
initdb := flag.Bool("initdb", false, "initialize database") initdb := flag.Bool("initdb", false, "initialize database")
upgradedb := flag.Bool("upgradedb", false, "update database")
passwd := flag.String("genpasswd", "", "bcrypt a password") passwd := flag.String("genpasswd", "", "bcrypt a password")
flag.Parse() flag.Parse()
@ -35,7 +36,13 @@ func main() {
if *initdb { if *initdb {
irc.InitDB(config.Server.Database) irc.InitDB(config.Server.Database)
log.Println("database initialized: " + config.Server.Database) log.Println("database initialized: ", config.Server.Database)
return
}
if *upgradedb {
irc.UpgradeDB(config.Server.Database)
log.Println("database upgraded: ", config.Server.Database)
return return
} }

View File

@ -443,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
if channel.flags[Persistent] { if channel.flags[Persistent] {
_, err = channel.server.db.Exec(` _, err = channel.server.db.Exec(`
INSERT OR REPLACE INTO channel INSERT OR REPLACE INTO channel
(name, flags, key, topic, user_limit) (name, flags, key, topic, user_limit, ban_list, except_list,
VALUES (?, ?, ?, ?, ?)`, invite_list)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
channel.name, channel.flags.String(), channel.key, channel.topic, channel.name, channel.flags.String(), channel.key, channel.topic,
channel.userLimit) channel.userLimit, channel.lists[BanMask].String(),
channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
} else { } else {
_, err = channel.server.db.Exec(` _, err = channel.server.db.Exec(`
DELETE FROM channel WHERE name = ?`, channel.name) DELETE FROM channel WHERE name = ?`, channel.name)

View File

@ -219,6 +219,16 @@ func (set *UserMaskSet) Match(userhost string) bool {
return set.regexp.MatchString(userhost) return set.regexp.MatchString(userhost)
} }
func (set *UserMaskSet) String() string {
masks := make([]string, len(set.masks))
index := 0
for mask := range set.masks {
masks[index] = mask
index += 1
}
return strings.Join(masks, " ")
}
func (set *UserMaskSet) setRegexp() { func (set *UserMaskSet) setRegexp() {
if len(set.masks) == 0 { if len(set.masks) == 0 {
set.regexp = nil set.regexp = nil

View File

@ -2,6 +2,7 @@ package irc
import ( import (
"database/sql" "database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"log" "log"
"os" "os"
@ -14,15 +15,30 @@ func InitDB(path string) {
_, err := db.Exec(` _, err := db.Exec(`
CREATE TABLE channel ( CREATE TABLE channel (
name TEXT NOT NULL UNIQUE, name TEXT NOT NULL UNIQUE,
flags TEXT NOT NULL, flags TEXT DEFAULT '',
key TEXT NOT NULL, key TEXT DEFAULT '',
topic TEXT NOT NULL, topic TEXT DEFAULT '',
user_limit INTEGER DEFAULT 0)`) user_limit INTEGER DEFAULT 0,
ban_list TEXT DEFAULT '',
except_list TEXT DEFAULT '',
invite_list TEXT DEFAULT '')`)
if err != nil { if err != nil {
log.Fatal("initdb error: ", err) log.Fatal("initdb error: ", err)
} }
} }
func UpgradeDB(path string) {
db := OpenDB(path)
alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
cols := []string{"ban_list", "except_list", "invite_list"}
for _, col := range cols {
_, err := db.Exec(fmt.Sprintf(alter, col))
if err != nil {
log.Fatal("updatedb error: ", err)
}
}
}
func OpenDB(path string) *sql.DB { func OpenDB(path string) *sql.DB {
db, err := sql.Open("sqlite3", path) db, err := sql.Open("sqlite3", path)
if err != nil { if err != nil {

View File

@ -64,9 +64,19 @@ func NewServer(config *Config) *Server {
return server return server
} }
func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
if list == "" {
return
}
for _, mask := range strings.Split(list, " ") {
channel.lists[maskMode].Add(mask)
}
}
func (server *Server) loadChannels() { func (server *Server) loadChannels() {
rows, err := server.db.Query(` rows, err := server.db.Query(`
SELECT name, flags, key, topic, user_limit SELECT name, flags, key, topic, user_limit, ban_list, except_list,
invite_list
FROM channel`) FROM channel`)
if err != nil { if err != nil {
log.Fatal("error loading channels: ", err) log.Fatal("error loading channels: ", err)
@ -74,9 +84,11 @@ func (server *Server) loadChannels() {
for rows.Next() { for rows.Next() {
var name, flags, key, topic string var name, flags, key, topic string
var userLimit uint64 var userLimit uint64
err = rows.Scan(&name, &flags, &key, &topic, &userLimit) var banList, exceptList, inviteList string
err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
&exceptList, &inviteList)
if err != nil { if err != nil {
log.Println(err) log.Println("Server.loadChannels:", err)
continue continue
} }
@ -87,6 +99,9 @@ func (server *Server) loadChannels() {
channel.key = key channel.key = key
channel.topic = topic channel.topic = topic
channel.userLimit = userLimit channel.userLimit = userLimit
loadChannelList(channel, banList, BanMask)
loadChannelList(channel, exceptList, ExceptMask)
loadChannelList(channel, inviteList, InviteMask)
} }
} }