diff --git a/.gitignore b/.gitignore index ac4d22c4..0b9380a8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ pkg bin src/code.google.com/ +ergonomadic.db diff --git a/README.md b/README.md index 6bef509f..03ba3a16 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,26 @@ # Ergonomadic -## A Go IRC Daemon +Ergonomadic is an IRC daemon written from scratch in Go. It supports (or will) +multiple concurrent connections for the same nick. -Ergonomadic is an IRC daemon written from scratch in Go. - -### Why? +## Why? I wanted to learn Go. -### What's with the name? +## What's with the name? "Ergonomadic" is an anagram of "Go IRC Daemon". -### Helpful Documentation +## Helpful Documentation -- [IRC Channel Management](http://tools.ietf.org/html/rfc2811) -- [IRC Client Protocol](http://tools.ietf.org/html/rfc2812) -- [IRC Server Protocol](http://tools.ietf.org/html/rfc2813) +- [RFC 2811: IRC Channel Management](http://tools.ietf.org/html/rfc2811) +- [RFC 2812: IRC Client Protocol](http://tools.ietf.org/html/rfc2812) +- [RFC 2813: IRC Server Protocol](http://tools.ietf.org/html/rfc2813) + +## Running the Server + +```sh +$ ./build.sh +$ bin/ergonomadicdb init +$ bin/ergonomadic +``` diff --git a/build.sh b/build.sh index f6bb868f..f5ed68c3 100755 --- a/build.sh +++ b/build.sh @@ -1,4 +1,5 @@ #!/bin/bash export GOPATH="$PWD" go get "code.google.com/p/go.crypto/bcrypt" -go install ergonomadic genpasswd +go get "github.com/mattn/go-sqlite3" +go install ergonomadic genpasswd ergonomadicdb diff --git a/sql/drop.sql b/sql/drop.sql new file mode 100644 index 00000000..b8347c95 --- /dev/null +++ b/sql/drop.sql @@ -0,0 +1,10 @@ +DROP INDEX user_id_channel_id +DROP TABLE user_channel + +DROP INDEX channel_name +DROP INDEX channel_id +DROP TABLE channel + +DROP INDEX user_nick +DROP INDEX user_id +DROP TABLE user diff --git a/sql/init.sql b/sql/init.sql new file mode 100644 index 00000000..7b7f2bd3 --- /dev/null +++ b/sql/init.sql @@ -0,0 +1,10 @@ +CREATE TABLE user (id integer not null primary key autoincrement, nick text not null, hash blob not null) +CREATE UNIQUE INDEX user_id ON user (id) +CREATE UNIQUE INDEX user_nick ON user (nick) + +CREATE TABLE channel (id integer not null primary key autoincrement, name text not null) +CREATE UNIQUE INDEX channel_id ON channel (id) +CREATE UNIQUE INDEX channel_name ON channel (name) + +CREATE_TABLE user_channel (id integer not null primary key autoincrement, user_id integer not null, channel_id integer not null) +CREATE UNIQUE INDEX user_id_channel_id ON user_channel (user_id, channel_id) diff --git a/src/ergonomadicdb/ergonomadicdb.go b/src/ergonomadicdb/ergonomadicdb.go new file mode 100644 index 00000000..bc925ed8 --- /dev/null +++ b/src/ergonomadicdb/ergonomadicdb.go @@ -0,0 +1,22 @@ +package main + +import ( + "flag" + "irc" +) + +var ( + actions = map[string]func(*irc.Database){ + "init": func(db *irc.Database) { + db.InitTables() + }, + "drop": func(db *irc.Database) { + db.DropTables() + }, + } +) + +func main() { + flag.Parse() + actions[flag.Arg(0)](irc.NewDatabase()) +} diff --git a/src/irc/channel.go b/src/irc/channel.go index 3ba213fd..b8237f76 100644 --- a/src/irc/channel.go +++ b/src/irc/channel.go @@ -9,6 +9,7 @@ const ( ) type Channel struct { + id *RowId server *Server commands chan<- ChannelCommand replies chan<- Reply @@ -30,6 +31,16 @@ func (set ChannelSet) Remove(channel *Channel) { delete(set, channel) } +func (set ChannelSet) Ids() (ids []RowId) { + ids = []RowId{} + for channel := range set { + if channel.id != nil { + ids = append(ids, *channel.id) + } + } + return ids +} + type ChannelCommand interface { Command HandleChannel(channel *Channel) @@ -52,6 +63,24 @@ func NewChannel(s *Server, name string) *Channel { return channel } +func (channel *Channel) Save(q Queryable) bool { + if channel.id == nil { + if err := InsertChannel(q, channel); err != nil { + return false + } + channelId, err := FindChannelIdByName(q, channel.name) + if err != nil { + return false + } + channel.id = &channelId + } else { + if err := UpdateChannel(q, channel); err != nil { + return false + } + } + return true +} + // Forward `Reply`s to all `User`s of the `Channel`. func (channel *Channel) receiveReplies(replies <-chan Reply) { for reply := range replies { diff --git a/src/irc/commands.go b/src/irc/commands.go index f6499cea..8ac00e06 100644 --- a/src/irc/commands.go +++ b/src/irc/commands.go @@ -53,7 +53,7 @@ func (command *BaseCommand) User() *User { } func (command *BaseCommand) SetClient(c *Client) { - command.client = c + *command = BaseCommand{c} } func (command *BaseCommand) Source() Identifier { diff --git a/src/irc/net.go b/src/irc/net.go index ce371c55..d28d1f9c 100644 --- a/src/irc/net.go +++ b/src/irc/net.go @@ -8,7 +8,7 @@ import ( ) const ( - DEBUG_NET = false + DEBUG_NET = true ) func readTrimmedLine(reader *bufio.Reader) (string, error) { diff --git a/src/irc/persistence.go b/src/irc/persistence.go new file mode 100644 index 00000000..8a68684a --- /dev/null +++ b/src/irc/persistence.go @@ -0,0 +1,220 @@ +package irc + +import ( + "database/sql" + //"fmt" + "bufio" + _ "github.com/mattn/go-sqlite3" + "log" + "os" + "path/filepath" + "strings" +) + +type Database struct { + *sql.DB +} + +type Transaction struct { + *sql.Tx +} + +type RowId uint64 + +type Queryable interface { + Exec(string, ...interface{}) (sql.Result, error) + Query(string, ...interface{}) (*sql.Rows, error) + QueryRow(string, ...interface{}) *sql.Row +} + +type TransactionFunc func(Queryable) bool + +// +// general +// + +func NewDatabase() *Database { + db, err := sql.Open("sqlite3", "ergonomadic.db") + if err != nil { + panic("cannot open database") + } + return &Database{db} +} + +func NewTransaction(tx *sql.Tx) *Transaction { + return &Transaction{tx} +} + +func readLines(filename string) <-chan string { + file, err := os.Open(filename) + if err != nil { + panic(err) + } + reader := bufio.NewReader(file) + lines := make(chan string) + go func(lines chan<- string) { + defer file.Close() + defer close(lines) + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + line = strings.TrimSpace(line) + if line == "" { + continue + } + lines <- line + } + }(lines) + return lines +} + +func (db *Database) execSqlFile(filename string) { + db.Transact(func(q Queryable) bool { + for line := range readLines(filepath.Join("sql", filename)) { + log.Println(line) + q.Exec(line) + } + return true + }) +} + +func (db *Database) InitTables() { + db.execSqlFile("init.sql") +} + +func (db *Database) DropTables() { + db.execSqlFile("drop.sql") +} + +func (db *Database) Transact(txf TransactionFunc) { + tx, err := db.Begin() + if err != nil { + panic(err) + } + if txf(tx) { + tx.Commit() + } else { + tx.Rollback() + } +} + +// +// data +// + +type UserRow struct { + id RowId + nick string + hash []byte +} + +type ChannelRow struct { + id RowId + name string +} + +// user + +func FindUserByNick(q Queryable, nick string) (ur *UserRow) { + ur = new(UserRow) + row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick) + err := row.Scan(&ur.id, &ur.nick, &ur.hash) + if err != nil { + ur = nil + } + return +} + +func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) { + row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick) + err = row.Scan(&rowId) + return +} + +func FindChannelByName(q Queryable, name string) (cr *ChannelRow) { + cr = new(ChannelRow) + row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name) + err := row.Scan(&(cr.id), &(cr.name)) + if err != nil { + cr = nil + } + return +} + +func InsertUser(q Queryable, user *User) (err error) { + _, err = q.Exec("INSERT INTO user (nick, hash) VALUES (?, ?)", + user.nick, user.hash) + return +} + +func UpdateUser(q Queryable, user *User) (err error) { + _, err = q.Exec("UPDATE user SET nick = ?, hash = ? WHERE id = ?", + user.nick, user.hash, *(user.id)) + return +} + +// user-channel + +func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) { + _, err = q.Exec("DELETE FROM user_channel WHERE user_id = ?", rowId) + return +} + +func DeleteOtherUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) { + _, err = q.Exec(`DELETE FROM user_channel WHERE +user_id = ? AND channel_id NOT IN ?`, userId, channelIds) + return +} + +func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) { + ins := "INSERT OR IGNORE INTO user_channel (user_id, channel_id) VALUES " + vals := strings.Repeat("(?, ?), ", len(channelIds)) + vals = vals[0 : len(vals)-2] + args := make([]interface{}, 2*len(channelIds)) + var i = 0 + for channelId := range channelIds { + args[i] = userId + args[i+1] = channelId + i += 2 + } + _, err = q.Exec(ins+vals, args) + return +} + +// channel + +func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) { + row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name) + err = row.Scan(&channelId) + return +} + +func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) { + rows, err := q.Query(`SELECT * FROM channel WHERE id IN +(SELECT channel_id from user_channel WHERE user_id = ?)`, userId) + if err != nil { + panic(err) + } + crs = make([]ChannelRow, 0) + for rows.Next() { + cr := ChannelRow{} + if err := rows.Scan(&(cr.id), &(cr.name)); err != nil { + panic(err) + } + crs = append(crs, cr) + } + return +} + +func InsertChannel(q Queryable, channel *Channel) (err error) { + _, err = q.Exec("INSERT INTO channel (name) VALUES (?)", channel.name) + return +} + +func UpdateChannel(q Queryable, channel *Channel) (err error) { + _, err = q.Exec("UPDATE channel SET name = ? WHERE id = ?", + channel.name, *(channel.id)) + return +} diff --git a/src/irc/user.go b/src/irc/user.go index 20023393..1f2076e8 100644 --- a/src/irc/user.go +++ b/src/irc/user.go @@ -16,6 +16,7 @@ type UserCommand interface { } type User struct { + id *RowId nick string hash []byte server *Server @@ -56,12 +57,46 @@ func NewUser(nick string, password string, server *Server) *User { replies: replies, } user.SetPassword(password) + go user.receiveCommands(commands) go user.receiveReplies(replies) server.users[nick] = user + return user } +func (user *User) Save(q Queryable) bool { + if user.id == nil { + if err := InsertUser(q, user); err != nil { + return false + } + userId, err := FindUserIdByNick(q, user.nick) + if err != nil { + return false + } + user.id = &userId + } else { + if err := UpdateUser(q, user); err != nil { + return false + } + } + + channelIds := user.channels.Ids() + if len(channelIds) == 0 { + if err := DeleteAllUserChannels(q, *(user.id)); err != nil { + return false + } + } else { + if err := DeleteOtherUserChannels(q, *(user.id), channelIds); err != nil { + return false + } + if err := InsertUserChannels(q, *(user.id), channelIds); err != nil { + return false + } + } + return true +} + func (user *User) SetPassword(password string) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil {