diff --git a/irc/client.go b/irc/client.go index b43af358..2409d0e1 100644 --- a/irc/client.go +++ b/irc/client.go @@ -3,7 +3,6 @@ package irc import ( "bufio" "fmt" - "io" "log" "net" "strings" @@ -11,6 +10,7 @@ import ( ) type Client struct { + atime time.Time away bool awayMessage string channels ChannelSet @@ -27,7 +27,7 @@ type Client struct { registered bool replies chan Reply server *Server - serverPass bool + authorized bool username string } @@ -48,9 +48,12 @@ func NewClient(server *Server, conn net.Conn) *Client { } func (client *Client) Touch() { + client.atime = time.Now() + if client.quitTimer != nil { client.quitTimer.Stop() } + if client.idleTimer == nil { client.idleTimer = time.AfterFunc(IDLE_TIMEOUT, client.Idle) } else { @@ -64,6 +67,7 @@ func (client *Client) Idle() { } else { client.quitTimer.Reset(QUIT_TIMEOUT) } + client.Reply(RplPing(client.server, client)) } @@ -90,11 +94,7 @@ func (c *Client) readConn() { line, err := recv.ReadString('\n') if err != nil { if DEBUG_NET { - if err == io.EOF { - log.Printf("%s → closed", c.conn.RemoteAddr()) - } else { - log.Printf("%s → error: %s", c.conn.RemoteAddr(), err) - } + log.Printf("%s → error: %s", c.conn.RemoteAddr(), err) } break } @@ -124,11 +124,7 @@ func (c *Client) readConn() { func (client *Client) maybeLogWriteError(err error) bool { if err != nil { if DEBUG_NET { - if err == io.EOF { - log.Printf("%s ← closed", client.conn.RemoteAddr()) - } else { - log.Printf("%s ← error: %s", client.conn.RemoteAddr(), err) - } + log.Printf("%s ← error: %s", client.conn.RemoteAddr(), err) } return true } @@ -181,6 +177,8 @@ func (client *Client) Destroy() { // clear channel list client.channels = make(ChannelSet) + client.server.clients.Remove(client) + client.destroyed = true } diff --git a/irc/server.go b/irc/server.go index a6e1344b..0a6475b1 100644 --- a/irc/server.go +++ b/irc/server.go @@ -55,28 +55,41 @@ func (server *Server) receiveCommands(commands <-chan Command) { log.Printf("%s → %s %s", command.Client(), server, command) } client := command.Client() - client.Touch() - if !client.serverPass { - if server.password == "" { - client.serverPass = true - - } else { - switch command.(type) { - case *PassCommand, *CapCommand, *ProxyCommand: - // no-op - default: - client.Reply(ErrPasswdMismatch(server)) - server.clients.Remove(client) - client.Destroy() - return - } - } + if !server.Authorize(client, command) { + client.Destroy() + return } + + client.Touch() command.HandleServer(server) + + if DEBUG_SERVER { + log.Printf("%s → %s %s processed", command.Client(), server, command) + } } } +func (server *Server) Authorize(client *Client, command Command) bool { + if client.authorized { + return true + } + + if server.password == "" { + client.authorized = true + return true + } + + switch command.(type) { + case *PassCommand, *CapCommand, *ProxyCommand: + // no-op + default: + return false + } + + return true +} + func newListener(config ListenerConfig) (net.Listener, error) { if config.IsTLS() { certificate, err := tls.LoadX509KeyPair(config.Certificate, config.Key) @@ -222,14 +235,20 @@ func (m *PongCommand) HandleServer(s *Server) { } func (m *PassCommand) HandleServer(s *Server) { - if s.password != m.password { - m.Client().Reply(ErrPasswdMismatch(s)) - m.Client().Destroy() + client := m.Client() + + if client.registered || client.authorized { + client.Reply(ErrAlreadyRegistered(s)) return } - m.Client().serverPass = true - // no reply? + if s.password != m.password { + client.Reply(ErrPasswdMismatch(s)) + client.Destroy() + return + } + + client.authorized = true } func (m *NickCommand) HandleServer(s *Server) { @@ -449,9 +468,7 @@ func (msg *CapCommand) HandleServer(server *Server) { } func (msg *ProxyCommand) HandleServer(server *Server) { - go func() { - msg.Client().hostname = LookupHostname(msg.sourceIP) - }() + msg.Client().hostname = LookupHostname(msg.sourceIP) } func (msg *AwayCommand) HandleServer(server *Server) {