3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-22 20:09:41 +01:00

Merge pull request #1 from jlatt/persistent-channels

persist channels to a sqlite db
This commit is contained in:
Jeremy Latt 2014-02-25 16:46:16 -08:00
commit 2f149cad1d
13 changed files with 200 additions and 83 deletions

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Jeremy Latt
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -40,5 +40,6 @@ byte strings. You can generate them with e.g. `ergonomadic -genpasswd
```sh ```sh
go get go get
go install go install
ergonomadic -conf '/path/to/config.json' -initdb
ergonomadic -conf '/path/to/config.json' ergonomadic -conf '/path/to/config.json'
``` ```

View File

@ -1,13 +1,39 @@
{ "name": "irc.example.com", // Ergonomadic IRC Server Config
// -----------------------------
// Passwords are generated by `ergonomadic -genpasswd "$plaintext"`.
// Comments are not allowed in the actual config file.
{
// `name` is usually a hostname.
"name": "irc.example.com",
// The path to the MOTD is relative to this file's directory.
"motd": "motd.txt", "motd": "motd.txt",
"listeners": [
{ "address": "localhost:7777" }, // PASS command password
{ "address": "[::1]:7777" } ], "password": "JDJhJDA0JHBBenUyV3Z5UU5iWUpiYmlNMlNLZC5VRDZDM21HUzFVbmxLUUI3NTVTLkZJOERLdUFaUWNt",
"operators": [
{ "name": "root", // `listeners` are places to bind and listen for
"password": "JDJhJDEwJFRWWGUya2E3Unk5bnZlb2o3alJ0ZnVQQm9ZVW1HOE53L29nVHg5QWh5TnpaMmtOaEwya1Vl" } ], // connections. http://golang.org/pkg/net/#Dial demonstrates valid
// values for `net` and `address`. `net` is optional and defaults
// to `tcp`.
"listeners": [ {
"address": "localhost:7777"
}, {
"net": "tcp6",
"address": "[::1]:7777"
} ],
// Operators for the OPER command
"operators": [ {
"name": "root",
"password": "JDJhJDA0JHBBenUyV3Z5UU5iWUpiYmlNMlNLZC5VRDZDM21HUzFVbmxLUUI3NTVTLkZJOERLdUFaUWNt"
} ],
// Global debug flags. `net` generates a lot of output.
"debug": { "debug": {
"net": true, "net": true,
"client": false, "client": false,
"channel": false, "channel": false,
"server": false } } "server": false
}
}

View File

@ -2,11 +2,14 @@ package main
import ( import (
"code.google.com/p/go.crypto/bcrypt" "code.google.com/p/go.crypto/bcrypt"
"database/sql"
"encoding/base64" "encoding/base64"
"flag" "flag"
"fmt" "fmt"
"github.com/jlatt/ergonomadic/irc" "github.com/jlatt/ergonomadic/irc"
_ "github.com/mattn/go-sqlite3"
"log" "log"
"os"
) )
func genPasswd(passwd string) { func genPasswd(passwd string) {
@ -18,8 +21,30 @@ func genPasswd(passwd string) {
fmt.Println(encoded) fmt.Println(encoded)
} }
func initDB(config *irc.Config) {
os.Remove(config.Database())
db, err := sql.Open("sqlite3", config.Database())
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`
CREATE TABLE channel (
name TEXT NOT NULL UNIQUE,
flags TEXT NOT NULL,
key TEXT NOT NULL,
topic TEXT NOT NULL,
user_limit INTEGER DEFAULT 0)`)
if err != nil {
log.Fatal(err)
}
}
func main() { func main() {
conf := flag.String("conf", "ergonomadic.json", "ergonomadic config file") conf := flag.String("conf", "ergonomadic.json", "ergonomadic config file")
initdb := flag.Bool("initdb", false, "initialize database")
passwd := flag.String("genpasswd", "", "bcrypt a password") passwd := flag.String("genpasswd", "", "bcrypt a password")
flag.Parse() flag.Parse()
@ -31,9 +56,14 @@ func main() {
config, err := irc.LoadConfig(*conf) config, err := irc.LoadConfig(*conf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
}
if *initdb {
initDB(config)
return return
} }
// TODO move to data structures
irc.DEBUG_NET = config.Debug["net"] irc.DEBUG_NET = config.Debug["net"]
irc.DEBUG_CLIENT = config.Debug["client"] irc.DEBUG_CLIENT = config.Debug["client"]
irc.DEBUG_CHANNEL = config.Debug["channel"] irc.DEBUG_CHANNEL = config.Debug["channel"]

View File

@ -33,7 +33,9 @@ func NewChannel(s *Server, name string) *Channel {
name: name, name: name,
server: s, server: s,
} }
s.channels[name] = channel s.channels[name] = channel
return channel return channel
} }
@ -142,7 +144,9 @@ func (channel *Channel) Join(client *Client, key string) {
client.channels.Add(channel) client.channels.Add(channel)
channel.members.Add(client) channel.members.Add(client)
if len(channel.members) == 1 { if len(channel.members) == 1 {
if !channel.flags[Persistent] {
channel.members[client][ChannelCreator] = true channel.members[client][ChannelCreator] = true
}
channel.members[client][ChannelOperator] = true channel.members[client][ChannelOperator] = true
} }
@ -166,7 +170,7 @@ func (channel *Channel) Part(client *Client, message string) {
} }
channel.Quit(client) channel.Quit(client)
if channel.IsEmpty() { if !channel.flags[Persistent] && channel.IsEmpty() {
channel.server.channels.Remove(channel) channel.server.channels.Remove(channel)
} }
} }
@ -203,6 +207,8 @@ func (channel *Channel) SetTopic(client *Client, topic string) {
for member := range channel.members { for member := range channel.members {
member.Reply(reply) member.Reply(reply)
} }
channel.Persist()
} }
func (channel *Channel) CanSpeak(client *Client) bool { func (channel *Channel) CanSpeak(client *Client) bool {
@ -296,7 +302,7 @@ func (channel *Channel) applyMode(client *Client, change *ChannelModeChange) boo
} }
client.RplEndOfMaskList(change.mode, channel) client.RplEndOfMaskList(change.mode, channel)
case Moderated, NoOutside, OpOnlyTopic, Private: case Moderated, NoOutside, OpOnlyTopic, Persistent, Private:
return channel.applyModeFlag(client, change.mode, change.op) return channel.applyModeFlag(client, change.mode, change.op)
case Key: case Key:
@ -361,6 +367,21 @@ func (channel *Channel) Mode(client *Client, changes ChannelModeChanges) {
for member := range channel.members { for member := range channel.members {
member.Reply(reply) member.Reply(reply)
} }
channel.Persist()
}
}
func (channel *Channel) Persist() {
if channel.flags[Persistent] {
channel.server.db.Exec(`
INSERT OR REPLACE INTO channel
(name, flags, key, topic)
VALUES (?, ?, ?, ?, ?)`,
channel.name, channel.flags.String(), channel.key, channel.topic,
channel.userLimit)
} else {
channel.server.db.Exec(`DELETE FROM channel WHERE name = ?`, channel.name)
} }
} }

View File

@ -96,35 +96,17 @@ var (
spacesExpr = regexp.MustCompile(` +`) spacesExpr = regexp.MustCompile(` +`)
) )
func parseArg(line string) (arg string, rest string) { func parseLine(line string) (StringCode, []string) {
if line == "" { var parts []string
return if colonIndex := strings.IndexRune(line, ':'); colonIndex >= 0 {
} lastArg := line[colonIndex+len(":"):]
line = line[:colonIndex-len(" ")]
if strings.HasPrefix(line, ":") { parts = append(spacesExpr.Split(line, -1), lastArg)
arg = line[1:]
} else { } else {
parts := spacesExpr.Split(line, 2) parts = spacesExpr.Split(line, -1)
arg = parts[0]
if len(parts) > 1 {
rest = parts[1]
}
}
return
}
func parseLine(line string) (command StringCode, args []string) {
args = make([]string, 0)
for arg, rest := parseArg(line); arg != ""; arg, rest = parseArg(rest) {
if arg == "" {
continue
} }
args = append(args, arg) return StringCode(strings.ToUpper(parts[0])), parts[1:]
}
if len(args) > 0 {
command, args = StringCode(strings.ToUpper(args[0])), args[1:]
}
return
} }
// <command> [args...] // <command> [args...]

View File

@ -26,6 +26,11 @@ type Config struct {
Name string Name string
Operators []OperatorConfig Operators []OperatorConfig
Password string Password string
directory string
}
func (conf *Config) Database() string {
return filepath.Join(conf.directory, "ergonomadic.db")
} }
func (conf *Config) PasswordBytes() []byte { func (conf *Config) PasswordBytes() []byte {
@ -75,9 +80,8 @@ func LoadConfig(filename string) (config *Config, err error) {
return return
} }
dir := filepath.Dir(filename) config.directory = filepath.Dir(filename)
config.MOTD = filepath.Join(dir, config.MOTD) config.MOTD = filepath.Join(config.directory, config.MOTD)
for _, lconf := range config.Listeners { for _, lconf := range config.Listeners {
if lconf.Net == "" { if lconf.Net == "" {
lconf.Net = "tcp" lconf.Net = "tcp"

View File

@ -23,7 +23,7 @@ var (
) )
const ( const (
SERVER_VERSION = "1.1.0" SEM_VER = "ergonomadic-1.1.0"
CRLF = "\r\n" CRLF = "\r\n"
MAX_REPLY_LEN = 512 - len(CRLF) MAX_REPLY_LEN = 512 - len(CRLF)
@ -209,7 +209,7 @@ const (
LocalOperator UserMode = 'O' LocalOperator UserMode = 'O'
Operator UserMode = 'o' Operator UserMode = 'o'
Restricted UserMode = 'r' Restricted UserMode = 'r'
ServerNotice UserMode = 's' ServerNotice UserMode = 's' // deprecated
WallOps UserMode = 'w' WallOps UserMode = 'w'
Anonymous ChannelMode = 'a' // flag Anonymous ChannelMode = 'a' // flag
@ -223,6 +223,7 @@ const (
Moderated ChannelMode = 'm' // flag Moderated ChannelMode = 'm' // flag
NoOutside ChannelMode = 'n' // flag NoOutside ChannelMode = 'n' // flag
OpOnlyTopic ChannelMode = 't' // flag OpOnlyTopic ChannelMode = 't' // flag
Persistent ChannelMode = 'P' // flag
Private ChannelMode = 'p' // flag Private ChannelMode = 'p' // flag
Quiet ChannelMode = 'q' // flag Quiet ChannelMode = 'q' // flag
ReOp ChannelMode = 'r' // flag ReOp ChannelMode = 'r' // flag

View File

@ -151,7 +151,7 @@ func (target *Client) RplWelcome() {
func (target *Client) RplYourHost() { func (target *Client) RplYourHost() {
target.NumericReply(RPL_YOURHOST, target.NumericReply(RPL_YOURHOST,
":Your host is %s, running version %s", target.server.name, SERVER_VERSION) ":Your host is %s, running version %s", target.server.name, SEM_VER)
} }
func (target *Client) RplCreated() { func (target *Client) RplCreated() {
@ -161,7 +161,7 @@ func (target *Client) RplCreated() {
func (target *Client) RplMyInfo() { func (target *Client) RplMyInfo() {
target.NumericReply(RPL_MYINFO, target.NumericReply(RPL_MYINFO,
"%s %s aiOorsw abeIikmntpqrsl", target.server.name, SERVER_VERSION) "%s %s aiOorsw abeIikmntpqrsl", target.server.name, SEM_VER)
} }
func (target *Client) RplUModeIs(client *Client) { func (target *Client) RplUModeIs(client *Client) {
@ -371,7 +371,7 @@ func (target *Client) RplWhoisChannels(client *Client) {
func (target *Client) RplVersion() { func (target *Client) RplVersion() {
target.NumericReply(RPL_VERSION, target.NumericReply(RPL_VERSION,
"ergonomadic-%s %s", SERVER_VERSION, target.server.name) "%s %s", SEM_VER, target.server.name)
} }
func (target *Client) RplInviting(invitee *Client, channel string) { func (target *Client) RplInviting(invitee *Client, channel string) {

View File

@ -4,11 +4,14 @@ import (
"bufio" "bufio"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"database/sql"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3"
"log" "log"
"net" "net"
"os" "os"
"os/signal"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"runtime/pprof" "runtime/pprof"
@ -21,30 +24,43 @@ type Server struct {
clients ClientNameMap clients ClientNameMap
commands chan Command commands chan Command
ctime time.Time ctime time.Time
db *sql.DB
idle chan *Client idle chan *Client
motdFile string motdFile string
name string name string
newConns chan net.Conn newConns chan net.Conn
operators map[string][]byte operators map[string][]byte
password []byte password []byte
signals chan os.Signal
timeout chan *Client timeout chan *Client
} }
func NewServer(config *Config) *Server { func NewServer(config *Config) *Server {
db, err := sql.Open("sqlite3", config.Database())
if err != nil {
log.Fatal(err)
}
server := &Server{ server := &Server{
channels: make(ChannelNameMap), channels: make(ChannelNameMap),
clients: make(ClientNameMap), clients: make(ClientNameMap),
commands: make(chan Command, 16), commands: make(chan Command, 16),
ctime: time.Now(), ctime: time.Now(),
db: db,
idle: make(chan *Client, 16), idle: make(chan *Client, 16),
motdFile: config.MOTD, motdFile: config.MOTD,
name: config.Name, name: config.Name,
newConns: make(chan net.Conn, 16), newConns: make(chan net.Conn, 16),
operators: config.OperatorsMap(), operators: config.OperatorsMap(),
password: config.PasswordBytes(), password: config.PasswordBytes(),
signals: make(chan os.Signal, 1),
timeout: make(chan *Client, 16), timeout: make(chan *Client, 16),
} }
signal.Notify(server.signals, os.Interrupt, os.Kill)
server.loadChannels()
for _, listenerConf := range config.Listeners { for _, listenerConf := range config.Listeners {
go server.listen(listenerConf) go server.listen(listenerConf)
} }
@ -52,6 +68,32 @@ func NewServer(config *Config) *Server {
return server return server
} }
func (server *Server) loadChannels() {
rows, err := server.db.Query(`
SELECT name, flags, key, topic, user_limit
FROM channel`)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
var name, flags, key, topic string
var userLimit uint64
err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
if err != nil {
log.Println(err)
continue
}
channel := NewChannel(server, name)
for _, flag := range flags {
channel.flags[ChannelMode(flag)] = true
}
channel.key = key
channel.topic = topic
channel.userLimit = userLimit
}
}
func (server *Server) processCommand(cmd Command) { func (server *Server) processCommand(cmd Command) {
client := cmd.Client() client := cmd.Client()
if DEBUG_SERVER { if DEBUG_SERVER {
@ -97,8 +139,14 @@ func (server *Server) processCommand(cmd Command) {
} }
func (server *Server) Run() { func (server *Server) Run() {
for { done := false
for !done {
select { select {
case <-server.signals:
server.db.Close()
done = true
continue
case conn := <-server.newConns: case conn := <-server.newConns:
NewClient(server, conn) NewClient(server, conn)

View File

@ -106,6 +106,19 @@ func (clients ClientNameMap) Remove(client *Client) error {
type ChannelModeSet map[ChannelMode]bool type ChannelModeSet map[ChannelMode]bool
func (set ChannelModeSet) String() string {
if len(set) == 0 {
return ""
}
strs := make([]string, len(set))
index := 0
for mode := range set {
strs[index] = mode.String()
index += 1
}
return strings.Join(strs, "")
}
type ClientSet map[*Client]bool type ClientSet map[*Client]bool
func (clients ClientSet) Add(client *Client) { func (clients ClientSet) Add(client *Client) {

View File

@ -1,10 +0,0 @@
DROP INDEX IF EXISTS index_user_id_channel_id;
DROP TABLE IF EXISTS user_channel;
DROP INDEX IF EXISTS index_channel_name;
DROP INDEX IF EXISTS index_channel_id;
DROP TABLE IF EXISTS channel;
DROP INDEX IF EXISTS index_user_nick;
DROP INDEX IF EXISTS index_user_id;
DROP TABLE IF EXISTS user;

View File

@ -1,20 +0,0 @@
CREATE TABLE user (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
nick TEXT NOT NULL UNIQUE,
hash BLOB NOT NULL
);
CREATE INDEX index_user_id ON user(id);
CREATE INDEX index_user_nick ON user(nick);
CREATE TABLE channel (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
name TEXT NOT NULL UNIQUE
);
CREATE INDEX index_channel_id ON channel(id);
CREATE TABLE user_channel (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
user_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL
);
CREATE UNIQUE INDEX index_user_id_channel_id ON user_channel (user_id, channel_id);