diff --git a/ergonomadic.go b/ergonomadic.go index dbd45ddb..3d107a7c 100644 --- a/ergonomadic.go +++ b/ergonomadic.go @@ -12,6 +12,7 @@ import ( func main() { conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file") initdb := flag.Bool("initdb", false, "initialize database") + upgradedb := flag.Bool("upgradedb", false, "update database") passwd := flag.String("genpasswd", "", "bcrypt a password") flag.Parse() @@ -35,7 +36,13 @@ func main() { if *initdb { irc.InitDB(config.Server.Database) - log.Println("database initialized: " + config.Server.Database) + log.Println("database initialized: ", config.Server.Database) + return + } + + if *upgradedb { + irc.UpgradeDB(config.Server.Database) + log.Println("database upgraded: ", config.Server.Database) return } @@ -45,5 +52,8 @@ func main() { irc.DEBUG_CHANNEL = config.Debug.Channel irc.DEBUG_SERVER = config.Debug.Server - irc.NewServer(config).Run() + server := irc.NewServer(config) + log.Println(irc.SEM_VER, "running") + defer log.Println(irc.SEM_VER, "exiting") + server.Run() } diff --git a/irc/channel.go b/irc/channel.go index d7fa7371..41f7f9a8 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -8,7 +8,7 @@ import ( type Channel struct { flags ChannelModeSet - lists map[ChannelMode][]UserMask + lists map[ChannelMode]*UserMaskSet key string members MemberSet name string @@ -26,10 +26,10 @@ func IsChannel(target string) bool { func NewChannel(s *Server, name string) *Channel { channel := &Channel{ flags: make(ChannelModeSet), - lists: map[ChannelMode][]UserMask{ - BanMask: []UserMask{}, - ExceptMask: []UserMask{}, - InviteMask: []UserMask{}, + lists: map[ChannelMode]*UserMaskSet{ + BanMask: NewUserMaskSet(), + ExceptMask: NewUserMaskSet(), + InviteMask: NewUserMaskSet(), }, members: make(MemberSet), name: strings.ToLower(name), @@ -151,6 +151,19 @@ func (channel *Channel) Join(client *Client, key string) { return } + isInvited := channel.lists[InviteMask].Match(client.UserHost()) + if channel.flags[InviteOnly] && !isInvited { + client.ErrInviteOnlyChan(channel) + return + } + + if channel.lists[BanMask].Match(client.UserHost()) && + !isInvited && + !channel.lists[ExceptMask].Match(client.UserHost()) { + client.ErrBannedFromChan(channel) + return + } + client.channels.Add(channel) channel.members.Add(client) if !channel.flags[Persistent] && (len(channel.members) == 1) { @@ -213,7 +226,7 @@ func (channel *Channel) SetTopic(client *Client, topic string) { } if err := channel.Persist(); err != nil { - log.Println(err) + log.Println("Channel.Persist:", channel, err) } } @@ -310,17 +323,48 @@ func (channel *Channel) applyModeMember(client *Client, mode ChannelMode, return false } +func (channel *Channel) ShowMaskList(client *Client, mode ChannelMode) { + for lmask := range channel.lists[mode].masks { + client.RplMaskList(mode, channel, lmask) + } + client.RplEndOfMaskList(mode, channel) +} + +func (channel *Channel) applyModeMask(client *Client, mode ChannelMode, op ModeOp, + mask string) bool { + list := channel.lists[mode] + if list == nil { + // This should never happen, but better safe than panicky. + return false + } + + if (op == List) || (mask == "") { + channel.ShowMaskList(client, mode) + return false + } + + if !channel.ClientIsOperator(client) { + client.ErrChanOPrivIsNeeded(channel) + return false + } + + if op == Add { + return list.Add(mask) + } + + if op == Remove { + return list.Remove(mask) + } + + return false +} + func (channel *Channel) applyMode(client *Client, change *ChannelModeChange) bool { switch change.mode { case BanMask, ExceptMask, InviteMask: - // TODO add/remove + return channel.applyModeMask(client, change.mode, change.op, change.arg) - for _, mask := range channel.lists[change.mode] { - client.RplMaskList(change.mode, channel, mask) - } - client.RplEndOfMaskList(change.mode, channel) - - case Moderated, NoOutside, OpOnlyTopic, Persistent, Private: + case InviteOnly, Moderated, NoOutside, OpOnlyTopic, Persistent, Private: return channel.applyModeFlag(client, change.mode, change.op) case Key: @@ -390,7 +434,7 @@ func (channel *Channel) Mode(client *Client, changes ChannelModeChanges) { } if err := channel.Persist(); err != nil { - log.Println(err) + log.Println("Channel.Persist:", channel, err) } } } @@ -399,10 +443,12 @@ func (channel *Channel) Persist() (err error) { if channel.flags[Persistent] { _, err = channel.server.db.Exec(` INSERT OR REPLACE INTO channel - (name, flags, key, topic, user_limit) - VALUES (?, ?, ?, ?, ?)`, + (name, flags, key, topic, user_limit, ban_list, except_list, + invite_list) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, channel.name, channel.flags.String(), channel.key, channel.topic, - channel.userLimit) + channel.userLimit, channel.lists[BanMask].String(), + channel.lists[ExceptMask].String(), channel.lists[InviteMask].String()) } else { _, err = channel.server.db.Exec(` DELETE FROM channel WHERE name = ?`, channel.name) @@ -464,6 +510,13 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client) { return } + if channel.flags[InviteOnly] { + channel.lists[InviteMask].Add(invitee.UserHost()) + if err := channel.Persist(); err != nil { + log.Println("Channel.Persist:", channel, err) + } + } + inviter.RplInviting(invitee, channel.name) invitee.Reply(RplInviteMsg(inviter, invitee, channel.name)) if invitee.flags[Away] { diff --git a/irc/client.go b/irc/client.go index ff306f34..2adb4309 100644 --- a/irc/client.go +++ b/irc/client.go @@ -229,6 +229,7 @@ func (client *Client) ChangeNickname(nickname string) { // Make reply before changing nick to capture original source id. reply := RplNick(client, nickname) client.server.clients.Remove(client) + client.server.whoWas.Append(client) client.nick = nickname client.server.clients.Add(client) for friend := range client.Friends() { @@ -249,8 +250,8 @@ func (client *Client) Quit(message string) { } client.Reply(RplError("connection closed")) - client.hasQuit = true + client.server.whoWas.Append(client) friends := client.Friends() friends.Remove(client) client.destroy() diff --git a/irc/client_lookup_set.go b/irc/client_lookup_set.go new file mode 100644 index 00000000..261461d9 --- /dev/null +++ b/irc/client_lookup_set.go @@ -0,0 +1,272 @@ +package irc + +import ( + "database/sql" + "errors" + "log" + "regexp" + "strings" +) + +var ( + ErrNickMissing = errors.New("nick missing") + ErrNicknameInUse = errors.New("nickname in use") + ErrNicknameMismatch = errors.New("nickname mismatch") + wildMaskExpr = regexp.MustCompile(`\*|\?`) + likeQuoter = strings.NewReplacer( + `\`, `\\`, + `%`, `\%`, + `_`, `\_`, + `*`, `%`, + `?`, `_`) +) + +func HasWildcards(mask string) bool { + return wildMaskExpr.MatchString(mask) +} + +func ExpandUserHost(userhost string) (expanded string) { + expanded = userhost + // fill in missing wildcards for nicks + if !strings.Contains(expanded, "!") { + expanded += "!*" + } + if !strings.Contains(expanded, "@") { + expanded += "@*" + } + return +} + +func QuoteLike(userhost string) string { + return likeQuoter.Replace(userhost) +} + +type ClientLookupSet struct { + byNick map[string]*Client + db *ClientDB +} + +func NewClientLookupSet() *ClientLookupSet { + return &ClientLookupSet{ + byNick: make(map[string]*Client), + db: NewClientDB(), + } +} + +func (clients *ClientLookupSet) Get(nick string) *Client { + return clients.byNick[strings.ToLower(nick)] +} + +func (clients *ClientLookupSet) Add(client *Client) error { + if !client.HasNick() { + return ErrNickMissing + } + if clients.Get(client.nick) != nil { + return ErrNicknameInUse + } + clients.byNick[strings.ToLower(client.nick)] = client + clients.db.Add(client) + return nil +} + +func (clients *ClientLookupSet) Remove(client *Client) error { + if !client.HasNick() { + return ErrNickMissing + } + if clients.Get(client.nick) != client { + return ErrNicknameMismatch + } + delete(clients.byNick, strings.ToLower(client.nick)) + clients.db.Remove(client) + return nil +} + +func (clients *ClientLookupSet) FindAll(userhost string) (set ClientSet) { + userhost = ExpandUserHost(userhost) + set = make(ClientSet) + rows, err := clients.db.db.Query( + `SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\'`, + QuoteLike(userhost)) + if err != nil { + if DEBUG_SERVER { + log.Println("ClientLookupSet.FindAll.Query:", err) + } + return + } + for rows.Next() { + var nickname string + err := rows.Scan(&nickname) + if err != nil { + if DEBUG_SERVER { + log.Println("ClientLookupSet.FindAll.Scan:", err) + } + return + } + client := clients.Get(nickname) + if client == nil { + if DEBUG_SERVER { + log.Println("ClientLookupSet.FindAll: missing client:", nickname) + } + continue + } + set.Add(client) + } + return +} + +func (clients *ClientLookupSet) Find(userhost string) *Client { + userhost = ExpandUserHost(userhost) + row := clients.db.db.QueryRow( + `SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\' LIMIT 1`, + QuoteLike(userhost)) + var nickname string + err := row.Scan(&nickname) + if err != nil { + if DEBUG_SERVER { + log.Println("ClientLookupSet.Find:", err) + } + return nil + } + return clients.Get(nickname) +} + +// +// client db +// + +type ClientDB struct { + db *sql.DB +} + +func NewClientDB() *ClientDB { + db := &ClientDB{ + db: OpenDB(":memory:"), + } + stmts := []string{ + `CREATE TABLE client ( + nickname TEXT NOT NULL COLLATE NOCASE UNIQUE, + userhost TEXT NOT NULL COLLATE NOCASE, + UNIQUE (nickname, userhost) ON CONFLICT REPLACE)`, + `CREATE UNIQUE INDEX idx_nick ON client (nickname COLLATE NOCASE)`, + `CREATE UNIQUE INDEX idx_uh ON client (userhost COLLATE NOCASE)`, + } + for _, stmt := range stmts { + _, err := db.db.Exec(stmt) + if err != nil { + log.Fatal("NewClientDB: ", stmt, err) + } + } + return db +} + +func (db *ClientDB) Add(client *Client) { + _, err := db.db.Exec(`INSERT INTO client (nickname, userhost) VALUES (?, ?)`, + client.Nick(), client.UserHost()) + if err != nil { + if DEBUG_SERVER { + log.Println("ClientDB.Add:", err) + } + } +} + +func (db *ClientDB) Remove(client *Client) { + _, err := db.db.Exec(`DELETE FROM client WHERE nickname = ?`, + client.Nick()) + if err != nil { + if DEBUG_SERVER { + log.Println("ClientDB.Remove:", err) + } + } +} + +// +// usermask to regexp +// + +type UserMaskSet struct { + masks map[string]bool + regexp *regexp.Regexp +} + +func NewUserMaskSet() *UserMaskSet { + return &UserMaskSet{ + masks: make(map[string]bool), + } +} + +func (set *UserMaskSet) Add(mask string) bool { + if set.masks[mask] { + return false + } + set.masks[mask] = true + set.setRegexp() + return true +} + +func (set *UserMaskSet) AddAll(masks []string) (added bool) { + for _, mask := range masks { + if !added && !set.masks[mask] { + added = true + } + set.masks[mask] = true + } + set.setRegexp() + return +} + +func (set *UserMaskSet) Remove(mask string) bool { + if !set.masks[mask] { + return false + } + delete(set.masks, mask) + set.setRegexp() + return true +} + +func (set *UserMaskSet) Match(userhost string) bool { + if set.regexp == nil { + return false + } + return set.regexp.MatchString(userhost) +} + +func (set *UserMaskSet) String() string { + masks := make([]string, len(set.masks)) + index := 0 + for mask := range set.masks { + masks[index] = mask + index += 1 + } + return strings.Join(masks, " ") +} + +// Generate a regular expression from the set of user mask +// strings. Masks are split at the two types of wildcards, `*` and +// `?`. All the pieces are meta-escaped. `*` is replaced with `.*`, +// the regexp equivalent. Likewise, `?` is replaced with `.`. The +// parts are re-joined and finally all masks are joined into a big +// or-expression. +func (set *UserMaskSet) setRegexp() { + if len(set.masks) == 0 { + set.regexp = nil + return + } + + maskExprs := make([]string, len(set.masks)) + index := 0 + for mask := range set.masks { + manyParts := strings.Split(mask, "*") + manyExprs := make([]string, len(manyParts)) + for mindex, manyPart := range manyParts { + oneParts := strings.Split(manyPart, "?") + oneExprs := make([]string, len(oneParts)) + for oindex, onePart := range oneParts { + oneExprs[oindex] = regexp.QuoteMeta(onePart) + } + manyExprs[mindex] = strings.Join(oneExprs, ".") + } + maskExprs[index] = strings.Join(manyExprs, ".*") + } + expr := "^" + strings.Join(maskExprs, "|") + "$" + set.regexp, _ = regexp.Compile(expr) +} diff --git a/irc/commands.go b/irc/commands.go index e7d36130..976d8987 100644 --- a/irc/commands.go +++ b/irc/commands.go @@ -54,6 +54,7 @@ var ( VERSION: NewVersionCommand, WHO: NewWhoCommand, WHOIS: NewWhoisCommand, + WHOWAS: NewWhoWasCommand, } ) @@ -656,7 +657,7 @@ func (msg *WhoisCommand) String() string { type WhoCommand struct { BaseCommand - mask Mask + mask string operatorOnly bool } @@ -665,7 +666,7 @@ func NewWhoCommand(args []string) (editableCommand, error) { cmd := &WhoCommand{} if len(args) > 0 { - cmd.mask = Mask(args[0]) + cmd.mask = args[0] } if (len(args) > 1) && (args[1] == "o") { @@ -982,3 +983,26 @@ func NewKillCommand(args []string) (editableCommand, error) { comment: args[1], }, nil } + +type WhoWasCommand struct { + BaseCommand + nicknames []string + count int64 + target string +} + +func NewWhoWasCommand(args []string) (editableCommand, error) { + if len(args) < 1 { + return nil, NotEnoughArgsError + } + cmd := &WhoWasCommand{ + nicknames: strings.Split(args[0], ","), + } + if len(args) > 1 { + cmd.count, _ = strconv.ParseInt(args[1], 10, 64) + } + if len(args) > 2 { + cmd.target = args[2] + } + return cmd, nil +} diff --git a/irc/constants.go b/irc/constants.go index 2a766927..63368b63 100644 --- a/irc/constants.go +++ b/irc/constants.go @@ -61,6 +61,7 @@ const ( VERSION StringCode = "VERSION" WHO StringCode = "WHO" WHOIS StringCode = "WHOIS" + WHOWAS StringCode = "WHOWAS" // numeric codes RPL_WELCOME NumericCode = 1 diff --git a/irc/database.go b/irc/database.go index c7f9264a..2a482ecf 100644 --- a/irc/database.go +++ b/irc/database.go @@ -2,6 +2,7 @@ package irc import ( "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" "log" "os" @@ -14,15 +15,30 @@ func InitDB(path string) { _, err := db.Exec(` CREATE TABLE channel ( name TEXT NOT NULL UNIQUE, - flags TEXT NOT NULL, - key TEXT NOT NULL, - topic TEXT NOT NULL, - user_limit INTEGER DEFAULT 0)`) + flags TEXT DEFAULT '', + key TEXT DEFAULT '', + topic TEXT DEFAULT '', + user_limit INTEGER DEFAULT 0, + ban_list TEXT DEFAULT '', + except_list TEXT DEFAULT '', + invite_list TEXT DEFAULT '')`) if err != nil { log.Fatal("initdb error: ", err) } } +func UpgradeDB(path string) { + db := OpenDB(path) + alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''` + cols := []string{"ban_list", "except_list", "invite_list"} + for _, col := range cols { + _, err := db.Exec(fmt.Sprintf(alter, col)) + if err != nil { + log.Fatal("updatedb error: ", err) + } + } +} + func OpenDB(path string) *sql.DB { db, err := sql.Open("sqlite3", path) if err != nil { diff --git a/irc/reply.go b/irc/reply.go index 4510f76e..72002f93 100644 --- a/irc/reply.go +++ b/irc/reply.go @@ -200,6 +200,16 @@ func (target *Client) RplYoureOper() { ":You are now an IRC operator") } +func (target *Client) RplWhois(client *Client) { + target.RplWhoisUser(client) + if client.flags[Operator] { + target.RplWhoisOperator(client) + } + target.RplWhoisIdle(client) + target.RplWhoisChannels(client) + target.RplEndOfWhois() +} + func (target *Client) RplWhoisUser(client *Client) { target.NumericReply(RPL_WHOISUSER, "%s %s %s * :%s", client.Nick(), client.username, client.hostname, @@ -270,7 +280,7 @@ func (target *Client) RplEndOfWho(name string) { "%s :End of WHO list", name) } -func (target *Client) RplMaskList(mode ChannelMode, channel *Channel, mask UserMask) { +func (target *Client) RplMaskList(mode ChannelMode, channel *Channel, mask string) { switch mode { case BanMask: target.RplBanList(channel, mask) @@ -296,7 +306,7 @@ func (target *Client) RplEndOfMaskList(mode ChannelMode, channel *Channel) { } } -func (target *Client) RplBanList(channel *Channel, mask UserMask) { +func (target *Client) RplBanList(channel *Channel, mask string) { target.NumericReply(RPL_BANLIST, "%s %s", channel, mask) } @@ -306,7 +316,7 @@ func (target *Client) RplEndOfBanList(channel *Channel) { "%s :End of channel ban list", channel) } -func (target *Client) RplExceptList(channel *Channel, mask UserMask) { +func (target *Client) RplExceptList(channel *Channel, mask string) { target.NumericReply(RPL_EXCEPTLIST, "%s %s", channel, mask) } @@ -316,7 +326,7 @@ func (target *Client) RplEndOfExceptList(channel *Channel) { "%s :End of channel exception list", channel) } -func (target *Client) RplInviteList(channel *Channel, mask UserMask) { +func (target *Client) RplInviteList(channel *Channel, mask string) { target.NumericReply(RPL_INVITELIST, "%s %s", channel, mask) } @@ -396,6 +406,17 @@ func (target *Client) RplTime() { "%s :%s", target.server.name, time.Now().Format(time.RFC1123)) } +func (target *Client) RplWhoWasUser(whoWas *WhoWas) { + target.NumericReply(RPL_WHOWASUSER, + "%s %s %s * :%s", + whoWas.nickname, whoWas.username, whoWas.hostname, whoWas.realname) +} + +func (target *Client) RplEndOfWhoWas(nickname string) { + target.NumericReply(RPL_ENDOFWHOWAS, + "%s :End of WHOWAS", nickname) +} + // // errors (also numeric) // @@ -515,7 +536,22 @@ func (target *Client) ErrChannelIsFull(channel *Channel) { "%s :Cannot join channel (+l)", channel) } +func (target *Client) ErrWasNoSuchNick(nickname string) { + target.NumericReply(ERR_WASNOSUCHNICK, + "%s :There was no such nickname", nickname) +} + func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) { target.NumericReply(ERR_INVALIDCAPCMD, "%s :Invalid CAP subcommand", subCommand) } + +func (target *Client) ErrBannedFromChan(channel *Channel) { + target.NumericReply(ERR_BANNEDFROMCHAN, + "%s :Cannot join channel (+b)", channel) +} + +func (target *Client) ErrInviteOnlyChan(channel *Channel) { + target.NumericReply(ERR_INVITEONLYCHAN, + "%s :Cannot join channel (+i)", channel) +} diff --git a/irc/server.go b/irc/server.go index ff083897..9c7acae9 100644 --- a/irc/server.go +++ b/irc/server.go @@ -18,7 +18,7 @@ import ( type Server struct { channels ChannelNameMap - clients ClientNameMap + clients *ClientLookupSet commands chan Command ctime time.Time db *sql.DB @@ -29,12 +29,13 @@ type Server struct { operators map[string][]byte password []byte signals chan os.Signal + whoWas *WhoWasList } func NewServer(config *Config) *Server { server := &Server{ channels: make(ChannelNameMap), - clients: make(ClientNameMap), + clients: NewClientLookupSet(), commands: make(chan Command, 16), ctime: time.Now(), db: OpenDB(config.Server.Database), @@ -44,6 +45,7 @@ func NewServer(config *Config) *Server { newConns: make(chan net.Conn, 16), operators: config.Operators(), signals: make(chan os.Signal, 1), + whoWas: NewWhoWasList(100), } if config.Server.Password != "" { @@ -62,9 +64,17 @@ func NewServer(config *Config) *Server { return server } +func loadChannelList(channel *Channel, list string, maskMode ChannelMode) { + if list == "" { + return + } + channel.lists[maskMode].AddAll(strings.Split(list, " ")) +} + func (server *Server) loadChannels() { rows, err := server.db.Query(` - SELECT name, flags, key, topic, user_limit + SELECT name, flags, key, topic, user_limit, ban_list, except_list, + invite_list FROM channel`) if err != nil { log.Fatal("error loading channels: ", err) @@ -72,9 +82,11 @@ func (server *Server) loadChannels() { for rows.Next() { var name, flags, key, topic string var userLimit uint64 - err = rows.Scan(&name, &flags, &key, &topic, &userLimit) + var banList, exceptList, inviteList string + err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList, + &exceptList, &inviteList) if err != nil { - log.Println(err) + log.Println("Server.loadChannels:", err) continue } @@ -85,6 +97,9 @@ func (server *Server) loadChannels() { channel.key = key channel.topic = topic channel.userLimit = userLimit + loadChannelList(channel, banList, BanMask) + loadChannelList(channel, exceptList, ExceptMask) + loadChannelList(channel, inviteList, InviteMask) } } @@ -126,7 +141,7 @@ func (server *Server) processCommand(cmd Command) { func (server *Server) Shutdown() { server.db.Close() - for _, client := range server.clients { + for _, client := range server.clients.byNick { client.Reply(RplNotice(server, client, "shutting down")) } } @@ -340,7 +355,7 @@ func (msg *RFC1459UserCommand) HandleRegServer(server *Server) { client.Quit("bad password") return } - msg.HandleRegServer2(server) + msg.setUserInfo(server) } func (msg *RFC2812UserCommand) HandleRegServer(server *Server) { @@ -357,15 +372,19 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) { } client.RplUModeIs(client) } - msg.HandleRegServer2(server) + msg.setUserInfo(server) } -func (msg *UserCommand) HandleRegServer2(server *Server) { +func (msg *UserCommand) setUserInfo(server *Server) { client := msg.Client() if client.capState == CapNegotiating { client.capState = CapNegotiated } + + server.clients.Remove(client) client.username, client.realname = msg.username, msg.realname + server.clients.Add(client) + server.tryRegister(client) } @@ -514,7 +533,7 @@ func (m *ModeCommand) HandleServer(s *Server) { return } - changes := make(ModeChanges, 0) + changes := make(ModeChanges, 0, len(m.changes)) for _, change := range m.changes { switch change.mode { @@ -577,19 +596,14 @@ func (m *WhoisCommand) HandleServer(server *Server) { // TODO implement target query for _, mask := range m.masks { - // TODO implement wildcard matching - mclient := server.clients.Get(mask) - if mclient == nil { + matches := server.clients.FindAll(mask) + if len(matches) == 0 { client.ErrNoSuchNick(mask) continue } - client.RplWhoisUser(mclient) - if mclient.flags[Operator] { - client.RplWhoisOperator(mclient) + for mclient := range matches { + client.RplWhois(mclient) } - client.RplWhoisIdle(mclient) - client.RplWhoisChannels(mclient) - client.RplEndOfWhois() } } @@ -604,9 +618,9 @@ func (msg *ChannelModeCommand) HandleServer(server *Server) { channel.Mode(client, msg.changes) } -func whoChannel(client *Client, channel *Channel) { +func whoChannel(client *Client, channel *Channel, friends ClientSet) { for member := range channel.members { - if !client.flags[Invisible] { + if !client.flags[Invisible] || friends[client] { client.RplWhoReply(channel, member) } } @@ -614,27 +628,21 @@ func whoChannel(client *Client, channel *Channel) { func (msg *WhoCommand) HandleServer(server *Server) { client := msg.Client() + friends := client.Friends() + mask := msg.mask - // TODO implement wildcard matching - mask := string(msg.mask) if mask == "" { for _, channel := range server.channels { - for member := range channel.members { - if !client.flags[Invisible] { - client.RplWhoReply(channel, member) - } - } + whoChannel(client, channel, friends) } } else if IsChannel(mask) { + // TODO implement wildcard matching channel := server.channels.Get(mask) if channel != nil { - for member := range channel.members { - client.RplWhoReply(channel, member) - } + whoChannel(client, channel, friends) } } else { - mclient := server.clients.Get(mask) - if mclient != nil { + for mclient := range server.clients.FindAll(mask) { client.RplWhoReply(nil, mclient) } } @@ -874,3 +882,18 @@ func (msg *KillCommand) HandleServer(server *Server) { quitMsg := fmt.Sprintf("KILLed by %s: %s", client.Nick(), msg.comment) target.Quit(quitMsg) } + +func (msg *WhoWasCommand) HandleServer(server *Server) { + client := msg.Client() + for _, nickname := range msg.nicknames { + results := server.whoWas.Find(nickname, msg.count) + if len(results) == 0 { + client.ErrWasNoSuchNick(nickname) + } else { + for _, whoWas := range results { + client.RplWhoWasUser(whoWas) + } + } + client.RplEndOfWhoWas(nickname) + } +} diff --git a/irc/types.go b/irc/types.go index 198f7b65..19014012 100644 --- a/irc/types.go +++ b/irc/types.go @@ -1,7 +1,6 @@ package irc import ( - "errors" "fmt" "strings" ) @@ -48,9 +47,6 @@ func (set CapabilitySet) DisableString() string { return strings.Join(parts, " ") } -// a string with wildcards -type Mask string - // add, remove, list modes type ModeOp rune @@ -112,40 +108,6 @@ func (channels ChannelNameMap) Remove(channel *Channel) error { return nil } -type ClientNameMap map[string]*Client - -var ( - ErrNickMissing = errors.New("nick missing") - ErrNicknameInUse = errors.New("nickname in use") - ErrNicknameMismatch = errors.New("nickname mismatch") -) - -func (clients ClientNameMap) Get(nick string) *Client { - return clients[strings.ToLower(nick)] -} - -func (clients ClientNameMap) Add(client *Client) error { - if !client.HasNick() { - return ErrNickMissing - } - if clients.Get(client.nick) != nil { - return ErrNicknameInUse - } - clients[strings.ToLower(client.nick)] = client - return nil -} - -func (clients ClientNameMap) Remove(client *Client) error { - if !client.HasNick() { - return ErrNickMissing - } - if clients.Get(client.nick) != client { - return ErrNicknameMismatch - } - delete(clients, strings.ToLower(client.nick)) - return nil -} - type ChannelModeSet map[ChannelMode]bool func (set ChannelModeSet) String() string { @@ -247,17 +209,3 @@ type RegServerCommand interface { Command HandleRegServer(*Server) } - -// -// structs -// - -type UserMask struct { - nickname Mask - username Mask - hostname Mask -} - -func (mask *UserMask) String() string { - return fmt.Sprintf("%s!%s@%s", mask.nickname, mask.username, mask.hostname) -} diff --git a/irc/whowas.go b/irc/whowas.go new file mode 100644 index 00000000..008ed7f3 --- /dev/null +++ b/irc/whowas.go @@ -0,0 +1,73 @@ +package irc + +type WhoWasList struct { + buffer []*WhoWas + start uint + end uint +} + +type WhoWas struct { + nickname string + username string + hostname string + realname string +} + +func NewWhoWasList(size uint) *WhoWasList { + return &WhoWasList{ + buffer: make([]*WhoWas, size), + } +} + +func (list *WhoWasList) Append(client *Client) { + list.buffer[list.end] = &WhoWas{ + nickname: client.Nick(), + username: client.username, + hostname: client.hostname, + realname: client.realname, + } + list.end = (list.end + 1) % uint(len(list.buffer)) + if list.end == list.start { + list.start = (list.end + 1) % uint(len(list.buffer)) + } +} + +func (list *WhoWasList) Find(nickname string, limit int64) []*WhoWas { + results := make([]*WhoWas, 0) + for whoWas := range list.Each() { + if nickname != whoWas.nickname { + continue + } + results = append(results, whoWas) + if int64(len(results)) >= limit { + break + } + } + return results +} + +func (list *WhoWasList) prev(index uint) uint { + index -= 1 + if index < 0 { + index += uint(len(list.buffer)) + } + return index +} + +// Iterate the buffer in reverse. +func (list *WhoWasList) Each() <-chan *WhoWas { + ch := make(chan *WhoWas) + go func() { + defer close(ch) + if list.start == list.end { + return + } + start := list.prev(list.end) + end := list.prev(list.start) + for start != end { + ch <- list.buffer[start] + start = list.prev(start) + } + }() + return ch +}