From 02abeeb164130e3aad760547f4def696794a11dd Mon Sep 17 00:00:00 2001 From: Jeremy Latt Date: Tue, 25 Feb 2014 11:11:34 -0800 Subject: [PATCH] persistent channels persisted to a sqlite db --- README.md | 1 + ergonomadic.go | 30 ++++++++++++++++++++++++++++++ irc/channel.go | 20 +++++++++++++++++++- irc/config.go | 10 +++++++--- irc/server.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- irc/types.go | 10 ++++++++++ 6 files changed, 112 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index cc68f07b..5702e183 100644 --- a/README.md +++ b/README.md @@ -40,5 +40,6 @@ byte strings. You can generate them with e.g. `ergonomadic -genpasswd ```sh go get go install +ergonomadic -conf '/path/to/config.json' -initdb ergonomadic -conf '/path/to/config.json' ``` diff --git a/ergonomadic.go b/ergonomadic.go index 9321e5da..e4633b6e 100644 --- a/ergonomadic.go +++ b/ergonomadic.go @@ -2,11 +2,14 @@ package main import ( "code.google.com/p/go.crypto/bcrypt" + "database/sql" "encoding/base64" "flag" "fmt" "github.com/jlatt/ergonomadic/irc" + _ "github.com/mattn/go-sqlite3" "log" + "os" ) func genPasswd(passwd string) { @@ -18,8 +21,30 @@ func genPasswd(passwd string) { fmt.Println(encoded) } +func initDB(config *irc.Config) { + os.Remove(config.Database()) + + db, err := sql.Open("sqlite3", config.Database()) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(` + CREATE TABLE channel ( + name TEXT NOT NULL UNIQUE, + flags TEXT, + key TEXT, + topic TEXT, + user_limit INTEGER)`) + if err != nil { + log.Fatal(err) + } +} + func main() { conf := flag.String("conf", "ergonomadic.json", "ergonomadic config file") + initdb := flag.Bool("initdb", false, "initialize database") passwd := flag.String("genpasswd", "", "bcrypt a password") flag.Parse() @@ -31,9 +56,14 @@ func main() { config, err := irc.LoadConfig(*conf) if err != nil { log.Fatal(err) + } + + if *initdb { + initDB(config) return } + // TODO move to data structures irc.DEBUG_NET = config.Debug["net"] irc.DEBUG_CLIENT = config.Debug["client"] irc.DEBUG_CHANNEL = config.Debug["channel"] diff --git a/irc/channel.go b/irc/channel.go index 7254eea2..fb66073b 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -33,7 +33,14 @@ func NewChannel(s *Server, name string) *Channel { name: name, server: s, } + s.channels[name] = channel + s.db.Exec(`INSERT INTO channel + (name, flags, key, topic, user_limit) + VALUES (?, ?, ?, ?, ?)`, + channel.name, channel.flags.String(), channel.key, channel.topic, + channel.userLimit) + return channel } @@ -142,7 +149,9 @@ func (channel *Channel) Join(client *Client, key string) { client.channels.Add(channel) channel.members.Add(client) if len(channel.members) == 1 { - channel.members[client][ChannelCreator] = true + if !channel.flags[Persistent] { + channel.members[client][ChannelCreator] = true + } channel.members[client][ChannelOperator] = true } @@ -198,6 +207,10 @@ func (channel *Channel) SetTopic(client *Client, topic string) { } channel.topic = topic + channel.server.db.Exec(` + UPDATE channel + SET topic = ? + WHERE name = ?`, channel.topic, channel.name) reply := RplTopicMsg(client, channel) for member := range channel.members { @@ -361,6 +374,11 @@ func (channel *Channel) Mode(client *Client, changes ChannelModeChanges) { for member := range channel.members { member.Reply(reply) } + + channel.server.db.Exec(` + UPDATE channel + SET flags = ? + WHERE name = ?`, channel.flags.String(), channel.name) } } diff --git a/irc/config.go b/irc/config.go index 5ddebb21..07518f00 100644 --- a/irc/config.go +++ b/irc/config.go @@ -26,6 +26,11 @@ type Config struct { Name string Operators []OperatorConfig Password string + directory string +} + +func (conf *Config) Database() string { + return filepath.Join(conf.directory, "ergonomadic.db") } func (conf *Config) PasswordBytes() []byte { @@ -75,9 +80,8 @@ func LoadConfig(filename string) (config *Config, err error) { return } - dir := filepath.Dir(filename) - config.MOTD = filepath.Join(dir, config.MOTD) - + config.directory = filepath.Dir(filename) + config.MOTD = filepath.Join(config.directory, config.MOTD) for _, lconf := range config.Listeners { if lconf.Net == "" { lconf.Net = "tcp" diff --git a/irc/server.go b/irc/server.go index 889bea8f..00a8ff94 100644 --- a/irc/server.go +++ b/irc/server.go @@ -4,11 +4,14 @@ import ( "bufio" "crypto/rand" "crypto/tls" + "database/sql" "encoding/binary" "fmt" + _ "github.com/mattn/go-sqlite3" "log" "net" "os" + "os/signal" "runtime" "runtime/debug" "runtime/pprof" @@ -21,30 +24,65 @@ type Server struct { clients ClientNameMap commands chan Command ctime time.Time + db *sql.DB idle chan *Client motdFile string name string newConns chan net.Conn operators map[string][]byte password []byte + signals chan os.Signal timeout chan *Client } func NewServer(config *Config) *Server { + db, err := sql.Open("sqlite3", config.Database()) + if err != nil { + log.Fatal(err) + } + server := &Server{ channels: make(ChannelNameMap), clients: make(ClientNameMap), commands: make(chan Command, 16), ctime: time.Now(), + db: db, idle: make(chan *Client, 16), motdFile: config.MOTD, name: config.Name, newConns: make(chan net.Conn, 16), operators: config.OperatorsMap(), password: config.PasswordBytes(), + signals: make(chan os.Signal, 1), timeout: make(chan *Client, 16), } + signal.Notify(server.signals, os.Interrupt, os.Kill) + + rows, err := db.Query(` + SELECT name, flags, key, topic, user_limit + FROM channel`) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + var name, flags, key, topic string + var userLimit uint64 + err = rows.Scan(&name, &flags, &key, &topic, &userLimit) + if err != nil { + log.Println(err) + continue + } + + channel := NewChannel(server, name) + for flag := range flags { + channel.flags[ChannelMode(flag)] = true + } + channel.key = key + channel.topic = topic + channel.userLimit = userLimit + } + for _, listenerConf := range config.Listeners { go server.listen(listenerConf) } @@ -97,8 +135,14 @@ func (server *Server) processCommand(cmd Command) { } func (server *Server) Run() { - for { + done := false + for !done { select { + case <-server.signals: + server.db.Close() + done = true + continue + case conn := <-server.newConns: NewClient(server, conn) diff --git a/irc/types.go b/irc/types.go index 0f15d314..8e25f2f6 100644 --- a/irc/types.go +++ b/irc/types.go @@ -106,6 +106,16 @@ func (clients ClientNameMap) Remove(client *Client) error { type ChannelModeSet map[ChannelMode]bool +func (set ChannelModeSet) String() string { + strs := make([]string, len(set)) + index := 0 + for mode := range set { + strs[index] = mode.String() + index += 1 + } + return strings.Join(strs, "") +} + type ClientSet map[*Client]bool func (clients ClientSet) Add(client *Client) {