From ccdf7779a5fb1d57e66f5cba1eb4a6b489235478 Mon Sep 17 00:00:00 2001 From: Jeremy Latt Date: Sun, 26 May 2013 13:28:22 -0700 Subject: [PATCH] User persistence to sqlite. --- .gitignore | 1 + sql/drop.sql | 16 ++-- sql/init.sql | 26 +++++-- src/ergonomadicdb/ergonomadicdb.go | 13 +--- src/irc/channel.go | 59 +++++++-------- src/irc/client.go | 4 +- src/irc/net.go | 4 +- src/irc/nickserv.go | 3 +- src/irc/persistence.go | 114 +++++++++++++++++++---------- src/irc/server.go | 18 ++++- src/irc/service.go | 2 +- src/irc/user.go | 27 ++++--- 12 files changed, 172 insertions(+), 115 deletions(-) diff --git a/.gitignore b/.gitignore index 0b9380a8..b90e3dfc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ pkg bin src/code.google.com/ +src/github.com/ ergonomadic.db diff --git a/sql/drop.sql b/sql/drop.sql index b8347c95..b74ab1ba 100644 --- a/sql/drop.sql +++ b/sql/drop.sql @@ -1,10 +1,10 @@ -DROP INDEX user_id_channel_id -DROP TABLE user_channel +DROP INDEX IF EXISTS index_user_id_channel_id; +DROP TABLE IF EXISTS user_channel; -DROP INDEX channel_name -DROP INDEX channel_id -DROP TABLE channel +DROP INDEX IF EXISTS index_channel_name; +DROP INDEX IF EXISTS index_channel_id; +DROP TABLE IF EXISTS channel; -DROP INDEX user_nick -DROP INDEX user_id -DROP TABLE user +DROP INDEX IF EXISTS index_user_nick; +DROP INDEX IF EXISTS index_user_id; +DROP TABLE IF EXISTS user; diff --git a/sql/init.sql b/sql/init.sql index 7b7f2bd3..b73dbffc 100644 --- a/sql/init.sql +++ b/sql/init.sql @@ -1,10 +1,20 @@ -CREATE TABLE user (id integer not null primary key autoincrement, nick text not null, hash blob not null) -CREATE UNIQUE INDEX user_id ON user (id) -CREATE UNIQUE INDEX user_nick ON user (nick) +CREATE TABLE user ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + nick TEXT NOT NULL UNIQUE, + hash BLOB NOT NULL +); +CREATE INDEX index_user_id ON user(id); +CREATE INDEX index_user_nick ON user(nick); -CREATE TABLE channel (id integer not null primary key autoincrement, name text not null) -CREATE UNIQUE INDEX channel_id ON channel (id) -CREATE UNIQUE INDEX channel_name ON channel (name) +CREATE TABLE channel ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE +); +CREATE INDEX index_channel_id ON channel(id); -CREATE_TABLE user_channel (id integer not null primary key autoincrement, user_id integer not null, channel_id integer not null) -CREATE UNIQUE INDEX user_id_channel_id ON user_channel (user_id, channel_id) +CREATE TABLE user_channel ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + user_id INTEGER NOT NULL, + channel_id INTEGER NOT NULL +); +CREATE UNIQUE INDEX index_user_id_channel_id ON user_channel (user_id, channel_id); diff --git a/src/ergonomadicdb/ergonomadicdb.go b/src/ergonomadicdb/ergonomadicdb.go index bc925ed8..05ee8084 100644 --- a/src/ergonomadicdb/ergonomadicdb.go +++ b/src/ergonomadicdb/ergonomadicdb.go @@ -5,18 +5,7 @@ import ( "irc" ) -var ( - actions = map[string]func(*irc.Database){ - "init": func(db *irc.Database) { - db.InitTables() - }, - "drop": func(db *irc.Database) { - db.DropTables() - }, - } -) - func main() { flag.Parse() - actions[flag.Arg(0)](irc.NewDatabase()) + irc.NewDatabase().ExecSqlFile(flag.Arg(0) + ".sql").Close() } diff --git a/src/irc/channel.go b/src/irc/channel.go index b8237f76..6f826ab3 100644 --- a/src/irc/channel.go +++ b/src/irc/channel.go @@ -49,8 +49,8 @@ type ChannelCommand interface { // NewChannel creates a new channel from a `Server` and a `name` string, which // must be unique on the server. func NewChannel(s *Server, name string) *Channel { - replies := make(chan Reply) - commands := make(chan ChannelCommand) + commands := make(chan ChannelCommand, 1) + replies := make(chan Reply, 1) channel := &Channel{ name: name, members: make(UserSet), @@ -58,8 +58,8 @@ func NewChannel(s *Server, name string) *Channel { commands: commands, replies: replies, } - go channel.receiveReplies(replies) go channel.receiveCommands(commands) + go channel.receiveReplies(replies) return channel } @@ -81,29 +81,27 @@ func (channel *Channel) Save(q Queryable) bool { return true } -// Forward `Reply`s to all `User`s of the `Channel`. -func (channel *Channel) receiveReplies(replies <-chan Reply) { - for reply := range replies { - if DEBUG_CHANNEL { - log.Printf("%s → %s", channel, reply) - } - for user := range channel.members { - if user != reply.Source() { - user.replies <- reply - } - } - } -} - func (channel *Channel) receiveCommands(commands <-chan ChannelCommand) { for command := range commands { if DEBUG_CHANNEL { - log.Printf("%s ← %s %s", channel, command.Source(), command) + log.Printf("%s → %s : %s", command.Source(), channel, command) } command.HandleChannel(channel) } } +func (channel *Channel) receiveReplies(replies <-chan Reply) { + for reply := range replies { + if DEBUG_CHANNEL { + log.Printf("%s ← %s : %s", channel, reply.Source(), reply) + } + for user := range channel.members { + if user != reply.Source() { + user.Replies() <- reply + } + } + } +} func (channel *Channel) Nicks() []string { return channel.members.Nicks() } @@ -121,6 +119,10 @@ func (channel *Channel) GetTopic(replier Replier) { replier.Replies() <- RplTopic(channel) } +func (channel *Channel) Replies() chan<- Reply { + return channel.replies +} + func (channel *Channel) Id() string { return channel.name } @@ -133,10 +135,6 @@ func (channel *Channel) Commands() chan<- ChannelCommand { return channel.commands } -func (channel *Channel) Replies() chan<- Reply { - return channel.replies -} - func (channel *Channel) String() string { return channel.Id() } @@ -150,17 +148,17 @@ func (m *JoinCommand) HandleChannel(channel *Channel) { user := client.user if channel.key != m.channels[channel.name] { - client.user.replies <- ErrBadChannelKey(channel) + client.user.Replies() <- ErrBadChannelKey(channel) return } channel.members.Add(user) user.channels.Add(channel) - channel.replies <- RplJoin(channel, user) + channel.Replies() <- RplJoin(channel, user) channel.GetTopic(user) - user.replies <- RplNamReply(channel) - user.replies <- RplEndOfNames(channel.server) + user.Replies() <- RplNamReply(channel) + user.Replies() <- RplEndOfNames(channel.server) } func (m *PartCommand) HandleChannel(channel *Channel) { @@ -176,7 +174,7 @@ func (m *PartCommand) HandleChannel(channel *Channel) { msg = user.Nick() } - channel.replies <- RplPart(channel, user, msg) + channel.Replies() <- RplPart(channel, user, msg) channel.members.Remove(user) user.channels.Remove(channel) @@ -190,7 +188,7 @@ func (m *TopicCommand) HandleChannel(channel *Channel) { user := m.User() if !channel.members[user] { - user.replies <- ErrNotOnChannel(channel) + user.Replies() <- ErrNotOnChannel(channel) return } @@ -202,11 +200,10 @@ func (m *TopicCommand) HandleChannel(channel *Channel) { channel.topic = m.topic if channel.topic == "" { - channel.replies <- RplNoTopic(channel) + channel.Replies() <- RplNoTopic(channel) return } - - channel.replies <- RplTopic(channel) + channel.Replies() <- RplTopic(channel) } func (m *PrivMsgCommand) HandleChannel(channel *Channel) { diff --git a/src/irc/client.go b/src/irc/client.go index 7a0b0894..351d166d 100644 --- a/src/irc/client.go +++ b/src/irc/client.go @@ -31,7 +31,7 @@ type ClientSet map[*Client]bool func NewClient(server *Server, conn net.Conn) *Client { read := StringReadChan(conn) write := StringWriteChan(conn) - replies := make(chan Reply) + replies := make(chan Reply, 1) client := &Client{ conn: conn, @@ -67,7 +67,7 @@ func (c *Client) readConn(recv <-chan string) { func (c *Client) writeConn(write chan<- string, replies <-chan Reply) { for reply := range replies { if DEBUG_CLIENT { - log.Printf("%s ← %s", c, reply) + log.Printf("%s ← %s : %s", c, reply.Source(), reply) } write <- reply.Format(c) } diff --git a/src/irc/net.go b/src/irc/net.go index d28d1f9c..d241a0d0 100644 --- a/src/irc/net.go +++ b/src/irc/net.go @@ -30,7 +30,7 @@ func StringReadChan(conn net.Conn) <-chan string { break } if DEBUG_NET { - log.Printf("%s → %s", conn.RemoteAddr(), line) + log.Printf("%s → %s : %s", conn.RemoteAddr(), conn.LocalAddr(), line) } ch <- line } @@ -45,7 +45,7 @@ func StringWriteChan(conn net.Conn) chan<- string { go func() { for str := range ch { if DEBUG_NET { - log.Printf("%s ← %s", conn.RemoteAddr(), str) + log.Printf("%s ← %s : %s", conn.RemoteAddr(), conn.LocalAddr(), str) } if _, err := writer.WriteString(str + "\r\n"); err != nil { break diff --git a/src/irc/nickserv.go b/src/irc/nickserv.go index 0f3d0286..798df1c7 100644 --- a/src/irc/nickserv.go +++ b/src/irc/nickserv.go @@ -106,7 +106,8 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) { return } - user := NewUser(client.nick, m.password, ns.server) + user := NewUser(client.nick, ns.server).SetPassword(m.password) + ns.server.db.Save(user) ns.Reply(client, "You have registered.") if !user.Login(client, client.nick, m.password) { diff --git a/src/irc/persistence.go b/src/irc/persistence.go index 8a68684a..c8805fa5 100644 --- a/src/irc/persistence.go +++ b/src/irc/persistence.go @@ -27,7 +27,9 @@ type Queryable interface { QueryRow(string, ...interface{}) *sql.Row } -type TransactionFunc func(Queryable) bool +type Savable interface { + Save(q Queryable) bool +} // // general @@ -36,7 +38,7 @@ type TransactionFunc func(Queryable) bool func NewDatabase() *Database { db, err := sql.Open("sqlite3", "ergonomadic.db") if err != nil { - panic("cannot open database") + log.Fatalln("cannot open database") } return &Database{db} } @@ -48,7 +50,7 @@ func NewTransaction(tx *sql.Tx) *Transaction { func readLines(filename string) <-chan string { file, err := os.Open(filename) if err != nil { - panic(err) + log.Fatalln(err) } reader := bufio.NewReader(file) lines := make(chan string) @@ -56,7 +58,7 @@ func readLines(filename string) <-chan string { defer file.Close() defer close(lines) for { - line, err := reader.ReadString('\n') + line, err := reader.ReadString(';') if err != nil { break } @@ -70,28 +72,24 @@ func readLines(filename string) <-chan string { return lines } -func (db *Database) execSqlFile(filename string) { +func (db *Database) ExecSqlFile(filename string) *Database { db.Transact(func(q Queryable) bool { for line := range readLines(filepath.Join("sql", filename)) { log.Println(line) - q.Exec(line) + _, err := q.Exec(line) + if err != nil { + log.Fatalln(err) + } } return true }) + return db } -func (db *Database) InitTables() { - db.execSqlFile("init.sql") -} - -func (db *Database) DropTables() { - db.execSqlFile("drop.sql") -} - -func (db *Database) Transact(txf TransactionFunc) { +func (db *Database) Transact(txf func(Queryable) bool) { tx, err := db.Begin() if err != nil { - panic(err) + log.Panicln(err) } if txf(tx) { tx.Commit() @@ -100,6 +98,28 @@ func (db *Database) Transact(txf TransactionFunc) { } } +func (db *Database) Save(s Savable) { + db.Transact(func(tx Queryable) bool { + return s.Save(tx) + }) +} + +// +// general purpose sql +// + +func FindId(q Queryable, sql string, args ...interface{}) (rowId RowId, err error) { + row := q.QueryRow(sql, args...) + err = row.Scan(&rowId) + return +} + +func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) { + row := q.QueryRow(sql, args...) + err = row.Scan(&count) + return +} + // // data // @@ -117,25 +137,39 @@ type ChannelRow struct { // user -func FindUserByNick(q Queryable, nick string) (ur *UserRow) { - ur = new(UserRow) - row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick) - err := row.Scan(&ur.id, &ur.nick, &ur.hash) +func FindAllUsers(q Queryable) (urs []UserRow, err error) { + var rows *sql.Rows + rows, err = q.Query("SELECT id, nick, hash FROM user") if err != nil { - ur = nil + return + } + urs = make([]UserRow, 0) + for rows.Next() { + ur := UserRow{} + err = rows.Scan(&(ur.id), &(ur.nick), &(ur.hash)) + if err != nil { + return + } + urs = append(urs, ur) } return } -func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) { - row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick) - err = row.Scan(&rowId) +func FindUserByNick(q Queryable, nick string) (ur *UserRow, err error) { + ur = &UserRow{} + row := q.QueryRow("SELECT id, nick, hash FROM user LIMIT 1 WHERE nick = ?", + nick) + err = row.Scan(&(ur.id), &(ur.nick), &(ur.hash)) return } +func FindUserIdByNick(q Queryable, nick string) (RowId, error) { + return FindId(q, "SELECT id FROM user WHERE nick = ?", nick) +} + func FindChannelByName(q Queryable, name string) (cr *ChannelRow) { cr = new(ChannelRow) - row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name) + row := q.QueryRow("SELECT id, name FROM channel LIMIT 1 WHERE name = ?", name) err := row.Scan(&(cr.id), &(cr.name)) if err != nil { cr = nil @@ -185,25 +219,31 @@ func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err erro // channel -func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) { - row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name) - err = row.Scan(&channelId) - return +func FindChannelIdByName(q Queryable, name string) (RowId, error) { + return FindId(q, "SELECT id FROM channel WHERE name = ?", name) } -func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) { - rows, err := q.Query(`SELECT * FROM channel WHERE id IN -(SELECT channel_id from user_channel WHERE user_id = ?)`, userId) +func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow, err error) { + query := ` FROM channel WHERE id IN +(SELECT channel_id from user_channel WHERE user_id = ?)` + count, err := Count(q, "SELECT COUNT(id)"+query, userId) if err != nil { - panic(err) + return } - crs = make([]ChannelRow, 0) + rows, err := q.Query("SELECT id, name"+query, userId) + if err != nil { + return + } + crs = make([]ChannelRow, count) + var i = 0 for rows.Next() { cr := ChannelRow{} - if err := rows.Scan(&(cr.id), &(cr.name)); err != nil { - panic(err) + err = rows.Scan(&(cr.id), &(cr.name)) + if err != nil { + return } - crs = append(crs, cr) + crs[i] = cr + i++ } return } diff --git a/src/irc/server.go b/src/irc/server.go index b9ccfa12..a7788c7c 100644 --- a/src/irc/server.go +++ b/src/irc/server.go @@ -24,10 +24,11 @@ type Server struct { channels ChannelNameMap services ServiceNameMap commands chan<- Command + db *Database } func NewServer(name string) *Server { - commands := make(chan Command) + commands := make(chan Command, 1) server := &Server{ ctime: time.Now(), name: name, @@ -35,16 +36,27 @@ func NewServer(name string) *Server { users: make(UserNameMap), channels: make(ChannelNameMap), services: make(ServiceNameMap), + db: NewDatabase(), } go server.receiveCommands(commands) NewNickServ(server) + server.db.Transact(func(q Queryable) bool { + urs, err := FindAllUsers(server.db) + if err != nil { + return false + } + for _, ur := range urs { + NewUser(ur.nick, server).SetHash(ur.hash) + } + return false + }) return server } func (server *Server) receiveCommands(commands <-chan Command) { for command := range commands { if DEBUG_SERVER { - log.Printf("%s ← %s %s", server, command.Client(), command) + log.Printf("%s → %s : %s", command.Client(), server, command) } command.Client().atime = time.Now() command.HandleServer(server) @@ -278,7 +290,7 @@ func (m *TopicCommand) HandleServer(s *Server) { channel := s.channels[m.channel] if channel == nil { - user.Replies() <- ErrNoSuchChannel(s, m.channel) + m.Client().Replies() <- ErrNoSuchChannel(s, m.channel) return } diff --git a/src/irc/service.go b/src/irc/service.go index 162287b2..581bc13c 100644 --- a/src/irc/service.go +++ b/src/irc/service.go @@ -29,7 +29,7 @@ type BaseService struct { } func NewService(service EditableService, s *Server, name string) Service { - commands := make(chan ServiceCommand) + commands := make(chan ServiceCommand, 1) base := &BaseService{ server: s, name: name, diff --git a/src/irc/user.go b/src/irc/user.go index 1f2076e8..30d6da60 100644 --- a/src/irc/user.go +++ b/src/irc/user.go @@ -46,9 +46,9 @@ func (set UserSet) Nicks() []string { return nicks } -func NewUser(nick string, password string, server *Server) *User { - commands := make(chan UserCommand) - replies := make(chan Reply) +func NewUser(nick string, server *Server) *User { + commands := make(chan UserCommand, 1) + replies := make(chan Reply, 1) user := &User{ nick: nick, server: server, @@ -56,7 +56,6 @@ func NewUser(nick string, password string, server *Server) *User { channels: make(ChannelSet), replies: replies, } - user.SetPassword(password) go user.receiveCommands(commands) go user.receiveReplies(replies) @@ -81,34 +80,40 @@ func (user *User) Save(q Queryable) bool { } } + userId := *(user.id) channelIds := user.channels.Ids() if len(channelIds) == 0 { - if err := DeleteAllUserChannels(q, *(user.id)); err != nil { + if err := DeleteAllUserChannels(q, userId); err != nil { return false } } else { - if err := DeleteOtherUserChannels(q, *(user.id), channelIds); err != nil { + if err := DeleteOtherUserChannels(q, userId, channelIds); err != nil { return false } - if err := InsertUserChannels(q, *(user.id), channelIds); err != nil { + if err := InsertUserChannels(q, userId, channelIds); err != nil { return false } } return true } -func (user *User) SetPassword(password string) { +func (user *User) SetPassword(password string) *User { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { panic("bcrypt failed; cannot generate password hash") } + return user.SetHash(hash) +} + +func (user *User) SetHash(hash []byte) *User { user.hash = hash + return user } func (user *User) receiveCommands(commands <-chan UserCommand) { for command := range commands { if DEBUG_USER { - log.Printf("%s ← %s %s", user, command.Client(), command) + log.Printf("%s → %s : %s", command.Client(), user, command) } command.HandleUser(user) } @@ -117,7 +122,9 @@ func (user *User) receiveCommands(commands <-chan UserCommand) { // Distribute replies to clients. func (user *User) receiveReplies(replies <-chan Reply) { for reply := range replies { - log.Printf("%s ← %s", user, reply) + if DEBUG_USER { + log.Printf("%s ← %s : %s", user, reply.Source(), reply) + } for client := range user.clients { client.Replies() <- reply }