3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-25 13:29:27 +01:00

Encapsulate SQL statements and refactor Save functions as transactionable.

This commit is contained in:
Jeremy Latt 2013-05-24 21:39:53 -07:00
parent f24bb5ee7d
commit 48ca57c43d
11 changed files with 347 additions and 12 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
pkg pkg
bin bin
src/code.google.com/ src/code.google.com/
ergonomadic.db

View File

@ -1,19 +1,26 @@
# Ergonomadic # Ergonomadic
## A Go IRC Daemon Ergonomadic is an IRC daemon written from scratch in Go. It supports (or will)
multiple concurrent connections for the same nick.
Ergonomadic is an IRC daemon written from scratch in Go. ## Why?
### Why?
I wanted to learn Go. I wanted to learn Go.
### What's with the name? ## What's with the name?
"Ergonomadic" is an anagram of "Go IRC Daemon". "Ergonomadic" is an anagram of "Go IRC Daemon".
### Helpful Documentation ## Helpful Documentation
- [IRC Channel Management](http://tools.ietf.org/html/rfc2811) - [RFC 2811: IRC Channel Management](http://tools.ietf.org/html/rfc2811)
- [IRC Client Protocol](http://tools.ietf.org/html/rfc2812) - [RFC 2812: IRC Client Protocol](http://tools.ietf.org/html/rfc2812)
- [IRC Server Protocol](http://tools.ietf.org/html/rfc2813) - [RFC 2813: IRC Server Protocol](http://tools.ietf.org/html/rfc2813)
## Running the Server
```sh
$ ./build.sh
$ bin/ergonomadicdb init
$ bin/ergonomadic
```

View File

@ -1,4 +1,5 @@
#!/bin/bash #!/bin/bash
export GOPATH="$PWD" export GOPATH="$PWD"
go get "code.google.com/p/go.crypto/bcrypt" go get "code.google.com/p/go.crypto/bcrypt"
go install ergonomadic genpasswd go get "github.com/mattn/go-sqlite3"
go install ergonomadic genpasswd ergonomadicdb

10
sql/drop.sql Normal file
View File

@ -0,0 +1,10 @@
DROP INDEX user_id_channel_id
DROP TABLE user_channel
DROP INDEX channel_name
DROP INDEX channel_id
DROP TABLE channel
DROP INDEX user_nick
DROP INDEX user_id
DROP TABLE user

10
sql/init.sql Normal file
View File

@ -0,0 +1,10 @@
CREATE TABLE user (id integer not null primary key autoincrement, nick text not null, hash blob not null)
CREATE UNIQUE INDEX user_id ON user (id)
CREATE UNIQUE INDEX user_nick ON user (nick)
CREATE TABLE channel (id integer not null primary key autoincrement, name text not null)
CREATE UNIQUE INDEX channel_id ON channel (id)
CREATE UNIQUE INDEX channel_name ON channel (name)
CREATE_TABLE user_channel (id integer not null primary key autoincrement, user_id integer not null, channel_id integer not null)
CREATE UNIQUE INDEX user_id_channel_id ON user_channel (user_id, channel_id)

View File

@ -0,0 +1,22 @@
package main
import (
"flag"
"irc"
)
var (
actions = map[string]func(*irc.Database){
"init": func(db *irc.Database) {
db.InitTables()
},
"drop": func(db *irc.Database) {
db.DropTables()
},
}
)
func main() {
flag.Parse()
actions[flag.Arg(0)](irc.NewDatabase())
}

View File

@ -9,6 +9,7 @@ const (
) )
type Channel struct { type Channel struct {
id *RowId
server *Server server *Server
commands chan<- ChannelCommand commands chan<- ChannelCommand
replies chan<- Reply replies chan<- Reply
@ -30,6 +31,16 @@ func (set ChannelSet) Remove(channel *Channel) {
delete(set, channel) delete(set, channel)
} }
func (set ChannelSet) Ids() (ids []RowId) {
ids = []RowId{}
for channel := range set {
if channel.id != nil {
ids = append(ids, *channel.id)
}
}
return ids
}
type ChannelCommand interface { type ChannelCommand interface {
Command Command
HandleChannel(channel *Channel) HandleChannel(channel *Channel)
@ -52,6 +63,24 @@ func NewChannel(s *Server, name string) *Channel {
return channel return channel
} }
func (channel *Channel) Save(q Queryable) bool {
if channel.id == nil {
if err := InsertChannel(q, channel); err != nil {
return false
}
channelId, err := FindChannelIdByName(q, channel.name)
if err != nil {
return false
}
channel.id = &channelId
} else {
if err := UpdateChannel(q, channel); err != nil {
return false
}
}
return true
}
// Forward `Reply`s to all `User`s of the `Channel`. // Forward `Reply`s to all `User`s of the `Channel`.
func (channel *Channel) receiveReplies(replies <-chan Reply) { func (channel *Channel) receiveReplies(replies <-chan Reply) {
for reply := range replies { for reply := range replies {

View File

@ -53,7 +53,7 @@ func (command *BaseCommand) User() *User {
} }
func (command *BaseCommand) SetClient(c *Client) { func (command *BaseCommand) SetClient(c *Client) {
command.client = c *command = BaseCommand{c}
} }
func (command *BaseCommand) Source() Identifier { func (command *BaseCommand) Source() Identifier {

View File

@ -8,7 +8,7 @@ import (
) )
const ( const (
DEBUG_NET = false DEBUG_NET = true
) )
func readTrimmedLine(reader *bufio.Reader) (string, error) { func readTrimmedLine(reader *bufio.Reader) (string, error) {

220
src/irc/persistence.go Normal file
View File

@ -0,0 +1,220 @@
package irc
import (
"database/sql"
//"fmt"
"bufio"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
"path/filepath"
"strings"
)
type Database struct {
*sql.DB
}
type Transaction struct {
*sql.Tx
}
type RowId uint64
type Queryable interface {
Exec(string, ...interface{}) (sql.Result, error)
Query(string, ...interface{}) (*sql.Rows, error)
QueryRow(string, ...interface{}) *sql.Row
}
type TransactionFunc func(Queryable) bool
//
// general
//
func NewDatabase() *Database {
db, err := sql.Open("sqlite3", "ergonomadic.db")
if err != nil {
panic("cannot open database")
}
return &Database{db}
}
func NewTransaction(tx *sql.Tx) *Transaction {
return &Transaction{tx}
}
func readLines(filename string) <-chan string {
file, err := os.Open(filename)
if err != nil {
panic(err)
}
reader := bufio.NewReader(file)
lines := make(chan string)
go func(lines chan<- string) {
defer file.Close()
defer close(lines)
for {
line, err := reader.ReadString('\n')
if err != nil {
break
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
lines <- line
}
}(lines)
return lines
}
func (db *Database) execSqlFile(filename string) {
db.Transact(func(q Queryable) bool {
for line := range readLines(filepath.Join("sql", filename)) {
log.Println(line)
q.Exec(line)
}
return true
})
}
func (db *Database) InitTables() {
db.execSqlFile("init.sql")
}
func (db *Database) DropTables() {
db.execSqlFile("drop.sql")
}
func (db *Database) Transact(txf TransactionFunc) {
tx, err := db.Begin()
if err != nil {
panic(err)
}
if txf(tx) {
tx.Commit()
} else {
tx.Rollback()
}
}
//
// data
//
type UserRow struct {
id RowId
nick string
hash []byte
}
type ChannelRow struct {
id RowId
name string
}
// user
func FindUserByNick(q Queryable, nick string) (ur *UserRow) {
ur = new(UserRow)
row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick)
err := row.Scan(&ur.id, &ur.nick, &ur.hash)
if err != nil {
ur = nil
}
return
}
func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) {
row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick)
err = row.Scan(&rowId)
return
}
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
cr = new(ChannelRow)
row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name)
err := row.Scan(&(cr.id), &(cr.name))
if err != nil {
cr = nil
}
return
}
func InsertUser(q Queryable, user *User) (err error) {
_, err = q.Exec("INSERT INTO user (nick, hash) VALUES (?, ?)",
user.nick, user.hash)
return
}
func UpdateUser(q Queryable, user *User) (err error) {
_, err = q.Exec("UPDATE user SET nick = ?, hash = ? WHERE id = ?",
user.nick, user.hash, *(user.id))
return
}
// user-channel
func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) {
_, err = q.Exec("DELETE FROM user_channel WHERE user_id = ?", rowId)
return
}
func DeleteOtherUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) {
_, err = q.Exec(`DELETE FROM user_channel WHERE
user_id = ? AND channel_id NOT IN ?`, userId, channelIds)
return
}
func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) {
ins := "INSERT OR IGNORE INTO user_channel (user_id, channel_id) VALUES "
vals := strings.Repeat("(?, ?), ", len(channelIds))
vals = vals[0 : len(vals)-2]
args := make([]interface{}, 2*len(channelIds))
var i = 0
for channelId := range channelIds {
args[i] = userId
args[i+1] = channelId
i += 2
}
_, err = q.Exec(ins+vals, args)
return
}
// channel
func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) {
row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name)
err = row.Scan(&channelId)
return
}
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) {
rows, err := q.Query(`SELECT * FROM channel WHERE id IN
(SELECT channel_id from user_channel WHERE user_id = ?)`, userId)
if err != nil {
panic(err)
}
crs = make([]ChannelRow, 0)
for rows.Next() {
cr := ChannelRow{}
if err := rows.Scan(&(cr.id), &(cr.name)); err != nil {
panic(err)
}
crs = append(crs, cr)
}
return
}
func InsertChannel(q Queryable, channel *Channel) (err error) {
_, err = q.Exec("INSERT INTO channel (name) VALUES (?)", channel.name)
return
}
func UpdateChannel(q Queryable, channel *Channel) (err error) {
_, err = q.Exec("UPDATE channel SET name = ? WHERE id = ?",
channel.name, *(channel.id))
return
}

