mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-03 08:32:43 +01:00
persist and load channel mask lists
This commit is contained in:
parent
04c30c8c9b
commit
cf76d2bd77
@ -12,6 +12,7 @@ import (
|
|||||||
func main() {
|
func main() {
|
||||||
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
|
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
|
||||||
initdb := flag.Bool("initdb", false, "initialize database")
|
initdb := flag.Bool("initdb", false, "initialize database")
|
||||||
|
upgradedb := flag.Bool("upgradedb", false, "update database")
|
||||||
passwd := flag.String("genpasswd", "", "bcrypt a password")
|
passwd := flag.String("genpasswd", "", "bcrypt a password")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
@ -35,7 +36,13 @@ func main() {
|
|||||||
|
|
||||||
if *initdb {
|
if *initdb {
|
||||||
irc.InitDB(config.Server.Database)
|
irc.InitDB(config.Server.Database)
|
||||||
log.Println("database initialized: " + config.Server.Database)
|
log.Println("database initialized: ", config.Server.Database)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if *upgradedb {
|
||||||
|
irc.UpgradeDB(config.Server.Database)
|
||||||
|
log.Println("database upgraded: ", config.Server.Database)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -443,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
|
|||||||
if channel.flags[Persistent] {
|
if channel.flags[Persistent] {
|
||||||
_, err = channel.server.db.Exec(`
|
_, err = channel.server.db.Exec(`
|
||||||
INSERT OR REPLACE INTO channel
|
INSERT OR REPLACE INTO channel
|
||||||
(name, flags, key, topic, user_limit)
|
(name, flags, key, topic, user_limit, ban_list, except_list,
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
invite_list)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
channel.name, channel.flags.String(), channel.key, channel.topic,
|
channel.name, channel.flags.String(), channel.key, channel.topic,
|
||||||
channel.userLimit)
|
channel.userLimit, channel.lists[BanMask].String(),
|
||||||
|
channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
|
||||||
} else {
|
} else {
|
||||||
_, err = channel.server.db.Exec(`
|
_, err = channel.server.db.Exec(`
|
||||||
DELETE FROM channel WHERE name = ?`, channel.name)
|
DELETE FROM channel WHERE name = ?`, channel.name)
|
||||||
|
@ -219,6 +219,16 @@ func (set *UserMaskSet) Match(userhost string) bool {
|
|||||||
return set.regexp.MatchString(userhost)
|
return set.regexp.MatchString(userhost)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (set *UserMaskSet) String() string {
|
||||||
|
masks := make([]string, len(set.masks))
|
||||||
|
index := 0
|
||||||
|
for mask := range set.masks {
|
||||||
|
masks[index] = mask
|
||||||
|
index += 1
|
||||||
|
}
|
||||||
|
return strings.Join(masks, " ")
|
||||||
|
}
|
||||||
|
|
||||||
func (set *UserMaskSet) setRegexp() {
|
func (set *UserMaskSet) setRegexp() {
|
||||||
if len(set.masks) == 0 {
|
if len(set.masks) == 0 {
|
||||||
set.regexp = nil
|
set.regexp = nil
|
||||||
|
@ -2,6 +2,7 @@ package irc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@ -14,15 +15,30 @@ func InitDB(path string) {
|
|||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
CREATE TABLE channel (
|
CREATE TABLE channel (
|
||||||
name TEXT NOT NULL UNIQUE,
|
name TEXT NOT NULL UNIQUE,
|
||||||
flags TEXT NOT NULL,
|
flags TEXT DEFAULT '',
|
||||||
key TEXT NOT NULL,
|
key TEXT DEFAULT '',
|
||||||
topic TEXT NOT NULL,
|
topic TEXT DEFAULT '',
|
||||||
user_limit INTEGER DEFAULT 0)`)
|
user_limit INTEGER DEFAULT 0,
|
||||||
|
ban_list TEXT DEFAULT '',
|
||||||
|
except_list TEXT DEFAULT '',
|
||||||
|
invite_list TEXT DEFAULT '')`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("initdb error: ", err)
|
log.Fatal("initdb error: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UpgradeDB(path string) {
|
||||||
|
db := OpenDB(path)
|
||||||
|
alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
|
||||||
|
cols := []string{"ban_list", "except_list", "invite_list"}
|
||||||
|
for _, col := range cols {
|
||||||
|
_, err := db.Exec(fmt.Sprintf(alter, col))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("updatedb error: ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func OpenDB(path string) *sql.DB {
|
func OpenDB(path string) *sql.DB {
|
||||||
db, err := sql.Open("sqlite3", path)
|
db, err := sql.Open("sqlite3", path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -64,9 +64,19 @@ func NewServer(config *Config) *Server {
|
|||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
|
||||||
|
if list == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, mask := range strings.Split(list, " ") {
|
||||||
|
channel.lists[maskMode].Add(mask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (server *Server) loadChannels() {
|
func (server *Server) loadChannels() {
|
||||||
rows, err := server.db.Query(`
|
rows, err := server.db.Query(`
|
||||||
SELECT name, flags, key, topic, user_limit
|
SELECT name, flags, key, topic, user_limit, ban_list, except_list,
|
||||||
|
invite_list
|
||||||
FROM channel`)
|
FROM channel`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("error loading channels: ", err)
|
log.Fatal("error loading channels: ", err)
|
||||||
@ -74,9 +84,11 @@ func (server *Server) loadChannels() {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var name, flags, key, topic string
|
var name, flags, key, topic string
|
||||||
var userLimit uint64
|
var userLimit uint64
|
||||||
err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
|
var banList, exceptList, inviteList string
|
||||||
|
err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
|
||||||
|
&exceptList, &inviteList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println("Server.loadChannels:", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,6 +99,9 @@ func (server *Server) loadChannels() {
|
|||||||
channel.key = key
|
channel.key = key
|
||||||
channel.topic = topic
|
channel.topic = topic
|
||||||
channel.userLimit = userLimit
|
channel.userLimit = userLimit
|
||||||
|
loadChannelList(channel, banList, BanMask)
|
||||||
|
loadChannelList(channel, exceptList, ExceptMask)
|
||||||
|
loadChannelList(channel, inviteList, InviteMask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user