diff --git a/irc/channel.go b/irc/channel.go index 5c771fe4..7d3bd333 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -46,6 +46,23 @@ func NewChannel(s *Server, name string) *Channel { return channel } +func (channel *Channel) Destroy() error { + if channel.replies == nil { + return ErrAlreadyDestroyed + } + close(channel.replies) + channel.replies = nil + return nil +} + +func (channel *Channel) Reply(reply Reply) error { + if channel.replies == nil { + return ErrAlreadyDestroyed + } + channel.replies <- reply + return nil +} + func (channel *Channel) receiveCommands(commands <-chan ChannelCommand) { for command := range commands { if DEBUG_CHANNEL { @@ -63,7 +80,7 @@ func (channel *Channel) receiveReplies(replies <-chan Reply) { for client := range channel.members { var dest Identifier = client if reply.Source() != dest { - client.replies <- reply + client.Reply(reply) } } } @@ -75,15 +92,15 @@ func (channel *Channel) IsEmpty() bool { func (channel *Channel) GetTopic(replier Replier) { if channel.topic == "" { - replier.Replies() <- RplNoTopic(channel) + replier.Reply(RplNoTopic(channel)) return } - replier.Replies() <- RplTopic(channel) + replier.Reply(RplTopic(channel)) } func (channel *Channel) GetUsers(replier Replier) { - replier.Replies() <- NewNamesReply(channel) + replier.Reply(NewNamesReply(channel)) } func (channel *Channel) Nicks() []string { @@ -96,10 +113,6 @@ func (channel *Channel) Nicks() []string { return nicks } -func (channel *Channel) Replies() chan<- Reply { - return channel.replies -} - func (channel *Channel) Id() string { return channel.name } @@ -124,8 +137,8 @@ func (channel *Channel) Join(client *Client) { channel.members.Add(client) client.channels.Add(channel) reply := RplJoin(client, channel) - client.replies <- reply - channel.replies <- reply + client.Reply(reply) + channel.Reply(reply) channel.GetTopic(client) channel.GetUsers(client) } @@ -137,7 +150,7 @@ func (channel *Channel) Join(client *Client) { func (m *JoinCommand) HandleChannel(channel *Channel) { client := m.Client() if channel.key != m.channels[channel.name] { - client.replies <- ErrBadChannelKey(channel) + client.Reply(ErrBadChannelKey(channel)) return } @@ -148,13 +161,13 @@ func (m *PartCommand) HandleChannel(channel *Channel) { client := m.Client() if !channel.members.Has(client) { - client.replies <- ErrNotOnChannel(channel) + client.Reply(ErrNotOnChannel(channel)) return } reply := RplPart(client, channel, m.Message()) - client.replies <- reply - channel.replies <- reply + client.Reply(reply) + channel.Reply(reply) channel.members.Remove(client) client.channels.Remove(channel) @@ -169,7 +182,7 @@ func (m *TopicCommand) HandleChannel(channel *Channel) { client := m.Client() if !channel.members.Has(client) { - client.replies <- ErrNotOnChannel(channel) + client.Reply(ErrNotOnChannel(channel)) return } @@ -186,10 +199,10 @@ func (m *TopicCommand) HandleChannel(channel *Channel) { func (m *PrivMsgCommand) HandleChannel(channel *Channel) { client := m.Client() if channel.noOutside && !channel.members.Has(client) { - client.replies <- ErrCannotSendToChan(channel) + client.Reply(ErrCannotSendToChan(channel)) return } - channel.replies <- RplPrivMsg(client, channel, m.message) + channel.Reply(RplPrivMsg(client, channel, m.message)) } func (msg *ChannelModeCommand) HandleChannel(channel *Channel) { @@ -200,9 +213,9 @@ func (msg *ChannelModeCommand) HandleChannel(channel *Channel) { case BanMask: // TODO add/remove for _, banMask := range channel.banList { - client.replies <- RplBanList(channel, banMask) + client.Reply(RplBanList(channel, banMask)) } - client.replies <- RplEndOfBanList(channel) + client.Reply(RplEndOfBanList(channel)) case NoOutside: // TODO perms switch modeOp.op { @@ -214,5 +227,5 @@ func (msg *ChannelModeCommand) HandleChannel(channel *Channel) { } } - client.replies <- RplChannelModeIs(channel) + client.Reply(RplChannelModeIs(channel)) } diff --git a/irc/client.go b/irc/client.go index 33dfbac1..269474c8 100644 --- a/irc/client.go +++ b/irc/client.go @@ -29,12 +29,11 @@ func NewClient(server *Server, conn net.Conn) *Client { replies := make(chan Reply) client := &Client{ - channels: make(ChannelSet), - conn: conn, - hostname: LookupHostname(conn.RemoteAddr()), - replies: replies, - server: server, - serverPass: server.password == "", + channels: make(ChannelSet), + conn: conn, + hostname: LookupHostname(conn.RemoteAddr()), + replies: replies, + server: server, } go client.readConn(read) @@ -48,9 +47,9 @@ func (c *Client) readConn(recv <-chan string) { m, err := ParseCommand(str) if err != nil { if err == NotEnoughArgsError { - c.replies <- ErrNeedMoreParams(c.server, str) + c.Reply(ErrNeedMoreParams(c.server, str)) } else { - c.replies <- ErrUnknownCommand(c.server, str) + c.Reply(ErrUnknownCommand(c.server, str)) } continue } @@ -69,14 +68,22 @@ func (c *Client) writeConn(write chan<- string, replies <-chan Reply) { } } -func (client *Client) Destroy() *Client { - client.conn.Close() +func (client *Client) Destroy() error { + if client.replies == nil { + return ErrAlreadyDestroyed + } close(client.replies) - return client + client.replies = nil + client.conn.Close() + return nil } -func (c *Client) Replies() chan<- Reply { - return c.replies +func (client *Client) Reply(reply Reply) error { + if client.replies == nil { + return ErrAlreadyDestroyed + } + client.replies <- reply + return nil } func (client *Client) HasNick() bool { diff --git a/irc/commands.go b/irc/commands.go index 9884f717..c5785cf8 100644 --- a/irc/commands.go +++ b/irc/commands.go @@ -51,7 +51,7 @@ func (command *BaseCommand) Source() Identifier { } func (command *BaseCommand) Reply(reply Reply) { - command.client.Replies() <- reply + command.client.Reply(reply) } func ParseCommand(line string) (editableCommand, error) { diff --git a/irc/constants.go b/irc/constants.go index 6a3aeb4d..cb8dfc89 100644 --- a/irc/constants.go +++ b/irc/constants.go @@ -1,10 +1,16 @@ package irc +import ( + "errors" +) + var ( DEBUG_NET = false DEBUG_CLIENT = false DEBUG_CHANNEL = false DEBUG_SERVER = false + + ErrAlreadyDestroyed = errors.New("already destroyed") ) const ( diff --git a/irc/reply.go b/irc/reply.go index cc05b08b..45b995b1 100644 --- a/irc/reply.go +++ b/irc/reply.go @@ -18,20 +18,6 @@ func joinedLen(names []string) int { return l } -type Identifier interface { - Id() string - Nick() string -} - -type Replier interface { - Replies() chan<- Reply -} - -type Reply interface { - Format(*Client, chan<- string) - Source() Identifier -} - type BaseReply struct { source Identifier message string diff --git a/irc/server.go b/irc/server.go index 009400ac..03e72f69 100644 --- a/irc/server.go +++ b/irc/server.go @@ -22,11 +22,12 @@ type Server struct { func NewServer(config *Config) *Server { commands := make(chan Command) server := &Server{ + channels: make(ChannelNameMap), + clients: make(ClientNameMap), + commands: commands, ctime: time.Now(), name: config.Name, - commands: commands, - clients: make(ClientNameMap), - channels: make(ChannelNameMap), + password: config.Password, } go server.receiveCommands(commands) go server.listen(config.Listen) @@ -38,7 +39,18 @@ func (server *Server) receiveCommands(commands <-chan Command) { if DEBUG_SERVER { log.Printf("%s → %s : %s", command.Client(), server, command) } - command.Client().atime = time.Now() + client := command.Client() + client.atime = time.Now() + if !client.serverPass { + if server.password == "" { + client.serverPass = true + + } else if _, ok := command.(*PassCommand); !ok { + client.Reply(ErrPasswdMismatch(server)) + client.Destroy() + return + } + } command.HandleServer(server) } } @@ -97,7 +109,7 @@ func (s *Server) GenerateGuestNick() string { // server functionality func (s *Server) tryRegister(c *Client) { - if !c.registered && c.HasNick() && c.HasUsername() && c.serverPass { + if !c.registered && c.HasNick() && c.HasUsername() { c.registered = true replies := []Reply{ RplWelcome(s, c), @@ -106,7 +118,7 @@ func (s *Server) tryRegister(c *Client) { RplMyInfo(s), } for _, reply := range replies { - c.Replies() <- reply + c.Reply(reply) } } } @@ -128,11 +140,11 @@ func (s *Server) Nick() string { // func (m *UnknownCommand) HandleServer(s *Server) { - m.Client().replies <- ErrUnknownCommand(s, m.command) + m.Client().Reply(ErrUnknownCommand(s, m.command)) } func (m *PingCommand) HandleServer(s *Server) { - m.Client().replies <- RplPong(s, m.Client()) + m.Client().Reply(RplPong(s, m.Client())) } func (m *PongCommand) HandleServer(s *Server) { @@ -141,7 +153,7 @@ func (m *PongCommand) HandleServer(s *Server) { func (m *PassCommand) HandleServer(s *Server) { if s.password != m.password { - m.Client().replies <- ErrPasswdMismatch(s) + m.Client().Reply(ErrPasswdMismatch(s)) m.Client().Destroy() return } @@ -154,14 +166,17 @@ func (m *NickCommand) HandleServer(s *Server) { c := m.Client() if s.clients[m.nickname] != nil { - c.replies <- ErrNickNameInUse(s, m.nickname) + c.Reply(ErrNickNameInUse(s, m.nickname)) return } + if !c.HasNick() { + c.nick = m.nickname + } reply := RplNick(c, m.nickname) - c.replies <- reply + c.Reply(reply) for iclient := range c.InterestedClients() { - iclient.replies <- reply + iclient.Reply(reply) } s.clients.Remove(c) @@ -174,7 +189,7 @@ func (m *NickCommand) HandleServer(s *Server) { func (m *UserMsgCommand) HandleServer(s *Server) { c := m.Client() if c.registered { - c.replies <- ErrAlreadyRegistered(s) + c.Reply(ErrAlreadyRegistered(s)) return } @@ -190,12 +205,12 @@ func (m *QuitCommand) HandleServer(s *Server) { channel.members.Remove(c) } - c.replies <- RplError(s, c) + c.Reply(RplError(s, c)) c.Destroy() reply := RplQuit(c, m.message) for client := range c.InterestedClients() { - client.replies <- reply + client.Reply(reply) } } @@ -221,7 +236,7 @@ func (m *PartCommand) HandleServer(s *Server) { channel := s.channels[chname] if channel == nil { - m.Client().replies <- ErrNoSuchChannel(s, channel.name) + m.Client().Reply(ErrNoSuchChannel(s, channel.name)) continue } @@ -232,7 +247,7 @@ func (m *PartCommand) HandleServer(s *Server) { func (m *TopicCommand) HandleServer(s *Server) { channel := s.channels[m.channel] if channel == nil { - m.Client().replies <- ErrNoSuchChannel(s, m.channel) + m.Client().Reply(ErrNoSuchChannel(s, m.channel)) return } @@ -243,7 +258,7 @@ func (m *PrivMsgCommand) HandleServer(s *Server) { if m.TargetIsChannel() { channel := s.channels[m.target] if channel == nil { - m.Client().replies <- ErrNoSuchChannel(s, m.target) + m.Client().Reply(ErrNoSuchChannel(s, m.target)) return } @@ -253,10 +268,10 @@ func (m *PrivMsgCommand) HandleServer(s *Server) { target := s.clients[m.target] if target == nil { - m.Client().replies <- ErrNoSuchNick(s, m.target) + m.Client().Reply(ErrNoSuchNick(s, m.target)) return } - target.replies <- RplPrivMsg(m.Client(), target, m.message) + target.Reply(RplPrivMsg(m.Client(), target, m.message)) } func (m *ModeCommand) HandleServer(s *Server) { @@ -272,11 +287,11 @@ func (m *ModeCommand) HandleServer(s *Server) { } } } - client.replies <- RplUModeIs(s, client) + client.Reply(RplUModeIs(s, client)) return } - client.replies <- ErrUsersDontMatch(client) + client.Reply(ErrUsersDontMatch(client)) } func (m *WhoisCommand) HandleServer(server *Server) { @@ -284,7 +299,7 @@ func (m *WhoisCommand) HandleServer(server *Server) { // TODO implement target query if m.target != "" { - client.replies <- ErrNoSuchServer(server, m.target) + client.Reply(ErrNoSuchServer(server, m.target)) return } @@ -292,17 +307,17 @@ func (m *WhoisCommand) HandleServer(server *Server) { // TODO implement wildcard matching mclient := server.clients[mask] if mclient != nil { - client.replies <- RplWhoisUser(server, mclient) + client.Reply(RplWhoisUser(server, mclient)) } } - client.replies <- RplEndOfWhois(server) + client.Reply(RplEndOfWhois(server)) } func (msg *ChannelModeCommand) HandleServer(server *Server) { client := msg.Client() channel := server.channels[msg.channel] if channel == nil { - client.replies <- ErrNoSuchChannel(server, msg.channel) + client.Reply(ErrNoSuchChannel(server, msg.channel)) return } channel.commands <- msg @@ -310,7 +325,7 @@ func (msg *ChannelModeCommand) HandleServer(server *Server) { func whoChannel(client *Client, server *Server, channel *Channel) { for member := range channel.members { - client.replies <- RplWhoReply(server, channel, member) + client.Reply(RplWhoReply(server, channel, member)) } } @@ -331,9 +346,9 @@ func (msg *WhoCommand) HandleServer(server *Server) { } else { mclient := server.clients[mask] if mclient != nil { - client.replies <- RplWhoReply(server, mclient.channels.First(), mclient) + client.Reply(RplWhoReply(server, mclient.channels.First(), mclient)) } } - client.replies <- RplEndOfWho(server, mask) + client.Reply(RplEndOfWho(server, mask)) } diff --git a/irc/types.go b/irc/types.go index ec974e67..0c380493 100644 --- a/irc/types.go +++ b/irc/types.go @@ -98,6 +98,20 @@ func (channels ChannelSet) First() *Channel { // interfaces // +type Identifier interface { + Id() string + Nick() string +} + +type Replier interface { + Reply(Reply) error +} + +type Reply interface { + Format(*Client, chan<- string) + Source() Identifier +} + // commands the server understands // TODO rename ServerCommand type Command interface {