3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-22 03:49:27 +01:00

User persistence to sqlite.

This commit is contained in:
Jeremy Latt 2013-05-26 13:28:22 -07:00
parent 48ca57c43d
commit ccdf7779a5
12 changed files with 172 additions and 115 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
pkg pkg
bin bin
src/code.google.com/ src/code.google.com/
src/github.com/
ergonomadic.db ergonomadic.db

View File

@ -1,10 +1,10 @@
DROP INDEX user_id_channel_id DROP INDEX IF EXISTS index_user_id_channel_id;
DROP TABLE user_channel DROP TABLE IF EXISTS user_channel;
DROP INDEX channel_name DROP INDEX IF EXISTS index_channel_name;
DROP INDEX channel_id DROP INDEX IF EXISTS index_channel_id;
DROP TABLE channel DROP TABLE IF EXISTS channel;
DROP INDEX user_nick DROP INDEX IF EXISTS index_user_nick;
DROP INDEX user_id DROP INDEX IF EXISTS index_user_id;
DROP TABLE user DROP TABLE IF EXISTS user;

View File

@ -1,10 +1,20 @@
CREATE TABLE user (id integer not null primary key autoincrement, nick text not null, hash blob not null) CREATE TABLE user (
CREATE UNIQUE INDEX user_id ON user (id) id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
CREATE UNIQUE INDEX user_nick ON user (nick) nick TEXT NOT NULL UNIQUE,
hash BLOB NOT NULL
);
CREATE INDEX index_user_id ON user(id);
CREATE INDEX index_user_nick ON user(nick);
CREATE TABLE channel (id integer not null primary key autoincrement, name text not null) CREATE TABLE channel (
CREATE UNIQUE INDEX channel_id ON channel (id) id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
CREATE UNIQUE INDEX channel_name ON channel (name) name TEXT NOT NULL UNIQUE
);
CREATE INDEX index_channel_id ON channel(id);
CREATE_TABLE user_channel (id integer not null primary key autoincrement, user_id integer not null, channel_id integer not null) CREATE TABLE user_channel (
CREATE UNIQUE INDEX user_id_channel_id ON user_channel (user_id, channel_id) id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
user_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL
);
CREATE UNIQUE INDEX index_user_id_channel_id ON user_channel (user_id, channel_id);

View File

@ -5,18 +5,7 @@ import (
"irc" "irc"
) )
var (
actions = map[string]func(*irc.Database){
"init": func(db *irc.Database) {
db.InitTables()
},
"drop": func(db *irc.Database) {
db.DropTables()
},
}
)
func main() { func main() {
flag.Parse() flag.Parse()
actions[flag.Arg(0)](irc.NewDatabase()) irc.NewDatabase().ExecSqlFile(flag.Arg(0) + ".sql").Close()
} }

View File

