persist and load channel mask lists

This commit is contained in:
Jeremy Latt 2014-03-07 18:14:02 -08:00
parent 04c30c8c9b
commit cf76d2bd77
5 changed files with 61 additions and 11 deletions

View File

@ -12,6 +12,7 @@ import (
func main() {
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
initdb := flag.Bool("initdb", false, "initialize database")
upgradedb := flag.Bool("upgradedb", false, "update database")
passwd := flag.String("genpasswd", "", "bcrypt a password")
flag.Parse()
@ -35,7 +36,13 @@ func main() {
if *initdb {
irc.InitDB(config.Server.Database)
log.Println("database initialized: " + config.Server.Database)
log.Println("database initialized: ", config.Server.Database)
return
}
if *upgradedb {
irc.UpgradeDB(config.Server.Database)
log.Println("database upgraded: ", config.Server.Database)
return
}

View File

@ -443,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
if channel.flags[Persistent] {
_, err = channel.server.db.Exec(`
INSERT OR REPLACE INTO channel
(name, flags, key, topic, user_limit)
VALUES (?, ?, ?, ?, ?)`,
(name, flags, key, topic, user_limit, ban_list, except_list,
invite_list)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
channel.name, channel.flags.String(), channel.key, channel.topic,
channel.userLimit)
channel.userLimit, channel.lists[BanMask].String(),
channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
} else {
_, err = channel.server.db.Exec(`
DELETE FROM channel WHERE name = ?`, channel.name)

View File

@ -219,6 +219,16 @@ func (set *UserMaskSet) Match(userhost string) bool {
return set.regexp.MatchString(userhost)
}
func (set *UserMaskSet) String() string {
masks := make([]string, len(set.masks))
index := 0
for mask := range set.masks {
masks[index] = mask
index += 1
}
return strings.Join(masks, " ")
}
func (set *UserMaskSet) setRegexp() {
if len(set.masks) == 0 {
set.regexp = nil

View File

@ -2,6 +2,7 @@ package irc
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
@ -14,15 +15,30 @@ func InitDB(path string) {
_, err := db.Exec(`
CREATE TABLE channel (
name TEXT NOT NULL UNIQUE,
flags TEXT NOT NULL,
key TEXT NOT NULL,
topic TEXT NOT NULL,
user_limit INTEGER DEFAULT 0)`)
flags TEXT DEFAULT '',
key TEXT DEFAULT '',
topic TEXT DEFAULT '',
user_limit INTEGER DEFAULT 0,
ban_list TEXT DEFAULT '',
except_list TEXT DEFAULT '',
invite_list TEXT DEFAULT '')`)
if err != nil {
log.Fatal("initdb error: ", err)
}
}
func UpgradeDB(path string) {
db := OpenDB(path)
alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
cols := []string{"ban_list", "except_list", "invite_list"}
for _, col := range cols {
_, err := db.Exec(fmt.Sprintf(alter, col))
if err != nil {
log.Fatal("updatedb error: ", err)
}
}
}
func OpenDB(path string) *sql.DB {
db, err := sql.Open("sqlite3", path)
if err != nil {

View File

@ -64,9 +64,19 @@ func NewServer(config *Config) *Server {
return server
}
func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
if list == "" {
return
}
for _, mask := range strings.Split(list, " ") {
channel.lists[maskMode].Add(mask)
}
}
func (server *Server) loadChannels() {
rows, err := server.db.Query(`
SELECT name, flags, key, topic, user_limit
SELECT name, flags, key, topic, user_limit, ban_list, except_list,
invite_list
FROM channel`)
if err != nil {
log.Fatal("error loading channels: ", err)
@ -74,9 +84,11 @@ func (server *Server) loadChannels() {
for rows.Next() {
var name, flags, key, topic string
var userLimit uint64
err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
var banList, exceptList, inviteList string
err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
&exceptList, &inviteList)
if err != nil {
log.Println(err)
log.Println("Server.loadChannels:", err)
continue
}
@ -87,6 +99,9 @@ func (server *Server) loadChannels() {
channel.key = key
channel.topic = topic
channel.userLimit = userLimit
loadChannelList(channel, banList, BanMask)
loadChannelList(channel, exceptList, ExceptMask)
loadChannelList(channel, inviteList, InviteMask)
}
}