View File

@ -16,6 +16,7 @@ type UserCommand interface {
} }
type User struct { type User struct {
id *RowId
nick string nick string
hash []byte hash []byte
server *Server server *Server
@ -56,12 +57,46 @@ func NewUser(nick string, password string, server *Server) *User {
replies: replies, replies: replies,
} }
user.SetPassword(password) user.SetPassword(password)
go user.receiveCommands(commands) go user.receiveCommands(commands)
go user.receiveReplies(replies) go user.receiveReplies(replies)
server.users[nick] = user server.users[nick] = user
return user return user
} }
func (user *User) Save(q Queryable) bool {
if user.id == nil {
if err := InsertUser(q, user); err != nil {
return false
}
userId, err := FindUserIdByNick(q, user.nick)
if err != nil {
return false
}
user.id = &userId
} else {
if err := UpdateUser(q, user); err != nil {
return false
}
}
channelIds := user.channels.Ids()
if len(channelIds) == 0 {
if err := DeleteAllUserChannels(q, *(user.id)); err != nil {
return false
}
} else {
if err := DeleteOtherUserChannels(q, *(user.id), channelIds); err != nil {
return false
}
if err := InsertUserChannels(q, *(user.id), channelIds); err != nil {
return false
}
}
return true
}
func (user *User) SetPassword(password string) { func (user *User) SetPassword(password string) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {