mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-10 22:19:31 +01:00
Encapsulate SQL statements and refactor Save functions as transactionable.
This commit is contained in:
parent
f24bb5ee7d
commit
48ca57c43d
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
pkg
|
pkg
|
||||||
bin
|
bin
|
||||||
src/code.google.com/
|
src/code.google.com/
|
||||||
|
ergonomadic.db
|
||||||
|
25
README.md
25
README.md
@ -1,19 +1,26 @@
|
|||||||
# Ergonomadic
|
# Ergonomadic
|
||||||
|
|
||||||
## A Go IRC Daemon
|
Ergonomadic is an IRC daemon written from scratch in Go. It supports (or will)
|
||||||
|
multiple concurrent connections for the same nick.
|
||||||
|
|
||||||
Ergonomadic is an IRC daemon written from scratch in Go.
|
## Why?
|
||||||
|
|
||||||
### Why?
|
|
||||||
|
|
||||||
I wanted to learn Go.
|
I wanted to learn Go.
|
||||||
|
|
||||||
### What's with the name?
|
## What's with the name?
|
||||||
|
|
||||||
"Ergonomadic" is an anagram of "Go IRC Daemon".
|
"Ergonomadic" is an anagram of "Go IRC Daemon".
|
||||||
|
|
||||||
### Helpful Documentation
|
## Helpful Documentation
|
||||||
|
|
||||||
- [IRC Channel Management](http://tools.ietf.org/html/rfc2811)
|
- [RFC 2811: IRC Channel Management](http://tools.ietf.org/html/rfc2811)
|
||||||
- [IRC Client Protocol](http://tools.ietf.org/html/rfc2812)
|
- [RFC 2812: IRC Client Protocol](http://tools.ietf.org/html/rfc2812)
|
||||||
- [IRC Server Protocol](http://tools.ietf.org/html/rfc2813)
|
- [RFC 2813: IRC Server Protocol](http://tools.ietf.org/html/rfc2813)
|
||||||
|
|
||||||
|
## Running the Server
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ ./build.sh
|
||||||
|
$ bin/ergonomadicdb init
|
||||||
|
$ bin/ergonomadic
|
||||||
|
```
|
||||||
|
3
build.sh
3
build.sh
@ -1,4 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
export GOPATH="$PWD"
|
export GOPATH="$PWD"
|
||||||
go get "code.google.com/p/go.crypto/bcrypt"
|
go get "code.google.com/p/go.crypto/bcrypt"
|
||||||
go install ergonomadic genpasswd
|
go get "github.com/mattn/go-sqlite3"
|
||||||
|
go install ergonomadic genpasswd ergonomadicdb
|
||||||
|
10
sql/drop.sql
Normal file
10
sql/drop.sql
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
DROP INDEX user_id_channel_id
|
||||||
|
DROP TABLE user_channel
|
||||||
|
|
||||||
|
DROP INDEX channel_name
|
||||||
|
DROP INDEX channel_id
|
||||||
|
DROP TABLE channel
|
||||||
|
|
||||||
|
DROP INDEX user_nick
|
||||||
|
DROP INDEX user_id
|
||||||
|
DROP TABLE user
|
10
sql/init.sql
Normal file
10
sql/init.sql
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
CREATE TABLE user (id integer not null primary key autoincrement, nick text not null, hash blob not null)
|
||||||
|
CREATE UNIQUE INDEX user_id ON user (id)
|
||||||
|
CREATE UNIQUE INDEX user_nick ON user (nick)
|
||||||
|
|
||||||
|
CREATE TABLE channel (id integer not null primary key autoincrement, name text not null)
|
||||||
|
CREATE UNIQUE INDEX channel_id ON channel (id)
|
||||||
|
CREATE UNIQUE INDEX channel_name ON channel (name)
|
||||||
|
|
||||||
|
CREATE_TABLE user_channel (id integer not null primary key autoincrement, user_id integer not null, channel_id integer not null)
|
||||||
|
CREATE UNIQUE INDEX user_id_channel_id ON user_channel (user_id, channel_id)
|
22
src/ergonomadicdb/ergonomadicdb.go
Normal file
22
src/ergonomadicdb/ergonomadicdb.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"irc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
actions = map[string]func(*irc.Database){
|
||||||
|
"init": func(db *irc.Database) {
|
||||||
|
db.InitTables()
|
||||||
|
},
|
||||||
|
"drop": func(db *irc.Database) {
|
||||||
|
db.DropTables()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
actions[flag.Arg(0)](irc.NewDatabase())
|
||||||
|
}
|
@ -9,6 +9,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
|
id *RowId
|
||||||
server *Server
|
server *Server
|
||||||
commands chan<- ChannelCommand
|
commands chan<- ChannelCommand
|
||||||
replies chan<- Reply
|
replies chan<- Reply
|
||||||
@ -30,6 +31,16 @@ func (set ChannelSet) Remove(channel *Channel) {
|
|||||||
delete(set, channel)
|
delete(set, channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (set ChannelSet) Ids() (ids []RowId) {
|
||||||
|
ids = []RowId{}
|
||||||
|
for channel := range set {
|
||||||
|
if channel.id != nil {
|
||||||
|
ids = append(ids, *channel.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
type ChannelCommand interface {
|
type ChannelCommand interface {
|
||||||
Command
|
Command
|
||||||
HandleChannel(channel *Channel)
|
HandleChannel(channel *Channel)
|
||||||
@ -52,6 +63,24 @@ func NewChannel(s *Server, name string) *Channel {
|
|||||||
return channel
|
return channel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) Save(q Queryable) bool {
|
||||||
|
if channel.id == nil {
|
||||||
|
if err := InsertChannel(q, channel); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
channelId, err := FindChannelIdByName(q, channel.name)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
channel.id = &channelId
|
||||||
|
} else {
|
||||||
|
if err := UpdateChannel(q, channel); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Forward `Reply`s to all `User`s of the `Channel`.
|
// Forward `Reply`s to all `User`s of the `Channel`.
|
||||||
func (channel *Channel) receiveReplies(replies <-chan Reply) {
|
func (channel *Channel) receiveReplies(replies <-chan Reply) {
|
||||||
for reply := range replies {
|
for reply := range replies {
|
||||||
|
@ -53,7 +53,7 @@ func (command *BaseCommand) User() *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (command *BaseCommand) SetClient(c *Client) {
|
func (command *BaseCommand) SetClient(c *Client) {
|
||||||
command.client = c
|
*command = BaseCommand{c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (command *BaseCommand) Source() Identifier {
|
func (command *BaseCommand) Source() Identifier {
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DEBUG_NET = false
|
DEBUG_NET = true
|
||||||
)
|
)
|
||||||
|
|
||||||
func readTrimmedLine(reader *bufio.Reader) (string, error) {
|
func readTrimmedLine(reader *bufio.Reader) (string, error) {
|
||||||
|
220
src/irc/persistence.go
Normal file
220
src/irc/persistence.go
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
package irc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
//"fmt"
|
||||||
|
"bufio"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database struct {
|
||||||
|
*sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
type Transaction struct {
|
||||||
|
*sql.Tx
|
||||||
|
}
|
||||||
|
|
||||||
|
type RowId uint64
|
||||||
|
|
||||||
|
type Queryable interface {
|
||||||
|
Exec(string, ...interface{}) (sql.Result, error)
|
||||||
|
Query(string, ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryRow(string, ...interface{}) *sql.Row
|
||||||
|
}
|
||||||
|
|
||||||
|
type TransactionFunc func(Queryable) bool
|
||||||
|
|
||||||
|
//
|
||||||
|
// general
|
||||||
|
//
|
||||||
|
|
||||||
|
func NewDatabase() *Database {
|
||||||
|
db, err := sql.Open("sqlite3", "ergonomadic.db")
|
||||||
|
if err != nil {
|
||||||
|
panic("cannot open database")
|
||||||
|
}
|
||||||
|
return &Database{db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransaction(tx *sql.Tx) *Transaction {
|
||||||
|
return &Transaction{tx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLines(filename string) <-chan string {
|
||||||
|
file, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
reader := bufio.NewReader(file)
|
||||||
|
lines := make(chan string)
|
||||||
|
go func(lines chan<- string) {
|
||||||
|
defer file.Close()
|
||||||
|
defer close(lines)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lines <- line
|
||||||
|
}
|
||||||
|
}(lines)
|
||||||
|
return lines
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Database) execSqlFile(filename string) {
|
||||||
|
db.Transact(func(q Queryable) bool {
|
||||||
|
for line := range readLines(filepath.Join("sql", filename)) {
|
||||||
|
log.Println(line)
|
||||||
|
q.Exec(line)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Database) InitTables() {
|
||||||
|
db.execSqlFile("init.sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Database) DropTables() {
|
||||||
|
db.execSqlFile("drop.sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Database) Transact(txf TransactionFunc) {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if txf(tx) {
|
||||||
|
tx.Commit()
|
||||||
|
} else {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// data
|
||||||
|
//
|
||||||
|
|
||||||
|
type UserRow struct {
|
||||||
|
id RowId
|
||||||
|
nick string
|
||||||
|
hash []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelRow struct {
|
||||||
|
id RowId
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// user
|
||||||
|
|
||||||
|
func FindUserByNick(q Queryable, nick string) (ur *UserRow) {
|
||||||
|
ur = new(UserRow)
|
||||||
|
row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick)
|
||||||
|
err := row.Scan(&ur.id, &ur.nick, &ur.hash)
|
||||||
|
if err != nil {
|
||||||
|
ur = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) {
|
||||||
|
row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick)
|
||||||
|
err = row.Scan(&rowId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
|
||||||
|
cr = new(ChannelRow)
|
||||||
|
row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name)
|
||||||
|
err := row.Scan(&(cr.id), &(cr.name))
|
||||||
|
if err != nil {
|
||||||
|
cr = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func InsertUser(q Queryable, user *User) (err error) {
|
||||||
|
_, err = q.Exec("INSERT INTO user (nick, hash) VALUES (?, ?)",
|
||||||
|
user.nick, user.hash)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateUser(q Queryable, user *User) (err error) {
|
||||||
|
_, err = q.Exec("UPDATE user SET nick = ?, hash = ? WHERE id = ?",
|
||||||
|
user.nick, user.hash, *(user.id))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// user-channel
|
||||||
|
|
||||||
|
func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) {
|
||||||
|
_, err = q.Exec("DELETE FROM user_channel WHERE user_id = ?", rowId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteOtherUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) {
|
||||||
|
_, err = q.Exec(`DELETE FROM user_channel WHERE
|
||||||
|
user_id = ? AND channel_id NOT IN ?`, userId, channelIds)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) {
|
||||||
|
ins := "INSERT OR IGNORE INTO user_channel (user_id, channel_id) VALUES "
|
||||||
|
vals := strings.Repeat("(?, ?), ", len(channelIds))
|
||||||
|
vals = vals[0 : len(vals)-2]
|
||||||
|
args := make([]interface{}, 2*len(channelIds))
|
||||||
|
var i = 0
|
||||||
|
for channelId := range channelIds {
|
||||||
|
args[i] = userId
|
||||||
|
args[i+1] = channelId
|
||||||
|
i += 2
|
||||||
|
}
|
||||||
|
_, err = q.Exec(ins+vals, args)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// channel
|
||||||
|
|
||||||
|
func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) {
|
||||||
|
row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name)
|
||||||
|
err = row.Scan(&channelId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) {
|
||||||
|
rows, err := q.Query(`SELECT * FROM channel WHERE id IN
|
||||||
|
(SELECT channel_id from user_channel WHERE user_id = ?)`, userId)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
crs = make([]ChannelRow, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
cr := ChannelRow{}
|
||||||
|
if err := rows.Scan(&(cr.id), &(cr.name)); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
crs = append(crs, cr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func InsertChannel(q Queryable, channel *Channel) (err error) {
|
||||||
|
_, err = q.Exec("INSERT INTO channel (name) VALUES (?)", channel.name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateChannel(q Queryable, channel *Channel) (err error) {
|
||||||
|
_, err = q.Exec("UPDATE channel SET name = ? WHERE id = ?",
|
||||||
|
channel.name, *(channel.id))
|
||||||
|
return
|
||||||
|
}
|
@ -16,6 +16,7 @@ type UserCommand interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
|
id *RowId
|
||||||
nick string
|
nick string
|
||||||
hash []byte
|
hash []byte
|
||||||
server *Server
|
server *Server
|
||||||
@ -56,12 +57,46 @@ func NewUser(nick string, password string, server *Server) *User {
|
|||||||
replies: replies,
|
replies: replies,
|
||||||
}
|
}
|
||||||
user.SetPassword(password)
|
user.SetPassword(password)
|
||||||
|
|
||||||
go user.receiveCommands(commands)
|
go user.receiveCommands(commands)
|
||||||
go user.receiveReplies(replies)
|
go user.receiveReplies(replies)
|
||||||
server.users[nick] = user
|
server.users[nick] = user
|
||||||
|
|
||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) Save(q Queryable) bool {
|
||||||
|
if user.id == nil {
|
||||||
|
if err := InsertUser(q, user); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
userId, err := FindUserIdByNick(q, user.nick)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
user.id = &userId
|
||||||
|
} else {
|
||||||
|
if err := UpdateUser(q, user); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelIds := user.channels.Ids()
|
||||||
|
if len(channelIds) == 0 {
|
||||||
|
if err := DeleteAllUserChannels(q, *(user.id)); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := DeleteOtherUserChannels(q, *(user.id), channelIds); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err := InsertUserChannels(q, *(user.id), channelIds); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) SetPassword(password string) {
|
func (user *User) SetPassword(password string) {
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user