// ergo/irc/usermaskset.go (Go; listed as 167 lines, 3.4 KiB)

// Copyright (c) 2012-2014 Jeremy Latt
// Copyright (c) 2016-2018 Daniel Oaks
// Copyright (c) 2019-2020 Shivaram Lingamneni
// released under the MIT license
package irc
import (
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ergochat/ergo/irc/utils"
)
// MaskInfo is the metadata stored alongside each mask in a UserMaskSet.
type MaskInfo struct {
// TimeCreated is when the mask was recorded; UserMaskSet.Add fills it
// with time.Now().UTC().
TimeCreated time.Time
// CreatorNickmask is the creatorNickmask argument that was passed to
// UserMaskSet.Add for this mask.
CreatorNickmask string
// CreatorAccount is the creatorAccount argument that was passed to
// UserMaskSet.Add for this mask.
CreatorAccount string
}
// UserMaskSet holds a set of client masks and lets you match hostnames to them.
//
// Matching does not consult the map directly: Match and MatchMute read
// precompiled patterns from the atomic pointers below, which setRegexp
// rebuilds after the map changes.
type UserMaskSet struct {
// The embedded RWMutex guards masks.
sync.RWMutex
// serialCacheUpdateMutex is held by Add and Remove across both the map
// mutation and the subsequent setRegexp call, so that those
// mutate-then-rebuild sequences run one at a time.
serialCacheUpdateMutex sync.Mutex
// masks maps each stored mask string to its metadata. It may be nil;
// Add allocates it on first use.
masks map[string]MaskInfo
// regexp is the pattern compiled from the entries that do not start
// with "m:"; setRegexp stores nil when there are no such entries.
regexp atomic.Pointer[regexp.Regexp]
// muteRegexp is the pattern compiled from the entries that start with
// "m:" (with that prefix stripped); setRegexp stores nil when there are
// no such entries.
muteRegexp atomic.Pointer[regexp.Regexp]
}
// NewUserMaskSet returns a pointer to an empty, zero-valued UserMaskSet.
// The internal map is left nil; Add allocates it on first use.
func NewUserMaskSet() *UserMaskSet {
	return &UserMaskSet{}
}
// Add canonicalizes mask with CanonicalizeMaskWildcard and inserts the result
// into the set, recording the creation time (UTC) and the supplied creator
// details.
//
// It returns the canonical mask when a new entry was inserted, and the empty
// string when the mask was already present. A canonicalization failure is
// returned as err with an empty maskAdded. The cached match patterns are
// rebuilt only when an entry was actually inserted.
func (set *UserMaskSet) Add(mask, creatorNickmask, creatorAccount string) (maskAdded string, err error) {
	canonical, err := CanonicalizeMaskWildcard(mask)
	if err != nil {
		return "", err
	}

	// Keep the map mutation and the pattern rebuild together as one
	// serialized step.
	set.serialCacheUpdateMutex.Lock()
	defer set.serialCacheUpdateMutex.Unlock()

	inserted := func() bool {
		set.Lock()
		defer set.Unlock()

		if set.masks == nil {
			set.masks = make(map[string]MaskInfo)
		}
		if _, exists := set.masks[canonical]; exists {
			return false
		}
		set.masks[canonical] = MaskInfo{
			TimeCreated:     time.Now().UTC(),
			CreatorNickmask: creatorNickmask,
			CreatorAccount:  creatorAccount,
		}
		return true
	}()

	if !inserted {
		return "", nil
	}
	// setRegexp takes the read lock itself, so the write lock must already
	// be released here.
	set.setRegexp()
	return canonical, nil
}
// Remove canonicalizes mask with CanonicalizeMaskWildcard and deletes the
// result from the set.
//
// It returns the canonical mask when an entry was deleted, and the empty
// string when no such entry existed. A canonicalization failure is returned
// as err with an empty maskRemoved. The cached match patterns are rebuilt
// only when an entry was actually deleted.
func (set *UserMaskSet) Remove(mask string) (maskRemoved string, err error) {
	canonical, err := CanonicalizeMaskWildcard(mask)
	if err != nil {
		return "", err
	}

	// Keep the map mutation and the pattern rebuild together as one
	// serialized step.
	set.serialCacheUpdateMutex.Lock()
	defer set.serialCacheUpdateMutex.Unlock()

	set.Lock()
	_, found := set.masks[canonical]
	if found {
		delete(set.masks, canonical)
	}
	set.Unlock()

	if !found {
		return "", nil
	}
	// setRegexp takes the read lock itself, so the write lock must already
	// be released here.
	set.setRegexp()
	return canonical, nil
}
// SetMasks replaces the entire contents of the set with masks and rebuilds
// the cached match patterns.
//
// The map is stored as-is rather than copied, so later Add and Remove calls
// mutate the very map that was passed in. Keys are used verbatim: unlike Add,
// no canonicalization is applied here.
//
// serialCacheUpdateMutex is held across both the swap and the rebuild, the
// same way Add and Remove do it. Without it, a concurrent mutation could
// interleave its own rebuild with this one and leave a pattern compiled from
// an older snapshot of the map as the published one.
func (set *UserMaskSet) SetMasks(masks map[string]MaskInfo) {
	set.serialCacheUpdateMutex.Lock()
	defer set.serialCacheUpdateMutex.Unlock()

	set.Lock()
	set.masks = masks
	set.Unlock()

	// setRegexp takes the read lock itself, so the write lock must already
	// be released here.
	set.setRegexp()
}
// Masks returns a copy of the set's mask-to-metadata map, taken under the
// read lock. The result is always a non-nil map that the caller may modify
// without affecting the set.
func (set *UserMaskSet) Masks() (result map[string]MaskInfo) {
	set.RLock()
	defer set.RUnlock()

	snapshot := make(map[string]MaskInfo, len(set.masks))
	for key, meta := range set.masks {
		snapshot[key] = meta
	}
	return snapshot
}
// Match reports whether the given n!u@h matches any of the standard
// (non-ext) bans. It returns false when no pattern is currently stored.
func (set *UserMaskSet) Match(userhost string) bool {
	re := set.regexp.Load()
	return re != nil && re.MatchString(userhost)
}
// MatchMute reports whether the given NUH matches any of the mute extbans.
// It returns false when no mute pattern is currently stored.
func (set *UserMaskSet) MatchMute(userhost string) bool {
	re := set.MuteRegexp()
	return re != nil && re.MatchString(userhost)
}
// MuteRegexp returns the currently stored mute pattern, loaded atomically.
// The result is nil when no mute pattern has been stored.
func (set *UserMaskSet) MuteRegexp() *regexp.Regexp {
	current := set.muteRegexp.Load()
	return current
}
// Length returns the number of masks in the set, read under the read lock.
func (set *UserMaskSet) Length() int {
	set.RLock()
	count := len(set.masks)
	set.RUnlock()
	return count
}
// setRegexp rebuilds both cached patterns from the current contents of the
// mask map and publishes them through the atomic pointers.
//
// Entries starting with "m:" feed the mute pattern (with the prefix
// stripped); every other entry feeds the regular pattern. A group with no
// entries is published as a nil pattern.
//
// The read lock is taken here, so callers must not hold the write lock.
func (set *UserMaskSet) setRegexp() {
	const mutePrefix = "m:"

	// Snapshot the keys under the read lock; compilation happens after the
	// lock is released.
	set.RLock()
	banExprs := make([]string, 0, len(set.masks))
	var muteExprs []string
	for entry := range set.masks {
		if !strings.HasPrefix(entry, mutePrefix) {
			banExprs = append(banExprs, entry)
			continue
		}
		muteExprs = append(muteExprs, entry[len(mutePrefix):])
	}
	set.RUnlock()

	build := func(exprs []string) (compiled *regexp.Regexp) {
		if len(exprs) != 0 {
			// NOTE(review): the error from CompileMasks is discarded, so
			// on failure whatever pattern it returns (presumably nil) is
			// published as-is — confirm this is intended.
			compiled, _ = utils.CompileMasks(exprs)
		}
		return compiled
	}

	// Compile both patterns before publishing either one.
	banRe := build(banExprs)
	muteRe := build(muteExprs)
	set.regexp.Store(banRe)
	set.muteRegexp.Store(muteRe)
}