diff --git a/ergonomadic.go b/ergonomadic.go index 2c4cce76..cf91bdea 100644 --- a/ergonomadic.go +++ b/ergonomadic.go @@ -12,6 +12,7 @@ import ( func main() { conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file") initdb := flag.Bool("initdb", false, "initialize database") + upgradedb := flag.Bool("upgradedb", false, "update database") passwd := flag.String("genpasswd", "", "bcrypt a password") flag.Parse() @@ -35,7 +36,13 @@ func main() { if *initdb { irc.InitDB(config.Server.Database) - log.Println("database initialized: " + config.Server.Database) + log.Println("database initialized: ", config.Server.Database) + return + } + + if *upgradedb { + irc.UpgradeDB(config.Server.Database) + log.Println("database upgraded: ", config.Server.Database) return } diff --git a/irc/channel.go b/irc/channel.go index c44e191d..41f7f9a8 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -443,10 +443,12 @@ func (channel *Channel) Persist() (err error) { if channel.flags[Persistent] { _, err = channel.server.db.Exec(` INSERT OR REPLACE INTO channel - (name, flags, key, topic, user_limit) - VALUES (?, ?, ?, ?, ?)`, + (name, flags, key, topic, user_limit, ban_list, except_list, + invite_list) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, channel.name, channel.flags.String(), channel.key, channel.topic, - channel.userLimit) + channel.userLimit, channel.lists[BanMask].String(), + channel.lists[ExceptMask].String(), channel.lists[InviteMask].String()) } else { _, err = channel.server.db.Exec(` DELETE FROM channel WHERE name = ?`, channel.name) diff --git a/irc/client_lookup_set.go b/irc/client_lookup_set.go index 44aa4bda..6506741e 100644 --- a/irc/client_lookup_set.go +++ b/irc/client_lookup_set.go @@ -219,6 +219,16 @@ func (set *UserMaskSet) Match(userhost string) bool { return set.regexp.MatchString(userhost) } +func (set *UserMaskSet) String() string { + masks := make([]string, len(set.masks)) + index := 0 + for mask := range set.masks { + masks[index] = mask + index += 1 + } + return strings.Join(masks, " ") +} + func (set *UserMaskSet) setRegexp() { if len(set.masks) == 0 { set.regexp = nil diff --git a/irc/database.go b/irc/database.go index c7f9264a..2a482ecf 100644 --- a/irc/database.go +++ b/irc/database.go @@ -2,6 +2,7 @@ package irc import ( "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" "log" "os" @@ -14,15 +15,30 @@ func InitDB(path string) { _, err := db.Exec(` CREATE TABLE channel ( name TEXT NOT NULL UNIQUE, - flags TEXT NOT NULL, - key TEXT NOT NULL, - topic TEXT NOT NULL, - user_limit INTEGER DEFAULT 0)`) + flags TEXT DEFAULT '', + key TEXT DEFAULT '', + topic TEXT DEFAULT '', + user_limit INTEGER DEFAULT 0, + ban_list TEXT DEFAULT '', + except_list TEXT DEFAULT '', + invite_list TEXT DEFAULT '')`) if err != nil { log.Fatal("initdb error: ", err) } } +func UpgradeDB(path string) { + db := OpenDB(path) + alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''` + cols := []string{"ban_list", "except_list", "invite_list"} + for _, col := range cols { + _, err := db.Exec(fmt.Sprintf(alter, col)) + if err != nil { + log.Fatal("updatedb error: ", err) + } + } +} + func OpenDB(path string) *sql.DB { db, err := sql.Open("sqlite3", path) if err != nil { diff --git a/irc/server.go b/irc/server.go index 5f19f773..6d83fc50 100644 --- a/irc/server.go +++ b/irc/server.go @@ -64,9 +64,19 @@ func NewServer(config *Config) *Server { return server } +func loadChannelList(channel *Channel, list string, maskMode ChannelMode) { + if list == "" { + return + } + for _, mask := range strings.Split(list, " ") { + channel.lists[maskMode].Add(mask) + } +} + func (server *Server) loadChannels() { rows, err := server.db.Query(` - SELECT name, flags, key, topic, user_limit + SELECT name, flags, key, topic, user_limit, ban_list, except_list, + invite_list FROM channel`) if err != nil { log.Fatal("error loading channels: ", err) @@ -74,9 +84,11 @@ func (server *Server) loadChannels() { for rows.Next() { var name, flags, key, topic string var userLimit uint64 - err = rows.Scan(&name, &flags, &key, &topic, &userLimit) + var banList, exceptList, inviteList string + err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList, + &exceptList, &inviteList) if err != nil { - log.Println(err) + log.Println("Server.loadChannels:", err) continue } @@ -87,6 +99,9 @@ func (server *Server) loadChannels() { channel.key = key channel.topic = topic channel.userLimit = userLimit + loadChannelList(channel, banList, BanMask) + loadChannelList(channel, exceptList, ExceptMask) + loadChannelList(channel, inviteList, InviteMask) } }