3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-10 22:19:31 +01:00

Merge pull request #8 from jlatt/cap-protocol

basic capability negotiation
This commit is contained in:
Jeremy Latt 2014-03-06 17:39:12 -08:00
commit 4bcd42ff34
7 changed files with 222 additions and 101 deletions

View File

@ -54,18 +54,26 @@ func (channel *Channel) ClientIsOperator(client *Client) bool {
return client.flags[Operator] || channel.members.HasMode(client, ChannelOperator) return client.flags[Operator] || channel.members.HasMode(client, ChannelOperator)
} }
func (channel *Channel) Nicks() []string { func (channel *Channel) Nicks(target *Client) []string {
isMultiPrefix := (target != nil) && target.capabilities[MultiPrefix]
nicks := make([]string, len(channel.members)) nicks := make([]string, len(channel.members))
i := 0 i := 0
for client, modes := range channel.members { for client, modes := range channel.members {
switch { if isMultiPrefix {
case modes[ChannelOperator]: if modes[ChannelOperator] {
nicks[i] = "@" + client.Nick() nicks[i] += "@"
case modes[Voice]: }
nicks[i] = "+" + client.Nick() if modes[Voice] {
default: nicks[i] += "+"
nicks[i] = client.Nick() }
} else {
if modes[ChannelOperator] {
nicks[i] += "@"
} else if modes[Voice] {
nicks[i] += "+"
}
} }
nicks[i] += client.Nick()
i += 1 i += 1
} }
return nicks return nicks

View File

@ -12,36 +12,41 @@ func IsNickname(nick string) bool {
} }
type Client struct { type Client struct {
atime time.Time atime time.Time
awayMessage string authorized bool
channels ChannelSet awayMessage string
commands chan editableCommand capabilities CapabilitySet
ctime time.Time capState CapState
flags map[UserMode]bool channels ChannelSet
hasQuit bool commands chan editableCommand
hops uint ctime time.Time
hostname string flags map[UserMode]bool
idleTimer *time.Timer hasQuit bool
loginTimer *time.Timer hops uint
nick string hostname string
phase Phase idleTimer *time.Timer
quitTimer *time.Timer loginTimer *time.Timer
realname string nick string
server *Server phase Phase
socket *Socket quitTimer *time.Timer
username string realname string
server *Server
socket *Socket
username string
} }
func NewClient(server *Server, conn net.Conn) *Client { func NewClient(server *Server, conn net.Conn) *Client {
now := time.Now() now := time.Now()
client := &Client{ client := &Client{
atime: now, atime: now,
channels: make(ChannelSet), capState: CapNone,
commands: make(chan editableCommand), capabilities: make(CapabilitySet),
ctime: now, channels: make(ChannelSet),
flags: make(map[UserMode]bool), commands: make(chan editableCommand),
phase: server.InitPhase(), ctime: now,
server: server, flags: make(map[UserMode]bool),
phase: Registration,
server: server,
} }
client.socket = NewSocket(conn, client.commands) client.socket = NewSocket(conn, client.commands)
client.loginTimer = time.AfterFunc(LOGIN_TIMEOUT, client.connectionTimeout) client.loginTimer = time.AfterFunc(LOGIN_TIMEOUT, client.connectionTimeout)
@ -68,6 +73,12 @@ func (client *Client) run() {
} }
} }
func (client *Client) connectionTimeout() {
client.commands <- &QuitCommand{
message: "connection timeout",
}
}
// //
// idle timer goroutine // idle timer goroutine
// //
@ -76,14 +87,6 @@ func (client *Client) connectionIdle() {
client.server.idle <- client client.server.idle <- client
} }
//
// quit timer goroutine
//
func (client *Client) connectionTimeout() {
client.server.timeout <- client
}
// //
// server goroutine // server goroutine
// //
@ -232,7 +235,10 @@ func (client *Client) ChangeNickname(nickname string) {
} }
} }
func (client *Client) Reply(reply string) { func (client *Client) Reply(reply string, args ...interface{}) {
if len(args) > 0 {
reply = fmt.Sprintf(reply, args...)
}
client.socket.Write(reply) client.socket.Write(reply)
} }

View File