@ -49,8 +49,8 @@ type ChannelCommand interface {
// NewChannel creates a new channel from a `Server` and a `name` string, which // NewChannel creates a new channel from a `Server` and a `name` string, which
// must be unique on the server. // must be unique on the server.
func NewChannel(s *Server, name string) *Channel { func NewChannel(s *Server, name string) *Channel {
replies := make(chan Reply) commands := make(chan ChannelCommand, 1)
commands := make(chan ChannelCommand) replies := make(chan Reply, 1)
channel := &Channel{ channel := &Channel{
name: name, name: name,
members: make(UserSet), members: make(UserSet),
@ -58,8 +58,8 @@ func NewChannel(s *Server, name string) *Channel {
commands: commands, commands: commands,
replies: replies, replies: replies,
} }
go channel.receiveReplies(replies)
go channel.receiveCommands(commands) go channel.receiveCommands(commands)
go channel.receiveReplies(replies)
return channel return channel
} }
@ -81,29 +81,27 @@ func (channel *Channel) Save(q Queryable) bool {
return true return true
} }
// Forward `Reply`s to all `User`s of the `Channel`.
func (channel *Channel) receiveReplies(replies <-chan Reply) {
for reply := range replies {
if DEBUG_CHANNEL {
log.Printf("%s → %s", channel, reply)
}
for user := range channel.members {
if user != reply.Source() {
user.replies <- reply
}
}
}
}
func (channel *Channel) receiveCommands(commands <-chan ChannelCommand) { func (channel *Channel) receiveCommands(commands <-chan ChannelCommand) {
for command := range commands { for command := range commands {
if DEBUG_CHANNEL { if DEBUG_CHANNEL {
log.Printf("%s ← %s %s", channel, command.Source(), command) log.Printf("%s → %s : %s", command.Source(), channel, command)
} }
command.HandleChannel(channel) command.HandleChannel(channel)
} }
} }
func (channel *Channel) receiveReplies(replies <-chan Reply) {
for reply := range replies {
if DEBUG_CHANNEL {
log.Printf("%s ← %s : %s", channel, reply.Source(), reply)
}
for user := range channel.members {
if user != reply.Source() {
user.Replies() <- reply
}
}
}
}
func (channel *Channel) Nicks() []string { func (channel *Channel) Nicks() []string {
return channel.members.Nicks() return channel.members.Nicks()
} }
@ -121,6 +119,10 @@ func (channel *Channel) GetTopic(replier Replier) {
replier.Replies() <- RplTopic(channel) replier.Replies() <- RplTopic(channel)
} }
func (channel *Channel) Replies() chan<- Reply {
return channel.replies
}
func (channel *Channel) Id() string { func (channel *Channel) Id() string {
return channel.name return channel.name
} }
@ -133,10 +135,6 @@ func (channel *Channel) Commands() chan<- ChannelCommand {
return channel.commands return channel.commands
} }
func (channel *Channel) Replies() chan<- Reply {
return channel.replies
}
func (channel *Channel) String() string { func (channel *Channel) String() string {
return channel.Id() return channel.Id()
} }
@ -150,17 +148,17 @@ func (m *JoinCommand) HandleChannel(channel *Channel) {
user := client.user user := client.user
if channel.key != m.channels[channel.name] { if channel.key != m.channels[channel.name] {
client.user.replies <- ErrBadChannelKey(channel) client.user.Replies() <- ErrBadChannelKey(channel)
return return
} }
channel.members.Add(user) channel.members.Add(user)
user.channels.Add(channel) user.channels.Add(channel)
channel.replies <- RplJoin(channel, user) channel.Replies() <- RplJoin(channel, user)
channel.GetTopic(user) channel.GetTopic(user)
user.replies <- RplNamReply(channel) user.Replies() <- RplNamReply(channel)
user.replies <- RplEndOfNames(channel.server) user.Replies() <- RplEndOfNames(channel.server)
} }
func (m *PartCommand) HandleChannel(channel *Channel) { func (m *PartCommand) HandleChannel(channel *Channel) {
@ -176,7 +174,7 @@ func (m *PartCommand) HandleChannel(channel *Channel) {
msg = user.Nick() msg = user.Nick()
} }
channel.replies <- RplPart(channel, user, msg) channel.Replies() <- RplPart(channel, user, msg)
channel.members.Remove(user) channel.members.Remove(user)
user.channels.Remove(channel) user.channels.Remove(channel)
@ -190,7 +188,7 @@ func (m *TopicCommand) HandleChannel(channel *Channel) {
user := m.User() user := m.User()
if !channel.members[user] { if !channel.members[user] {
user.replies <- ErrNotOnChannel(channel) user.Replies() <- ErrNotOnChannel(channel)
return return
} }
@ -202,11 +200,10 @@ func (m *TopicCommand) HandleChannel(channel *Channel) {
channel.topic = m.topic channel.topic = m.topic
if channel.topic == "" { if channel.topic == "" {
channel.replies <- RplNoTopic(channel) channel.Replies() <- RplNoTopic(channel)
return return
} }
channel.Replies() <- RplTopic(channel)
channel.replies <- RplTopic(channel)
} }
func (m *PrivMsgCommand) HandleChannel(channel *Channel) { func (m *PrivMsgCommand) HandleChannel(channel *Channel) {

View File

@ -31,7 +31,7 @@ type ClientSet map[*Client]bool
func NewClient(server *Server, conn net.Conn) *Client { func NewClient(server *Server, conn net.Conn) *Client {
read := StringReadChan(conn) read := StringReadChan(conn)
write := StringWriteChan(conn) write := StringWriteChan(conn)
replies := make(chan Reply) replies := make(chan Reply, 1)
client := &Client{ client := &Client{
conn: conn, conn: conn,
@ -67,7 +67,7 @@ func (c *Client) readConn(recv <-chan string) {
func (c *Client) writeConn(write chan<- string, replies <-chan Reply) { func (c *Client) writeConn(write chan<- string, replies <-chan Reply) {
for reply := range replies { for reply := range replies {
if DEBUG_CLIENT { if DEBUG_CLIENT {
log.Printf("%s ← %s", c, reply) log.Printf("%s ← %s : %s", c, reply.Source(), reply)
} }
write <- reply.Format(c) write <- reply.Format(c)
} }

View File

@ -30,7 +30,7 @@ func StringReadChan(conn net.Conn) <-chan string {
break break
} }
if DEBUG_NET { if DEBUG_NET {
log.Printf("%s → %s", conn.RemoteAddr(), line) log.Printf("%s → %s : %s", conn.RemoteAddr(), conn.LocalAddr(), line)
} }
ch <- line ch <- line
} }
@ -45,7 +45,7 @@ func StringWriteChan(conn net.Conn) chan<- string {
go func() { go func() {
for str := range ch { for str := range ch {
if DEBUG_NET { if DEBUG_NET {
log.Printf("%s ← %s", conn.RemoteAddr(), str) log.Printf("%s ← %s : %s", conn.RemoteAddr(), conn.LocalAddr(), str)
} }
if _, err := writer.WriteString(str + "\r\n"); err != nil { if _, err := writer.WriteString(str + "\r\n"); err != nil {
break break

View File

@ -106,7 +106,8 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) {
return return
} }
user := NewUser(client.nick, m.password, ns.server) user := NewUser(client.nick, ns.server).SetPassword(m.password)
ns.server.db.Save(user)
ns.Reply(client, "You have registered.") ns.Reply(client, "You have registered.")
if !user.Login(client, client.nick, m.password) { if !user.Login(client, client.nick, m.password) {

View File

@ -27,7 +27,9 @@ type Queryable interface {
QueryRow(string, ...interface{}) *sql.Row QueryRow(string, ...interface{}) *sql.Row
} }
type TransactionFunc func(Queryable) bool type Savable interface {
Save(q Queryable) bool
}
// //
// general // general
@ -36,7 +38,7 @@ type TransactionFunc func(Queryable) bool
func NewDatabase() *Database { func NewDatabase() *Database {
db, err := sql.Open("sqlite3", "ergonomadic.db") db, err := sql.Open("sqlite3", "ergonomadic.db")
if err != nil { if err != nil {
panic("cannot open database") log.Fatalln("cannot open database")
} }
return &Database{db} return &Database{db}
} }
@ -48,7 +50,7 @@ func NewTransaction(tx *sql.Tx) *Transaction {
func readLines(filename string) <-chan string { func readLines(filename string) <-chan string {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
panic(err) log.Fatalln(err)
} }
reader := bufio.NewReader(file) reader := bufio.NewReader(file)
lines := make(chan string) lines := make(chan string)
@ -56,7 +58,7 @@ func readLines(filename string) <-chan string {
defer file.Close() defer file.Close()
defer close(lines) defer close(lines)
for { for {
line, err := reader.ReadString('\n') line, err := reader.ReadString(';')
if err != nil { if err != nil {
break break
} }
@ -70,28 +72,24 @@ func readLines(filename string) <-chan string {
return lines return lines
} }
func (db *Database) execSqlFile(filename string) { func (db *Database) ExecSqlFile(filename string) *Database {
db.Transact(func(q Queryable) bool { db.Transact(func(q Queryable) bool {
for line := range readLines(filepath.Join("sql", filename)) { for line := range readLines(filepath.Join("sql", filename)) {
log.Println(line) log.Println(line)
q.Exec(line) _, err := q.Exec(line)
if err != nil {
log.Fatalln(err)
}
} }
return true return true
}) })
return db
} }
func (db *Database) InitTables() { func (db *Database) Transact(txf func(Queryable) bool) {
db.execSqlFile("init.sql")
}
func (db *Database) DropTables() {
db.execSqlFile("drop.sql")
}
func (db *Database) Transact(txf TransactionFunc) {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
panic(err) log.Panicln(err)
} }
if txf(tx) { if txf(tx) {
tx.Commit() tx.Commit()
@ -100,6 +98,28 @@ func (db *Database) Transact(txf TransactionFunc) {
} }
} }
func (db *Database) Save(s Savable) {
db.Transact(func(tx Queryable) bool {
return s.Save(tx)
})
}
//
// general purpose sql
//
func FindId(q Queryable, sql string, args ...interface{}) (rowId RowId, err error) {
row := q.QueryRow(sql, args...)
err = row.Scan(&rowId)
return
}
func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) {
row := q.QueryRow(sql, args...)
err = row.Scan(&count)
return
}
// //
// data // data
// //
@ -117,25 +137,39 @@ type ChannelRow struct {
// user // user
func FindUserByNick(q Queryable, nick string) (ur *UserRow) { func FindAllUsers(q Queryable) (urs []UserRow, err error) {
ur = new(UserRow) var rows *sql.Rows
row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick) rows, err = q.Query("SELECT id, nick, hash FROM user")
err := row.Scan(&ur.id, &ur.nick, &ur.hash)
if err != nil { if err != nil {
ur = nil return
}
urs = make([]UserRow, 0)
for rows.Next() {
ur := UserRow{}
err = rows.Scan(&(ur.id), &(ur.nick), &(ur.hash))
if err != nil {
return
}
urs = append(urs, ur)
} }
return return
} }
func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) { func FindUserByNick(q Queryable, nick string) (ur *UserRow, err error) {
row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick) ur = &UserRow{}
err = row.Scan(&rowId) row := q.QueryRow("SELECT id, nick, hash FROM user LIMIT 1 WHERE nick = ?",
nick)
err = row.Scan(&(ur.id), &(ur.nick), &(ur.hash))
return return
} }
func FindUserIdByNick(q Queryable, nick string) (RowId, error) {
return FindId(q, "SELECT id FROM user WHERE nick = ?", nick)
}
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) { func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
cr = new(ChannelRow) cr = new(ChannelRow)
row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name) row := q.QueryRow("SELECT id, name FROM channel LIMIT 1 WHERE name = ?", name)
err := row.Scan(&(cr.id), &(cr.name)) err := row.Scan(&(cr.id), &(cr.name))
if err != nil { if err != nil {
cr = nil cr = nil
@ -185,25 +219,31 @@ func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err erro
// channel // channel
func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) { func FindChannelIdByName(q Queryable, name string) (RowId, error) {
row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name) return FindId(q, "SELECT id FROM channel WHERE name = ?", name)
err = row.Scan(&channelId)
return
} }
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) { func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow, err error) {
rows, err := q.Query(`SELECT * FROM channel WHERE id IN query := ` FROM channel WHERE id IN
(SELECT channel_id from user_channel WHERE user_id = ?)`, userId) (SELECT channel_id from user_channel WHERE user_id = ?)`
count, err := Count(q, "SELECT COUNT(id)"+query, userId)
if err != nil { if err != nil {
panic(err) return
} }
crs = make([]ChannelRow, 0) rows, err := q.Query("SELECT id, name"+query, userId)
if err != nil {
return
}
crs = make([]ChannelRow, count)
var i = 0
for rows.Next() { for rows.Next() {
cr := ChannelRow{} cr := ChannelRow{}
if err := rows.Scan(&(cr.id), &(cr.name)); err != nil { err = rows.Scan(&(cr.id), &(cr.name))
panic(err) if err != nil {
return
} }
crs = append(crs, cr) crs[i] = cr
i++
} }
return return
} }

View File

@ -24,10 +24,11 @@ type Server struct {
channels ChannelNameMap channels ChannelNameMap
services ServiceNameMap services ServiceNameMap
commands chan<- Command commands chan<- Command
db *Database
} }
func NewServer(name string) *Server { func NewServer(name string) *Server {
commands := make(chan Command) commands := make(chan Command, 1)
server := &Server{ server := &Server{
ctime: time.Now(), ctime: time.Now(),
name: name, name: name,
@ -35,16 +36,27 @@ func NewServer(name string) *Server {
users: make(UserNameMap), users: make(UserNameMap),
channels: make(ChannelNameMap), channels: make(ChannelNameMap),
services: make(ServiceNameMap), services: make(ServiceNameMap),
db: NewDatabase(),
} }
go server.receiveCommands(commands) go server.receiveCommands(commands)
NewNickServ(server) NewNickServ(server)
server.db.Transact(func(q Queryable) bool {
urs, err := FindAllUsers(server.db)
if err != nil {
return false
}
for _, ur := range urs {
NewUser(ur.nick, server).SetHash(ur.hash)
}
return false
})
return server return server
} }
func (server *Server) receiveCommands(commands <-chan Command) { func (server *Server) receiveCommands(commands <-chan Command) {
for command := range commands { for command := range commands {
if DEBUG_SERVER { if DEBUG_SERVER {
log.Printf("%s ← %s %s", server, command.Client(), command) log.Printf("%s → %s : %s", command.Client(), server, command)
} }
command.Client().atime = time.Now() command.Client().atime = time.Now()
command.HandleServer(server) command.HandleServer(server)
@ -278,7 +290,7 @@ func (m *TopicCommand) HandleServer(s *Server) {
channel := s.channels[m.channel] channel := s.channels[m.channel]
if channel == nil { if channel == nil {
user.Replies() <- ErrNoSuchChannel(s, m.channel) m.Client().Replies() <- ErrNoSuchChannel(s, m.channel)
return return
} }

View File

@ -29,7 +29,7 @@ type BaseService struct {
} }
func NewService(service EditableService, s *Server, name string) Service { func NewService(service EditableService, s *Server, name string) Service {
commands := make(chan ServiceCommand) commands := make(chan ServiceCommand, 1)
base := &BaseService{ base := &BaseService{
server: s, server: s,
name: name, name: name,

View File

@ -46,9 +46,9 @@ func (set UserSet) Nicks() []string {
return nicks return nicks
} }
func NewUser(nick string, password string, server *Server) *User { func NewUser(nick string, server *Server) *User {
commands := make(chan UserCommand) commands := make(chan UserCommand, 1)
replies := make(chan Reply) replies := make(chan Reply, 1)
user := &User{ user := &User{
nick: nick, nick: nick,
server: server, server: server,
@ -56,7 +56,6 @@ func NewUser(nick string, password string, server *Server) *User {
channels: make(ChannelSet), channels: make(ChannelSet),
replies: replies, replies: replies,
} }
user.SetPassword(password)
go user.receiveCommands(commands) go user.receiveCommands(commands)
go user.receiveReplies(replies) go user.receiveReplies(replies)
@ -81,34 +80,40 @@ func (user *User) Save(q Queryable) bool {
} }
} }
userId := *(user.id)
channelIds := user.channels.Ids() channelIds := user.channels.Ids()
if len(channelIds) == 0 { if len(channelIds) == 0 {
if err := DeleteAllUserChannels(q, *(user.id)); err != nil { if err := DeleteAllUserChannels(q, userId); err != nil {
return false return false
} }
} else { } else {
if err := DeleteOtherUserChannels(q, *(user.id), channelIds); err != nil { if err := DeleteOtherUserChannels(q, userId, channelIds); err != nil {
return false return false
} }
if err := InsertUserChannels(q, *(user.id), channelIds); err != nil { if err := InsertUserChannels(q, userId, channelIds); err != nil {
return false return false
} }
} }
return true return true
} }
func (user *User) SetPassword(password string) { func (user *User) SetPassword(password string) *User {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {
panic("bcrypt failed; cannot generate password hash") panic("bcrypt failed; cannot generate password hash")
} }
return user.SetHash(hash)
}
func (user *User) SetHash(hash []byte) *User {
user.hash = hash user.hash = hash
return user
} }
func (user *User) receiveCommands(commands <-chan UserCommand) { func (user *User) receiveCommands(commands <-chan UserCommand) {
for command := range commands { for command := range commands {
if DEBUG_USER { if DEBUG_USER {
log.Printf("%s ← %s %s", user, command.Client(), command) log.Printf("%s → %s : %s", command.Client(), user, command)
} }
command.HandleUser(user) command.HandleUser(user)
} }
@ -117,7 +122,9 @@ func (user *User) receiveCommands(commands <-chan UserCommand) {
// Distribute replies to clients. // Distribute replies to clients.
func (user *User) receiveReplies(replies <-chan Reply) { func (user *User) receiveReplies(replies <-chan Reply) {
for reply := range replies { for reply := range replies {
log.Printf("%s ← %s", user, reply) if DEBUG_USER {
log.Printf("%s ← %s : %s", user, reply.Source(), reply)
}
for client := range user.clients { for client := range user.clients {
client.Replies() <- reply client.Replies() <- reply
} }