3
0
mirror of https://github.com/ergochat/ergo.git synced 2025-01-03 16:42:38 +01:00

Send Names reply in multiple messages.

This commit is contained in:
Jeremy Latt 2013-06-02 22:07:50 -07:00
parent 0001131095
commit f133f3691d
8 changed files with 235 additions and 133 deletions

View File

@ -32,11 +32,11 @@ func (set ChannelSet) Remove(channel *Channel) {
}
func (set ChannelSet) Ids() (ids []RowId) {
ids = []RowId{}
ids = make([]RowId, len(set))
var i = 0
for channel := range set {
if channel.id != nil {
ids = append(ids, *channel.id)
}
ids[i] = *(channel.id)
i++
}
return ids
}
@ -67,15 +67,18 @@ func NewChannel(s *Server, name string) *Channel {
func (channel *Channel) Save(q Queryable) bool {
if channel.id == nil {
if err := InsertChannel(q, channel); err != nil {
log.Println(err)
return false
}
channelId, err := FindChannelIdByName(q, channel.name)
if err != nil {
log.Println(err)
return false
}
channel.id = &channelId
} else {
if err := UpdateChannel(q, channel); err != nil {
log.Println(err)
return false
}
}
@ -120,6 +123,10 @@ func (channel *Channel) GetTopic(replier Replier) {
replier.Replies() <- RplTopic(channel)
}
func (channel *Channel) GetUsers(replier Replier) {
replier.Replies() <- NewNamesReply(channel)
}
func (channel *Channel) Replies() chan<- Reply {
return channel.replies
}
@ -128,6 +135,10 @@ func (channel *Channel) Id() string {
return channel.name
}
func (channel *Channel) Nick() string {
return channel.name
}
func (channel *Channel) PublicId() string {
return channel.name
}
@ -140,33 +151,33 @@ func (channel *Channel) String() string {
return channel.Id()
}
func (channel *Channel) Join(user *User) {
channel.members.Add(user)
user.channels.Add(channel)
channel.Replies() <- RplJoin(channel, user)
channel.GetTopic(user)
channel.GetUsers(user)
}
//
// commands
//
func (m *JoinCommand) HandleChannel(channel *Channel) {
client := m.Client()
user := client.user
if channel.key != m.channels[channel.name] {
client.user.Replies() <- ErrBadChannelKey(channel)
client.Replies() <- ErrBadChannelKey(channel)
return
}
channel.members.Add(user)
user.channels.Add(channel)
channel.Replies() <- RplJoin(channel, user)
channel.GetTopic(user)
user.Replies() <- RplNamReply(channel)
user.Replies() <- RplEndOfNames(channel.server)
channel.Join(client.user)
}
func (m *PartCommand) HandleChannel(channel *Channel) {
user := m.Client().user
if !channel.members[user] {
user.replies <- ErrNotOnChannel(channel)
user.Replies() <- ErrNotOnChannel(channel)
return
}

View File

@ -48,7 +48,6 @@ func NewClient(server *Server, conn net.Conn) *Client {
func (c *Client) readConn(recv <-chan string) {
for str := range recv {
m, err := ParseCommand(str)
if err != nil {
if err == NotEnoughArgsError {
@ -59,7 +58,7 @@ func (c *Client) readConn(recv <-chan string) {
continue
}
m.SetClient(c)
m.SetBase(c)
c.server.commands <- m
}
}
@ -69,7 +68,7 @@ func (c *Client) writeConn(write chan<- string, replies <-chan Reply) {
if DEBUG_CLIENT {
log.Printf("%s ← %s : %s", c, reply.Source(), reply)
}
write <- reply.Format(c)
reply.Format(c, write)
}
}

View File

@ -11,12 +11,13 @@ type Command interface {
Client() *Client
User() *User
Source() Identifier
Reply(Reply)
HandleServer(*Server)
}
type EditableCommand interface {
Command
SetClient(*Client)
SetBase(*Client)
}
var (
@ -46,25 +47,19 @@ func (command *BaseCommand) Client() *Client {
}
func (command *BaseCommand) User() *User {
if command.Client() == nil {
return nil
}
return command.User()
return command.Client().user
}
func (command *BaseCommand) SetClient(c *Client) {
func (command *BaseCommand) SetBase(c *Client) {
*command = BaseCommand{c}
}
func (command *BaseCommand) Source() Identifier {
client := command.Client()
if client == nil {
return nil
}
if client.user != nil {
return client.user
}
return client
return command.client
}
func (command *BaseCommand) Reply(reply Reply) {
command.client.Replies() <- reply
}
func ParseCommand(line string) (EditableCommand, error) {
@ -116,9 +111,8 @@ func (cmd *UnknownCommand) String() string {
func NewUnknownCommand(command string, args []string) *UnknownCommand {
return &UnknownCommand{
BaseCommand: BaseCommand{},
command: command,
args: args,
command: command,
args: args,
}
}
@ -139,8 +133,7 @@ func NewPingCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
msg := &PingCommand{
BaseCommand: BaseCommand{},
server: args[0],
server: args[0],
}
if len(args) > 1 {
msg.server2 = args[1]
@ -165,8 +158,7 @@ func NewPongCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
message := &PongCommand{
BaseCommand: BaseCommand{},
server1: args[0],
server1: args[0],
}
if len(args) > 1 {
message.server2 = args[1]
@ -190,8 +182,7 @@ func NewPassCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
return &PassCommand{
BaseCommand: BaseCommand{},
password: args[0],
password: args[0],
}, nil
}
@ -211,8 +202,7 @@ func NewNickCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
return &NickCommand{
BaseCommand: BaseCommand{},
nickname: args[0],
nickname: args[0],
}, nil
}
@ -236,10 +226,9 @@ func NewUserMsgCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
msg := &UserMsgCommand{
BaseCommand: BaseCommand{},
user: args[0],
unused: args[2],
realname: args[3],
user: args[0],
unused: args[2],
realname: args[3],
}
mode, err := strconv.ParseUint(args[1], 10, 8)
if err == nil {
@ -260,9 +249,7 @@ func (cmd *QuitCommand) String() string {
}
func NewQuitCommand(args []string) (EditableCommand, error) {
msg := &QuitCommand{
BaseCommand: BaseCommand{},
}
msg := &QuitCommand{}
if len(args) > 0 {
msg.message = args[0]
}
@ -283,8 +270,7 @@ func (cmd *JoinCommand) String() string {
func NewJoinCommand(args []string) (EditableCommand, error) {
msg := &JoinCommand{
BaseCommand: BaseCommand{},
channels: make(map[string]string),
channels: make(map[string]string),
}
if len(args) == 0 {
@ -327,8 +313,7 @@ func NewPartCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
msg := &PartCommand{
BaseCommand: BaseCommand{},
channels: strings.Split(args[0], ","),
channels: strings.Split(args[0], ","),
}
if len(args) > 1 {
msg.message = args[1]
@ -353,9 +338,8 @@ func NewPrivMsgCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
return &PrivMsgCommand{
BaseCommand: BaseCommand{},
target: args[0],
message: args[1],
target: args[0],
message: args[1],
}, nil
}
@ -384,8 +368,7 @@ func NewTopicCommand(args []string) (EditableCommand, error) {
return nil, NotEnoughArgsError
}
msg := &TopicCommand{
BaseCommand: BaseCommand{},
channel: args[0],
channel: args[0],
}
if len(args) > 1 {
msg.topic = args[1]
@ -409,8 +392,7 @@ func NewModeCommand(args []string) (EditableCommand, error) {
}
cmd := &ModeCommand{
BaseCommand: BaseCommand{},
nickname: args[0],
nickname: args[0],
}
if len(args) > 1 {

View File

@ -12,7 +12,7 @@ const (
type NickServCommand interface {
HandleNickServ(*NickServ)
Client() *Client
SetClient(*Client)
SetBase(*Client)
}
type NickServ struct {
@ -56,7 +56,7 @@ func (ns *NickServ) HandlePrivMsg(m *PrivMsgCommand) {
return
}
cmd.SetClient(m.Client())
cmd.SetBase(m.Client())
if ns.Debug() {
log.Printf("%s ← %s %s", ns, cmd.Client(), cmd)
}
@ -106,7 +106,8 @@ func (m *RegisterCommand) HandleNickServ(ns *NickServ) {
return
}
user := NewUser(client.nick, ns.server).SetPassword(m.password)
user := NewUser(client.nick, ns.server)
user.SetPassword(m.password)
Save(ns.server.db, user)
ns.Reply(client, "You have registered.")

View File

@ -23,6 +23,10 @@ type Savable interface {
Save(q Queryable) bool
}
type Loadable interface {
Load(q Queryable) bool
}
//
// general
//
@ -89,6 +93,10 @@ func Save(db *sql.DB, s Savable) {
Transact(db, s.Save)
}
func Load(db *sql.DB, l Loadable) {
Transact(db, l.Load)
}
//
// general purpose sql
//
@ -99,7 +107,7 @@ func findId(q Queryable, sql string, args ...interface{}) (rowId RowId, err erro
return
}
func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) {
func countRows(q Queryable, sql string, args ...interface{}) (count uint, err error) {
row := q.QueryRow(sql, args...)
err = row.Scan(&count)
return
@ -162,20 +170,20 @@ func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
return
}
func InsertUser(q Queryable, user *User) (err error) {
func InsertUser(q Queryable, row *UserRow) (err error) {
_, err = q.Exec("INSERT INTO user (nick, hash) VALUES (?, ?)",
user.nick, user.hash)
row.nick, row.hash)
return
}
func UpdateUser(q Queryable, user *User) (err error) {
func UpdateUser(q Queryable, row *UserRow) (err error) {
_, err = q.Exec("UPDATE user SET nick = ?, hash = ? WHERE id = ?",
user.nick, user.hash, *(user.id))
row.nick, row.hash, row.id)
return
}
func DeleteUser(q Queryable, user *User) (err error) {
_, err = q.Exec("DELETE FROM user WHERE id = ?", *(user.id))
func DeleteUser(q Queryable, id RowId) (err error) {
_, err = q.Exec("DELETE FROM user WHERE id = ?", id)
return
}
@ -211,14 +219,12 @@ func FindChannelIdByName(q Queryable, name string) (RowId, error) {
return findId(q, "SELECT id FROM channel WHERE name = ?", name)
}
func FindChannelsForUser(q Queryable, userId RowId) (crs []*ChannelRow, err error) {
query := ` FROM channel WHERE id IN
(SELECT channel_id from user_channel WHERE user_id = ?)`
count, err := Count(q, "SELECT COUNT(id)"+query, userId)
func findChannels(q Queryable, where string, args ...interface{}) (crs []*ChannelRow, err error) {
count, err := countRows(q, "SELECT COUNT(id) FROM channel "+where, args...)
if err != nil {
return
}
rows, err := q.Query("SELECT id, name"+query, userId)
rows, err := q.Query("SELECT id, name FROM channel "+where, args...)
if err != nil {
return
}
@ -236,6 +242,17 @@ func FindChannelsForUser(q Queryable, userId RowId) (crs []*ChannelRow, err erro
return
}
func FindChannelsForUser(q Queryable, userId RowId) (crs []*ChannelRow, err error) {
crs, err = findChannels(q,
"WHERE id IN (SELECT channel_id from user_channel WHERE user_id = ?)", userId)
return
}
func FindAllChannels(q Queryable) (crs []*ChannelRow, err error) {
crs, err = findChannels(q, "")
return
}
func InsertChannel(q Queryable, channel *Channel) (err error) {
_, err = q.Exec("INSERT INTO channel (name) VALUES (?)", channel.name)
return

View File

@ -17,7 +17,7 @@ type Replier interface {
}
type Reply interface {
Format(client *Client) string
Format(*Client, chan<- string)
Source() Identifier
}
@ -31,7 +31,7 @@ func (reply *BaseReply) Source() Identifier {
}
type StringReply struct {
BaseReply
*BaseReply
code string
}
@ -40,13 +40,13 @@ func NewStringReply(source Identifier, code string,
message := fmt.Sprintf(format, args...)
fullMessage := fmt.Sprintf(":%s %s %s", source.Id(), code, message)
return &StringReply{
BaseReply: BaseReply{source, fullMessage},
BaseReply: &BaseReply{source, fullMessage},
code: code,
}
}
func (reply *StringReply) Format(client *Client) string {
return reply.message
func (reply *StringReply) Format(client *Client, write chan<- string) {
write <- reply.message
}
func (reply *StringReply) String() string {
@ -55,19 +55,23 @@ func (reply *StringReply) String() string {
}
type NumericReply struct {
BaseReply
*BaseReply
code int
}
func NewNumericReply(source Identifier, code int, format string,
args ...interface{}) *NumericReply {
return &NumericReply{
BaseReply: BaseReply{source, fmt.Sprintf(format, args...)},
BaseReply: &BaseReply{source, fmt.Sprintf(format, args...)},
code: code,
}
}
func (reply *NumericReply) Format(client *Client) string {
func (reply *NumericReply) Format(client *Client, write chan<- string) {
write <- reply.FormatString(client)
}
func (reply *NumericReply) FormatString(client *Client) string {
return fmt.Sprintf(":%s %03d %s %s", reply.Source().Id(), reply.code,
client.Nick(), reply.message)
}
@ -77,6 +81,53 @@ func (reply *NumericReply) String() string {
reply.source, reply.code, reply.message)
}
// names reply
type NamesReply struct {
*BaseReply
channel *Channel
}
func NewNamesReply(channel *Channel) Reply {
return &NamesReply{
BaseReply: &BaseReply{
source: channel,
},
}
}
const (
MAX_REPLY_LEN = 510 // 512 - CRLF
)
func joinedLen(names []string) int {
var l = len(names) - 1 // " " between names
for _, name := range names {
l += len(name)
}
return l
}
func (reply *NamesReply) Format(client *Client, write chan<- string) {
base := RplNamReply(reply.channel, []string{})
baseLen := len(base.FormatString(client))
tooLong := func(names []string) bool {
return (baseLen + joinedLen(names)) > MAX_REPLY_LEN
}
var start = 0
nicks := reply.channel.Nicks()
for i := range nicks {
if (i > start) && tooLong(nicks[start:i]) {
RplNamReply(reply.channel, nicks[start:i-1]).Format(client, write)
start = i - 1
}
}
if start < (len(nicks) - 1) {
RplNamReply(reply.channel, nicks[start:]).Format(client, write)
}
RplEndOfNames(reply.channel).Format(client, write)
}
// messaging replies
func RplPrivMsg(source Identifier, target Identifier, message string) Reply {
@ -118,7 +169,7 @@ func RplWelcome(source Identifier, client *Client) Reply {
"Welcome to the Internet Relay Network %s", client.Id())
}
func RplYourHost(server *Server, target *Client) Reply {
func RplYourHost(server *Server) Reply {
return NewNumericReply(server, RPL_YOURHOST,
"Your host is %s, running version %s", server.name, VERSION)
}
@ -152,10 +203,9 @@ func RplInvitingMsg(channel *Channel, invitee *Client) Reply {
"%s %s", channel.name, invitee.Nick())
}
func RplNamReply(channel *Channel) Reply {
// TODO multiple names and splitting based on message size
return NewNumericReply(channel.server, RPL_NAMREPLY,
"= %s :%s", channel.name, strings.Join(channel.Nicks(), " "))
func RplNamReply(channel *Channel, names []string) *NumericReply {
return NewNumericReply(channel.server, RPL_NAMREPLY, "= %s :%s",
channel.name, strings.Join(names, " "))
}
func RplEndOfNames(source Identifier) Reply {

View File

@ -41,17 +41,34 @@ func NewServer(name string) *Server {
}
go server.receiveCommands(commands)
NewNickServ(server)
Transact(server.db, func(q Queryable) bool {
urs, err := FindAllUsers(server.db)
if err != nil {
Load(server.db, server)
return server
}
func (server *Server) Load(q Queryable) bool {
crs, err := FindAllChannels(q)
if err != nil {
log.Println(err)
return false
}
for _, cr := range crs {
channel := server.GetOrMakeChannel(cr.name)
channel.id = &(cr.id)
}
urs, err := FindAllUsers(q)
if err != nil {
log.Println(err)
return false
}
for _, ur := range urs {
user := NewUser(ur.nick, server)
user.SetHash(ur.hash)
if !user.Load(q) {
return false
}
for _, ur := range urs {
NewUser(ur.nick, server).SetHash(ur.hash)
}
return false
})
return server
}
return true
}
func (server *Server) receiveCommands(commands <-chan Command) {
@ -115,7 +132,7 @@ func (s *Server) tryRegister(c *Client) {
c.registered = true
replies := []Reply{
RplWelcome(s, c),
RplYourHost(s, c),
RplYourHost(s),
RplCreated(s),
RplMyInfo(s),
}
@ -318,21 +335,21 @@ func (m *PrivMsgCommand) HandleServer(s *Server) {
if m.TargetIsChannel() {
channel := s.channels[m.target]
if channel == nil {
user.Replies() <- ErrNoSuchChannel(s, m.target)
m.Client().Replies() <- ErrNoSuchChannel(s, m.target)
return
}
channel.Commands() <- m
channel.commands <- m
return
}
target := s.users[m.target]
if target == nil {
user.Replies() <- ErrNoSuchNick(s, m.target)
m.Client().Replies() <- ErrNoSuchNick(s, m.target)
return
}
target.Commands() <- m
target.commands <- m
}
func (m *ModeCommand) HandleServer(s *Server) {

View File

@ -16,7 +16,7 @@ type UserCommand interface {
}
type User struct {
id *RowId
id RowId
nick string
hash []byte
server *Server
@ -64,50 +64,80 @@ func NewUser(nick string, server *Server) *User {
return user
}
func (user *User) Row() *UserRow {
return &UserRow{user.id, user.nick, user.hash}
}
func (user *User) Create(q Queryable) bool {
var err error
if err := InsertUser(q, user.Row()); err != nil {
log.Println(err)
return false
}
user.id, err = FindUserIdByNick(q, user.nick)
if err != nil {
log.Println(err)
return false
}
return true
}
func (user *User) Save(q Queryable) bool {
if user.id == nil {
if err := InsertUser(q, user); err != nil {
return false
}
userId, err := FindUserIdByNick(q, user.nick)
if err != nil {
return false
}
user.id = &userId
} else {
if err := UpdateUser(q, user); err != nil {
return false
}
if err := UpdateUser(q, user.Row()); err != nil {
log.Println(err)
return false
}
userId := *(user.id)
channelIds := user.channels.Ids()
if len(channelIds) == 0 {
if err := DeleteAllUserChannels(q, userId); err != nil {
if err := DeleteAllUserChannels(q, user.id); err != nil {
log.Println(err)
return false
}
} else {
if err := DeleteOtherUserChannels(q, userId, channelIds); err != nil {
if err := DeleteOtherUserChannels(q, user.id, channelIds); err != nil {
log.Println(err)
return false
}
if err := InsertUserChannels(q, userId, channelIds); err != nil {
if err := InsertUserChannels(q, user.id, channelIds); err != nil {
log.Println(err)
return false
}
}
return true
}
func (user *User) SetPassword(password string) *User {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
func (user *User) Delete(q Queryable) bool {
err := DeleteUser(q, user.id)
if err != nil {
panic("bcrypt failed; cannot generate password hash")
log.Println(err)
return false
}
return user.SetHash(hash)
return true
}
func (user *User) SetHash(hash []byte) *User {
func (user *User) Load(q Queryable) bool {
crs, err := FindChannelsForUser(q, user.id)
if err != nil {
log.Println(err)
return false
}
for _, cr := range crs {
user.server.GetOrMakeChannel(cr.name).Join(user)
}
return true
}
func (user *User) SetPassword(password string) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Panicln(err)
}
user.SetHash(hash)
}
func (user *User) SetHash(hash []byte) {
user.hash = hash
return user
}
func (user *User) receiveCommands(commands <-chan UserCommand) {
@ -149,10 +179,6 @@ func (user *User) String() string {
return user.Id()
}
func (user *User) Commands() chan<- UserCommand {
return user.commands
}
func (user *User) Login(c *Client, nick string, password string) bool {
if nick != c.nick {
return false
@ -172,8 +198,7 @@ func (user *User) Login(c *Client, nick string, password string) bool {
c.user = user
for channel := range user.channels {
channel.GetTopic(c)
c.Replies() <- RplNamReply(channel)
c.Replies() <- RplEndOfNames(channel.server)
channel.GetUsers(c)
}
return true
}