3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-10 22:19:31 +01:00

persistent channels persisted to a sqlite db

This commit is contained in:
Jeremy Latt 2014-02-25 11:11:34 -08:00
parent de5538f5d5
commit 02abeeb164
6 changed files with 112 additions and 5 deletions

View File

@ -40,5 +40,6 @@ byte strings. You can generate them with e.g. `ergonomadic -genpasswd
```sh ```sh
go get go get
go install go install
ergonomadic -conf '/path/to/config.json' -initdb
ergonomadic -conf '/path/to/config.json' ergonomadic -conf '/path/to/config.json'
``` ```

View File

@ -2,11 +2,14 @@ package main
import ( import (
"code.google.com/p/go.crypto/bcrypt" "code.google.com/p/go.crypto/bcrypt"
"database/sql"
"encoding/base64" "encoding/base64"
"flag" "flag"
"fmt" "fmt"
"github.com/jlatt/ergonomadic/irc" "github.com/jlatt/ergonomadic/irc"
_ "github.com/mattn/go-sqlite3"
"log" "log"
"os"
) )
func genPasswd(passwd string) { func genPasswd(passwd string) {
@ -18,8 +21,30 @@ func genPasswd(passwd string) {
fmt.Println(encoded) fmt.Println(encoded)
} }
func initDB(config *irc.Config) {
os.Remove(config.Database())
db, err := sql.Open("sqlite3", config.Database())
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`
CREATE TABLE channel (
name TEXT NOT NULL UNIQUE,
flags TEXT,
key TEXT,
topic TEXT,
user_limit INTEGER)`)
if err != nil {
log.Fatal(err)
}
}
func main() { func main() {
conf := flag.String("conf", "ergonomadic.json", "ergonomadic config file") conf := flag.String("conf", "ergonomadic.json", "ergonomadic config file")
initdb := flag.Bool("initdb", false, "initialize database")
passwd := flag.String("genpasswd", "", "bcrypt a password") passwd := flag.String("genpasswd", "", "bcrypt a password")
flag.Parse() flag.Parse()
@ -31,9 +56,14 @@ func main() {
config, err := irc.LoadConfig(*conf) config, err := irc.LoadConfig(*conf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
}
if *initdb {
initDB(config)
return return
} }
// TODO move to data structures
irc.DEBUG_NET = config.Debug["net"] irc.DEBUG_NET = config.Debug["net"]
irc.DEBUG_CLIENT = config.Debug["client"] irc.DEBUG_CLIENT = config.Debug["client"]
irc.DEBUG_CHANNEL = config.Debug["channel"] irc.DEBUG_CHANNEL = config.Debug["channel"]

View File

@ -33,7 +33,14 @@ func NewChannel(s *Server, name string) *Channel {
name: name, name: name,
server: s, server: s,
} }
s.channels[name] = channel s.channels[name] = channel
s.db.Exec(`INSERT INTO channel
(name, flags, key, topic, user_limit)
VALUES (?, ?, ?, ?, ?)`,
channel.name, channel.flags.String(), channel.key, channel.topic,
channel.userLimit)
return channel return channel
} }
@ -142,7 +149,9 @@ func (channel *Channel) Join(client *Client, key string) {
client.channels.Add(channel) client.channels.Add(channel)
channel.members.Add(client) channel.members.Add(client)
if len(channel.members) == 1 { if len(channel.members) == 1 {
if !channel.flags[Persistent] {
channel.members[client][ChannelCreator] = true channel.members[client][ChannelCreator] = true
}
channel.members[client][ChannelOperator] = true channel.members[client][ChannelOperator] = true
} }
@ -198,6 +207,10 @@ func (channel *Channel) SetTopic(client *Client, topic string) {
} }
channel.topic = topic channel.topic = topic
channel.server.db.Exec(`
UPDATE channel
SET topic = ?
WHERE name = ?`, channel.topic, channel.name)
reply := RplTopicMsg(client, channel) reply := RplTopicMsg(client, channel)
for member := range channel.members { for member := range channel.members {
@ -361,6 +374,11 @@ func (channel *Channel) Mode(client *Client, changes ChannelModeChanges) {
for member := range channel.members { for member := range channel.members {
member.Reply(reply) member.Reply(reply)
} }
channel.server.db.Exec(`
UPDATE channel
SET flags = ?
WHERE name = ?`, channel.flags.String(), channel.name)
} }
} }

View File

@ -26,6 +26,11 @@ type Config struct {
Name string Name string
Operators []OperatorConfig Operators []OperatorConfig
Password string Password string
directory string
}
func (conf *Config) Database() string {
return filepath.Join(conf.directory, "ergonomadic.db")
} }
func (conf *Config) PasswordBytes() []byte { func (conf *Config) PasswordBytes() []byte {
@ -75,9 +80,8 @@ func LoadConfig(filename string) (config *Config, err error) {
return return
} }
dir := filepath.Dir(filename) config.directory = filepath.Dir(filename)
config.MOTD = filepath.Join(dir, config.MOTD) config.MOTD = filepath.Join(config.directory, config.MOTD)
for _, lconf := range config.Listeners { for _, lconf := range config.Listeners {
if lconf.Net == "" { if lconf.Net == "" {
lconf.Net = "tcp" lconf.Net = "tcp"

View File

@ -4,11 +4,14 @@ import (
"bufio" "bufio"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"database/sql"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3"
"log" "log"
"net" "net"
"os" "os"
"os/signal"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"runtime/pprof" "runtime/pprof"
@ -21,30 +24,65 @@ type Server struct {
clients ClientNameMap clients ClientNameMap
commands chan Command commands chan Command
ctime time.Time ctime time.Time
db *sql.DB
idle chan *Client idle chan *Client
motdFile string motdFile string
name string name string
newConns chan net.Conn newConns chan net.Conn
operators map[string][]byte operators map[string][]byte
password []byte password []byte
signals chan os.Signal
timeout chan *Client timeout chan *Client
} }
func NewServer(config *Config) *Server { func NewServer(config *Config) *Server {
db, err := sql.Open("sqlite3", config.Database())
if err != nil {
log.Fatal(err)
}
server := &Server{ server := &Server{
channels: make(ChannelNameMap), channels: make(ChannelNameMap),
clients: make(ClientNameMap), clients: make(ClientNameMap),
commands: make(chan Command, 16), commands: make(chan Command, 16),
ctime: time.Now(), ctime: time.Now(),
db: db,
idle: make(chan *Client, 16), idle: make(chan *Client, 16),
motdFile: config.MOTD, motdFile: config.MOTD,
name: config.Name, name: config.Name,
newConns: make(chan net.Conn, 16), newConns: make(chan net.Conn, 16),
operators: config.OperatorsMap(), operators: config.OperatorsMap(),
password: config.PasswordBytes(), password: config.PasswordBytes(),
signals: make(chan os.Signal, 1),
timeout: make(chan *Client, 16), timeout: make(chan *Client, 16),
} }
signal.Notify(server.signals, os.Interrupt, os.Kill)
rows, err := db.Query(`
SELECT name, flags, key, topic, user_limit
FROM channel`)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
var name, flags, key, topic string
var userLimit uint64
err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
if err != nil {
log.Println(err)
continue
}
channel := NewChannel(server, name)
for flag := range flags {
channel.flags[ChannelMode(flag)] = true
}
channel.key = key
channel.topic = topic
channel.userLimit = userLimit
}
for _, listenerConf := range config.Listeners { for _, listenerConf := range config.Listeners {
go server.listen(listenerConf) go server.listen(listenerConf)
} }
@ -97,8 +135,14 @@ func (server *Server) processCommand(cmd Command) {
} }
func (server *Server) Run() { func (server *Server) Run() {
for { done := false
for !done {
select { select {
case <-server.signals:
server.db.Close()
done = true
continue
case conn := <-server.newConns: case conn := <-server.newConns:
NewClient(server, conn) NewClient(server, conn)

View File

@ -106,6 +106,16 @@ func (clients ClientNameMap) Remove(client *Client) error {
type ChannelModeSet map[ChannelMode]bool type ChannelModeSet map[ChannelMode]bool
func (set ChannelModeSet) String() string {
strs := make([]string, len(set))
index := 0
for mode := range set {
strs[index] = mode.String()
index += 1
}
return strings.Join(strs, "")
}
type ClientSet map[*Client]bool type ClientSet map[*Client]bool
func (clients ClientSet) Add(client *Client) { func (clients ClientSet) Add(client *Client) {