diff --git a/src/irc/channel.go b/src/irc/channel.go index b124391e..ccb7e121 100644 --- a/src/irc/channel.go +++ b/src/irc/channel.go @@ -32,11 +32,11 @@ func (set ChannelSet) Remove(channel *Channel) { } func (set ChannelSet) Ids() (ids []RowId) { - ids = []RowId{} + ids = make([]RowId, len(set)) + var i = 0 for channel := range set { - if channel.id != nil { - ids = append(ids, *channel.id) - } + ids[i] = *(channel.id) + i++ } return ids } @@ -67,15 +67,18 @@ func NewChannel(s *Server, name string) *Channel { func (channel *Channel) Save(q Queryable) bool { if channel.id == nil { if err := InsertChannel(q, channel); err != nil { + log.Println(err) return false } channelId, err := FindChannelIdByName(q, channel.name) if err != nil { + log.Println(err) return false } channel.id = &channelId } else { if err := UpdateChannel(q, channel); err != nil { + log.Println(err) return false } } @@ -120,6 +123,10 @@ func (channel *Channel) GetTopic(replier Replier) { replier.Replies() <- RplTopic(channel) } +func (channel *Channel) GetUsers(replier Replier) { + replier.Replies() <- NewNamesReply(channel) +} + func (channel *Channel) Replies() chan<- Reply { return channel.replies } @@ -128,6 +135,10 @@ func (channel *Channel) Id() string { return channel.name } +func (channel *Channel) Nick() string { + return channel.name +} + func (channel *Channel) PublicId() string { return channel.name } @@ -140,33 +151,33 @@ func (channel *Channel) String() string { return channel.Id() } +func (channel *Channel) Join(user *User) { + channel.members.Add(user) + user.channels.Add(channel) + channel.Replies() <- RplJoin(channel, user) + channel.GetTopic(user) + channel.GetUsers(user) +} + // // commands // func (m *JoinCommand) HandleChannel(channel *Channel) { client := m.Client() - user := client.user - if channel.key != m.channels[channel.name] { - client.user.Replies() <- ErrBadChannelKey(channel) + client.Replies() <- ErrBadChannelKey(channel) return } - channel.members.Add(user) - user.channels.Add(channel) - - channel.Replies() <- RplJoin(channel, user) - channel.GetTopic(user) - user.Replies() <- RplNamReply(channel) - user.Replies() <- RplEndOfNames(channel.server) + channel.Join(client.user) } func (m *PartCommand) HandleChannel(channel *Channel) { user := m.Client().user if !channel.members[user] { - user.replies <- ErrNotOnChannel(channel) + user.Replies() <- ErrNotOnChannel(channel) return } diff --git a/src/irc/client.go b/src/irc/client.go index fbbc2cc1..a0652e80 100644 --- a/src/irc/client.go +++ b/src/irc/client.go @@ -48,7 +48,6 @@ func NewClient(server *Server, conn net.Conn) *Client { func (c *Client) readConn(recv <-chan string) { for str := range recv { - m, err := ParseCommand(str) if err != nil { if err == NotEnoughArgsError { @@ -59,7 +58,7 @@ func (c *Client) readConn(recv <-chan string) { continue } - m.SetClient(c) + m.SetBase(c) c.server.commands <- m } } @@ -69,7 +68,7 @@ func (c *Client) writeConn(write chan<- string, replies <-chan Reply) { if DEBUG_CLIENT { log.Printf("%s ← %s : %s", c, reply.Source(), reply) } - write <- reply.Format(c) + reply.Format(c, write) } } diff --git a/src/irc/commands.go b/src/irc/commands.go index 8ac00e06..1ecb5090 100644 --- a/src/irc/commands.go +++ b/src/irc/commands.go @@ -11,12 +11,13 @@ type Command interface { Client() *Client User() *User Source() Identifier + Reply(Reply) HandleServer(*Server) } type EditableCommand interface { Command - SetClient(*Client) + SetBase(*Client) } var ( @@ -46,25 +47,19 @@ func (command *BaseCommand) Client() *Client { } func (command *BaseCommand) User() *User { - if command.Client() == nil { - return nil - } - return command.User() + return command.Client().user } -func (command *BaseCommand) SetClient(c *Client) { +func (command *BaseCommand) SetBase(c *Client) { *command = BaseCommand{c} } func (command *BaseCommand) Source() Identifier { - client := command.Client() - if client == nil { - return nil - } - if client.user != nil { - return client.user - } - return client + return command.client +} + +func (command *BaseCommand) Reply(reply Reply) { + command.client.Replies() <- reply } func ParseCommand(line string) (EditableCommand, error) { @@ -116,9 +111,8 @@ func (cmd *UnknownCommand) String() string { func NewUnknownCommand(command string, args []string) *UnknownCommand { return &UnknownCommand{ - BaseCommand: BaseCommand{}, - command: command, - args: args, + command: command, + args: args, } } @@ -139,8 +133,7 @@ func NewPingCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } msg := &PingCommand{ - BaseCommand: BaseCommand{}, - server: args[0], + server: args[0], } if len(args) > 1 { msg.server2 = args[1] @@ -165,8 +158,7 @@ func NewPongCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } message := &PongCommand{ - BaseCommand: BaseCommand{}, - server1: args[0], + server1: args[0], } if len(args) > 1 { message.server2 = args[1] @@ -190,8 +182,7 @@ func NewPassCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } return &PassCommand{ - BaseCommand: BaseCommand{}, - password: args[0], + password: args[0], }, nil } @@ -211,8 +202,7 @@ func NewNickCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } return &NickCommand{ - BaseCommand: BaseCommand{}, - nickname: args[0], + nickname: args[0], }, nil } @@ -236,10 +226,9 @@ func NewUserMsgCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } msg := &UserMsgCommand{ - BaseCommand: BaseCommand{}, - user: args[0], - unused: args[2], - realname: args[3], + user: args[0], + unused: args[2], + realname: args[3], } mode, err := strconv.ParseUint(args[1], 10, 8) if err == nil { @@ -260,9 +249,7 @@ func (cmd *QuitCommand) String() string { } func NewQuitCommand(args []string) (EditableCommand, error) { - msg := &QuitCommand{ - BaseCommand: BaseCommand{}, - } + msg := &QuitCommand{} if len(args) > 0 { msg.message = args[0] } @@ -283,8 +270,7 @@ func (cmd *JoinCommand) String() string { func NewJoinCommand(args []string) (EditableCommand, error) { msg := &JoinCommand{ - BaseCommand: BaseCommand{}, - channels: make(map[string]string), + channels: make(map[string]string), } if len(args) == 0 { @@ -327,8 +313,7 @@ func NewPartCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } msg := &PartCommand{ - BaseCommand: BaseCommand{}, - channels: strings.Split(args[0], ","), + channels: strings.Split(args[0], ","), } if len(args) > 1 { msg.message = args[1] @@ -353,9 +338,8 @@ func NewPrivMsgCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } return &PrivMsgCommand{ - BaseCommand: BaseCommand{}, - target: args[0], - message: args[1], + target: args[0], + message: args[1], }, nil } @@ -384,8 +368,7 @@ func NewTopicCommand(args []string) (EditableCommand, error) { return nil, NotEnoughArgsError } msg := &TopicCommand{ - BaseCommand: BaseCommand{}, - channel: args[0], + channel: args[0], } if len(args) > 1 { msg.topic = args[1] @@ -409,8 +392,7 @@ func NewModeCommand(args []string) (EditableCommand, error) { } cmd := &ModeCommand{ - BaseCommand: BaseCommand{}, - nickname: args[0], + nickname: args[0], } if len(args) > 1 { diff --git a/src/irc/nickserv.go b/src/irc/nickserv.go index 98fd9b43..e0fd5880 100644 --- a/src/irc/nickserv.go +++ b/src/irc/nickserv.go @@ -12,7 +12,7 @@ const ( type NickServCommand interface { HandleNickServ(*NickServ) Client() *Client - SetClient(*Client) + SetBase(*Client) } type NickServ struct { @@ -56,7 +56,7 @@ func (ns *NickServ) HandlePrivMsg(m *PrivMsgCommand) { return } - cmd.SetClient(m.Client()) + cmd.SetBase(m.Client()) if ns.Debug() { log.Printf("%s ← %s %s", ns, cmd.Client(), cmd) } @@ -106,7 +106,8 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) { return } - user := NewUser(client.nick, ns.server).SetPassword(m.password) + user := NewUser(client.nick, ns.server) + user.SetPassword(m.password) Save(ns.server.db, user) ns.Reply(client, "You have registered.") diff --git a/src/irc/persistence.go b/src/irc/persistence.go index dc299358..8a450855 100644 --- a/src/irc/persistence.go +++ b/src/irc/persistence.go @@ -23,6 +23,10 @@ type Savable interface { Save(q Queryable) bool } +type Loadable interface { + Load(q Queryable) bool +} + // // general // @@ -89,6 +93,10 @@ func Save(db *sql.DB, s Savable) { Transact(db, s.Save) } +func Load(db *sql.DB, l Loadable) { + Transact(db, l.Load) +} + // // general purpose sql // @@ -99,7 +107,7 @@ func findId(q Queryable, sql string, args ...interface{}) (rowId RowId, err erro return } -func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) { +func countRows(q Queryable, sql string, args ...interface{}) (count uint, err error) { row := q.QueryRow(sql, args...) err = row.Scan(&count) return @@ -162,20 +170,20 @@ func FindChannelByName(q Queryable, name string) (cr *ChannelRow) { return } -func InsertUser(q Queryable, user *User) (err error) { +func InsertUser(q Queryable, row *UserRow) (err error) { _, err = q.Exec("INSERT INTO user (nick, hash) VALUES (?, ?)", - user.nick, user.hash) + row.nick, row.hash) return } -func UpdateUser(q Queryable, user *User) (err error) { +func UpdateUser(q Queryable, row *UserRow) (err error) { _, err = q.Exec("UPDATE user SET nick = ?, hash = ? WHERE id = ?", - user.nick, user.hash, *(user.id)) + row.nick, row.hash, row.id) return } -func DeleteUser(q Queryable, user *User) (err error) { - _, err = q.Exec("DELETE FROM user WHERE id = ?", *(user.id)) +func DeleteUser(q Queryable, id RowId) (err error) { + _, err = q.Exec("DELETE FROM user WHERE id = ?", id) return } @@ -211,14 +219,12 @@ func FindChannelIdByName(q Queryable, name string) (RowId, error) { return findId(q, "SELECT id FROM channel WHERE name = ?", name) } -func FindChannelsForUser(q Queryable, userId RowId) (crs []*ChannelRow, err error) { - query := ` FROM channel WHERE id IN -(SELECT channel_id from user_channel WHERE user_id = ?)` - count, err := Count(q, "SELECT COUNT(id)"+query, userId) +func findChannels(q Queryable, where string, args ...interface{}) (crs []*ChannelRow, err error) { + count, err := countRows(q, "SELECT COUNT(id) FROM channel "+where, args...) if err != nil { return } - rows, err := q.Query("SELECT id, name"+query, userId) + rows, err := q.Query("SELECT id, name FROM channel "+where, args...) if err != nil { return } @@ -236,6 +242,17 @@ func FindChannelsForUser(q Queryable, userId RowId) (crs []*ChannelRow, err erro return } +func FindChannelsForUser(q Queryable, userId RowId) (crs []*ChannelRow, err error) { + crs, err = findChannels(q, + "WHERE id IN (SELECT channel_id from user_channel WHERE user_id = ?)", userId) + return +} + +func FindAllChannels(q Queryable) (crs []*ChannelRow, err error) { + crs, err = findChannels(q, "") + return +} + func InsertChannel(q Queryable, channel *Channel) (err error) { _, err = q.Exec("INSERT INTO channel (name) VALUES (?)", channel.name) return diff --git a/src/irc/reply.go b/src/irc/reply.go index ba6a036c..918ed4f8 100644 --- a/src/irc/reply.go +++ b/src/irc/reply.go @@ -17,7 +17,7 @@ type Replier interface { } type Reply interface { - Format(client *Client) string + Format(*Client, chan<- string) Source() Identifier } @@ -31,7 +31,7 @@ func (reply *BaseReply) Source() Identifier { } type StringReply struct { - BaseReply + *BaseReply code string } @@ -40,13 +40,13 @@ func NewStringReply(source Identifier, code string, message := fmt.Sprintf(format, args...) fullMessage := fmt.Sprintf(":%s %s %s", source.Id(), code, message) return &StringReply{ - BaseReply: BaseReply{source, fullMessage}, + BaseReply: &BaseReply{source, fullMessage}, code: code, } } -func (reply *StringReply) Format(client *Client) string { - return reply.message +func (reply *StringReply) Format(client *Client, write chan<- string) { + write <- reply.message } func (reply *StringReply) String() string { @@ -55,19 +55,23 @@ func (reply *StringReply) String() string { } type NumericReply struct { - BaseReply + *BaseReply code int } func NewNumericReply(source Identifier, code int, format string, args ...interface{}) *NumericReply { return &NumericReply{ - BaseReply: BaseReply{source, fmt.Sprintf(format, args...)}, + BaseReply: &BaseReply{source, fmt.Sprintf(format, args...)}, code: code, } } -func (reply *NumericReply) Format(client *Client) string { +func (reply *NumericReply) Format(client *Client, write chan<- string) { + write <- reply.FormatString(client) +} + +func (reply *NumericReply) FormatString(client *Client) string { return fmt.Sprintf(":%s %03d %s %s", reply.Source().Id(), reply.code, client.Nick(), reply.message) } @@ -77,6 +81,53 @@ func (reply *NumericReply) String() string { reply.source, reply.code, reply.message) } +// names reply + +type NamesReply struct { + *BaseReply + channel *Channel +} + +func NewNamesReply(channel *Channel) Reply { + return &NamesReply{ + BaseReply: &BaseReply{ + source: channel, + }, + } +} + +const ( + MAX_REPLY_LEN = 510 // 512 - CRLF +) + +func joinedLen(names []string) int { + var l = len(names) - 1 // " " between names + for _, name := range names { + l += len(name) + } + return l +} + +func (reply *NamesReply) Format(client *Client, write chan<- string) { + base := RplNamReply(reply.channel, []string{}) + baseLen := len(base.FormatString(client)) + tooLong := func(names []string) bool { + return (baseLen + joinedLen(names)) > MAX_REPLY_LEN + } + var start = 0 + nicks := reply.channel.Nicks() + for i := range nicks { + if (i > start) && tooLong(nicks[start:i]) { + RplNamReply(reply.channel, nicks[start:i-1]).Format(client, write) + start = i - 1 + } + } + if start < (len(nicks) - 1) { + RplNamReply(reply.channel, nicks[start:]).Format(client, write) + } + RplEndOfNames(reply.channel).Format(client, write) +} + // messaging replies func RplPrivMsg(source Identifier, target Identifier, message string) Reply { @@ -118,7 +169,7 @@ func RplWelcome(source Identifier, client *Client) Reply { "Welcome to the Internet Relay Network %s", client.Id()) } -func RplYourHost(server *Server, target *Client) Reply { +func RplYourHost(server *Server) Reply { return NewNumericReply(server, RPL_YOURHOST, "Your host is %s, running version %s", server.name, VERSION) } @@ -152,10 +203,9 @@ func RplInvitingMsg(channel *Channel, invitee *Client) Reply { "%s %s", channel.name, invitee.Nick()) } -func RplNamReply(channel *Channel) Reply { - // TODO multiple names and splitting based on message size - return NewNumericReply(channel.server, RPL_NAMREPLY, - "= %s :%s", channel.name, strings.Join(channel.Nicks(), " ")) +func RplNamReply(channel *Channel, names []string) *NumericReply { + return NewNumericReply(channel.server, RPL_NAMREPLY, "= %s :%s", + channel.name, strings.Join(names, " ")) } func RplEndOfNames(source Identifier) Reply { diff --git a/src/irc/server.go b/src/irc/server.go index 373e65b4..39e7c750 100644 --- a/src/irc/server.go +++ b/src/irc/server.go @@ -41,17 +41,34 @@ func NewServer(name string) *Server { } go server.receiveCommands(commands) NewNickServ(server) - Transact(server.db, func(q Queryable) bool { - urs, err := FindAllUsers(server.db) - if err != nil { + Load(server.db, server) + return server +} + +func (server *Server) Load(q Queryable) bool { + crs, err := FindAllChannels(q) + if err != nil { + log.Println(err) + return false + } + for _, cr := range crs { + channel := server.GetOrMakeChannel(cr.name) + channel.id = &(cr.id) + } + + urs, err := FindAllUsers(q) + if err != nil { + log.Println(err) + return false + } + for _, ur := range urs { + user := NewUser(ur.nick, server) + user.SetHash(ur.hash) + if !user.Load(q) { return false } - for _, ur := range urs { - NewUser(ur.nick, server).SetHash(ur.hash) - } - return false - }) - return server + } + return true } func (server *Server) receiveCommands(commands <-chan Command) { @@ -115,7 +132,7 @@ func (s *Server) tryRegister(c *Client) { c.registered = true replies := []Reply{ RplWelcome(s, c), - RplYourHost(s, c), + RplYourHost(s), RplCreated(s), RplMyInfo(s), } @@ -318,21 +335,21 @@ func (m *PrivMsgCommand) HandleServer(s *Server) { if m.TargetIsChannel() { channel := s.channels[m.target] if channel == nil { - user.Replies() <- ErrNoSuchChannel(s, m.target) + m.Client().Replies() <- ErrNoSuchChannel(s, m.target) return } - channel.Commands() <- m + channel.commands <- m return } target := s.users[m.target] if target == nil { - user.Replies() <- ErrNoSuchNick(s, m.target) + m.Client().Replies() <- ErrNoSuchNick(s, m.target) return } - target.Commands() <- m + target.commands <- m } func (m *ModeCommand) HandleServer(s *Server) { diff --git a/src/irc/user.go b/src/irc/user.go index 98c66567..6064c939 100644 --- a/src/irc/user.go +++ b/src/irc/user.go @@ -16,7 +16,7 @@ type UserCommand interface { } type User struct { - id *RowId + id RowId nick string hash []byte server *Server @@ -64,50 +64,80 @@ func NewUser(nick string, server *Server) *User { return user } +func (user *User) Row() *UserRow { + return &UserRow{user.id, user.nick, user.hash} +} + +func (user *User) Create(q Queryable) bool { + var err error + if err := InsertUser(q, user.Row()); err != nil { + log.Println(err) + return false + } + user.id, err = FindUserIdByNick(q, user.nick) + if err != nil { + log.Println(err) + return false + } + return true +} + func (user *User) Save(q Queryable) bool { - if user.id == nil { - if err := InsertUser(q, user); err != nil { - return false - } - userId, err := FindUserIdByNick(q, user.nick) - if err != nil { - return false - } - user.id = &userId - } else { - if err := UpdateUser(q, user); err != nil { - return false - } + if err := UpdateUser(q, user.Row()); err != nil { + log.Println(err) + return false } - userId := *(user.id) channelIds := user.channels.Ids() if len(channelIds) == 0 { - if err := DeleteAllUserChannels(q, userId); err != nil { + if err := DeleteAllUserChannels(q, user.id); err != nil { + log.Println(err) return false } } else { - if err := DeleteOtherUserChannels(q, userId, channelIds); err != nil { + if err := DeleteOtherUserChannels(q, user.id, channelIds); err != nil { + log.Println(err) return false } - if err := InsertUserChannels(q, userId, channelIds); err != nil { + if err := InsertUserChannels(q, user.id, channelIds); err != nil { + log.Println(err) return false } } return true } -func (user *User) SetPassword(password string) *User { - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) +func (user *User) Delete(q Queryable) bool { + err := DeleteUser(q, user.id) if err != nil { - panic("bcrypt failed; cannot generate password hash") + log.Println(err) + return false } - return user.SetHash(hash) + return true } -func (user *User) SetHash(hash []byte) *User { +func (user *User) Load(q Queryable) bool { + crs, err := FindChannelsForUser(q, user.id) + if err != nil { + log.Println(err) + return false + } + for _, cr := range crs { + user.server.GetOrMakeChannel(cr.name).Join(user) + } + return true +} + +func (user *User) SetPassword(password string) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + log.Panicln(err) + } + user.SetHash(hash) +} + +func (user *User) SetHash(hash []byte) { user.hash = hash - return user } func (user *User) receiveCommands(commands <-chan UserCommand) { @@ -149,10 +179,6 @@ func (user *User) String() string { return user.Id() } -func (user *User) Commands() chan<- UserCommand { - return user.commands -} - func (user *User) Login(c *Client, nick string, password string) bool { if nick != c.nick { return false @@ -172,8 +198,7 @@ func (user *User) Login(c *Client, nick string, password string) bool { c.user = user for channel := range user.channels { channel.GetTopic(c) - c.Replies() <- RplNamReply(channel) - c.Replies() <- RplEndOfNames(channel.server) + channel.GetUsers(c) } return true }