mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-22 03:49:27 +01:00
User persistence to sqlite.
This commit is contained in:
parent
48ca57c43d
commit
ccdf7779a5
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
|||||||
pkg
|
pkg
|
||||||
bin
|
bin
|
||||||
src/code.google.com/
|
src/code.google.com/
|
||||||
|
src/github.com/
|
||||||
ergonomadic.db
|
ergonomadic.db
|
||||||
|
16
sql/drop.sql
16
sql/drop.sql
@ -1,10 +1,10 @@
|
|||||||
DROP INDEX user_id_channel_id
|
DROP INDEX IF EXISTS index_user_id_channel_id;
|
||||||
DROP TABLE user_channel
|
DROP TABLE IF EXISTS user_channel;
|
||||||
|
|
||||||
DROP INDEX channel_name
|
DROP INDEX IF EXISTS index_channel_name;
|
||||||
DROP INDEX channel_id
|
DROP INDEX IF EXISTS index_channel_id;
|
||||||
DROP TABLE channel
|
DROP TABLE IF EXISTS channel;
|
||||||
|
|
||||||
DROP INDEX user_nick
|
DROP INDEX IF EXISTS index_user_nick;
|
||||||
DROP INDEX user_id
|
DROP INDEX IF EXISTS index_user_id;
|
||||||
DROP TABLE user
|
DROP TABLE IF EXISTS user;
|
||||||
|
26
sql/init.sql
26
sql/init.sql
@ -1,10 +1,20 @@
|
|||||||
CREATE TABLE user (id integer not null primary key autoincrement, nick text not null, hash blob not null)
|
CREATE TABLE user (
|
||||||
CREATE UNIQUE INDEX user_id ON user (id)
|
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||||
CREATE UNIQUE INDEX user_nick ON user (nick)
|
nick TEXT NOT NULL UNIQUE,
|
||||||
|
hash BLOB NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX index_user_id ON user(id);
|
||||||
|
CREATE INDEX index_user_nick ON user(nick);
|
||||||
|
|
||||||
CREATE TABLE channel (id integer not null primary key autoincrement, name text not null)
|
CREATE TABLE channel (
|
||||||
CREATE UNIQUE INDEX channel_id ON channel (id)
|
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||||
CREATE UNIQUE INDEX channel_name ON channel (name)
|
name TEXT NOT NULL UNIQUE
|
||||||
|
);
|
||||||
|
CREATE INDEX index_channel_id ON channel(id);
|
||||||
|
|
||||||
CREATE_TABLE user_channel (id integer not null primary key autoincrement, user_id integer not null, channel_id integer not null)
|
CREATE TABLE user_channel (
|
||||||
CREATE UNIQUE INDEX user_id_channel_id ON user_channel (user_id, channel_id)
|
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
channel_id INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX index_user_id_channel_id ON user_channel (user_id, channel_id);
|
||||||
|
@ -5,18 +5,7 @@ import (
|
|||||||
"irc"
|
"irc"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
actions = map[string]func(*irc.Database){
|
|
||||||
"init": func(db *irc.Database) {
|
|
||||||
db.InitTables()
|
|
||||||
},
|
|
||||||
"drop": func(db *irc.Database) {
|
|
||||||
db.DropTables()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
actions[flag.Arg(0)](irc.NewDatabase())
|
irc.NewDatabase().ExecSqlFile(flag.Arg(0) + ".sql").Close()
|
||||||
}
|
}
|
||||||
|
@ -49,8 +49,8 @@ type ChannelCommand interface {
|
|||||||
// NewChannel creates a new channel from a `Server` and a `name` string, which
|
// NewChannel creates a new channel from a `Server` and a `name` string, which
|
||||||
// must be unique on the server.
|
// must be unique on the server.
|
||||||
func NewChannel(s *Server, name string) *Channel {
|
func NewChannel(s *Server, name string) *Channel {
|
||||||
replies := make(chan Reply)
|
commands := make(chan ChannelCommand, 1)
|
||||||
commands := make(chan ChannelCommand)
|
replies := make(chan Reply, 1)
|
||||||
channel := &Channel{
|
channel := &Channel{
|
||||||
name: name,
|
name: name,
|
||||||
members: make(UserSet),
|
members: make(UserSet),
|
||||||
@ -58,8 +58,8 @@ func NewChannel(s *Server, name string) *Channel {
|
|||||||
commands: commands,
|
commands: commands,
|
||||||
replies: replies,
|
replies: replies,
|
||||||
}
|
}
|
||||||
go channel.receiveReplies(replies)
|
|
||||||
go channel.receiveCommands(commands)
|
go channel.receiveCommands(commands)
|
||||||
|
go channel.receiveReplies(replies)
|
||||||
return channel
|
return channel
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,29 +81,27 @@ func (channel *Channel) Save(q Queryable) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward `Reply`s to all `User`s of the `Channel`.
|
|
||||||
func (channel *Channel) receiveReplies(replies <-chan Reply) {
|
|
||||||
for reply := range replies {
|
|
||||||
if DEBUG_CHANNEL {
|
|
||||||
log.Printf("%s → %s", channel, reply)
|
|
||||||
}
|
|
||||||
for user := range channel.members {
|
|
||||||
if user != reply.Source() {
|
|
||||||
user.replies <- reply
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) receiveCommands(commands <-chan ChannelCommand) {
|
func (channel *Channel) receiveCommands(commands <-chan ChannelCommand) {
|
||||||
for command := range commands {
|
for command := range commands {
|
||||||
if DEBUG_CHANNEL {
|
if DEBUG_CHANNEL {
|
||||||
log.Printf("%s ← %s %s", channel, command.Source(), command)
|
log.Printf("%s → %s : %s", command.Source(), channel, command)
|
||||||
}
|
}
|
||||||
command.HandleChannel(channel)
|
command.HandleChannel(channel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) receiveReplies(replies <-chan Reply) {
|
||||||
|
for reply := range replies {
|
||||||
|
if DEBUG_CHANNEL {
|
||||||
|
log.Printf("%s ← %s : %s", channel, reply.Source(), reply)
|
||||||
|
}
|
||||||
|
for user := range channel.members {
|
||||||
|
if user != reply.Source() {
|
||||||
|
user.Replies() <- reply
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
func (channel *Channel) Nicks() []string {
|
func (channel *Channel) Nicks() []string {
|
||||||
return channel.members.Nicks()
|
return channel.members.Nicks()
|
||||||
}
|
}
|
||||||
@ -121,6 +119,10 @@ func (channel *Channel) GetTopic(replier Replier) {
|
|||||||
replier.Replies() <- RplTopic(channel)
|
replier.Replies() <- RplTopic(channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) Replies() chan<- Reply {
|
||||||
|
return channel.replies
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Id() string {
|
func (channel *Channel) Id() string {
|
||||||
return channel.name
|
return channel.name
|
||||||
}
|
}
|
||||||
@ -133,10 +135,6 @@ func (channel *Channel) Commands() chan<- ChannelCommand {
|
|||||||
return channel.commands
|
return channel.commands
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) Replies() chan<- Reply {
|
|
||||||
return channel.replies
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) String() string {
|
func (channel *Channel) String() string {
|
||||||
return channel.Id()
|
return channel.Id()
|
||||||
}
|
}
|
||||||
@ -150,17 +148,17 @@ func (m *JoinCommand) HandleChannel(channel *Channel) {
|
|||||||
user := client.user
|
user := client.user
|
||||||
|
|
||||||
if channel.key != m.channels[channel.name] {
|
if channel.key != m.channels[channel.name] {
|
||||||
client.user.replies <- ErrBadChannelKey(channel)
|
client.user.Replies() <- ErrBadChannelKey(channel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.members.Add(user)
|
channel.members.Add(user)
|
||||||
user.channels.Add(channel)
|
user.channels.Add(channel)
|
||||||
|
|
||||||
channel.replies <- RplJoin(channel, user)
|
channel.Replies() <- RplJoin(channel, user)
|
||||||
channel.GetTopic(user)
|
channel.GetTopic(user)
|
||||||
user.replies <- RplNamReply(channel)
|
user.Replies() <- RplNamReply(channel)
|
||||||
user.replies <- RplEndOfNames(channel.server)
|
user.Replies() <- RplEndOfNames(channel.server)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *PartCommand) HandleChannel(channel *Channel) {
|
func (m *PartCommand) HandleChannel(channel *Channel) {
|
||||||
@ -176,7 +174,7 @@ func (m *PartCommand) HandleChannel(channel *Channel) {
|
|||||||
msg = user.Nick()
|
msg = user.Nick()
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.replies <- RplPart(channel, user, msg)
|
channel.Replies() <- RplPart(channel, user, msg)
|
||||||
|
|
||||||
channel.members.Remove(user)
|
channel.members.Remove(user)
|
||||||
user.channels.Remove(channel)
|
user.channels.Remove(channel)
|
||||||
@ -190,7 +188,7 @@ func (m *TopicCommand) HandleChannel(channel *Channel) {
|
|||||||
user := m.User()
|
user := m.User()
|
||||||
|
|
||||||
if !channel.members[user] {
|
if !channel.members[user] {
|
||||||
user.replies <- ErrNotOnChannel(channel)
|
user.Replies() <- ErrNotOnChannel(channel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,11 +200,10 @@ func (m *TopicCommand) HandleChannel(channel *Channel) {
|
|||||||
channel.topic = m.topic
|
channel.topic = m.topic
|
||||||
|
|
||||||
if channel.topic == "" {
|
if channel.topic == "" {
|
||||||
channel.replies <- RplNoTopic(channel)
|
channel.Replies() <- RplNoTopic(channel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
channel.Replies() <- RplTopic(channel)
|
||||||
channel.replies <- RplTopic(channel)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *PrivMsgCommand) HandleChannel(channel *Channel) {
|
func (m *PrivMsgCommand) HandleChannel(channel *Channel) {
|
||||||
|
@ -31,7 +31,7 @@ type ClientSet map[*Client]bool
|
|||||||
func NewClient(server *Server, conn net.Conn) *Client {
|
func NewClient(server *Server, conn net.Conn) *Client {
|
||||||
read := StringReadChan(conn)
|
read := StringReadChan(conn)
|
||||||
write := StringWriteChan(conn)
|
write := StringWriteChan(conn)
|
||||||
replies := make(chan Reply)
|
replies := make(chan Reply, 1)
|
||||||
|
|
||||||
client := &Client{
|
client := &Client{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
@ -67,7 +67,7 @@ func (c *Client) readConn(recv <-chan string) {
|
|||||||
func (c *Client) writeConn(write chan<- string, replies <-chan Reply) {
|
func (c *Client) writeConn(write chan<- string, replies <-chan Reply) {
|
||||||
for reply := range replies {
|
for reply := range replies {
|
||||||
if DEBUG_CLIENT {
|
if DEBUG_CLIENT {
|
||||||
log.Printf("%s ← %s", c, reply)
|
log.Printf("%s ← %s : %s", c, reply.Source(), reply)
|
||||||
}
|
}
|
||||||
write <- reply.Format(c)
|
write <- reply.Format(c)
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ func StringReadChan(conn net.Conn) <-chan string {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if DEBUG_NET {
|
if DEBUG_NET {
|
||||||
log.Printf("%s → %s", conn.RemoteAddr(), line)
|
log.Printf("%s → %s : %s", conn.RemoteAddr(), conn.LocalAddr(), line)
|
||||||
}
|
}
|
||||||
ch <- line
|
ch <- line
|
||||||
}
|
}
|
||||||
@ -45,7 +45,7 @@ func StringWriteChan(conn net.Conn) chan<- string {
|
|||||||
go func() {
|
go func() {
|
||||||
for str := range ch {
|
for str := range ch {
|
||||||
if DEBUG_NET {
|
if DEBUG_NET {
|
||||||
log.Printf("%s ← %s", conn.RemoteAddr(), str)
|
log.Printf("%s ← %s : %s", conn.RemoteAddr(), conn.LocalAddr(), str)
|
||||||
}
|
}
|
||||||
if _, err := writer.WriteString(str + "\r\n"); err != nil {
|
if _, err := writer.WriteString(str + "\r\n"); err != nil {
|
||||||
break
|
break
|
||||||
|
@ -106,7 +106,8 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user := NewUser(client.nick, m.password, ns.server)
|
user := NewUser(client.nick, ns.server).SetPassword(m.password)
|
||||||
|
ns.server.db.Save(user)
|
||||||
ns.Reply(client, "You have registered.")
|
ns.Reply(client, "You have registered.")
|
||||||
|
|
||||||
if !user.Login(client, client.nick, m.password) {
|
if !user.Login(client, client.nick, m.password) {
|
||||||
|
@ -27,7 +27,9 @@ type Queryable interface {
|
|||||||
QueryRow(string, ...interface{}) *sql.Row
|
QueryRow(string, ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionFunc func(Queryable) bool
|
type Savable interface {
|
||||||
|
Save(q Queryable) bool
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// general
|
// general
|
||||||
@ -36,7 +38,7 @@ type TransactionFunc func(Queryable) bool
|
|||||||
func NewDatabase() *Database {
|
func NewDatabase() *Database {
|
||||||
db, err := sql.Open("sqlite3", "ergonomadic.db")
|
db, err := sql.Open("sqlite3", "ergonomadic.db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("cannot open database")
|
log.Fatalln("cannot open database")
|
||||||
}
|
}
|
||||||
return &Database{db}
|
return &Database{db}
|
||||||
}
|
}
|
||||||
@ -48,7 +50,7 @@ func NewTransaction(tx *sql.Tx) *Transaction {
|
|||||||
func readLines(filename string) <-chan string {
|
func readLines(filename string) <-chan string {
|
||||||
file, err := os.Open(filename)
|
file, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
reader := bufio.NewReader(file)
|
reader := bufio.NewReader(file)
|
||||||
lines := make(chan string)
|
lines := make(chan string)
|
||||||
@ -56,7 +58,7 @@ func readLines(filename string) <-chan string {
|
|||||||
defer file.Close()
|
defer file.Close()
|
||||||
defer close(lines)
|
defer close(lines)
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadString('\n')
|
line, err := reader.ReadString(';')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -70,28 +72,24 @@ func readLines(filename string) <-chan string {
|
|||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Database) execSqlFile(filename string) {
|
func (db *Database) ExecSqlFile(filename string) *Database {
|
||||||
db.Transact(func(q Queryable) bool {
|
db.Transact(func(q Queryable) bool {
|
||||||
for line := range readLines(filepath.Join("sql", filename)) {
|
for line := range readLines(filepath.Join("sql", filename)) {
|
||||||
log.Println(line)
|
log.Println(line)
|
||||||
q.Exec(line)
|
_, err := q.Exec(line)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Database) InitTables() {
|
func (db *Database) Transact(txf func(Queryable) bool) {
|
||||||
db.execSqlFile("init.sql")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) DropTables() {
|
|
||||||
db.execSqlFile("drop.sql")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) Transact(txf TransactionFunc) {
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Panicln(err)
|
||||||
}
|
}
|
||||||
if txf(tx) {
|
if txf(tx) {
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
@ -100,6 +98,28 @@ func (db *Database) Transact(txf TransactionFunc) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *Database) Save(s Savable) {
|
||||||
|
db.Transact(func(tx Queryable) bool {
|
||||||
|
return s.Save(tx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// general purpose sql
|
||||||
|
//
|
||||||
|
|
||||||
|
func FindId(q Queryable, sql string, args ...interface{}) (rowId RowId, err error) {
|
||||||
|
row := q.QueryRow(sql, args...)
|
||||||
|
err = row.Scan(&rowId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) {
|
||||||
|
row := q.QueryRow(sql, args...)
|
||||||
|
err = row.Scan(&count)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// data
|
// data
|
||||||
//
|
//
|
||||||
@ -117,25 +137,39 @@ type ChannelRow struct {
|
|||||||
|
|
||||||
// user
|
// user
|
||||||
|
|
||||||
func FindUserByNick(q Queryable, nick string) (ur *UserRow) {
|
func FindAllUsers(q Queryable) (urs []UserRow, err error) {
|
||||||
ur = new(UserRow)
|
var rows *sql.Rows
|
||||||
row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick)
|
rows, err = q.Query("SELECT id, nick, hash FROM user")
|
||||||
err := row.Scan(&ur.id, &ur.nick, &ur.hash)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ur = nil
|
return
|
||||||
|
}
|
||||||
|
urs = make([]UserRow, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
ur := UserRow{}
|
||||||
|
err = rows.Scan(&(ur.id), &(ur.nick), &(ur.hash))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
urs = append(urs, ur)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) {
|
func FindUserByNick(q Queryable, nick string) (ur *UserRow, err error) {
|
||||||
row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick)
|
ur = &UserRow{}
|
||||||
err = row.Scan(&rowId)
|
row := q.QueryRow("SELECT id, nick, hash FROM user LIMIT 1 WHERE nick = ?",
|
||||||
|
nick)
|
||||||
|
err = row.Scan(&(ur.id), &(ur.nick), &(ur.hash))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FindUserIdByNick(q Queryable, nick string) (RowId, error) {
|
||||||
|
return FindId(q, "SELECT id FROM user WHERE nick = ?", nick)
|
||||||
|
}
|
||||||
|
|
||||||
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
|
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
|
||||||
cr = new(ChannelRow)
|
cr = new(ChannelRow)
|
||||||
row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name)
|
row := q.QueryRow("SELECT id, name FROM channel LIMIT 1 WHERE name = ?", name)
|
||||||
err := row.Scan(&(cr.id), &(cr.name))
|
err := row.Scan(&(cr.id), &(cr.name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cr = nil
|
cr = nil
|
||||||
@ -185,25 +219,31 @@ func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err erro
|
|||||||
|
|
||||||
// channel
|
// channel
|
||||||
|
|
||||||
func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) {
|
func FindChannelIdByName(q Queryable, name string) (RowId, error) {
|
||||||
row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name)
|
return FindId(q, "SELECT id FROM channel WHERE name = ?", name)
|
||||||
err = row.Scan(&channelId)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) {
|
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow, err error) {
|
||||||
rows, err := q.Query(`SELECT * FROM channel WHERE id IN
|
query := ` FROM channel WHERE id IN
|
||||||
(SELECT channel_id from user_channel WHERE user_id = ?)`, userId)
|
(SELECT channel_id from user_channel WHERE user_id = ?)`
|
||||||
|
count, err := Count(q, "SELECT COUNT(id)"+query, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return
|
||||||
}
|
}
|
||||||
crs = make([]ChannelRow, 0)
|
rows, err := q.Query("SELECT id, name"+query, userId)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
crs = make([]ChannelRow, count)
|
||||||
|
var i = 0
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
cr := ChannelRow{}
|
cr := ChannelRow{}
|
||||||
if err := rows.Scan(&(cr.id), &(cr.name)); err != nil {
|
err = rows.Scan(&(cr.id), &(cr.name))
|
||||||
panic(err)
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
crs = append(crs, cr)
|
crs[i] = cr
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -24,10 +24,11 @@ type Server struct {
|
|||||||
channels ChannelNameMap
|
channels ChannelNameMap
|
||||||
services ServiceNameMap
|
services ServiceNameMap
|
||||||
commands chan<- Command
|
commands chan<- Command
|
||||||
|
db *Database
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(name string) *Server {
|
func NewServer(name string) *Server {
|
||||||
commands := make(chan Command)
|
commands := make(chan Command, 1)
|
||||||
server := &Server{
|
server := &Server{
|
||||||
ctime: time.Now(),
|
ctime: time.Now(),
|
||||||
name: name,
|
name: name,
|
||||||
@ -35,16 +36,27 @@ func NewServer(name string) *Server {
|
|||||||
users: make(UserNameMap),
|
users: make(UserNameMap),
|
||||||
channels: make(ChannelNameMap),
|
channels: make(ChannelNameMap),
|
||||||
services: make(ServiceNameMap),
|
services: make(ServiceNameMap),
|
||||||
|
db: NewDatabase(),
|
||||||
}
|
}
|
||||||
go server.receiveCommands(commands)
|
go server.receiveCommands(commands)
|
||||||
NewNickServ(server)
|
NewNickServ(server)
|
||||||
|
server.db.Transact(func(q Queryable) bool {
|
||||||
|
urs, err := FindAllUsers(server.db)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, ur := range urs {
|
||||||
|
NewUser(ur.nick, server).SetHash(ur.hash)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) receiveCommands(commands <-chan Command) {
|
func (server *Server) receiveCommands(commands <-chan Command) {
|
||||||
for command := range commands {
|
for command := range commands {
|
||||||
if DEBUG_SERVER {
|
if DEBUG_SERVER {
|
||||||
log.Printf("%s ← %s %s", server, command.Client(), command)
|
log.Printf("%s → %s : %s", command.Client(), server, command)
|
||||||
}
|
}
|
||||||
command.Client().atime = time.Now()
|
command.Client().atime = time.Now()
|
||||||
command.HandleServer(server)
|
command.HandleServer(server)
|
||||||
@ -278,7 +290,7 @@ func (m *TopicCommand) HandleServer(s *Server) {
|
|||||||
|
|
||||||
channel := s.channels[m.channel]
|
channel := s.channels[m.channel]
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
user.Replies() <- ErrNoSuchChannel(s, m.channel)
|
m.Client().Replies() <- ErrNoSuchChannel(s, m.channel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ type BaseService struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewService(service EditableService, s *Server, name string) Service {
|
func NewService(service EditableService, s *Server, name string) Service {
|
||||||
commands := make(chan ServiceCommand)
|
commands := make(chan ServiceCommand, 1)
|
||||||
base := &BaseService{
|
base := &BaseService{
|
||||||
server: s,
|
server: s,
|
||||||
name: name,
|
name: name,
|
||||||
|
@ -46,9 +46,9 @@ func (set UserSet) Nicks() []string {
|
|||||||
return nicks
|
return nicks
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUser(nick string, password string, server *Server) *User {
|
func NewUser(nick string, server *Server) *User {
|
||||||
commands := make(chan UserCommand)
|
commands := make(chan UserCommand, 1)
|
||||||
replies := make(chan Reply)
|
replies := make(chan Reply, 1)
|
||||||
user := &User{
|
user := &User{
|
||||||
nick: nick,
|
nick: nick,
|
||||||
server: server,
|
server: server,
|
||||||
@ -56,7 +56,6 @@ func NewUser(nick string, password string, server *Server) *User {
|
|||||||
channels: make(ChannelSet),
|
channels: make(ChannelSet),
|
||||||
replies: replies,
|
replies: replies,
|
||||||
}
|
}
|
||||||
user.SetPassword(password)
|
|
||||||
|
|
||||||
go user.receiveCommands(commands)
|
go user.receiveCommands(commands)
|
||||||
go user.receiveReplies(replies)
|
go user.receiveReplies(replies)
|
||||||
@ -81,34 +80,40 @@ func (user *User) Save(q Queryable) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userId := *(user.id)
|
||||||
channelIds := user.channels.Ids()
|
channelIds := user.channels.Ids()
|
||||||
if len(channelIds) == 0 {
|
if len(channelIds) == 0 {
|
||||||
if err := DeleteAllUserChannels(q, *(user.id)); err != nil {
|
if err := DeleteAllUserChannels(q, userId); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := DeleteOtherUserChannels(q, *(user.id), channelIds); err != nil {
|
if err := DeleteOtherUserChannels(q, userId, channelIds); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if err := InsertUserChannels(q, *(user.id), channelIds); err != nil {
|
if err := InsertUserChannels(q, userId, channelIds); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) SetPassword(password string) {
|
func (user *User) SetPassword(password string) *User {
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("bcrypt failed; cannot generate password hash")
|
panic("bcrypt failed; cannot generate password hash")
|
||||||
}
|
}
|
||||||
|
return user.SetHash(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) SetHash(hash []byte) *User {
|
||||||
user.hash = hash
|
user.hash = hash
|
||||||
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) receiveCommands(commands <-chan UserCommand) {
|
func (user *User) receiveCommands(commands <-chan UserCommand) {
|
||||||
for command := range commands {
|
for command := range commands {
|
||||||
if DEBUG_USER {
|
if DEBUG_USER {
|
||||||
log.Printf("%s ← %s %s", user, command.Client(), command)
|
log.Printf("%s → %s : %s", command.Client(), user, command)
|
||||||
}
|
}
|
||||||
command.HandleUser(user)
|
command.HandleUser(user)
|
||||||
}
|
}
|
||||||
@ -117,7 +122,9 @@ func (user *User) receiveCommands(commands <-chan UserCommand) {
|
|||||||
// Distribute replies to clients.
|
// Distribute replies to clients.
|
||||||
func (user *User) receiveReplies(replies <-chan Reply) {
|
func (user *User) receiveReplies(replies <-chan Reply) {
|
||||||
for reply := range replies {
|
for reply := range replies {
|
||||||
log.Printf("%s ← %s", user, reply)
|
if DEBUG_USER {
|
||||||
|
log.Printf("%s ← %s : %s", user, reply.Source(), reply)
|
||||||
|
}
|
||||||
for client := range user.clients {
|
for client := range user.clients {
|
||||||
client.Replies() <- reply
|
client.Replies() <- reply
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user