@ -705,20 +705,34 @@ func NewOperCommand(args []string) (editableCommand, error) {
return cmd, nil return cmd, nil
} }
// TODO
type CapCommand struct { type CapCommand struct {
BaseCommand BaseCommand
args []string subCommand CapSubCommand
capabilities CapabilitySet
} }
func (msg *CapCommand) String() string { func (msg *CapCommand) String() string {
return fmt.Sprintf("CAP(args=%s)", msg.args) return fmt.Sprintf("CAP(subCommand=%s, capabilities=%s)",
msg.subCommand, msg.capabilities)
} }
func NewCapCommand(args []string) (editableCommand, error) { func NewCapCommand(args []string) (editableCommand, error) {
return &CapCommand{ if len(args) < 1 {
args: args, return nil, NotEnoughArgsError
}, nil }
cmd := &CapCommand{
subCommand: CapSubCommand(strings.ToUpper(args[0])),
capabilities: make(CapabilitySet),
}
if len(args) > 1 {
strs := spacesExpr.Split(args[1], -1)
for _, str := range strs {
cmd.capabilities[Capability(str)] = true
}
}
return cmd, nil
} }
// HAPROXY support // HAPROXY support

View File

@ -155,6 +155,7 @@ const (
ERR_TOOMANYTARGETS NumericCode = 407 ERR_TOOMANYTARGETS NumericCode = 407
ERR_NOSUCHSERVICE NumericCode = 408 ERR_NOSUCHSERVICE NumericCode = 408
ERR_NOORIGIN NumericCode = 409 ERR_NOORIGIN NumericCode = 409
ERR_INVALIDCAPCMD NumericCode = 410
ERR_NORECIPIENT NumericCode = 411 ERR_NORECIPIENT NumericCode = 411
ERR_NOTEXTTOSEND NumericCode = 412 ERR_NOTEXTTOSEND NumericCode = 412
ERR_NOTOPLEVEL NumericCode = 413 ERR_NOTOPLEVEL NumericCode = 413
@ -200,6 +201,14 @@ const (
ERR_UMODEUNKNOWNFLAG NumericCode = 501 ERR_UMODEUNKNOWNFLAG NumericCode = 501
ERR_USERSDONTMATCH NumericCode = 502 ERR_USERSDONTMATCH NumericCode = 502
CAP_LS CapSubCommand = "LS"
CAP_LIST CapSubCommand = "LIST"
CAP_REQ CapSubCommand = "REQ"
CAP_ACK CapSubCommand = "ACK"
CAP_NAK CapSubCommand = "NAK"
CAP_CLEAR CapSubCommand = "CLEAR"
CAP_END CapSubCommand = "END"
Add ModeOp = '+' Add ModeOp = '+'
List ModeOp = '=' List ModeOp = '='
Remove ModeOp = '-' Remove ModeOp = '-'
@ -230,10 +239,28 @@ const (
Secret ChannelMode = 's' // flag, deprecated Secret ChannelMode = 's' // flag, deprecated
UserLimit ChannelMode = 'l' // flag arg UserLimit ChannelMode = 'l' // flag arg
Voice ChannelMode = 'v' // arg Voice ChannelMode = 'v' // arg
MultiPrefix Capability = "multi-prefix"
SASL Capability = "sasl"
Disable CapModifier = '-'
Ack CapModifier = '~'
Sticky CapModifier = '='
)
var (
SupportedCapabilities = CapabilitySet{
MultiPrefix: true,
}
) )
const ( const (
Authorization Phase = iota Registration Phase = iota
Registration Phase = iota Normal Phase = iota
Normal Phase = iota )
const (
CapNone CapState = iota
CapNegotiating CapState = iota
CapNegotiated CapState = iota
) )

View File

@ -240,11 +240,19 @@ func (target *Client) RplWhoReply(channel *Channel, client *Client) {
if channel != nil { if channel != nil {
channelName = channel.name channelName = channel.name
if target.capabilities[MultiPrefix] {
if channel.members[client][ChannelOperator] { if channel.members[client][ChannelOperator] {
flags += "@" flags += "@"
} else if channel.members[client][Voice] { }
flags += "+" if channel.members[client][Voice] {
flags += "+"
}
} else {
if channel.members[client][ChannelOperator] {
flags += "@"
} else if channel.members[client][Voice] {
flags += "+"
}
} }
} }
target.NumericReply(RPL_WHOREPLY, target.NumericReply(RPL_WHOREPLY,
@ -360,7 +368,7 @@ func (target *Client) RplListEnd(server *Server) {
} }
func (target *Client) RplNamReply(channel *Channel) { func (target *Client) RplNamReply(channel *Channel) {
target.MultilineReply(channel.Nicks(), RPL_NAMREPLY, target.MultilineReply(channel.Nicks(target), RPL_NAMREPLY,
"= %s :%s", channel) "= %s :%s", channel)
} }
@ -502,3 +510,8 @@ func (target *Client) ErrChannelIsFull(channel *Channel) {
target.NumericReply(ERR_CHANNELISFULL, target.NumericReply(ERR_CHANNELISFULL,
"%s :Cannot join channel (+l)", channel) "%s :Cannot join channel (+l)", channel)
} }
func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) {
target.NumericReply(ERR_INVALIDCAPCMD,
"%s :Invalid CAP subcommand", subCommand)
}

View File

@ -29,7 +29,6 @@ type Server struct {
operators map[string][]byte operators map[string][]byte
password []byte password []byte
signals chan os.Signal signals chan os.Signal
timeout chan *Client
} }
func NewServer(config *Config) *Server { func NewServer(config *Config) *Server {
@ -45,7 +44,6 @@ func NewServer(config *Config) *Server {
newConns: make(chan net.Conn, 16), newConns: make(chan net.Conn, 16),
operators: config.Operators(), operators: config.Operators(),
signals: make(chan os.Signal, 1), signals: make(chan os.Signal, 1),
timeout: make(chan *Client, 16),
} }
if config.Server.Password != "" { if config.Server.Password != "" {
@ -97,14 +95,6 @@ func (server *Server) processCommand(cmd Command) {
} }
switch client.phase { switch client.phase {
case Authorization:
authCmd, ok := cmd.(AuthServerCommand)
if !ok {
client.Quit("unexpected command")
return
}
authCmd.HandleAuthServer(server)
case Registration: case Registration:
regCmd, ok := cmd.(RegServerCommand) regCmd, ok := cmd.(RegServerCommand)
if !ok { if !ok {
@ -113,7 +103,7 @@ func (server *Server) processCommand(cmd Command) {
} }
regCmd.HandleRegServer(server) regCmd.HandleRegServer(server)
default: case Normal:
srvCmd, ok := cmd.(ServerCommand) srvCmd, ok := cmd.(ServerCommand)
if !ok { if !ok {
client.ErrUnknownCommand(cmd.Code()) client.ErrUnknownCommand(cmd.Code())
@ -157,20 +147,10 @@ func (server *Server) Run() {
case client := <-server.idle: case client := <-server.idle:
client.Idle() client.Idle()
case client := <-server.timeout:
client.Quit("connection timeout")
} }
} }
} }
func (server *Server) InitPhase() Phase {
if server.password == nil {
return Registration
}
return Authorization
}
// //
// listen goroutine // listen goroutine
// //
@ -206,7 +186,7 @@ func (s *Server) listen(addr string) {
// //
func (s *Server) tryRegister(c *Client) { func (s *Server) tryRegister(c *Client) {
if c.HasNick() && c.HasUsername() { if c.HasNick() && c.HasUsername() && (c.capState != CapNegotiating) {
c.Register() c.Register()
c.RplWelcome() c.RplWelcome()
c.RplYourHost() c.RplYourHost()
@ -266,18 +246,10 @@ func (s *Server) Nick() string {
} }
// //
// authorization commands // registration commands
// //
func (msg *ProxyCommand) HandleAuthServer(server *Server) { func (msg *PassCommand) HandleRegServer(server *Server) {
msg.Client().hostname = msg.hostname
}
func (msg *CapCommand) HandleAuthServer(server *Server) {
// TODO
}
func (msg *PassCommand) HandleAuthServer(server *Server) {
client := msg.Client() client := msg.Client()
if msg.err != nil { if msg.err != nil {
client.ErrPasswdMismatch() client.ErrPasswdMismatch()
@ -285,27 +257,70 @@ func (msg *PassCommand) HandleAuthServer(server *Server) {
return return
} }
client.phase = Registration client.authorized = true
} }
func (msg *QuitCommand) HandleAuthServer(server *Server) {
msg.Client().Quit(msg.message)
}
//
// registration commands
//
func (msg *ProxyCommand) HandleRegServer(server *Server) { func (msg *ProxyCommand) HandleRegServer(server *Server) {
msg.Client().hostname = msg.hostname msg.Client().hostname = msg.hostname
} }
func (msg *CapCommand) HandleRegServer(server *Server) { func (msg *CapCommand) HandleRegServer(server *Server) {
// TODO client := msg.Client()
switch msg.subCommand {
case CAP_LS:
client.capState = CapNegotiating
client.Reply("CAP LS * :%s", SupportedCapabilities)
case CAP_LIST:
client.Reply("CAP LIST * :%s", client.capabilities)
case CAP_REQ:
client.capState = CapNegotiating
for capability := range msg.capabilities {
if !SupportedCapabilities[capability] {
client.Reply("CAP NAK * :%s", msg.capabilities)
return
}
}
for capability := range msg.capabilities {
client.capabilities[capability] = true
}
client.Reply("CAP ACK * :%s", msg.capabilities)
case CAP_CLEAR:
format := strings.TrimRight(
strings.Repeat("%s%s ", len(client.capabilities)), " ")
args := make([]interface{}, len(client.capabilities))
index := 0
for capability := range client.capabilities {
args[index] = Disable
args[index+1] = capability
index += 2
delete(client.capabilities, capability)
}
client.Reply("CAP ACK * :"+format, args...)
case CAP_END:
client.capState = CapNegotiated
server.tryRegister(client)
default:
client.ErrInvalidCapCmd(msg.subCommand)
}
} }
func (m *NickCommand) HandleRegServer(s *Server) { func (m *NickCommand) HandleRegServer(s *Server) {
client := m.Client() client := m.Client()
if !client.authorized {
client.ErrPasswdMismatch()
client.Quit("bad password")
return
}
if client.capState == CapNegotiating {
client.capState = CapNegotiated
}
if m.nickname == "" { if m.nickname == "" {
client.ErrNoNicknameGiven() client.ErrNoNicknameGiven()
@ -327,11 +342,22 @@ func (m *NickCommand) HandleRegServer(s *Server) {
} }
func (msg *RFC1459UserCommand) HandleRegServer(server *Server) { func (msg *RFC1459UserCommand) HandleRegServer(server *Server) {
client := msg.Client()
if !client.authorized {
client.ErrPasswdMismatch()
client.Quit("bad password")
return
}
msg.HandleRegServer2(server) msg.HandleRegServer2(server)
} }
func (msg *RFC2812UserCommand) HandleRegServer(server *Server) { func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
client := msg.Client() client := msg.Client()
if !client.authorized {
client.ErrPasswdMismatch()
client.Quit("bad password")
return
}
flags := msg.Flags() flags := msg.Flags()
if len(flags) > 0 { if len(flags) > 0 {
for _, mode := range msg.Flags() { for _, mode := range msg.Flags() {
@ -344,6 +370,9 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
func (msg *UserCommand) HandleRegServer2(server *Server) { func (msg *UserCommand) HandleRegServer2(server *Server) {
client := msg.Client() client := msg.Client()
if client.capState == CapNegotiating {
client.capState = CapNegotiated
}
client.username, client.realname = msg.username, msg.realname client.username, client.realname = msg.username, msg.realname
server.tryRegister(client) server.tryRegister(client)
} }

View File

@ -10,6 +10,30 @@ import (
// simple types // simple types
// //
type CapSubCommand string
type Capability string
type CapModifier rune
func (mod CapModifier) String() string {
return string(mod)
}
type CapState uint
type CapabilitySet map[Capability]bool
func (set CapabilitySet) String() string {
strs := make([]string, len(set))
index := 0
for capability := range set {
strs[index] = string(capability)
index += 1
}
return strings.Join(strs, " ")
}
// a string with wildcards // a string with wildcards
type Mask string type Mask string
@ -24,7 +48,7 @@ func (op ModeOp) String() string {
type UserMode rune type UserMode rune
func (mode UserMode) String() string { func (mode UserMode) String() string {
return fmt.Sprintf("%c", mode) return string(mode)
} }
type Phase uint type Phase uint
@ -49,7 +73,7 @@ func (code NumericCode) String() string {
type ChannelMode rune type ChannelMode rune
func (mode ChannelMode) String() string { func (mode ChannelMode) String() string {
return fmt.Sprintf("%c", mode) return string(mode)
} }
type ChannelNameMap map[string]*Channel type ChannelNameMap map[string]*Channel