3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-31 23:22:38 +01:00

Merge pull request #8 from jlatt/cap-protocol

basic capability negotiation
This commit is contained in:
Jeremy Latt 2014-03-06 17:39:12 -08:00
commit 4bcd42ff34
7 changed files with 222 additions and 101 deletions

View File

@ -54,18 +54,26 @@ func (channel *Channel) ClientIsOperator(client *Client) bool {
return client.flags[Operator] || channel.members.HasMode(client, ChannelOperator)
}
func (channel *Channel) Nicks() []string {
func (channel *Channel) Nicks(target *Client) []string {
isMultiPrefix := (target != nil) && target.capabilities[MultiPrefix]
nicks := make([]string, len(channel.members))
i := 0
for client, modes := range channel.members {
switch {
case modes[ChannelOperator]:
nicks[i] = "@" + client.Nick()
case modes[Voice]:
nicks[i] = "+" + client.Nick()
default:
nicks[i] = client.Nick()
if isMultiPrefix {
if modes[ChannelOperator] {
nicks[i] += "@"
}
if modes[Voice] {
nicks[i] += "+"
}
} else {
if modes[ChannelOperator] {
nicks[i] += "@"
} else if modes[Voice] {
nicks[i] += "+"
}
}
nicks[i] += client.Nick()
i += 1
}
return nicks

View File

@ -12,36 +12,41 @@ func IsNickname(nick string) bool {
}
type Client struct {
atime time.Time
awayMessage string
channels ChannelSet
commands chan editableCommand
ctime time.Time
flags map[UserMode]bool
hasQuit bool
hops uint
hostname string
idleTimer *time.Timer
loginTimer *time.Timer
nick string
phase Phase
quitTimer *time.Timer
realname string
server *Server
socket *Socket
username string
atime time.Time
authorized bool
awayMessage string
capabilities CapabilitySet
capState CapState
channels ChannelSet
commands chan editableCommand
ctime time.Time
flags map[UserMode]bool
hasQuit bool
hops uint
hostname string
idleTimer *time.Timer
loginTimer *time.Timer
nick string
phase Phase
quitTimer *time.Timer
realname string
server *Server
socket *Socket
username string
}
func NewClient(server *Server, conn net.Conn) *Client {
now := time.Now()
client := &Client{
atime: now,
channels: make(ChannelSet),
commands: make(chan editableCommand),
ctime: now,
flags: make(map[UserMode]bool),
phase: server.InitPhase(),
server: server,
atime: now,
capState: CapNone,
capabilities: make(CapabilitySet),
channels: make(ChannelSet),
commands: make(chan editableCommand),
ctime: now,
flags: make(map[UserMode]bool),
phase: Registration,
server: server,
}
client.socket = NewSocket(conn, client.commands)
client.loginTimer = time.AfterFunc(LOGIN_TIMEOUT, client.connectionTimeout)
@ -68,6 +73,12 @@ func (client *Client) run() {
}
}
func (client *Client) connectionTimeout() {
client.commands <- &QuitCommand{
message: "connection timeout",
}
}
//
// idle timer goroutine
//
@ -76,14 +87,6 @@ func (client *Client) connectionIdle() {
client.server.idle <- client
}
//
// quit timer goroutine
//
func (client *Client) connectionTimeout() {
client.server.timeout <- client
}
//
// server goroutine
//
@ -232,7 +235,10 @@ func (client *Client) ChangeNickname(nickname string) {
}
}
func (client *Client) Reply(reply string) {
func (client *Client) Reply(reply string, args ...interface{}) {
if len(args) > 0 {
reply = fmt.Sprintf(reply, args...)
}
client.socket.Write(reply)
}

View File

@ -705,20 +705,34 @@ func NewOperCommand(args []string) (editableCommand, error) {
return cmd, nil
}
// TODO
type CapCommand struct {
BaseCommand
args []string
subCommand CapSubCommand
capabilities CapabilitySet
}
func (msg *CapCommand) String() string {
return fmt.Sprintf("CAP(args=%s)", msg.args)
return fmt.Sprintf("CAP(subCommand=%s, capabilities=%s)",
msg.subCommand, msg.capabilities)
}
func NewCapCommand(args []string) (editableCommand, error) {
return &CapCommand{
args: args,
}, nil
if len(args) < 1 {
return nil, NotEnoughArgsError
}
cmd := &CapCommand{
subCommand: CapSubCommand(strings.ToUpper(args[0])),
capabilities: make(CapabilitySet),
}
if len(args) > 1 {
strs := spacesExpr.Split(args[1], -1)
for _, str := range strs {
cmd.capabilities[Capability(str)] = true
}
}
return cmd, nil
}
// HAPROXY support

View File

@ -155,6 +155,7 @@ const (
ERR_TOOMANYTARGETS NumericCode = 407
ERR_NOSUCHSERVICE NumericCode = 408
ERR_NOORIGIN NumericCode = 409
ERR_INVALIDCAPCMD NumericCode = 410
ERR_NORECIPIENT NumericCode = 411
ERR_NOTEXTTOSEND NumericCode = 412
ERR_NOTOPLEVEL NumericCode = 413
@ -200,6 +201,14 @@ const (
ERR_UMODEUNKNOWNFLAG NumericCode = 501
ERR_USERSDONTMATCH NumericCode = 502
CAP_LS CapSubCommand = "LS"
CAP_LIST CapSubCommand = "LIST"
CAP_REQ CapSubCommand = "REQ"
CAP_ACK CapSubCommand = "ACK"
CAP_NAK CapSubCommand = "NAK"
CAP_CLEAR CapSubCommand = "CLEAR"
CAP_END CapSubCommand = "END"
Add ModeOp = '+'
List ModeOp = '='
Remove ModeOp = '-'
@ -230,10 +239,28 @@ const (
Secret ChannelMode = 's' // flag, deprecated
UserLimit ChannelMode = 'l' // flag arg
Voice ChannelMode = 'v' // arg
MultiPrefix Capability = "multi-prefix"
SASL Capability = "sasl"
Disable CapModifier = '-'
Ack CapModifier = '~'
Sticky CapModifier = '='
)
var (
SupportedCapabilities = CapabilitySet{
MultiPrefix: true,
}
)
const (
Authorization Phase = iota
Registration Phase = iota
Normal Phase = iota
Registration Phase = iota
Normal Phase = iota
)
const (
CapNone CapState = iota
CapNegotiating CapState = iota
CapNegotiated CapState = iota
)

View File

@ -240,11 +240,19 @@ func (target *Client) RplWhoReply(channel *Channel, client *Client) {
if channel != nil {
channelName = channel.name
if channel.members[client][ChannelOperator] {
flags += "@"
} else if channel.members[client][Voice] {
flags += "+"
if target.capabilities[MultiPrefix] {
if channel.members[client][ChannelOperator] {
flags += "@"
}
if channel.members[client][Voice] {
flags += "+"
}
} else {
if channel.members[client][ChannelOperator] {
flags += "@"
} else if channel.members[client][Voice] {
flags += "+"
}
}
}
target.NumericReply(RPL_WHOREPLY,
@ -360,7 +368,7 @@ func (target *Client) RplListEnd(server *Server) {
}
func (target *Client) RplNamReply(channel *Channel) {
target.MultilineReply(channel.Nicks(), RPL_NAMREPLY,
target.MultilineReply(channel.Nicks(target), RPL_NAMREPLY,
"= %s :%s", channel)
}
@ -502,3 +510,8 @@ func (target *Client) ErrChannelIsFull(channel *Channel) {
target.NumericReply(ERR_CHANNELISFULL,
"%s :Cannot join channel (+l)", channel)
}
func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) {
target.NumericReply(ERR_INVALIDCAPCMD,
"%s :Invalid CAP subcommand", subCommand)
}

View File

@ -29,7 +29,6 @@ type Server struct {
operators map[string][]byte
password []byte
signals chan os.Signal
timeout chan *Client
}
func NewServer(config *Config) *Server {
@ -45,7 +44,6 @@ func NewServer(config *Config) *Server {
newConns: make(chan net.Conn, 16),
operators: config.Operators(),
signals: make(chan os.Signal, 1),
timeout: make(chan *Client, 16),
}
if config.Server.Password != "" {
@ -97,14 +95,6 @@ func (server *Server) processCommand(cmd Command) {
}
switch client.phase {
case Authorization:
authCmd, ok := cmd.(AuthServerCommand)
if !ok {
client.Quit("unexpected command")
return
}
authCmd.HandleAuthServer(server)
case Registration:
regCmd, ok := cmd.(RegServerCommand)
if !ok {
@ -113,7 +103,7 @@ func (server *Server) processCommand(cmd Command) {
}
regCmd.HandleRegServer(server)
default:
case Normal:
srvCmd, ok := cmd.(ServerCommand)
if !ok {
client.ErrUnknownCommand(cmd.Code())
@ -157,20 +147,10 @@ func (server *Server) Run() {
case client := <-server.idle:
client.Idle()
case client := <-server.timeout:
client.Quit("connection timeout")
}
}
}
func (server *Server) InitPhase() Phase {
if server.password == nil {
return Registration
}
return Authorization
}
//
// listen goroutine
//
@ -206,7 +186,7 @@ func (s *Server) listen(addr string) {
//
func (s *Server) tryRegister(c *Client) {
if c.HasNick() && c.HasUsername() {
if c.HasNick() && c.HasUsername() && (c.capState != CapNegotiating) {
c.Register()
c.RplWelcome()
c.RplYourHost()
@ -266,18 +246,10 @@ func (s *Server) Nick() string {
}
//
// authorization commands
// registration commands
//
func (msg *ProxyCommand) HandleAuthServer(server *Server) {
msg.Client().hostname = msg.hostname
}
func (msg *CapCommand) HandleAuthServer(server *Server) {
// TODO
}
func (msg *PassCommand) HandleAuthServer(server *Server) {
func (msg *PassCommand) HandleRegServer(server *Server) {
client := msg.Client()
if msg.err != nil {
client.ErrPasswdMismatch()
@ -285,27 +257,70 @@ func (msg *PassCommand) HandleAuthServer(server *Server) {
return
}
client.phase = Registration
client.authorized = true
}
func (msg *QuitCommand) HandleAuthServer(server *Server) {
msg.Client().Quit(msg.message)
}
//
// registration commands
//
func (msg *ProxyCommand) HandleRegServer(server *Server) {
msg.Client().hostname = msg.hostname
}
func (msg *CapCommand) HandleRegServer(server *Server) {
// TODO
client := msg.Client()
switch msg.subCommand {
case CAP_LS:
client.capState = CapNegotiating
client.Reply("CAP LS * :%s", SupportedCapabilities)
case CAP_LIST:
client.Reply("CAP LIST * :%s", client.capabilities)
case CAP_REQ:
client.capState = CapNegotiating
for capability := range msg.capabilities {
if !SupportedCapabilities[capability] {
client.Reply("CAP NAK * :%s", msg.capabilities)
return
}
}
for capability := range msg.capabilities {
client.capabilities[capability] = true
}
client.Reply("CAP ACK * :%s", msg.capabilities)
case CAP_CLEAR:
format := strings.TrimRight(
strings.Repeat("%s%s ", len(client.capabilities)), " ")
args := make([]interface{}, len(client.capabilities))
index := 0
for capability := range client.capabilities {
args[index] = Disable
args[index+1] = capability
index += 2
delete(client.capabilities, capability)
}
client.Reply("CAP ACK * :"+format, args...)
case CAP_END:
client.capState = CapNegotiated
server.tryRegister(client)
default:
client.ErrInvalidCapCmd(msg.subCommand)
}
}
func (m *NickCommand) HandleRegServer(s *Server) {
client := m.Client()
if !client.authorized {
client.ErrPasswdMismatch()
client.Quit("bad password")
return
}
if client.capState == CapNegotiating {
client.capState = CapNegotiated
}
if m.nickname == "" {
client.ErrNoNicknameGiven()
@ -327,11 +342,22 @@ func (m *NickCommand) HandleRegServer(s *Server) {
}
func (msg *RFC1459UserCommand) HandleRegServer(server *Server) {
client := msg.Client()
if !client.authorized {
client.ErrPasswdMismatch()
client.Quit("bad password")
return
}
msg.HandleRegServer2(server)
}
func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
client := msg.Client()
if !client.authorized {
client.ErrPasswdMismatch()
client.Quit("bad password")
return
}
flags := msg.Flags()
if len(flags) > 0 {
for _, mode := range msg.Flags() {
@ -344,6 +370,9 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
func (msg *UserCommand) HandleRegServer2(server *Server) {
client := msg.Client()
if client.capState == CapNegotiating {
client.capState = CapNegotiated
}
client.username, client.realname = msg.username, msg.realname
server.tryRegister(client)
}

View File

@ -10,6 +10,30 @@ import (
// simple types
//
type CapSubCommand string
type Capability string
type CapModifier rune
func (mod CapModifier) String() string {
return string(mod)
}
type CapState uint
type CapabilitySet map[Capability]bool
func (set CapabilitySet) String() string {
strs := make([]string, len(set))
index := 0
for capability := range set {
strs[index] = string(capability)
index += 1
}
return strings.Join(strs, " ")
}
// a string with wildcards
type Mask string
@ -24,7 +48,7 @@ func (op ModeOp) String() string {
type UserMode rune
func (mode UserMode) String() string {
return fmt.Sprintf("%c", mode)
return string(mode)
}
type Phase uint
@ -49,7 +73,7 @@ func (code NumericCode) String() string {
type ChannelMode rune
func (mode ChannelMode) String() string {
return fmt.Sprintf("%c", mode)
return string(mode)
}
type ChannelNameMap map[string]*Channel