3
0
mirror of https://github.com/ergochat/ergo.git synced 2025-01-24 19:24:16 +01:00

Merge pull request #16 from jlatt/user-mask

support user mask wildcards through an in-memory sqlite db
This commit is contained in:
Jeremy Latt 2014-03-08 19:23:36 -08:00
commit bc3480ebb8
11 changed files with 572 additions and 115 deletions

View File

@ -12,6 +12,7 @@ import (
func main() { func main() {
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file") conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
initdb := flag.Bool("initdb", false, "initialize database") initdb := flag.Bool("initdb", false, "initialize database")
upgradedb := flag.Bool("upgradedb", false, "update database")
passwd := flag.String("genpasswd", "", "bcrypt a password") passwd := flag.String("genpasswd", "", "bcrypt a password")
flag.Parse() flag.Parse()
@ -35,7 +36,13 @@ func main() {
if *initdb { if *initdb {
irc.InitDB(config.Server.Database) irc.InitDB(config.Server.Database)
log.Println("database initialized: " + config.Server.Database) log.Println("database initialized: ", config.Server.Database)
return
}
if *upgradedb {
irc.UpgradeDB(config.Server.Database)
log.Println("database upgraded: ", config.Server.Database)
return return
} }
@ -45,5 +52,8 @@ func main() {
irc.DEBUG_CHANNEL = config.Debug.Channel irc.DEBUG_CHANNEL = config.Debug.Channel
irc.DEBUG_SERVER = config.Debug.Server irc.DEBUG_SERVER = config.Debug.Server
irc.NewServer(config).Run() server := irc.NewServer(config)
log.Println(irc.SEM_VER, "running")
defer log.Println(irc.SEM_VER, "exiting")
server.Run()
} }

View File

@ -8,7 +8,7 @@ import (
type Channel struct { type Channel struct {
flags ChannelModeSet flags ChannelModeSet
lists map[ChannelMode][]UserMask lists map[ChannelMode]*UserMaskSet
key string key string
members MemberSet members MemberSet
name string name string
@ -26,10 +26,10 @@ func IsChannel(target string) bool {
func NewChannel(s *Server, name string) *Channel { func NewChannel(s *Server, name string) *Channel {
channel := &Channel{ channel := &Channel{
flags: make(ChannelModeSet), flags: make(ChannelModeSet),
lists: map[ChannelMode][]UserMask{ lists: map[ChannelMode]*UserMaskSet{
BanMask: []UserMask{}, BanMask: NewUserMaskSet(),
ExceptMask: []UserMask{}, ExceptMask: NewUserMaskSet(),
InviteMask: []UserMask{}, InviteMask: NewUserMaskSet(),
}, },
members: make(MemberSet), members: make(MemberSet),
name: strings.ToLower(name), name: strings.ToLower(name),
@ -151,6 +151,19 @@ func (channel *Channel) Join(client *Client, key string) {
return return
} }
isInvited := channel.lists[InviteMask].Match(client.UserHost())
if channel.flags[InviteOnly] && !isInvited {
client.ErrInviteOnlyChan(channel)
return
}
if channel.lists[BanMask].Match(client.UserHost()) &&
!isInvited &&
!channel.lists[ExceptMask].Match(client.UserHost()) {
client.ErrBannedFromChan(channel)
return
}
client.channels.Add(channel) client.channels.Add(channel)
channel.members.Add(client) channel.members.Add(client)
if !channel.flags[Persistent] && (len(channel.members) == 1) { if !channel.flags[Persistent] && (len(channel.members) == 1) {
@ -213,7 +226,7 @@ func (channel *Channel) SetTopic(client *Client, topic string) {
} }
if err := channel.Persist(); err != nil { if err := channel.Persist(); err != nil {
log.Println(err) log.Println("Channel.Persist:", channel, err)
} }
} }
@ -310,17 +323,48 @@ func (channel *Channel) applyModeMember(client *Client, mode ChannelMode,
return false return false
} }
func (channel *Channel) ShowMaskList(client *Client, mode ChannelMode) {
for lmask := range channel.lists[mode].masks {
client.RplMaskList(mode, channel, lmask)
}
client.RplEndOfMaskList(mode, channel)
}
func (channel *Channel) applyModeMask(client *Client, mode ChannelMode, op ModeOp,
mask string) bool {
list := channel.lists[mode]
if list == nil {
// This should never happen, but better safe than panicky.
return false
}
if (op == List) || (mask == "") {
channel.ShowMaskList(client, mode)
return false
}
if !channel.ClientIsOperator(client) {
client.ErrChanOPrivIsNeeded(channel)
return false
}
if op == Add {
return list.Add(mask)
}
if op == Remove {
return list.Remove(mask)
}
return false
}
func (channel *Channel) applyMode(client *Client, change *ChannelModeChange) bool { func (channel *Channel) applyMode(client *Client, change *ChannelModeChange) bool {
switch change.mode { switch change.mode {
case BanMask, ExceptMask, InviteMask: case BanMask, ExceptMask, InviteMask:
// TODO add/remove return channel.applyModeMask(client, change.mode, change.op, change.arg)
for _, mask := range channel.lists[change.mode] { case InviteOnly, Moderated, NoOutside, OpOnlyTopic, Persistent, Private:
client.RplMaskList(change.mode, channel, mask)
}
client.RplEndOfMaskList(change.mode, channel)
case Moderated, NoOutside, OpOnlyTopic, Persistent, Private:
return channel.applyModeFlag(client, change.mode, change.op) return channel.applyModeFlag(client, change.mode, change.op)
case Key: case Key:
@ -390,7 +434,7 @@ func (channel *Channel) Mode(client *Client, changes ChannelModeChanges) {
} }
if err := channel.Persist(); err != nil { if err := channel.Persist(); err != nil {
log.Println(err) log.Println("Channel.Persist:", channel, err)
} }
} }
} }
@ -399,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
if channel.flags[Persistent] { if channel.flags[Persistent] {
_, err = channel.server.db.Exec(` _, err = channel.server.db.Exec(`
INSERT OR REPLACE INTO channel INSERT OR REPLACE INTO channel
(name, flags, key, topic, user_limit) (name, flags, key, topic, user_limit, ban_list, except_list,
VALUES (?, ?, ?, ?, ?)`, invite_list)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
channel.name, channel.flags.String(), channel.key, channel.topic, channel.name, channel.flags.String(), channel.key, channel.topic,
channel.userLimit) channel.userLimit, channel.lists[BanMask].String(),
channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
} else { } else {
_, err = channel.server.db.Exec(` _, err = channel.server.db.Exec(`
DELETE FROM channel WHERE name = ?`, channel.name) DELETE FROM channel WHERE name = ?`, channel.name)
@ -464,6 +510,13 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client) {
return return
} }
if channel.flags[InviteOnly] {
channel.lists[InviteMask].Add(invitee.UserHost())
if err := channel.Persist(); err != nil {
log.Println("Channel.Persist:", channel, err)
}
}
inviter.RplInviting(invitee, channel.name) inviter.RplInviting(invitee, channel.name)
invitee.Reply(RplInviteMsg(inviter, invitee, channel.name)) invitee.Reply(RplInviteMsg(inviter, invitee, channel.name))
if invitee.flags[Away] { if invitee.flags[Away] {

View File

@ -229,6 +229,7 @@ func (client *Client) ChangeNickname(nickname string) {
// Make reply before changing nick to capture original source id. // Make reply before changing nick to capture original source id.
reply := RplNick(client, nickname) reply := RplNick(client, nickname)
client.server.clients.Remove(client) client.server.clients.Remove(client)
client.server.whoWas.Append(client)
client.nick = nickname client.nick = nickname
client.server.clients.Add(client) client.server.clients.Add(client)
for friend := range client.Friends() { for friend := range client.Friends() {
@ -249,8 +250,8 @@ func (client *Client) Quit(message string) {
} }
client.Reply(RplError("connection closed")) client.Reply(RplError("connection closed"))
client.hasQuit = true client.hasQuit = true
client.server.whoWas.Append(client)
friends := client.Friends() friends := client.Friends()
friends.Remove(client) friends.Remove(client)
client.destroy() client.destroy()

272
irc/client_lookup_set.go Normal file
View File

@ -0,0 +1,272 @@
package irc
import (
"database/sql"
"errors"
"log"
"regexp"
"strings"
)
var (
ErrNickMissing = errors.New("nick missing")
ErrNicknameInUse = errors.New("nickname in use")
ErrNicknameMismatch = errors.New("nickname mismatch")
wildMaskExpr = regexp.MustCompile(`\*|\?`)
likeQuoter = strings.NewReplacer(
`\`, `\\`,
`%`, `\%`,
`_`, `\_`,
`*`, `%`,
`?`, `_`)
)
func HasWildcards(mask string) bool {
return wildMaskExpr.MatchString(mask)
}
func ExpandUserHost(userhost string) (expanded string) {
expanded = userhost
// fill in missing wildcards for nicks
if !strings.Contains(expanded, "!") {
expanded += "!*"
}
if !strings.Contains(expanded, "@") {
expanded += "@*"
}
return
}
func QuoteLike(userhost string) string {
return likeQuoter.Replace(userhost)
}
type ClientLookupSet struct {
byNick map[string]*Client
db *ClientDB
}
func NewClientLookupSet() *ClientLookupSet {
return &ClientLookupSet{
byNick: make(map[string]*Client),
db: NewClientDB(),
}
}
func (clients *ClientLookupSet) Get(nick string) *Client {
return clients.byNick[strings.ToLower(nick)]
}
func (clients *ClientLookupSet) Add(client *Client) error {
if !client.HasNick() {
return ErrNickMissing
}
if clients.Get(client.nick) != nil {
return ErrNicknameInUse
}
clients.byNick[strings.ToLower(client.nick)] = client
clients.db.Add(client)
return nil
}
func (clients *ClientLookupSet) Remove(client *Client) error {
if !client.HasNick() {
return ErrNickMissing
}
if clients.Get(client.nick) != client {
return ErrNicknameMismatch
}
delete(clients.byNick, strings.ToLower(client.nick))
clients.db.Remove(client)
return nil
}
func (clients *ClientLookupSet) FindAll(userhost string) (set ClientSet) {
userhost = ExpandUserHost(userhost)
set = make(ClientSet)
rows, err := clients.db.db.Query(
`SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\'`,
QuoteLike(userhost))
if err != nil {
if DEBUG_SERVER {
log.Println("ClientLookupSet.FindAll.Query:", err)
}
return
}
for rows.Next() {
var nickname string
err := rows.Scan(&nickname)
if err != nil {
if DEBUG_SERVER {
log.Println("ClientLookupSet.FindAll.Scan:", err)
}
return
}
client := clients.Get(nickname)
if client == nil {
if DEBUG_SERVER {
log.Println("ClientLookupSet.FindAll: missing client:", nickname)
}
continue
}
set.Add(client)
}
return
}
func (clients *ClientLookupSet) Find(userhost string) *Client {
userhost = ExpandUserHost(userhost)
row := clients.db.db.QueryRow(
`SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\' LIMIT 1`,
QuoteLike(userhost))
var nickname string
err := row.Scan(&nickname)
if err != nil {
if DEBUG_SERVER {
log.Println("ClientLookupSet.Find:", err)
}
return nil
}
return clients.Get(nickname)
}
//
// client db
//
type ClientDB struct {
db *sql.DB
}
func NewClientDB() *ClientDB {
db := &ClientDB{
db: OpenDB(":memory:"),
}
stmts := []string{
`CREATE TABLE client (
nickname TEXT NOT NULL COLLATE NOCASE UNIQUE,
userhost TEXT NOT NULL COLLATE NOCASE,
UNIQUE (nickname, userhost) ON CONFLICT REPLACE)`,
`CREATE UNIQUE INDEX idx_nick ON client (nickname COLLATE NOCASE)`,
`CREATE UNIQUE INDEX idx_uh ON client (userhost COLLATE NOCASE)`,
}
for _, stmt := range stmts {
_, err := db.db.Exec(stmt)
if err != nil {
log.Fatal("NewClientDB: ", stmt, err)
}
}
return db
}
func (db *ClientDB) Add(client *Client) {
_, err := db.db.Exec(`INSERT INTO client (nickname, userhost) VALUES (?, ?)`,
client.Nick(), client.UserHost())
if err != nil {
if DEBUG_SERVER {
log.Println("ClientDB.Add:", err)
}
}
}
func (db *ClientDB) Remove(client *Client) {
_, err := db.db.Exec(`DELETE FROM client WHERE nickname = ?`,
client.Nick())
if err != nil {
if DEBUG_SERVER {
log.Println("ClientDB.Remove:", err)
}
}
}
//
// usermask to regexp
//
type UserMaskSet struct {
masks map[string]bool
regexp *regexp.Regexp
}
func NewUserMaskSet() *UserMaskSet {
return &UserMaskSet{
masks: make(map[string]bool),
}
}
func (set *UserMaskSet) Add(mask string) bool {
if set.masks[mask] {
return false
}
set.masks[mask] = true
set.setRegexp()
return true
}
func (set *UserMaskSet) AddAll(masks []string) (added bool) {
for _, mask := range masks {
if !added && !set.masks[mask] {
added = true
}
set.masks[mask] = true
}
set.setRegexp()
return
}
func (set *UserMaskSet) Remove(mask string) bool {
if !set.masks[mask] {
return false
}
delete(set.masks, mask)
set.setRegexp()
return true
}
func (set *UserMaskSet) Match(userhost string) bool {
if set.regexp == nil {
return false
}
return set.regexp.MatchString(userhost)
}
func (set *UserMaskSet) String() string {
masks := make([]string, len(set.masks))
index := 0
for mask := range set.masks {
masks[index] = mask
index += 1
}
return strings.Join(masks, " ")
}
// Generate a regular expression from the set of user mask
// strings. Masks are split at the two types of wildcards, `*` and
// `?`. All the pieces are meta-escaped. `*` is replaced with `.*`,
// the regexp equivalent. Likewise, `?` is replaced with `.`. The
// parts are re-joined and finally all masks are joined into a big
// or-expression.
func (set *UserMaskSet) setRegexp() {
if len(set.masks) == 0 {
set.regexp = nil
return
}
maskExprs := make([]string, len(set.masks))
index := 0
for mask := range set.masks {
manyParts := strings.Split(mask, "*")
manyExprs := make([]string, len(manyParts))
for mindex, manyPart := range manyParts {
oneParts := strings.Split(manyPart, "?")
oneExprs := make([]string, len(oneParts))
for oindex, onePart := range oneParts {
oneExprs[oindex] = regexp.QuoteMeta(onePart)
}
manyExprs[mindex] = strings.Join(oneExprs, ".")
}
maskExprs[index] = strings.Join(manyExprs, ".*")
}
expr := "^" + strings.Join(maskExprs, "|") + "$"
set.regexp, _ = regexp.Compile(expr)
}

View File

@ -54,6 +54,7 @@ var (
VERSION: NewVersionCommand, VERSION: NewVersionCommand,
WHO: NewWhoCommand, WHO: NewWhoCommand,
WHOIS: NewWhoisCommand, WHOIS: NewWhoisCommand,
WHOWAS: NewWhoWasCommand,
} }
) )
@ -656,7 +657,7 @@ func (msg *WhoisCommand) String() string {
type WhoCommand struct { type WhoCommand struct {
BaseCommand BaseCommand
mask Mask mask string
operatorOnly bool operatorOnly bool
} }
@ -665,7 +666,7 @@ func NewWhoCommand(args []string) (editableCommand, error) {
cmd := &WhoCommand{} cmd := &WhoCommand{}
if len(args) > 0 { if len(args) > 0 {
cmd.mask = Mask(args[0]) cmd.mask = args[0]
} }
if (len(args) > 1) && (args[1] == "o") { if (len(args) > 1) && (args[1] == "o") {
@ -982,3 +983,26 @@ func NewKillCommand(args []string) (editableCommand, error) {
comment: args[1], comment: args[1],
}, nil }, nil
} }
type WhoWasCommand struct {
BaseCommand
nicknames []string
count int64
target string
}
func NewWhoWasCommand(args []string) (editableCommand, error) {
if len(args) < 1 {
return nil, NotEnoughArgsError
}
cmd := &WhoWasCommand{
nicknames: strings.Split(args[0], ","),
}
if len(args) > 1 {
cmd.count, _ = strconv.ParseInt(args[1], 10, 64)
}
if len(args) > 2 {
cmd.target = args[2]
}
return cmd, nil
}

View File

@ -61,6 +61,7 @@ const (
VERSION StringCode = "VERSION" VERSION StringCode = "VERSION"
WHO StringCode = "WHO" WHO StringCode = "WHO"
WHOIS StringCode = "WHOIS" WHOIS StringCode = "WHOIS"
WHOWAS StringCode = "WHOWAS"
// numeric codes // numeric codes
RPL_WELCOME NumericCode = 1 RPL_WELCOME NumericCode = 1

View File

@ -2,6 +2,7 @@ package irc
import ( import (
"database/sql" "database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"log" "log"
"os" "os"
@ -14,15 +15,30 @@ func InitDB(path string) {
_, err := db.Exec(` _, err := db.Exec(`
CREATE TABLE channel ( CREATE TABLE channel (
name TEXT NOT NULL UNIQUE, name TEXT NOT NULL UNIQUE,
flags TEXT NOT NULL, flags TEXT DEFAULT '',
key TEXT NOT NULL, key TEXT DEFAULT '',
topic TEXT NOT NULL, topic TEXT DEFAULT '',
user_limit INTEGER DEFAULT 0)`) user_limit INTEGER DEFAULT 0,
ban_list TEXT DEFAULT '',
except_list TEXT DEFAULT '',
invite_list TEXT DEFAULT '')`)
if err != nil { if err != nil {
log.Fatal("initdb error: ", err) log.Fatal("initdb error: ", err)
} }
} }
func UpgradeDB(path string) {
db := OpenDB(path)
alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
cols := []string{"ban_list", "except_list", "invite_list"}
for _, col := range cols {
_, err := db.Exec(fmt.Sprintf(alter, col))
if err != nil {
log.Fatal("updatedb error: ", err)
}
}
}
func OpenDB(path string) *sql.DB { func OpenDB(path string) *sql.DB {
db, err := sql.Open("sqlite3", path) db, err := sql.Open("sqlite3", path)
if err != nil { if err != nil {

View File

@ -200,6 +200,16 @@ func (target *Client) RplYoureOper() {
":You are now an IRC operator") ":You are now an IRC operator")
} }
func (target *Client) RplWhois(client *Client) {
target.RplWhoisUser(client)
if client.flags[Operator] {
target.RplWhoisOperator(client)
}
target.RplWhoisIdle(client)
target.RplWhoisChannels(client)
target.RplEndOfWhois()
}
func (target *Client) RplWhoisUser(client *Client) { func (target *Client) RplWhoisUser(client *Client) {
target.NumericReply(RPL_WHOISUSER, target.NumericReply(RPL_WHOISUSER,
"%s %s %s * :%s", client.Nick(), client.username, client.hostname, "%s %s %s * :%s", client.Nick(), client.username, client.hostname,
@ -270,7 +280,7 @@ func (target *Client) RplEndOfWho(name string) {
"%s :End of WHO list", name) "%s :End of WHO list", name)
} }
func (target *Client) RplMaskList(mode ChannelMode, channel *Channel, mask UserMask) { func (target *Client) RplMaskList(mode ChannelMode, channel *Channel, mask string) {
switch mode { switch mode {
case BanMask: case BanMask:
target.RplBanList(channel, mask) target.RplBanList(channel, mask)
@ -296,7 +306,7 @@ func (target *Client) RplEndOfMaskList(mode ChannelMode, channel *Channel) {
} }
} }
func (target *Client) RplBanList(channel *Channel, mask UserMask) { func (target *Client) RplBanList(channel *Channel, mask string) {
target.NumericReply(RPL_BANLIST, target.NumericReply(RPL_BANLIST,
"%s %s", channel, mask) "%s %s", channel, mask)
} }
@ -306,7 +316,7 @@ func (target *Client) RplEndOfBanList(channel *Channel) {
"%s :End of channel ban list", channel) "%s :End of channel ban list", channel)
} }
func (target *Client) RplExceptList(channel *Channel, mask UserMask) { func (target *Client) RplExceptList(channel *Channel, mask string) {
target.NumericReply(RPL_EXCEPTLIST, target.NumericReply(RPL_EXCEPTLIST,
"%s %s", channel, mask) "%s %s", channel, mask)
} }
@ -316,7 +326,7 @@ func (target *Client) RplEndOfExceptList(channel *Channel) {
"%s :End of channel exception list", channel) "%s :End of channel exception list", channel)
} }
func (target *Client) RplInviteList(channel *Channel, mask UserMask) { func (target *Client) RplInviteList(channel *Channel, mask string) {
target.NumericReply(RPL_INVITELIST, target.NumericReply(RPL_INVITELIST,
"%s %s", channel, mask) "%s %s", channel, mask)
} }
@ -396,6 +406,17 @@ func (target *Client) RplTime() {
"%s :%s", target.server.name, time.Now().Format(time.RFC1123)) "%s :%s", target.server.name, time.Now().Format(time.RFC1123))
} }
func (target *Client) RplWhoWasUser(whoWas *WhoWas) {
target.NumericReply(RPL_WHOWASUSER,
"%s %s %s * :%s",
whoWas.nickname, whoWas.username, whoWas.hostname, whoWas.realname)
}
func (target *Client) RplEndOfWhoWas(nickname string) {
target.NumericReply(RPL_ENDOFWHOWAS,
"%s :End of WHOWAS", nickname)
}
// //
// errors (also numeric) // errors (also numeric)
// //
@ -515,7 +536,22 @@ func (target *Client) ErrChannelIsFull(channel *Channel) {
"%s :Cannot join channel (+l)", channel) "%s :Cannot join channel (+l)", channel)
} }
func (target *Client) ErrWasNoSuchNick(nickname string) {
target.NumericReply(ERR_WASNOSUCHNICK,
"%s :There was no such nickname", nickname)
}
func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) { func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) {
target.NumericReply(ERR_INVALIDCAPCMD, target.NumericReply(ERR_INVALIDCAPCMD,
"%s :Invalid CAP subcommand", subCommand) "%s :Invalid CAP subcommand", subCommand)
} }
func (target *Client) ErrBannedFromChan(channel *Channel) {
target.NumericReply(ERR_BANNEDFROMCHAN,
"%s :Cannot join channel (+b)", channel)
}
func (target *Client) ErrInviteOnlyChan(channel *Channel) {
target.NumericReply(ERR_INVITEONLYCHAN,
"%s :Cannot join channel (+i)", channel)
}

