mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-03 08:32:43 +01:00
persist and load channel mask lists
This commit is contained in:
parent
04c30c8c9b
commit
cf76d2bd77
@ -12,6 +12,7 @@ import (
|
||||
func main() {
|
||||
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
|
||||
initdb := flag.Bool("initdb", false, "initialize database")
|
||||
upgradedb := flag.Bool("upgradedb", false, "update database")
|
||||
passwd := flag.String("genpasswd", "", "bcrypt a password")
|
||||
flag.Parse()
|
||||
|
||||
@ -35,7 +36,13 @@ func main() {
|
||||
|
||||
if *initdb {
|
||||
irc.InitDB(config.Server.Database)
|
||||
log.Println("database initialized: " + config.Server.Database)
|
||||
log.Println("database initialized: ", config.Server.Database)
|
||||
return
|
||||
}
|
||||
|
||||
if *upgradedb {
|
||||
irc.UpgradeDB(config.Server.Database)
|
||||
log.Println("database upgraded: ", config.Server.Database)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -443,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
|
||||
if channel.flags[Persistent] {
|
||||
_, err = channel.server.db.Exec(`
|
||||
INSERT OR REPLACE INTO channel
|
||||
(name, flags, key, topic, user_limit)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
(name, flags, key, topic, user_limit, ban_list, except_list,
|
||||
invite_list)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
channel.name, channel.flags.String(), channel.key, channel.topic,
|
||||
channel.userLimit)
|
||||
channel.userLimit, channel.lists[BanMask].String(),
|
||||
channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
|
||||
} else {
|
||||
_, err = channel.server.db.Exec(`
|
||||
DELETE FROM channel WHERE name = ?`, channel.name)
|
||||
|
@ -219,6 +219,16 @@ func (set *UserMaskSet) Match(userhost string) bool {
|
||||
return set.regexp.MatchString(userhost)
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) String() string {
|
||||
masks := make([]string, len(set.masks))
|
||||
index := 0
|
||||
for mask := range set.masks {
|
||||
masks[index] = mask
|
||||
index += 1
|
||||
}
|
||||
return strings.Join(masks, " ")
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) setRegexp() {
|
||||
if len(set.masks) == 0 {
|
||||
set.regexp = nil
|
||||
|
@ -2,6 +2,7 @@ package irc
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"os"
|
||||
@ -14,15 +15,30 @@ func InitDB(path string) {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE channel (
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
flags TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
user_limit INTEGER DEFAULT 0)`)
|
||||
flags TEXT DEFAULT '',
|
||||
key TEXT DEFAULT '',
|
||||
topic TEXT DEFAULT '',
|
||||
user_limit INTEGER DEFAULT 0,
|
||||
ban_list TEXT DEFAULT '',
|
||||
except_list TEXT DEFAULT '',
|
||||
invite_list TEXT DEFAULT '')`)
|
||||
if err != nil {
|
||||
log.Fatal("initdb error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func UpgradeDB(path string) {
|
||||
db := OpenDB(path)
|
||||
alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
|
||||
cols := []string{"ban_list", "except_list", "invite_list"}
|
||||
for _, col := range cols {
|
||||
_, err := db.Exec(fmt.Sprintf(alter, col))
|
||||
if err != nil {
|
||||
log.Fatal("updatedb error: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func OpenDB(path string) *sql.DB {
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
|
@ -64,9 +64,19 @@ func NewServer(config *Config) *Server {
|
||||
return server
|
||||
}
|
||||
|
||||
func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
|
||||
if list == "" {
|
||||
return
|
||||
}
|
||||
for _, mask := range strings.Split(list, " ") {
|
||||
channel.lists[maskMode].Add(mask)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) loadChannels() {
|
||||
rows, err := server.db.Query(`
|
||||
SELECT name, flags, key, topic, user_limit
|
||||
SELECT name, flags, key, topic, user_limit, ban_list, except_list,
|
||||
invite_list
|
||||
FROM channel`)
|
||||
if err != nil {
|
||||
log.Fatal("error loading channels: ", err)
|
||||
@ -74,9 +84,11 @@ func (server *Server) loadChannels() {
|
||||
for rows.Next() {
|
||||
var name, flags, key, topic string
|
||||
var userLimit uint64
|
||||
err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
|
||||
var banList, exceptList, inviteList string
|
||||
err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
|
||||
&exceptList, &inviteList)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Println("Server.loadChannels:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -87,6 +99,9 @@ func (server *Server) loadChannels() {
|
||||
channel.key = key
|
||||
channel.topic = topic
|
||||
channel.userLimit = userLimit
|
||||
loadChannelList(channel, banList, BanMask)
|
||||
loadChannelList(channel, exceptList, ExceptMask)
|
||||
loadChannelList(channel, inviteList, InviteMask)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user