mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-10 22:19:31 +01:00
Merge pull request #16 from jlatt/user-mask
support user mask wildcards through an in-memory sqlite db
This commit is contained in:
commit
bc3480ebb8
@ -12,6 +12,7 @@ import (
|
||||
func main() {
|
||||
conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
|
||||
initdb := flag.Bool("initdb", false, "initialize database")
|
||||
upgradedb := flag.Bool("upgradedb", false, "update database")
|
||||
passwd := flag.String("genpasswd", "", "bcrypt a password")
|
||||
flag.Parse()
|
||||
|
||||
@ -35,7 +36,13 @@ func main() {
|
||||
|
||||
if *initdb {
|
||||
irc.InitDB(config.Server.Database)
|
||||
log.Println("database initialized: " + config.Server.Database)
|
||||
log.Println("database initialized: ", config.Server.Database)
|
||||
return
|
||||
}
|
||||
|
||||
if *upgradedb {
|
||||
irc.UpgradeDB(config.Server.Database)
|
||||
log.Println("database upgraded: ", config.Server.Database)
|
||||
return
|
||||
}
|
||||
|
||||
@ -45,5 +52,8 @@ func main() {
|
||||
irc.DEBUG_CHANNEL = config.Debug.Channel
|
||||
irc.DEBUG_SERVER = config.Debug.Server
|
||||
|
||||
irc.NewServer(config).Run()
|
||||
server := irc.NewServer(config)
|
||||
log.Println(irc.SEM_VER, "running")
|
||||
defer log.Println(irc.SEM_VER, "exiting")
|
||||
server.Run()
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
|
||||
type Channel struct {
|
||||
flags ChannelModeSet
|
||||
lists map[ChannelMode][]UserMask
|
||||
lists map[ChannelMode]*UserMaskSet
|
||||
key string
|
||||
members MemberSet
|
||||
name string
|
||||
@ -26,10 +26,10 @@ func IsChannel(target string) bool {
|
||||
func NewChannel(s *Server, name string) *Channel {
|
||||
channel := &Channel{
|
||||
flags: make(ChannelModeSet),
|
||||
lists: map[ChannelMode][]UserMask{
|
||||
BanMask: []UserMask{},
|
||||
ExceptMask: []UserMask{},
|
||||
InviteMask: []UserMask{},
|
||||
lists: map[ChannelMode]*UserMaskSet{
|
||||
BanMask: NewUserMaskSet(),
|
||||
ExceptMask: NewUserMaskSet(),
|
||||
InviteMask: NewUserMaskSet(),
|
||||
},
|
||||
members: make(MemberSet),
|
||||
name: strings.ToLower(name),
|
||||
@ -151,6 +151,19 @@ func (channel *Channel) Join(client *Client, key string) {
|
||||
return
|
||||
}
|
||||
|
||||
isInvited := channel.lists[InviteMask].Match(client.UserHost())
|
||||
if channel.flags[InviteOnly] && !isInvited {
|
||||
client.ErrInviteOnlyChan(channel)
|
||||
return
|
||||
}
|
||||
|
||||
if channel.lists[BanMask].Match(client.UserHost()) &&
|
||||
!isInvited &&
|
||||
!channel.lists[ExceptMask].Match(client.UserHost()) {
|
||||
client.ErrBannedFromChan(channel)
|
||||
return
|
||||
}
|
||||
|
||||
client.channels.Add(channel)
|
||||
channel.members.Add(client)
|
||||
if !channel.flags[Persistent] && (len(channel.members) == 1) {
|
||||
@ -213,7 +226,7 @@ func (channel *Channel) SetTopic(client *Client, topic string) {
|
||||
}
|
||||
|
||||
if err := channel.Persist(); err != nil {
|
||||
log.Println(err)
|
||||
log.Println("Channel.Persist:", channel, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,17 +323,48 @@ func (channel *Channel) applyModeMember(client *Client, mode ChannelMode,
|
||||
return false
|
||||
}
|
||||
|
||||
func (channel *Channel) ShowMaskList(client *Client, mode ChannelMode) {
|
||||
for lmask := range channel.lists[mode].masks {
|
||||
client.RplMaskList(mode, channel, lmask)
|
||||
}
|
||||
client.RplEndOfMaskList(mode, channel)
|
||||
}
|
||||
|
||||
func (channel *Channel) applyModeMask(client *Client, mode ChannelMode, op ModeOp,
|
||||
mask string) bool {
|
||||
list := channel.lists[mode]
|
||||
if list == nil {
|
||||
// This should never happen, but better safe than panicky.
|
||||
return false
|
||||
}
|
||||
|
||||
if (op == List) || (mask == "") {
|
||||
channel.ShowMaskList(client, mode)
|
||||
return false
|
||||
}
|
||||
|
||||
if !channel.ClientIsOperator(client) {
|
||||
client.ErrChanOPrivIsNeeded(channel)
|
||||
return false
|
||||
}
|
||||
|
||||
if op == Add {
|
||||
return list.Add(mask)
|
||||
}
|
||||
|
||||
if op == Remove {
|
||||
return list.Remove(mask)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (channel *Channel) applyMode(client *Client, change *ChannelModeChange) bool {
|
||||
switch change.mode {
|
||||
case BanMask, ExceptMask, InviteMask:
|
||||
// TODO add/remove
|
||||
return channel.applyModeMask(client, change.mode, change.op, change.arg)
|
||||
|
||||
for _, mask := range channel.lists[change.mode] {
|
||||
client.RplMaskList(change.mode, channel, mask)
|
||||
}
|
||||
client.RplEndOfMaskList(change.mode, channel)
|
||||
|
||||
case Moderated, NoOutside, OpOnlyTopic, Persistent, Private:
|
||||
case InviteOnly, Moderated, NoOutside, OpOnlyTopic, Persistent, Private:
|
||||
return channel.applyModeFlag(client, change.mode, change.op)
|
||||
|
||||
case Key:
|
||||
@ -390,7 +434,7 @@ func (channel *Channel) Mode(client *Client, changes ChannelModeChanges) {
|
||||
}
|
||||
|
||||
if err := channel.Persist(); err != nil {
|
||||
log.Println(err)
|
||||
log.Println("Channel.Persist:", channel, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -399,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
|
||||
if channel.flags[Persistent] {
|
||||
_, err = channel.server.db.Exec(`
|
||||
INSERT OR REPLACE INTO channel
|
||||
(name, flags, key, topic, user_limit)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
(name, flags, key, topic, user_limit, ban_list, except_list,
|
||||
invite_list)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
channel.name, channel.flags.String(), channel.key, channel.topic,
|
||||
channel.userLimit)
|
||||
channel.userLimit, channel.lists[BanMask].String(),
|
||||
channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
|
||||
} else {
|
||||
_, err = channel.server.db.Exec(`
|
||||
DELETE FROM channel WHERE name = ?`, channel.name)
|
||||
@ -464,6 +510,13 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client) {
|
||||
return
|
||||
}
|
||||
|
||||
if channel.flags[InviteOnly] {
|
||||
channel.lists[InviteMask].Add(invitee.UserHost())
|
||||
if err := channel.Persist(); err != nil {
|
||||
log.Println("Channel.Persist:", channel, err)
|
||||
}
|
||||
}
|
||||
|
||||
inviter.RplInviting(invitee, channel.name)
|
||||
invitee.Reply(RplInviteMsg(inviter, invitee, channel.name))
|
||||
if invitee.flags[Away] {
|
||||
|
@ -229,6 +229,7 @@ func (client *Client) ChangeNickname(nickname string) {
|
||||
// Make reply before changing nick to capture original source id.
|
||||
reply := RplNick(client, nickname)
|
||||
client.server.clients.Remove(client)
|
||||
client.server.whoWas.Append(client)
|
||||
client.nick = nickname
|
||||
client.server.clients.Add(client)
|
||||
for friend := range client.Friends() {
|
||||
@ -249,8 +250,8 @@ func (client *Client) Quit(message string) {
|
||||
}
|
||||
|
||||
client.Reply(RplError("connection closed"))
|
||||
|
||||
client.hasQuit = true
|
||||
client.server.whoWas.Append(client)
|
||||
friends := client.Friends()
|
||||
friends.Remove(client)
|
||||
client.destroy()
|
||||
|
272
irc/client_lookup_set.go
Normal file
272
irc/client_lookup_set.go
Normal file
@ -0,0 +1,272 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNickMissing = errors.New("nick missing")
|
||||
ErrNicknameInUse = errors.New("nickname in use")
|
||||
ErrNicknameMismatch = errors.New("nickname mismatch")
|
||||
wildMaskExpr = regexp.MustCompile(`\*|\?`)
|
||||
likeQuoter = strings.NewReplacer(
|
||||
`\`, `\\`,
|
||||
`%`, `\%`,
|
||||
`_`, `\_`,
|
||||
`*`, `%`,
|
||||
`?`, `_`)
|
||||
)
|
||||
|
||||
func HasWildcards(mask string) bool {
|
||||
return wildMaskExpr.MatchString(mask)
|
||||
}
|
||||
|
||||
func ExpandUserHost(userhost string) (expanded string) {
|
||||
expanded = userhost
|
||||
// fill in missing wildcards for nicks
|
||||
if !strings.Contains(expanded, "!") {
|
||||
expanded += "!*"
|
||||
}
|
||||
if !strings.Contains(expanded, "@") {
|
||||
expanded += "@*"
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func QuoteLike(userhost string) string {
|
||||
return likeQuoter.Replace(userhost)
|
||||
}
|
||||
|
||||
type ClientLookupSet struct {
|
||||
byNick map[string]*Client
|
||||
db *ClientDB
|
||||
}
|
||||
|
||||
func NewClientLookupSet() *ClientLookupSet {
|
||||
return &ClientLookupSet{
|
||||
byNick: make(map[string]*Client),
|
||||
db: NewClientDB(),
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *ClientLookupSet) Get(nick string) *Client {
|
||||
return clients.byNick[strings.ToLower(nick)]
|
||||
}
|
||||
|
||||
func (clients *ClientLookupSet) Add(client *Client) error {
|
||||
if !client.HasNick() {
|
||||
return ErrNickMissing
|
||||
}
|
||||
if clients.Get(client.nick) != nil {
|
||||
return ErrNicknameInUse
|
||||
}
|
||||
clients.byNick[strings.ToLower(client.nick)] = client
|
||||
clients.db.Add(client)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (clients *ClientLookupSet) Remove(client *Client) error {
|
||||
if !client.HasNick() {
|
||||
return ErrNickMissing
|
||||
}
|
||||
if clients.Get(client.nick) != client {
|
||||
return ErrNicknameMismatch
|
||||
}
|
||||
delete(clients.byNick, strings.ToLower(client.nick))
|
||||
clients.db.Remove(client)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (clients *ClientLookupSet) FindAll(userhost string) (set ClientSet) {
|
||||
userhost = ExpandUserHost(userhost)
|
||||
set = make(ClientSet)
|
||||
rows, err := clients.db.db.Query(
|
||||
`SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\'`,
|
||||
QuoteLike(userhost))
|
||||
if err != nil {
|
||||
if DEBUG_SERVER {
|
||||
log.Println("ClientLookupSet.FindAll.Query:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
for rows.Next() {
|
||||
var nickname string
|
||||
err := rows.Scan(&nickname)
|
||||
if err != nil {
|
||||
if DEBUG_SERVER {
|
||||
log.Println("ClientLookupSet.FindAll.Scan:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
client := clients.Get(nickname)
|
||||
if client == nil {
|
||||
if DEBUG_SERVER {
|
||||
log.Println("ClientLookupSet.FindAll: missing client:", nickname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
set.Add(client)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (clients *ClientLookupSet) Find(userhost string) *Client {
|
||||
userhost = ExpandUserHost(userhost)
|
||||
row := clients.db.db.QueryRow(
|
||||
`SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\' LIMIT 1`,
|
||||
QuoteLike(userhost))
|
||||
var nickname string
|
||||
err := row.Scan(&nickname)
|
||||
if err != nil {
|
||||
if DEBUG_SERVER {
|
||||
log.Println("ClientLookupSet.Find:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return clients.Get(nickname)
|
||||
}
|
||||
|
||||
//
|
||||
// client db
|
||||
//
|
||||
|
||||
type ClientDB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewClientDB() *ClientDB {
|
||||
db := &ClientDB{
|
||||
db: OpenDB(":memory:"),
|
||||
}
|
||||
stmts := []string{
|
||||
`CREATE TABLE client (
|
||||
nickname TEXT NOT NULL COLLATE NOCASE UNIQUE,
|
||||
userhost TEXT NOT NULL COLLATE NOCASE,
|
||||
UNIQUE (nickname, userhost) ON CONFLICT REPLACE)`,
|
||||
`CREATE UNIQUE INDEX idx_nick ON client (nickname COLLATE NOCASE)`,
|
||||
`CREATE UNIQUE INDEX idx_uh ON client (userhost COLLATE NOCASE)`,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
_, err := db.db.Exec(stmt)
|
||||
if err != nil {
|
||||
log.Fatal("NewClientDB: ", stmt, err)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *ClientDB) Add(client *Client) {
|
||||
_, err := db.db.Exec(`INSERT INTO client (nickname, userhost) VALUES (?, ?)`,
|
||||
client.Nick(), client.UserHost())
|
||||
if err != nil {
|
||||
if DEBUG_SERVER {
|
||||
log.Println("ClientDB.Add:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *ClientDB) Remove(client *Client) {
|
||||
_, err := db.db.Exec(`DELETE FROM client WHERE nickname = ?`,
|
||||
client.Nick())
|
||||
if err != nil {
|
||||
if DEBUG_SERVER {
|
||||
log.Println("ClientDB.Remove:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// usermask to regexp
|
||||
//
|
||||
|
||||
type UserMaskSet struct {
|
||||
masks map[string]bool
|
||||
regexp *regexp.Regexp
|
||||
}
|
||||
|
||||
func NewUserMaskSet() *UserMaskSet {
|
||||
return &UserMaskSet{
|
||||
masks: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) Add(mask string) bool {
|
||||
if set.masks[mask] {
|
||||
return false
|
||||
}
|
||||
set.masks[mask] = true
|
||||
set.setRegexp()
|
||||
return true
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) AddAll(masks []string) (added bool) {
|
||||
for _, mask := range masks {
|
||||
if !added && !set.masks[mask] {
|
||||
added = true
|
||||
}
|
||||
set.masks[mask] = true
|
||||
}
|
||||
set.setRegexp()
|
||||
return
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) Remove(mask string) bool {
|
||||
if !set.masks[mask] {
|
||||
return false
|
||||
}
|
||||
delete(set.masks, mask)
|
||||
set.setRegexp()
|
||||
return true
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) Match(userhost string) bool {
|
||||
if set.regexp == nil {
|
||||
return false
|
||||
}
|
||||
return set.regexp.MatchString(userhost)
|
||||
}
|
||||
|
||||
func (set *UserMaskSet) String() string {
|
||||
masks := make([]string, len(set.masks))
|
||||
index := 0
|
||||
for mask := range set.masks {
|
||||
masks[index] = mask
|
||||
index += 1
|
||||
}
|
||||
return strings.Join(masks, " ")
|
||||
}
|
||||
|
||||
// Generate a regular expression from the set of user mask
|
||||
// strings. Masks are split at the two types of wildcards, `*` and
|
||||
// `?`. All the pieces are meta-escaped. `*` is replaced with `.*`,
|
||||
// the regexp equivalent. Likewise, `?` is replaced with `.`. The
|
||||
// parts are re-joined and finally all masks are joined into a big
|
||||
// or-expression.
|
||||
func (set *UserMaskSet) setRegexp() {
|
||||
if len(set.masks) == 0 {
|
||||
set.regexp = nil
|
||||
return
|
||||
}
|
||||
|
||||
maskExprs := make([]string, len(set.masks))
|
||||
index := 0
|
||||
for mask := range set.masks {
|
||||
manyParts := strings.Split(mask, "*")
|
||||
manyExprs := make([]string, len(manyParts))
|
||||
for mindex, manyPart := range manyParts {
|
||||
oneParts := strings.Split(manyPart, "?")
|
||||
oneExprs := make([]string, len(oneParts))
|
||||
for oindex, onePart := range oneParts {
|
||||
oneExprs[oindex] = regexp.QuoteMeta(onePart)
|
||||
}
|
||||
manyExprs[mindex] = strings.Join(oneExprs, ".")
|
||||
}
|
||||
maskExprs[index] = strings.Join(manyExprs, ".*")
|
||||
}
|
||||
expr := "^" + strings.Join(maskExprs, "|") + "$"
|
||||
set.regexp, _ = regexp.Compile(expr)
|
||||
}
|
@ -54,6 +54,7 @@ var (
|
||||
VERSION: NewVersionCommand,
|
||||
WHO: NewWhoCommand,
|
||||
WHOIS: NewWhoisCommand,
|
||||
WHOWAS: NewWhoWasCommand,
|
||||
}
|
||||
)
|
||||
|
||||
@ -656,7 +657,7 @@ func (msg *WhoisCommand) String() string {
|
||||
|
||||
type WhoCommand struct {
|
||||
BaseCommand
|
||||
mask Mask
|
||||
mask string
|
||||
operatorOnly bool
|
||||
}
|
||||
|
||||
@ -665,7 +666,7 @@ func NewWhoCommand(args []string) (editableCommand, error) {
|
||||
cmd := &WhoCommand{}
|
||||
|
||||
if len(args) > 0 {
|
||||
cmd.mask = Mask(args[0])
|
||||
cmd.mask = args[0]
|
||||
}
|
||||
|
||||
if (len(args) > 1) && (args[1] == "o") {
|
||||
@ -982,3 +983,26 @@ func NewKillCommand(args []string) (editableCommand, error) {
|
||||
comment: args[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
type WhoWasCommand struct {
|
||||
BaseCommand
|
||||
nicknames []string
|
||||
count int64
|
||||
target string
|
||||
}
|
||||
|
||||
func NewWhoWasCommand(args []string) (editableCommand, error) {
|
||||
if len(args) < 1 {
|
||||
return nil, NotEnoughArgsError
|
||||
}
|
||||
cmd := &WhoWasCommand{
|
||||
nicknames: strings.Split(args[0], ","),
|
||||
}
|
||||
if len(args) > 1 {
|
||||
cmd.count, _ = strconv.ParseInt(args[1], 10, 64)
|
||||
}
|
||||
if len(args) > 2 {
|
||||
cmd.target = args[2]
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
@ -61,6 +61,7 @@ const (
|
||||
VERSION StringCode = "VERSION"
|
||||
WHO StringCode = "WHO"
|
||||
WHOIS StringCode = "WHOIS"
|
||||
WHOWAS StringCode = "WHOWAS"
|
||||
|
||||
// numeric codes
|
||||
RPL_WELCOME NumericCode = 1
|
||||
|
@ -2,6 +2,7 @@ package irc
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"os"
|
||||
@ -14,15 +15,30 @@ func InitDB(path string) {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE channel (
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
flags TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
user_limit INTEGER DEFAULT 0)`)
|
||||
flags TEXT DEFAULT '',
|
||||
key TEXT DEFAULT '',
|
||||
topic TEXT DEFAULT '',
|
||||
user_limit INTEGER DEFAULT 0,
|
||||
ban_list TEXT DEFAULT '',
|
||||
except_list TEXT DEFAULT '',
|
||||
invite_list TEXT DEFAULT '')`)
|
||||
if err != nil {
|
||||
log.Fatal("initdb error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func UpgradeDB(path string) {
|
||||
db := OpenDB(path)
|
||||
alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
|
||||
cols := []string{"ban_list", "except_list", "invite_list"}
|
||||
for _, col := range cols {
|
||||
_, err := db.Exec(fmt.Sprintf(alter, col))
|
||||
if err != nil {
|
||||
log.Fatal("updatedb error: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func OpenDB(path string) *sql.DB {
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
|
44
irc/reply.go
44
irc/reply.go
@ -200,6 +200,16 @@ func (target *Client) RplYoureOper() {
|
||||
":You are now an IRC operator")
|
||||
}
|
||||
|
||||
func (target *Client) RplWhois(client *Client) {
|
||||
target.RplWhoisUser(client)
|
||||
if client.flags[Operator] {
|
||||
target.RplWhoisOperator(client)
|
||||
}
|
||||
target.RplWhoisIdle(client)
|
||||
target.RplWhoisChannels(client)
|
||||
target.RplEndOfWhois()
|
||||
}
|
||||
|
||||
func (target *Client) RplWhoisUser(client *Client) {
|
||||
target.NumericReply(RPL_WHOISUSER,
|
||||
"%s %s %s * :%s", client.Nick(), client.username, client.hostname,
|
||||
@ -270,7 +280,7 @@ func (target *Client) RplEndOfWho(name string) {
|
||||
"%s :End of WHO list", name)
|
||||
}
|
||||
|
||||
func (target *Client) RplMaskList(mode ChannelMode, channel *Channel, mask UserMask) {
|
||||
func (target *Client) RplMaskList(mode ChannelMode, channel *Channel, mask string) {
|
||||
switch mode {
|
||||
case BanMask:
|
||||
target.RplBanList(channel, mask)
|
||||
@ -296,7 +306,7 @@ func (target *Client) RplEndOfMaskList(mode ChannelMode, channel *Channel) {
|
||||
}
|
||||
}
|
||||
|
||||
func (target *Client) RplBanList(channel *Channel, mask UserMask) {
|
||||
func (target *Client) RplBanList(channel *Channel, mask string) {
|
||||
target.NumericReply(RPL_BANLIST,
|
||||
"%s %s", channel, mask)
|
||||
}
|
||||
@ -306,7 +316,7 @@ func (target *Client) RplEndOfBanList(channel *Channel) {
|
||||
"%s :End of channel ban list", channel)
|
||||
}
|
||||
|
||||
func (target *Client) RplExceptList(channel *Channel, mask UserMask) {
|
||||
func (target *Client) RplExceptList(channel *Channel, mask string) {
|
||||
target.NumericReply(RPL_EXCEPTLIST,
|
||||
"%s %s", channel, mask)
|
||||
}
|
||||
@ -316,7 +326,7 @@ func (target *Client) RplEndOfExceptList(channel *Channel) {
|
||||
"%s :End of channel exception list", channel)
|
||||
}
|
||||
|
||||
func (target *Client) RplInviteList(channel *Channel, mask UserMask) {
|
||||
func (target *Client) RplInviteList(channel *Channel, mask string) {
|
||||
target.NumericReply(RPL_INVITELIST,
|
||||
"%s %s", channel, mask)
|
||||
}
|
||||
@ -396,6 +406,17 @@ func (target *Client) RplTime() {
|
||||
"%s :%s", target.server.name, time.Now().Format(time.RFC1123))
|
||||
}
|
||||
|
||||
func (target *Client) RplWhoWasUser(whoWas *WhoWas) {
|
||||
target.NumericReply(RPL_WHOWASUSER,
|
||||
"%s %s %s * :%s",
|
||||
whoWas.nickname, whoWas.username, whoWas.hostname, whoWas.realname)
|
||||
}
|
||||
|
||||
func (target *Client) RplEndOfWhoWas(nickname string) {
|
||||
target.NumericReply(RPL_ENDOFWHOWAS,
|
||||
"%s :End of WHOWAS", nickname)
|
||||
}
|
||||
|
||||
//
|
||||
// errors (also numeric)
|
||||
//
|
||||
@ -515,7 +536,22 @@ func (target *Client) ErrChannelIsFull(channel *Channel) {
|
||||
"%s :Cannot join channel (+l)", channel)
|
||||
}
|
||||
|
||||
func (target *Client) ErrWasNoSuchNick(nickname string) {
|
||||
target.NumericReply(ERR_WASNOSUCHNICK,
|
||||
"%s :There was no such nickname", nickname)
|
||||
}
|
||||
|
||||
func (target *Client) ErrInvalidCapCmd(subCommand CapSubCommand) {
|
||||
target.NumericReply(ERR_INVALIDCAPCMD,
|
||||
"%s :Invalid CAP subcommand", subCommand)
|
||||
}
|
||||
|
||||
func (target *Client) ErrBannedFromChan(channel *Channel) {
|
||||
target.NumericReply(ERR_BANNEDFROMCHAN,
|
||||
"%s :Cannot join channel (+b)", channel)
|
||||
}
|
||||
|
||||
func (target *Client) ErrInviteOnlyChan(channel *Channel) {
|
||||
target.NumericReply(ERR_INVITEONLYCHAN,
|
||||
"%s :Cannot join channel (+i)", channel)
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
|
||||
type Server struct {
|
||||
channels ChannelNameMap
|
||||
clients ClientNameMap
|
||||
clients *ClientLookupSet
|
||||
commands chan Command
|
||||
ctime time.Time
|
||||
db *sql.DB
|
||||
@ -29,12 +29,13 @@ type Server struct {
|
||||
operators map[string][]byte
|
||||
password []byte
|
||||
signals chan os.Signal
|
||||
whoWas *WhoWasList
|
||||
}
|
||||
|
||||
func NewServer(config *Config) *Server {
|
||||
server := &Server{
|
||||
channels: make(ChannelNameMap),
|
||||
clients: make(ClientNameMap),
|
||||
clients: NewClientLookupSet(),
|
||||
commands: make(chan Command, 16),
|
||||
ctime: time.Now(),
|
||||
db: OpenDB(config.Server.Database),
|
||||
@ -44,6 +45,7 @@ func NewServer(config *Config) *Server {
|
||||
newConns: make(chan net.Conn, 16),
|
||||
operators: config.Operators(),
|
||||
signals: make(chan os.Signal, 1),
|
||||
whoWas: NewWhoWasList(100),
|
||||
}
|
||||
|
||||
if config.Server.Password != "" {
|
||||
@ -62,9 +64,17 @@ func NewServer(config *Config) *Server {
|
||||
return server
|
||||
}
|
||||
|
||||
func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
|
||||
if list == "" {
|
||||
return
|
||||
}
|
||||
channel.lists[maskMode].AddAll(strings.Split(list, " "))
|
||||
}
|
||||
|
||||
func (server *Server) loadChannels() {
|
||||
rows, err := server.db.Query(`
|
||||
SELECT name, flags, key, topic, user_limit
|
||||
SELECT name, flags, key, topic, user_limit, ban_list, except_list,
|
||||
invite_list
|
||||
FROM channel`)
|
||||
if err != nil {
|
||||
log.Fatal("error loading channels: ", err)
|
||||
@ -72,9 +82,11 @@ func (server *Server) loadChannels() {
|
||||
for rows.Next() {
|
||||
var name, flags, key, topic string
|
||||
var userLimit uint64
|
||||
err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
|
||||
var banList, exceptList, inviteList string
|
||||
err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
|
||||
&exceptList, &inviteList)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Println("Server.loadChannels:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -85,6 +97,9 @@ func (server *Server) loadChannels() {
|
||||
channel.key = key
|
||||
channel.topic = topic
|
||||
channel.userLimit = userLimit
|
||||
loadChannelList(channel, banList, BanMask)
|
||||
loadChannelList(channel, exceptList, ExceptMask)
|
||||
loadChannelList(channel, inviteList, InviteMask)
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,7 +141,7 @@ func (server *Server) processCommand(cmd Command) {
|
||||
|
||||
func (server *Server) Shutdown() {
|
||||
server.db.Close()
|
||||
for _, client := range server.clients {
|
||||
for _, client := range server.clients.byNick {
|
||||
client.Reply(RplNotice(server, client, "shutting down"))
|
||||
}
|
||||
}
|
||||
@ -340,7 +355,7 @@ func (msg *RFC1459UserCommand) HandleRegServer(server *Server) {
|
||||
client.Quit("bad password")
|
||||
return
|
||||
}
|
||||
msg.HandleRegServer2(server)
|
||||
msg.setUserInfo(server)
|
||||
}
|
||||
|
||||
func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
|
||||
@ -357,15 +372,19 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
|
||||
}
|
||||
client.RplUModeIs(client)
|
||||
}
|
||||
msg.HandleRegServer2(server)
|
||||
msg.setUserInfo(server)
|
||||
}
|
||||
|
||||
func (msg *UserCommand) HandleRegServer2(server *Server) {
|
||||
func (msg *UserCommand) setUserInfo(server *Server) {
|
||||
client := msg.Client()
|
||||
if client.capState == CapNegotiating {
|
||||
client.capState = CapNegotiated
|
||||
}
|
||||
|
||||
server.clients.Remove(client)
|
||||
client.username, client.realname = msg.username, msg.realname
|
||||
server.clients.Add(client)
|
||||
|
||||
server.tryRegister(client)
|
||||
}
|
||||
|
||||
@ -514,7 +533,7 @@ func (m *ModeCommand) HandleServer(s *Server) {
|
||||
return
|
||||
}
|
||||
|
||||
changes := make(ModeChanges, 0)
|
||||
changes := make(ModeChanges, 0, len(m.changes))
|
||||
|
||||
for _, change := range m.changes {
|
||||
switch change.mode {
|
||||
@ -577,19 +596,14 @@ func (m *WhoisCommand) HandleServer(server *Server) {
|
||||
// TODO implement target query
|
||||
|
||||
for _, mask := range m.masks {
|
||||
// TODO implement wildcard matching
|
||||
mclient := server.clients.Get(mask)
|
||||
if mclient == nil {
|
||||
matches := server.clients.FindAll(mask)
|
||||
if len(matches) == 0 {
|
||||
client.ErrNoSuchNick(mask)
|
||||
continue
|
||||
}
|
||||
client.RplWhoisUser(mclient)
|
||||
if mclient.flags[Operator] {
|
||||
client.RplWhoisOperator(mclient)
|
||||
for mclient := range matches {
|
||||
client.RplWhois(mclient)
|
||||
}
|
||||
client.RplWhoisIdle(mclient)
|
||||
client.RplWhoisChannels(mclient)
|
||||
client.RplEndOfWhois()
|
||||
}
|
||||
}
|
||||
|
||||
@ -604,9 +618,9 @@ func (msg *ChannelModeCommand) HandleServer(server *Server) {
|
||||
channel.Mode(client, msg.changes)
|
||||
}
|
||||
|
||||
func whoChannel(client *Client, channel *Channel) {
|
||||
func whoChannel(client *Client, channel *Channel, friends ClientSet) {
|
||||
for member := range channel.members {
|
||||
if !client.flags[Invisible] {
|
||||
if !client.flags[Invisible] || friends[client] {
|
||||
client.RplWhoReply(channel, member)
|
||||
}
|
||||
}
|
||||
@ -614,27 +628,21 @@ func whoChannel(client *Client, channel *Channel) {
|
||||
|
||||
func (msg *WhoCommand) HandleServer(server *Server) {
|
||||
client := msg.Client()
|
||||
friends := client.Friends()
|
||||
mask := msg.mask
|
||||
|
||||
// TODO implement wildcard matching
|
||||
mask := string(msg.mask)
|
||||
if mask == "" {
|
||||
for _, channel := range server.channels {
|
||||
for member := range channel.members {
|
||||
if !client.flags[Invisible] {
|
||||
client.RplWhoReply(channel, member)
|
||||
}
|
||||
}
|
||||
whoChannel(client, channel, friends)
|
||||
}
|
||||
} else if IsChannel(mask) {
|
||||
// TODO implement wildcard matching
|
||||
channel := server.channels.Get(mask)
|
||||
if channel != nil {
|
||||
for member := range channel.members {
|
||||
client.RplWhoReply(channel, member)
|
||||
}
|
||||
whoChannel(client, channel, friends)
|
||||
}
|
||||
} else {
|
||||
mclient := server.clients.Get(mask)
|
||||
if mclient != nil {
|
||||
for mclient := range server.clients.FindAll(mask) {
|
||||
client.RplWhoReply(nil, mclient)
|
||||
}
|
||||
}
|
||||
@ -874,3 +882,18 @@ func (msg *KillCommand) HandleServer(server *Server) {
|
||||
quitMsg := fmt.Sprintf("KILLed by %s: %s", client.Nick(), msg.comment)
|
||||
target.Quit(quitMsg)
|
||||
}
|
||||
|
||||
func (msg *WhoWasCommand) HandleServer(server *Server) {
|
||||
client := msg.Client()
|
||||
for _, nickname := range msg.nicknames {
|
||||
results := server.whoWas.Find(nickname, msg.count)
|
||||
if len(results) == 0 {
|
||||
client.ErrWasNoSuchNick(nickname)
|
||||
} else {
|
||||
for _, whoWas := range results {
|
||||
client.RplWhoWasUser(whoWas)
|
||||
}
|
||||
}
|
||||
client.RplEndOfWhoWas(nickname)
|
||||
}
|
||||
}
|
||||
|
52
irc/types.go
52
irc/types.go
@ -1,7 +1,6 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
@ -48,9 +47,6 @@ func (set CapabilitySet) DisableString() string {
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// a string with wildcards
|
||||
type Mask string
|
||||
|
||||
// add, remove, list modes
|
||||
type ModeOp rune
|
||||
|
||||
@ -112,40 +108,6 @@ func (channels ChannelNameMap) Remove(channel *Channel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientNameMap map[string]*Client
|
||||
|
||||
var (
|
||||
ErrNickMissing = errors.New("nick missing")
|
||||
ErrNicknameInUse = errors.New("nickname in use")
|
||||
ErrNicknameMismatch = errors.New("nickname mismatch")
|
||||
)
|
||||
|
||||
func (clients ClientNameMap) Get(nick string) *Client {
|
||||
return clients[strings.ToLower(nick)]
|
||||
}
|
||||
|
||||
func (clients ClientNameMap) Add(client *Client) error {
|
||||
if !client.HasNick() {
|
||||
return ErrNickMissing
|
||||
}
|
||||
if clients.Get(client.nick) != nil {
|
||||
return ErrNicknameInUse
|
||||
}
|
||||
clients[strings.ToLower(client.nick)] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (clients ClientNameMap) Remove(client *Client) error {
|
||||
if !client.HasNick() {
|
||||
return ErrNickMissing
|
||||
}
|
||||
if clients.Get(client.nick) != client {
|
||||
return ErrNicknameMismatch
|
||||
}
|
||||
delete(clients, strings.ToLower(client.nick))
|
||||
return nil
|
||||
}
|
||||
|
||||
type ChannelModeSet map[ChannelMode]bool
|
||||
|
||||
func (set ChannelModeSet) String() string {
|
||||
@ -247,17 +209,3 @@ type RegServerCommand interface {
|
||||
Command
|
||||
HandleRegServer(*Server)
|
||||
}
|
||||
|
||||
//
|
||||
// structs
|
||||
//
|
||||
|
||||
type UserMask struct {
|
||||
nickname Mask
|
||||
username Mask
|
||||
hostname Mask
|
||||
}
|
||||
|
||||
func (mask *UserMask) String() string {
|
||||
return fmt.Sprintf("%s!%s@%s", mask.nickname, mask.username, mask.hostname)
|
||||
}
|
||||
|
73
irc/whowas.go
Normal file
73
irc/whowas.go
Normal file
@ -0,0 +1,73 @@
|
||||
package irc
|
||||
|
||||
type WhoWasList struct {
|
||||
buffer []*WhoWas
|
||||
start uint
|
||||
end uint
|
||||
}
|
||||
|
||||
type WhoWas struct {
|
||||
nickname string
|
||||
username string
|
||||
hostname string
|
||||
realname string
|
||||
}
|
||||
|
||||
func NewWhoWasList(size uint) *WhoWasList {
|
||||
return &WhoWasList{
|
||||
buffer: make([]*WhoWas, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (list *WhoWasList) Append(client *Client) {
|
||||
list.buffer[list.end] = &WhoWas{
|
||||
nickname: client.Nick(),
|
||||
username: client.username,
|
||||
hostname: client.hostname,
|
||||
realname: client.realname,
|
||||
}
|
||||
list.end = (list.end + 1) % uint(len(list.buffer))
|
||||
if list.end == list.start {
|
||||
list.start = (list.end + 1) % uint(len(list.buffer))
|
||||
}
|
||||
}
|
||||
|
||||
func (list *WhoWasList) Find(nickname string, limit int64) []*WhoWas {
|
||||
results := make([]*WhoWas, 0)
|
||||
for whoWas := range list.Each() {
|
||||
if nickname != whoWas.nickname {
|
||||
continue
|
||||
}
|
||||
results = append(results, whoWas)
|
||||
if int64(len(results)) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (list *WhoWasList) prev(index uint) uint {
|
||||
index -= 1
|
||||
if index < 0 {
|
||||
index += uint(len(list.buffer))
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
// Iterate the buffer in reverse.
|
||||
func (list *WhoWasList) Each() <-chan *WhoWas {
|
||||
ch := make(chan *WhoWas)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if list.start == list.end {
|
||||
return
|
||||
}
|
||||
start := list.prev(list.end)
|
||||
end := list.prev(list.start)
|
||||
for start != end {
|
||||
ch <- list.buffer[start]
|
||||
start = list.prev(start)
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
Loading…
Reference in New Issue
Block a user