View File

@ -18,7 +18,7 @@ import (
type Server struct { type Server struct {
channels ChannelNameMap channels ChannelNameMap
clients ClientNameMap clients *ClientLookupSet
commands chan Command commands chan Command
ctime time.Time ctime time.Time
db *sql.DB db *sql.DB
@ -29,12 +29,13 @@ type Server struct {
operators map[string][]byte operators map[string][]byte
password []byte password []byte
signals chan os.Signal signals chan os.Signal
whoWas *WhoWasList
} }
func NewServer(config *Config) *Server { func NewServer(config *Config) *Server {
server := &Server{ server := &Server{
channels: make(ChannelNameMap), channels: make(ChannelNameMap),
clients: make(ClientNameMap), clients: NewClientLookupSet(),
commands: make(chan Command, 16), commands: make(chan Command, 16),
ctime: time.Now(), ctime: time.Now(),
db: OpenDB(config.Server.Database), db: OpenDB(config.Server.Database),
@ -44,6 +45,7 @@ func NewServer(config *Config) *Server {
newConns: make(chan net.Conn, 16), newConns: make(chan net.Conn, 16),
operators: config.Operators(), operators: config.Operators(),
signals: make(chan os.Signal, 1), signals: make(chan os.Signal, 1),
whoWas: NewWhoWasList(100),
} }
if config.Server.Password != "" { if config.Server.Password != "" {
@ -62,9 +64,17 @@ func NewServer(config *Config) *Server {
return server return server
} }
func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
if list == "" {
return
}
channel.lists[maskMode].AddAll(strings.Split(list, " "))
}
func (server *Server) loadChannels() { func (server *Server) loadChannels() {
rows, err := server.db.Query(` rows, err := server.db.Query(`
SELECT name, flags, key, topic, user_limit SELECT name, flags, key, topic, user_limit, ban_list, except_list,
invite_list
FROM channel`) FROM channel`)
if err != nil { if err != nil {
log.Fatal("error loading channels: ", err) log.Fatal("error loading channels: ", err)
@ -72,9 +82,11 @@ func (server *Server) loadChannels() {
for rows.Next() { for rows.Next() {
var name, flags, key, topic string var name, flags, key, topic string
var userLimit uint64 var userLimit uint64
err = rows.Scan(&name, &flags, &key, &topic, &userLimit) var banList, exceptList, inviteList string
err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
&exceptList, &inviteList)
if err != nil { if err != nil {
log.Println(err) log.Println("Server.loadChannels:", err)
continue continue
} }
@ -85,6 +97,9 @@ func (server *Server) loadChannels() {
channel.key = key channel.key = key
channel.topic = topic channel.topic = topic
channel.userLimit = userLimit channel.userLimit = userLimit
loadChannelList(channel, banList, BanMask)
loadChannelList(channel, exceptList, ExceptMask)
loadChannelList(channel, inviteList, InviteMask)
} }
} }
@ -126,7 +141,7 @@ func (server *Server) processCommand(cmd Command) {
func (server *Server) Shutdown() { func (server *Server) Shutdown() {
server.db.Close() server.db.Close()
for _, client := range server.clients { for _, client := range server.clients.byNick {
client.Reply(RplNotice(server, client, "shutting down")) client.Reply(RplNotice(server, client, "shutting down"))
} }
} }
@ -340,7 +355,7 @@ func (msg *RFC1459UserCommand) HandleRegServer(server *Server) {
client.Quit("bad password") client.Quit("bad password")
return return
} }
msg.HandleRegServer2(server) msg.setUserInfo(server)
} }
func (msg *RFC2812UserCommand) HandleRegServer(server *Server) { func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
@ -357,15 +372,19 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
} }
client.RplUModeIs(client) client.RplUModeIs(client)
} }
msg.HandleRegServer2(server) msg.setUserInfo(server)
} }
func (msg *UserCommand) HandleRegServer2(server *Server) { func (msg *UserCommand) setUserInfo(server *Server) {
client := msg.Client() client := msg.Client()
if client.capState == CapNegotiating { if client.capState == CapNegotiating {
client.capState = CapNegotiated client.capState = CapNegotiated
} }
server.clients.Remove(client)
client.username, client.realname = msg.username, msg.realname client.username, client.realname = msg.username, msg.realname
server.clients.Add(client)
server.tryRegister(client) server.tryRegister(client)
} }
@ -514,7 +533,7 @@ func (m *ModeCommand) HandleServer(s *Server) {
return return
} }
changes := make(ModeChanges, 0) changes := make(ModeChanges, 0, len(m.changes))
for _, change := range m.changes { for _, change := range m.changes {
switch change.mode { switch change.mode {
@ -577,19 +596,14 @@ func (m *WhoisCommand) HandleServer(server *Server) {
// TODO implement target query // TODO implement target query
for _, mask := range m.masks { for _, mask := range m.masks {
// TODO implement wildcard matching matches := server.clients.FindAll(mask)
mclient := server.clients.Get(mask) if len(matches) == 0 {
if mclient == nil {
client.ErrNoSuchNick(mask) client.ErrNoSuchNick(mask)
continue continue
} }
client.RplWhoisUser(mclient) for mclient := range matches {
if mclient.flags[Operator] { client.RplWhois(mclient)
client.RplWhoisOperator(mclient)
} }
client.RplWhoisIdle(mclient)
client.RplWhoisChannels(mclient)
client.RplEndOfWhois()
} }
} }
@ -604,9 +618,9 @@ func (msg *ChannelModeCommand) HandleServer(server *Server) {
channel.Mode(client, msg.changes) channel.Mode(client, msg.changes)
} }
func whoChannel(client *Client, channel *Channel) { func whoChannel(client *Client, channel *Channel, friends ClientSet) {
for member := range channel.members { for member := range channel.members {
if !client.flags[Invisible] { if !client.flags[Invisible] || friends[client] {
client.RplWhoReply(channel, member) client.RplWhoReply(channel, member)
} }
} }
@ -614,27 +628,21 @@ func whoChannel(client *Client, channel *Channel) {
func (msg *WhoCommand) HandleServer(server *Server) { func (msg *WhoCommand) HandleServer(server *Server) {
client := msg.Client() client := msg.Client()
friends := client.Friends()
mask := msg.mask
// TODO implement wildcard matching
mask := string(msg.mask)
if mask == "" { if mask == "" {
for _, channel := range server.channels { for _, channel := range server.channels {
for member := range channel.members { whoChannel(client, channel, friends)
if !client.flags[Invisible] {
client.RplWhoReply(channel, member)
}
}
} }
} else if IsChannel(mask) { } else if IsChannel(mask) {
// TODO implement wildcard matching
channel := server.channels.Get(mask) channel := server.channels.Get(mask)
if channel != nil { if channel != nil {
for member := range channel.members { whoChannel(client, channel, friends)
client.RplWhoReply(channel, member)
}
} }
} else { } else {
mclient := server.clients.Get(mask) for mclient := range server.clients.FindAll(mask) {
if mclient != nil {
client.RplWhoReply(nil, mclient) client.RplWhoReply(nil, mclient)
} }
} }
@ -874,3 +882,18 @@ func (msg *KillCommand) HandleServer(server *Server) {
quitMsg := fmt.Sprintf("KILLed by %s: %s", client.Nick(), msg.comment) quitMsg := fmt.Sprintf("KILLed by %s: %s", client.Nick(), msg.comment)
target.Quit(quitMsg) target.Quit(quitMsg)
} }
func (msg *WhoWasCommand) HandleServer(server *Server) {
client := msg.Client()
for _, nickname := range msg.nicknames {
results := server.whoWas.Find(nickname, msg.count)
if len(results) == 0 {
client.ErrWasNoSuchNick(nickname)
} else {
for _, whoWas := range results {
client.RplWhoWasUser(whoWas)
}
}
client.RplEndOfWhoWas(nickname)
}
}

View File

@ -1,7 +1,6 @@
package irc package irc
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
) )
@ -48,9 +47,6 @@ func (set CapabilitySet) DisableString() string {
return strings.Join(parts, " ") return strings.Join(parts, " ")
} }
// a string with wildcards
type Mask string
// add, remove, list modes // add, remove, list modes
type ModeOp rune type ModeOp rune
@ -112,40 +108,6 @@ func (channels ChannelNameMap) Remove(channel *Channel) error {
return nil return nil
} }
type ClientNameMap map[string]*Client
var (
ErrNickMissing = errors.New("nick missing")
ErrNicknameInUse = errors.New("nickname in use")
ErrNicknameMismatch = errors.New("nickname mismatch")
)
func (clients ClientNameMap) Get(nick string) *Client {
return clients[strings.ToLower(nick)]
}
func (clients ClientNameMap) Add(client *Client) error {
if !client.HasNick() {
return ErrNickMissing
}
if clients.Get(client.nick) != nil {
return ErrNicknameInUse
}
clients[strings.ToLower(client.nick)] = client
return nil
}
func (clients ClientNameMap) Remove(client *Client) error {
if !client.HasNick() {
return ErrNickMissing
}
if clients.Get(client.nick) != client {
return ErrNicknameMismatch
}
delete(clients, strings.ToLower(client.nick))
return nil
}
type ChannelModeSet map[ChannelMode]bool type ChannelModeSet map[ChannelMode]bool
func (set ChannelModeSet) String() string { func (set ChannelModeSet) String() string {
@ -247,17 +209,3 @@ type RegServerCommand interface {
Command Command
HandleRegServer(*Server) HandleRegServer(*Server)
} }
//
// structs
//
type UserMask struct {
nickname Mask
username Mask
hostname Mask
}
func (mask *UserMask) String() string {
return fmt.Sprintf("%s!%s@%s", mask.nickname, mask.username, mask.hostname)
}

73
irc/whowas.go Normal file
View File

@ -0,0 +1,73 @@
package irc
type WhoWasList struct {
buffer []*WhoWas
start uint
end uint
}
type WhoWas struct {
nickname string
username string
hostname string
realname string
}
func NewWhoWasList(size uint) *WhoWasList {
return &WhoWasList{
buffer: make([]*WhoWas, size),
}
}
func (list *WhoWasList) Append(client *Client) {
list.buffer[list.end] = &WhoWas{
nickname: client.Nick(),
username: client.username,
hostname: client.hostname,
realname: client.realname,
}
list.end = (list.end + 1) % uint(len(list.buffer))
if list.end == list.start {
list.start = (list.end + 1) % uint(len(list.buffer))
}
}
func (list *WhoWasList) Find(nickname string, limit int64) []*WhoWas {
results := make([]*WhoWas, 0)
for whoWas := range list.Each() {
if nickname != whoWas.nickname {
continue
}
results = append(results, whoWas)
if int64(len(results)) >= limit {
break
}
}
return results
}
func (list *WhoWasList) prev(index uint) uint {
index -= 1
if index < 0 {
index += uint(len(list.buffer))
}
return index
}
// Iterate the buffer in reverse.
func (list *WhoWasList) Each() <-chan *WhoWas {
ch := make(chan *WhoWas)
go func() {
defer close(ch)
if list.start == list.end {
return
}
start := list.prev(list.end)
end := list.prev(list.start)
for start != end {
ch <- list.buffer[start]
start = list.prev(start)
}
}()
return ch
}