diff --git a/src/ergonomadicdb/ergonomadicdb.go b/src/ergonomadicdb/ergonomadicdb.go index 05ee8084..b724d7be 100644 --- a/src/ergonomadicdb/ergonomadicdb.go +++ b/src/ergonomadicdb/ergonomadicdb.go @@ -7,5 +7,7 @@ import ( func main() { flag.Parse() - irc.NewDatabase().ExecSqlFile(flag.Arg(0) + ".sql").Close() + db := irc.NewDatabase() + defer db.Close() + irc.ExecSqlFile(db, flag.Arg(0)+".sql") } diff --git a/src/irc/channel.go b/src/irc/channel.go index 6f826ab3..3a347958 100644 --- a/src/irc/channel.go +++ b/src/irc/channel.go @@ -60,6 +60,7 @@ func NewChannel(s *Server, name string) *Channel { } go channel.receiveCommands(commands) go channel.receiveReplies(replies) + Save(s.db, channel) return channel } diff --git a/src/irc/nickserv.go b/src/irc/nickserv.go index 798df1c7..98fd9b43 100644 --- a/src/irc/nickserv.go +++ b/src/irc/nickserv.go @@ -107,7 +107,7 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) { } user := NewUser(client.nick, ns.server).SetPassword(m.password) - ns.server.db.Save(user) + Save(ns.server.db, user) ns.Reply(client, "You have registered.") if !user.Login(client, client.nick, m.password) { diff --git a/src/irc/persistence.go b/src/irc/persistence.go index c8805fa5..32811943 100644 --- a/src/irc/persistence.go +++ b/src/irc/persistence.go @@ -11,14 +11,6 @@ import ( "strings" ) -type Database struct { - *sql.DB -} - -type Transaction struct { - *sql.Tx -} - type RowId uint64 type Queryable interface { @@ -35,16 +27,12 @@ type Savable interface { // general // -func NewDatabase() *Database { +func NewDatabase() (db *sql.DB) { db, err := sql.Open("sqlite3", "ergonomadic.db") if err != nil { log.Fatalln("cannot open database") } - return &Database{db} -} - -func NewTransaction(tx *sql.Tx) *Transaction { - return &Transaction{tx} + return } func readLines(filename string) <-chan string { @@ -72,8 +60,8 @@ func readLines(filename string) <-chan string { return lines } -func (db *Database) ExecSqlFile(filename string) *Database { - db.Transact(func(q Queryable) bool { +func ExecSqlFile(db *sql.DB, filename string) { + Transact(db, func(q Queryable) bool { for line := range readLines(filepath.Join("sql", filename)) { log.Println(line) _, err := q.Exec(line) @@ -83,10 +71,9 @@ func (db *Database) ExecSqlFile(filename string) *Database { } return true }) - return db } -func (db *Database) Transact(txf func(Queryable) bool) { +func Transact(db *sql.DB, txf func(Queryable) bool) { tx, err := db.Begin() if err != nil { log.Panicln(err) @@ -98,10 +85,8 @@ func (db *Database) Transact(txf func(Queryable) bool) { } } -func (db *Database) Save(s Savable) { - db.Transact(func(tx Queryable) bool { - return s.Save(tx) - }) +func Save(db *sql.DB, s Savable) { + Transact(db, s.Save) } // @@ -189,6 +174,11 @@ func UpdateUser(q Queryable, user *User) (err error) { return } +func DeleteUser(q Queryable, user *User) (err error) { + _, err = q.Exec("DELETE FROM user WHERE id = ?", *(user.id)) + return +} + // user-channel func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) { @@ -258,3 +248,8 @@ func UpdateChannel(q Queryable, channel *Channel) (err error) { channel.name, *(channel.id)) return } + +func DeleteChannel(q Queryable, channel *Channel) (err error) { + _, err = q.Exec("DELETE FROM channel WHERE id = ?", *(channel.id)) + return +} diff --git a/src/irc/server.go b/src/irc/server.go index a7788c7c..467afd67 100644 --- a/src/irc/server.go +++ b/src/irc/server.go @@ -2,6 +2,7 @@ package irc import ( "code.google.com/p/go.crypto/bcrypt" + "database/sql" "log" "net" "time" @@ -24,7 +25,7 @@ type Server struct { channels ChannelNameMap services ServiceNameMap commands chan<- Command - db *Database + db *sql.DB } func NewServer(name string) *Server { @@ -40,7 +41,7 @@ func NewServer(name string) *Server { } go server.receiveCommands(commands) NewNickServ(server) - server.db.Transact(func(q Queryable) bool { + Transact(server.db, func(q Queryable) bool { urs, err := FindAllUsers(server.db) if err != nil { return false @@ -146,6 +147,9 @@ func (s *Server) Nick() string { func (s *Server) DeleteChannel(channel *Channel) { delete(s.channels, channel.name) + if err := DeleteChannel(s.db, channel); err != nil { + log.Println(err) + } } //