3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-10 22:19:31 +01:00

update ClientLookupSet when username changes

This commit is contained in:
Jeremy Latt 2014-03-06 16:51:33 -08:00
parent 76852b0370
commit adde42a1bf
2 changed files with 72 additions and 53 deletions

View File

@ -4,9 +4,44 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"log" "log"
"regexp"
"strings" "strings"
) )
var (
ErrNickMissing = errors.New("nick missing")
ErrNicknameInUse = errors.New("nickname in use")
ErrNicknameMismatch = errors.New("nickname mismatch")
wildMaskExpr = regexp.MustCompile(`\*|\?`)
likeQuoter = strings.NewReplacer(
`\`, `\\`,
`%`, `\%`,
`_`, `\_`,
`*`, `%`,
`?`, `_`)
)
func HasWildcards(mask string) bool {
return wildMaskExpr.MatchString(mask)
}
func ExpandUserHost(userhost string) (expanded string) {
expanded = userhost
// fill in missing wildcards for nicks
if !strings.Contains(expanded, "!") {
expanded += "!*"
}
if !strings.Contains(expanded, "@") {
expanded += "@*"
}
return
}
func QuoteLike(userhost string) (like string) {
like = likeQuoter.Replace(userhost)
return
}
type ClientLookupSet struct { type ClientLookupSet struct {
byNick map[string]*Client byNick map[string]*Client
db *ClientDB db *ClientDB
@ -19,12 +54,6 @@ func NewClientLookupSet() *ClientLookupSet {
} }
} }
var (
ErrNickMissing = errors.New("nick missing")
ErrNicknameInUse = errors.New("nickname in use")
ErrNicknameMismatch = errors.New("nickname mismatch")
)
func (clients *ClientLookupSet) Get(nick string) *Client { func (clients *ClientLookupSet) Get(nick string) *Client {
return clients.byNick[strings.ToLower(nick)] return clients.byNick[strings.ToLower(nick)]
} }
@ -53,38 +82,35 @@ func (clients *ClientLookupSet) Remove(client *Client) error {
return nil return nil
} }
func ExpandUserHost(userhost string) (expanded string) {
expanded = userhost
// fill in missing wildcards for nicks
if !strings.Contains(expanded, "!") {
expanded += "!*"
}
if !strings.Contains(expanded, "@") {
expanded += "@*"
}
return
}
func (clients *ClientLookupSet) FindAll(userhost string) (set ClientSet) { func (clients *ClientLookupSet) FindAll(userhost string) (set ClientSet) {
userhost = ExpandUserHost(userhost) userhost = ExpandUserHost(userhost)
set = make(ClientSet) set = make(ClientSet)
rows, err := clients.db.db.Query( rows, err := clients.db.db.Query(
`SELECT nickname FROM client `SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\'`,
WHERE userhost LIKE ? ESCAPE '\'`,
QuoteLike(userhost)) QuoteLike(userhost))
if err != nil { if err != nil {
if DEBUG_SERVER {
log.Println("ClientLookupSet.FindAll.Query:", err)
}
return return
} }
for rows.Next() { for rows.Next() {
var nickname string var nickname string
err := rows.Scan(&nickname) err := rows.Scan(&nickname)
if err != nil { if err != nil {
if DEBUG_SERVER {
log.Println("ClientLookupSet.FindAll.Scan:", err)
}
return return
} }
client := clients.Get(nickname) client := clients.Get(nickname)
if client != nil { if client == nil {
set.Add(client) if DEBUG_SERVER {
log.Println("ClientLookupSet.FindAll: missing client:", nickname)
} }
continue
}
set.Add(client)
} }
return return
} }
@ -92,14 +118,14 @@ func (clients *ClientLookupSet) FindAll(userhost string) (set ClientSet) {
func (clients *ClientLookupSet) Find(userhost string) *Client { func (clients *ClientLookupSet) Find(userhost string) *Client {
userhost = ExpandUserHost(userhost) userhost = ExpandUserHost(userhost)
row := clients.db.db.QueryRow( row := clients.db.db.QueryRow(
`SELECT nickname FROM client `SELECT nickname FROM client WHERE userhost LIKE ? ESCAPE '\' LIMIT 1`,
WHERE userhost LIKE ? ESCAPE \
LIMIT 1`,
QuoteLike(userhost)) QuoteLike(userhost))
var nickname string var nickname string
err := row.Scan(&nickname) err := row.Scan(&nickname)
if err != nil { if err != nil {
log.Println("ClientLookupSet.Find: ", err) if DEBUG_SERVER {
log.Println("ClientLookupSet.Find:", err)
}
return nil return nil
} }
return clients.Get(nickname) return clients.Get(nickname)
@ -117,17 +143,19 @@ func NewClientDB() *ClientDB {
db := &ClientDB{ db := &ClientDB{
db: OpenDB(":memory:"), db: OpenDB(":memory:"),
} }
_, err := db.db.Exec(` stmts := []string{
CREATE TABLE client ( `CREATE TABLE client (
nickname TEXT NOT NULL UNIQUE, nickname TEXT NOT NULL COLLATE NOCASE UNIQUE,
userhost TEXT NOT NULL)`) userhost TEXT NOT NULL COLLATE NOCASE,
if err != nil { UNIQUE (nickname, userhost) ON CONFLICT REPLACE)`,
log.Fatal(err) `CREATE UNIQUE INDEX idx_nick ON client (nickname COLLATE NOCASE)`,
`CREATE UNIQUE INDEX idx_uh ON client (userhost COLLATE NOCASE)`,
} }
_, err = db.db.Exec(` for _, stmt := range stmts {
CREATE UNIQUE INDEX nickname_index ON client (nickname)`) _, err := db.db.Exec(stmt)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal("NewClientDB: ", stmt, err)
}
} }
return db return db
} }
@ -136,7 +164,9 @@ func (db *ClientDB) Add(client *Client) {
_, err := db.db.Exec(`INSERT INTO client (nickname, userhost) VALUES (?, ?)`, _, err := db.db.Exec(`INSERT INTO client (nickname, userhost) VALUES (?, ?)`,
client.Nick(), client.UserHost()) client.Nick(), client.UserHost())
if err != nil { if err != nil {
log.Println(err) if DEBUG_SERVER {
log.Println("ClientDB.Add:", err)
}
} }
} }
@ -144,21 +174,8 @@ func (db *ClientDB) Remove(client *Client) {
_, err := db.db.Exec(`DELETE FROM client WHERE nickname = ?`, _, err := db.db.Exec(`DELETE FROM client WHERE nickname = ?`,
client.Nick()) client.Nick())
if err != nil { if err != nil {
log.Println(err) if DEBUG_SERVER {
log.Println("ClientDB.Remove:", err)
}
} }
} }
func QuoteLike(userhost string) (like string) {
like = userhost
// escape escape char
like = strings.Replace(like, `\`, `\\`, -1)
// escape meta-many
like = strings.Replace(like, `%`, `\%`, -1)
// escape meta-one
like = strings.Replace(like, `_`, `\_`, -1)
// swap meta-many
like = strings.Replace(like, `*`, `%`, -1)
// swap meta-one
like = strings.Replace(like, `?`, `_`, -1)
return
}

View File

@ -346,7 +346,9 @@ func (msg *RFC2812UserCommand) HandleRegServer(server *Server) {
func (msg *UserCommand) HandleRegServer2(server *Server) { func (msg *UserCommand) HandleRegServer2(server *Server) {
client := msg.Client() client := msg.Client()
server.clients.Remove(client)
client.username, client.realname = msg.username, msg.realname client.username, client.realname = msg.username, msg.realname
server.clients.Add(client)
server.tryRegister(client) server.tryRegister(client)
} }