Remove unnecessary Database struct and add more persistence.

This commit is contained in:
Jeremy Latt 2013-06-02 16:32:22 -07:00
parent ccdf7779a5
commit be60d503be
5 changed files with 28 additions and 26 deletions

View File

@ -7,5 +7,7 @@ import (
func main() { func main() {
flag.Parse() flag.Parse()
irc.NewDatabase().ExecSqlFile(flag.Arg(0) + ".sql").Close() db := irc.NewDatabase()
defer db.Close()
irc.ExecSqlFile(db, flag.Arg(0)+".sql")
} }

View File

@ -60,6 +60,7 @@ func NewChannel(s *Server, name string) *Channel {
} }
go channel.receiveCommands(commands) go channel.receiveCommands(commands)
go channel.receiveReplies(replies) go channel.receiveReplies(replies)
Save(s.db, channel)
return channel return channel
} }

View File

@ -107,7 +107,7 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) {
} }
user := NewUser(client.nick, ns.server).SetPassword(m.password) user := NewUser(client.nick, ns.server).SetPassword(m.password)
ns.server.db.Save(user) Save(ns.server.db, user)
ns.Reply(client, "You have registered.") ns.Reply(client, "You have registered.")
if !user.Login(client, client.nick, m.password) { if !user.Login(client, client.nick, m.password) {

View File

@ -11,14 +11,6 @@ import (
"strings" "strings"
) )
type Database struct {
*sql.DB
}
type Transaction struct {
*sql.Tx
}
type RowId uint64 type RowId uint64
type Queryable interface { type Queryable interface {
@ -35,16 +27,12 @@ type Savable interface {
// general // general
// //
func NewDatabase() *Database { func NewDatabase() (db *sql.DB) {
db, err := sql.Open("sqlite3", "ergonomadic.db") db, err := sql.Open("sqlite3", "ergonomadic.db")
if err != nil { if err != nil {
log.Fatalln("cannot open database") log.Fatalln("cannot open database")
} }
return &Database{db} return
}
func NewTransaction(tx *sql.Tx) *Transaction {
return &Transaction{tx}
} }
func readLines(filename string) <-chan string { func readLines(filename string) <-chan string {
@ -72,8 +60,8 @@ func readLines(filename string) <-chan string {
return lines return lines
} }
func (db *Database) ExecSqlFile(filename string) *Database { func ExecSqlFile(db *sql.DB, filename string) {
db.Transact(func(q Queryable) bool { Transact(db, func(q Queryable) bool {
for line := range readLines(filepath.Join("sql", filename)) { for line := range readLines(filepath.Join("sql", filename)) {
log.Println(line) log.Println(line)
_, err := q.Exec(line) _, err := q.Exec(line)
@ -83,10 +71,9 @@ func (db *Database) ExecSqlFile(filename string) *Database {
} }
return true return true
}) })
return db
} }
func (db *Database) Transact(txf func(Queryable) bool) { func Transact(db *sql.DB, txf func(Queryable) bool) {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
log.Panicln(err) log.Panicln(err)
@ -98,10 +85,8 @@ func (db *Database) Transact(txf func(Queryable) bool) {
} }
} }
func (db *Database) Save(s Savable) { func Save(db *sql.DB, s Savable) {
db.Transact(func(tx Queryable) bool { Transact(db, s.Save)
return s.Save(tx)
})
} }
// //
@ -189,6 +174,11 @@ func UpdateUser(q Queryable, user *User) (err error) {
return return
} }
func DeleteUser(q Queryable, user *User) (err error) {
_, err = q.Exec("DELETE FROM user WHERE id = ?", *(user.id))
return
}
// user-channel // user-channel
func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) { func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) {
@ -258,3 +248,8 @@ func UpdateChannel(q Queryable, channel *Channel) (err error) {
channel.name, *(channel.id)) channel.name, *(channel.id))
return return
} }
func DeleteChannel(q Queryable, channel *Channel) (err error) {
_, err = q.Exec("DELETE FROM channel WHERE id = ?", *(channel.id))
return
}

View File

@ -2,6 +2,7 @@ package irc
import ( import (
"code.google.com/p/go.crypto/bcrypt" "code.google.com/p/go.crypto/bcrypt"
"database/sql"
"log" "log"
"net" "net"
"time" "time"
@ -24,7 +25,7 @@ type Server struct {
channels ChannelNameMap channels ChannelNameMap
services ServiceNameMap services ServiceNameMap
commands chan<- Command commands chan<- Command
db *Database db *sql.DB
} }
func NewServer(name string) *Server { func NewServer(name string) *Server {
@ -40,7 +41,7 @@ func NewServer(name string) *Server {
} }
go server.receiveCommands(commands) go server.receiveCommands(commands)
NewNickServ(server) NewNickServ(server)
server.db.Transact(func(q Queryable) bool { Transact(server.db, func(q Queryable) bool {
urs, err := FindAllUsers(server.db) urs, err := FindAllUsers(server.db)
if err != nil { if err != nil {
return false return false
@ -146,6 +147,9 @@ func (s *Server) Nick() string {
func (s *Server) DeleteChannel(channel *Channel) { func (s *Server) DeleteChannel(channel *Channel) {
delete(s.channels, channel.name) delete(s.channels, channel.name)
if err := DeleteChannel(s.db, channel); err != nil {
log.Println(err)
}
} }
// //