mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-08 19:22:53 +01:00
Merge pull request #8 from jlatt/cap-protocol
basic capability negotiation
This commit is contained in:
commit
4bcd42ff34
@ -54,18 +54,26 @@ func (channel *Channel) ClientIsOperator(client *Client) bool {
|
||||
return client.flags[Operator] || channel.members.HasMode(client, ChannelOperator)
|
||||
}
|
||||
|
||||
func (channel *Channel) Nicks() []string {
|
||||
func (channel *Channel) Nicks(target *Client) []string {
|
||||
isMultiPrefix := (target != nil) && target.capabilities[MultiPrefix]
|
||||
nicks := make([]string, len(channel.members))
|
||||
i := 0
|
||||
for client, modes := range channel.members {
|
||||
switch {
|
||||
case modes[ChannelOperator]:
|
||||
nicks[i] = "@" + client.Nick()
|
||||
case modes[Voice]:
|
||||
nicks[i] = "+" + client.Nick()
|
||||
default:
|
||||
nicks[i] = client.Nick()
|
||||
if isMultiPrefix {
|
||||
if modes[ChannelOperator] {
|
||||
nicks[i] += "@"
|
||||
}
|
||||
if modes[Voice] {
|
||||
nicks[i] += "+"
|
||||
}
|
||||
} else {
|
||||
if modes[ChannelOperator] {
|
||||
nicks[i] += "@"
|
||||
} else if modes[Voice] {
|
||||
nicks[i] += "+"
|
||||
}
|
||||
}
|
||||
nicks[i] += client.Nick()
|
||||
i += 1
|
||||
}
|
||||
return nicks
|
||||
|
@ -13,7 +13,10 @@ func IsNickname(nick string) bool {
|
||||
|
||||
type Client struct {
|
||||
atime time.Time
|
||||
authorized bool
|
||||
awayMessage string
|
||||
capabilities CapabilitySet
|
||||
capState CapState
|
||||
channels ChannelSet
|
||||
commands chan editableCommand
|
||||
ctime time.Time
|
||||
@ -36,11 +39,13 @@ func NewClient(server *Server, conn net.Conn) *Client {
|
||||
now := time.Now()
|
||||
client := &Client{
|
||||
atime: now,
|
||||
capState: CapNone,
|
||||
capabilities: make(CapabilitySet),
|
||||
channels: make(ChannelSet),
|
||||
commands: make(chan editableCommand),
|
||||
ctime: now,
|
||||
flags: make(map[UserMode]bool),
|
||||
phase: server.InitPhase(),
|
||||
phase: Registration,
|
||||
server: server,
|
||||
}
|
||||
client.socket = NewSocket(conn, client.commands)
|
||||
@ -68,6 +73,12 @@ func (client *Client) run() {
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) connectionTimeout() {
|
||||
client.commands <- &QuitCommand{
|
||||
message: "connection timeout",
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// idle timer goroutine
|
||||
//
|
||||
@ -76,14 +87,6 @@ func (client *Client) connectionIdle() {
|
||||
client.server.idle <- client
|
||||
}
|
||||
|
||||
//
|
||||
// quit timer goroutine
|
||||
//
|
||||
|
||||
func (client *Client) connectionTimeout() {
|
||||
client.server.timeout <- client
|
||||
}
|
||||
|
||||
//
|
||||
// server goroutine
|
||||
//
|
||||
@ -232,7 +235,10 @@ func (client *Client) ChangeNickname(nickname string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) Reply(reply string) {
|
||||
func (client *Client) Reply(reply string, args ...interface{}) {
|
||||
if len(args) > 0 {
|
||||
reply = fmt.Sprintf(reply, args...)
|
||||
}
|
||||
client.socket.Write(reply)
|
||||
}
|
||||
|
||||
|
@ -705,20 +705,34 @@ func NewOperCommand(args []string) (editableCommand, error) {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// TODO
|
||||
type CapCommand struct {
|
||||
BaseCommand
|
||||
args []string
|
||||
subCommand CapSubCommand
|
||||
capabilities CapabilitySet
|
||||
}
|
||||
|
||||
func (msg *CapCommand) String() string {
|
||||
return fmt.Sprintf("CAP(args=%s)", msg.args)
|
||||
return fmt.Sprintf("CAP(subCommand=%s, capabilities=%s)",
|
||||
msg.subCommand, msg.capabilities)
|
||||
}
|
||||
|
||||
func NewCapCommand(args []string) (editableCommand, error) {
|
||||
return &CapCommand{
|
||||
args: args,
|
||||
}, nil
|
||||
if len(args) < 1 {
|
||||
return nil, NotEnoughArgsError
|
||||
}
|
||||
|
||||
cmd := &CapCommand{
|
||||
subCommand: CapSubCommand(strings.ToUpper(args[0])),
|
||||
capabilities: make(CapabilitySet),
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
strs := spacesExpr.Split(args[1], -1)
|
||||
for _, str := range strs {
|
||||
cmd.capabilities[Capability(str)] = true
|
||||
}
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// HAPROXY support
|
||||
|
@ -155,6 +155,7 @@ const (
|
||||
ERR_TOOMANYTARGETS NumericCode = 407
|
||||
ERR_NOSUCHSERVICE NumericCode = 408
|
||||
ERR_NOORIGIN NumericCode = 409
|
||||
ERR_INVALIDCAPCMD NumericCode = 410
|
||||
ERR_NORECIPIENT NumericCode = 411
|
||||
ERR_NOTEXTTOSEND NumericCode = 412
|
||||
ERR_NOTOPLEVEL NumericCode = 413
|
||||
@ -200,6 +201,14 @@ const (
|
||||
ERR_UMODEUNKNOWNFLAG NumericCode = 501
|
||||
ERR_USERSDONTMATCH NumericCode = 502
|
||||
|
||||
CAP_LS CapSubCommand = "LS"
|
||||
CAP_LIST CapSubCommand = "LIST"
|
||||
CAP_REQ CapSubCommand = "REQ"
|
||||
CAP_ACK CapSubCommand = "ACK"
|
||||
CAP_NAK CapSubCommand = "NAK"
|
||||
CAP_CLEAR CapSubCommand = "CLEAR"
|
||||
CAP_END CapSubCommand = "END"
|
||||
|
||||
Add ModeOp = '+'
|
||||
List ModeOp = '='
|
||||
Remove ModeOp = '-'
|
||||
@ -230,10 +239,28 @@ const (
|
||||
Secret ChannelMode = 's' // flag, deprecated
|
||||
UserLimit ChannelMode = 'l' // flag arg
|
||||
Voice ChannelMode = 'v' // arg
|
||||
|
||||
MultiPrefix Capability = "multi-prefix"
|
||||
SASL Capability = "sasl"
|
||||
|
||||
Disable CapModifier = '-'
|
||||
Ack CapModifier = '~'
|
||||
Sticky CapModifier = '='
|
||||
)
|
||||
|
||||
var (
|
||||
SupportedCapabilities = CapabilitySet{
|
||||
MultiPrefix: true,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
Authorization Phase = iota
|
||||
Registration Phase = iota
|
||||
Normal Phase = iota
|
||||
)
|
||||
|
||||
const (
|
||||
CapNone CapState = iota
|
||||
CapNegotiating CapState = iota
|
||||
CapNegotiated CapState = iota
|
||||
)
|
||||
|
17
irc/reply.go
17
irc/reply.go
@ -240,13 +240,21 @@ func (target *Client) RplWhoReply(channel *Channel, client *Client) {
|
||||
|
||||
if channel != nil {
|
||||
channelName = channel.name
|
||||
|
||||
if target.capabilities[MultiPrefix] {
|
||||
if channel.members[client][ChannelOperator] {
|
||||
flags += "@"
|
||||
}
|
||||
if channel.members[client][Voice] {
|
||||
flags += "+"
|
||||
}
|
||||
} else {
|
||||
if channel.members[client][ChannelOperator] {
|
||||
flags += "@"
|
||||
} else if channel.members[client][Voice] {
|
||||
flags += "+"
|
||||
}
|
||||
}
|
||||
}
|
||||
target.NumericReply(RPL_WHOREPLY,
|
||||
"%s %s %s %s %s %s :%d %s", channelName, client.username, client.hostname,
|
||||
client.server.name, client.Nick(), flags, client.hops, client.realname)
|
||||
@ -360,7 +368,7 @@ func (target *Client) RplListEnd(server *Server) {
|
||||
}
|
||||
|
||||
func (target *Client) RplNamReply(channel *Channel) {
|
||||
target.MultilineReply(channel.Nicks(), RPL_NAMREPLY,
|
||||
target.MultilineReply(channel.Nicks(target), RPL_NAMREPLY,
|
||||
"= %s :%s", channel)
|
||||
}
|
||||
|
||||
@ -502,3 +510,8 @@ func (target *Client) ErrChannelIsFull(channel *Channel) {
|
||||
target.NumericReply(ERR_CHANNELISFULL,
|
||||
"%s :Cannot join channel (+l)", channel)
|
||||
}
|
||||
|
||||
func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) {
|
||||
target.NumericReply(ERR_INVALIDCAPCMD,
|
||||
"%s :Invalid CAP subcommand", subCommand)
|
||||
}
|
||||
|
113
irc/server.go
113
irc/server.go
@ -29,7 +29,6 @@ type Server struct {
|
||||
operators map[string][]byte
|
||||
password []byte
|
||||
signals chan os.Signal
|
||||
timeout chan *Client
|
||||
}
|
||||
|
||||
func NewServer(config *Config) *Server {
|
||||
@ -45,7 +44,6 @@ func NewServer(config *Config) *Server {
|
||||
newConns: make(chan net.Conn, 16),
|
||||
operators: config.Operators(),
|
||||
signals: make(chan os.Signal, 1),
|
||||
timeout: make(chan *Client, 16),
|
||||
}
|
||||
|
||||
if config.Server.Password != "" {
|
||||
@ -97,14 +95,6 @@ func (server *Server) processCommand(cmd Command) {
|
||||
}
|
||||
|
||||
switch client.phase {
|
||||
case Authorization:
|
||||
authCmd, ok := cmd.(AuthServerCommand)
|
||||
if !ok {
|
||||
client.Quit("unexpected command")
|
||||
return
|
||||
}
|
||||
authCmd.HandleAuthServer(server)
|
||||
|
||||
case Registration:
|
||||
regCmd, ok := cmd.(RegServerCommand)
|
||||
if !ok {
|
||||
@ -113,7 +103,7 @@ func (server *Server) processCommand(cmd Command) {
|
||||
}
|
||||
regCmd.HandleRegServer(server)
|
||||
|
||||
default:
|
||||
case Normal:
|
||||
srvCmd, ok := cmd.(ServerCommand)
|
||||
if !ok {
|
||||
client.ErrUnknownCommand(cmd.Code())
|
||||
@ -157,20 +147,10 @@ func (server *Server) Run() {
|
||||
|
||||
case client := <-server.idle:
|
||||
client.Idle()
|
||||
|
||||
case client := <-server.timeout:
|
||||
client.Quit("connection timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) InitPhase() Phase {
|
||||
if server.password == nil {
|
||||
return Registration
|
||||
}
|
||||
return Authorization
|
||||
}
|
||||
|
||||
//
|
||||
// listen goroutine
|
||||
//
|
||||
@ -206,7 +186,7 @@ func (s *Server) listen(addr string) {
|
||||
//
|
||||
|
||||
func (s *Server) tryRegister(c *Client) {
|
||||
if c.HasNick() && c.HasUsername() {
|
||||
if c.HasNick() && c.HasUsername() && (c.capState != CapNegotiating) {
|
||||
c.Register()
|
||||
c.RplWelcome()
|
||||
c.RplYourHost()
|
||||
@ -266,18 +246,10 @@ func (s *Server) Nick() string {
|
||||
}
|
||||
|
||||
//
|
||||
// authorization commands
|
||||
// registration commands
|
||||
//
|
||||
|
||||
func (msg *ProxyCommand) HandleAuthServer(server *Server) {
|
||||
msg.Client().hostname = msg.hostname
|
||||
}
|
||||
|
||||
func (msg *CapCommand) HandleAuthServer(server *Server) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func (msg *PassCommand) HandleAuthServer(server *Server) {
|
||||
func (msg *PassCommand) HandleRegServer(server *Server) {
|
||||
client := msg.Client()
|
||||
if msg.err != nil {
|
||||
client.ErrPasswdMismatch()
|
||||
@ -285,27 +257,70 @@ func (msg *PassCommand) HandleAuthServer(server *Server) {
|
||||
return
|
||||
}
|
||||
|
||||
client.phase = Registration
|
||||
client.authorized = true
|
||||
}
|
||||
|
||||
func (msg *QuitCommand) HandleAuthServer(server *Server) {
|
||||
msg.Client().Quit(msg.message)
|
||||
}
|
||||
|
||||
//
|
||||
// registration commands
|
||||
//
|
||||
|
||||
func (msg *ProxyCommand) HandleRegServer(server *Server) {
|
||||
msg.Client().hostname = msg.hostname
|
||||
}
|
||||
|
||||
func (msg *CapCommand) HandleRegServer(server *Server) {
|
||||
// TODO
|
||||
client := msg.Client()
|
||||
|
||||
switch msg.subCommand {
|
||||
case CAP_LS:
|
||||
client.capState = CapNegotiating
|
||||
client.Reply("CAP LS * :%s", SupportedCapabilities)
|
||||
|
||||
case CAP_LIST:
|
||||
client.Reply("CAP LIST * :%s", client.capabilities)
|
||||
|
||||
case CAP_REQ:
|
||||
client.capState = CapNegotiating
|
||||
for capability := range msg.capabilities {
|
||||
if !SupportedCapabilities[capability] {
|
||||
client.Reply("CAP NAK * :%s", msg.capabilities)
|
||||
return
|
||||
}
|
||||
}
|
||||
for capability := range msg.capabilities {
|
||||
client.capabilities[capability] = true
|
||||
}
|
||||
client.Reply("CAP ACK * :%s", msg.capabilities)
|
||||
|
||||
case CAP_CLEAR:
|
||||
format := strings.TrimRight(
|
||||
strings.Repeat("%s%s ", len(client.capabilities)), " ")
|
||||
args := make([]interface{}, len(client.capabilities))
|
||||
index := 0
|
||||
for capability := range client.capabilities {
|
||||
args[index] = Disable
|
||||
args[index+1] = capability
|
||||
index += 2
|
||||
delete(client.capabilities, capability)
|
||||
}
|
||||
client.Reply("CAP ACK * :"+format, args...)
|
||||
|
||||
case CAP_END:
|
||||
client.capState = CapNegotiated
|
||||
server.tryRegister(client)
|
||||
|
||||
default:
|
||||
client.ErrInvalidCapCmd(msg.subCommand)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *NickCommand) HandleRegServer(s *Server) {
|
||||
client := m.Client()
|
||||
if !client.authorized {
|
||||
client.ErrPasswdMismatch()
|
||||
client.Quit("bad password")
|
||||
return
|
||||
}
|
||||
|
||||
if client.capState == CapNegotiating {
|
||||
client.capState = CapNegotiated
|
||||
}
|
||||
|
||||
if m.nickname == "" {
|
||||
client.ErrNoNicknameGiven()
|
||||
@ -327,11 +342,22 @@ func (m *NickCommand) HandleRegServer(s *Server) {
|
||||
}
|
||||
|
||||
func (msg *RFC1459UserCommand) HandleRegServer(server *Server) {
|
||||
client := msg.Client()
|
||||
if !client.authorized {
|
||||
client.ErrPasswdMismatch()
|
||||
client.Quit("bad password")
|
||||
return
|
||||
}
|
||||
msg.HandleRegServer2(server)
|
||||
}
|
||||
|
||||
func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
|
||||
client := msg.Client()
|
||||
if !client.authorized {
|
||||
client.ErrPasswdMismatch()
|
||||
client.Quit("bad password")
|
||||
return
|
||||
}
|
||||
flags := msg.Flags()
|
||||
if len(flags) > 0 {
|
||||
for _, mode := range msg.Flags() {
|
||||
@ -344,6 +370,9 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
|
||||
|
||||
func (msg *UserCommand) HandleRegServer2(server *Server) {
|
||||
client := msg.Client()
|
||||
if client.capState == CapNegotiating {
|
||||
client.capState = CapNegotiated
|
||||
}
|
||||
client.username, client.realname = msg.username, msg.realname
|
||||
server.tryRegister(client)
|
||||
}
|
||||
|
28
irc/types.go
28
irc/types.go
@ -10,6 +10,30 @@ import (
|
||||
// simple types
|
||||
//
|
||||
|
||||
type CapSubCommand string
|
||||
|
||||
type Capability string
|
||||
|
||||
type CapModifier rune
|
||||
|
||||
func (mod CapModifier) String() string {
|
||||
return string(mod)
|
||||
}
|
||||
|
||||
type CapState uint
|
||||
|
||||
type CapabilitySet map[Capability]bool
|
||||
|
||||
func (set CapabilitySet) String() string {
|
||||
strs := make([]string, len(set))
|
||||
index := 0
|
||||
for capability := range set {
|
||||
strs[index] = string(capability)
|
||||
index += 1
|
||||
}
|
||||
return strings.Join(strs, " ")
|
||||
}
|
||||
|
||||
// a string with wildcards
|
||||
type Mask string
|
||||
|
||||
@ -24,7 +48,7 @@ func (op ModeOp) String() string {
|
||||
type UserMode rune
|
||||
|
||||
func (mode UserMode) String() string {
|
||||
return fmt.Sprintf("%c", mode)
|
||||
return string(mode)
|
||||
}
|
||||
|
||||
type Phase uint
|
||||
@ -49,7 +73,7 @@ func (code NumericCode) String() string {
|
||||
type ChannelMode rune
|
||||
|
||||
func (mode ChannelMode) String() string {
|
||||
return fmt.Sprintf("%c", mode)
|
||||
return string(mode)
|
||||
}
|
||||
|
||||
type ChannelNameMap map[string]*Channel
|
||||
|
Loading…
Reference in New Issue
Block a user