mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-10 22:19:31 +01:00
introduce "flat ip" representations
This commit is contained in:
parent
85c39f3ea0
commit
44cc4c2092
1
Makefile
1
Makefile
@ -25,6 +25,7 @@ test:
|
|||||||
cd irc/cloaks && go test . && go vet .
|
cd irc/cloaks && go test . && go vet .
|
||||||
cd irc/connection_limits && go test . && go vet .
|
cd irc/connection_limits && go test . && go vet .
|
||||||
cd irc/email && go test . && go vet .
|
cd irc/email && go test . && go vet .
|
||||||
|
cd irc/flatip && go test . && go vet .
|
||||||
cd irc/history && go test . && go vet .
|
cd irc/history && go test . && go vet .
|
||||||
cd irc/isupport && go test . && go vet .
|
cd irc/isupport && go test . && go vet .
|
||||||
cd irc/migrations && go test . && go vet .
|
cd irc/migrations && go test . && go vet .
|
||||||
|
@ -4,12 +4,14 @@
|
|||||||
package connection_limits
|
package connection_limits
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/md5"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oragono/oragono/irc/flatip"
|
||||||
"github.com/oragono/oragono/irc/utils"
|
"github.com/oragono/oragono/irc/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,10 +28,15 @@ type CustomLimitConfig struct {
|
|||||||
|
|
||||||
// tuples the key-value pair of a CIDR and its custom limit/throttle values
|
// tuples the key-value pair of a CIDR and its custom limit/throttle values
|
||||||
type customLimit struct {
|
type customLimit struct {
|
||||||
name string
|
name [16]byte
|
||||||
maxConcurrent int
|
maxConcurrent int
|
||||||
maxPerWindow int
|
maxPerWindow int
|
||||||
nets []net.IPNet
|
nets []flatip.IPNet
|
||||||
|
}
|
||||||
|
|
||||||
|
type limiterKey struct {
|
||||||
|
maskedIP flatip.IP
|
||||||
|
prefixLen uint8 // 0 for the fake nets we generate for custom limits
|
||||||
}
|
}
|
||||||
|
|
||||||
// LimiterConfig controls the automated connection limits.
|
// LimiterConfig controls the automated connection limits.
|
||||||
@ -55,9 +62,7 @@ type rawLimiterConfig struct {
|
|||||||
type LimiterConfig struct {
|
type LimiterConfig struct {
|
||||||
rawLimiterConfig
|
rawLimiterConfig
|
||||||
|
|
||||||
ipv4Mask net.IPMask
|
exemptedNets []flatip.IPNet
|
||||||
ipv6Mask net.IPMask
|
|
||||||
exemptedNets []net.IPNet
|
|
||||||
customLimits []customLimit
|
customLimits []customLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,15 +74,19 @@ func (config *LimiterConfig) UnmarshalYAML(unmarshal func(interface{}) error) (e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (config *LimiterConfig) postprocess() (err error) {
|
func (config *LimiterConfig) postprocess() (err error) {
|
||||||
config.exemptedNets, err = utils.ParseNetList(config.Exempted)
|
exemptedNets, err := utils.ParseNetList(config.Exempted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not parse limiter exemption list: %v", err.Error())
|
return fmt.Errorf("Could not parse limiter exemption list: %v", err.Error())
|
||||||
}
|
}
|
||||||
|
config.exemptedNets = make([]flatip.IPNet, len(exemptedNets))
|
||||||
|
for i, exempted := range exemptedNets {
|
||||||
|
config.exemptedNets[i] = flatip.FromNetIPNet(exempted)
|
||||||
|
}
|
||||||
|
|
||||||
for identifier, customLimitConf := range config.CustomLimits {
|
for identifier, customLimitConf := range config.CustomLimits {
|
||||||
nets := make([]net.IPNet, len(customLimitConf.Nets))
|
nets := make([]flatip.IPNet, len(customLimitConf.Nets))
|
||||||
for i, netStr := range customLimitConf.Nets {
|
for i, netStr := range customLimitConf.Nets {
|
||||||
normalizedNet, err := utils.NormalizedNetFromString(netStr)
|
normalizedNet, err := flatip.ParseToNormalizedNet(netStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Bad net %s in custom-limits block %s: %w", netStr, identifier, err)
|
return fmt.Errorf("Bad net %s in custom-limits block %s: %w", netStr, identifier, err)
|
||||||
}
|
}
|
||||||
@ -86,23 +95,20 @@ func (config *LimiterConfig) postprocess() (err error) {
|
|||||||
if len(customLimitConf.Nets) == 0 {
|
if len(customLimitConf.Nets) == 0 {
|
||||||
// see #1421: this is the legacy config format where the
|
// see #1421: this is the legacy config format where the
|
||||||
// dictionary key of the block is a CIDR string
|
// dictionary key of the block is a CIDR string
|
||||||
normalizedNet, err := utils.NormalizedNetFromString(identifier)
|
normalizedNet, err := flatip.ParseToNormalizedNet(identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Custom limit block %s has no defined nets", identifier)
|
return fmt.Errorf("Custom limit block %s has no defined nets", identifier)
|
||||||
}
|
}
|
||||||
nets = []net.IPNet{normalizedNet}
|
nets = []flatip.IPNet{normalizedNet}
|
||||||
}
|
}
|
||||||
config.customLimits = append(config.customLimits, customLimit{
|
config.customLimits = append(config.customLimits, customLimit{
|
||||||
maxConcurrent: customLimitConf.MaxConcurrent,
|
maxConcurrent: customLimitConf.MaxConcurrent,
|
||||||
maxPerWindow: customLimitConf.MaxPerWindow,
|
maxPerWindow: customLimitConf.MaxPerWindow,
|
||||||
name: "*" + identifier,
|
name: md5.Sum([]byte(identifier)),
|
||||||
nets: nets,
|
nets: nets,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
config.ipv4Mask = net.CIDRMask(config.CidrLenIPv4, 32)
|
|
||||||
config.ipv6Mask = net.CIDRMask(config.CidrLenIPv6, 128)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,50 +119,48 @@ type Limiter struct {
|
|||||||
config *LimiterConfig
|
config *LimiterConfig
|
||||||
|
|
||||||
// IP/CIDR -> count of clients connected from there:
|
// IP/CIDR -> count of clients connected from there:
|
||||||
limiter map[string]int
|
limiter map[limiterKey]int
|
||||||
// IP/CIDR -> throttle state:
|
// IP/CIDR -> throttle state:
|
||||||
throttler map[string]ThrottleDetails
|
throttler map[limiterKey]ThrottleDetails
|
||||||
}
|
}
|
||||||
|
|
||||||
// addrToKey canonicalizes `addr` to a string key, and returns
|
// addrToKey canonicalizes `addr` to a string key, and returns
|
||||||
// the relevant connection limit and throttle max-per-window values
|
// the relevant connection limit and throttle max-per-window values
|
||||||
func (cl *Limiter) addrToKey(addr net.IP) (key string, limit int, throttle int) {
|
func (cl *Limiter) addrToKey(flat flatip.IP) (key limiterKey, limit int, throttle int) {
|
||||||
// `key` will be a CIDR string like "8.8.8.8/32" or "2001:0db8::/32"
|
|
||||||
for _, custom := range cl.config.customLimits {
|
for _, custom := range cl.config.customLimits {
|
||||||
for _, net := range custom.nets {
|
for _, net := range custom.nets {
|
||||||
if net.Contains(addr) {
|
if net.Contains(flat) {
|
||||||
return custom.name, custom.maxConcurrent, custom.maxPerWindow
|
return limiterKey{maskedIP: custom.name, prefixLen: 0}, custom.maxConcurrent, custom.maxPerWindow
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var ipNet net.IPNet
|
var prefixLen int
|
||||||
addrv4 := addr.To4()
|
if flat.IsIPv4() {
|
||||||
if addrv4 != nil {
|
prefixLen = cl.config.CidrLenIPv4
|
||||||
ipNet = net.IPNet{
|
flat = flat.Mask(prefixLen, 32)
|
||||||
IP: addrv4.Mask(cl.config.ipv4Mask),
|
prefixLen += 96
|
||||||
Mask: cl.config.ipv4Mask,
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
ipNet = net.IPNet{
|
prefixLen = cl.config.CidrLenIPv6
|
||||||
IP: addr.Mask(cl.config.ipv6Mask),
|
flat = flat.Mask(prefixLen, 128)
|
||||||
Mask: cl.config.ipv6Mask,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return ipNet.String(), cl.config.MaxConcurrent, cl.config.MaxPerWindow
|
|
||||||
|
return limiterKey{maskedIP: flat, prefixLen: uint8(prefixLen)}, cl.config.MaxConcurrent, cl.config.MaxPerWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddClient adds a client to our population if possible. If we can't, throws an error instead.
|
// AddClient adds a client to our population if possible. If we can't, throws an error instead.
|
||||||
func (cl *Limiter) AddClient(addr net.IP) error {
|
func (cl *Limiter) AddClient(addr net.IP) error {
|
||||||
|
flat := flatip.FromNetIP(addr)
|
||||||
|
|
||||||
cl.Lock()
|
cl.Lock()
|
||||||
defer cl.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
// we don't track populations for exempted addresses or nets - this is by design
|
// we don't track populations for exempted addresses or nets - this is by design
|
||||||
if utils.IPInNets(addr, cl.config.exemptedNets) {
|
if flatip.IPInNets(flat, cl.config.exemptedNets) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
addrString, maxConcurrent, maxPerWindow := cl.addrToKey(addr)
|
addrString, maxConcurrent, maxPerWindow := cl.addrToKey(flat)
|
||||||
|
|
||||||
// XXX check throttle first; if we checked limit first and then checked throttle,
|
// XXX check throttle first; if we checked limit first and then checked throttle,
|
||||||
// we'd have to decrement the limit on an unsuccessful throttle check
|
// we'd have to decrement the limit on an unsuccessful throttle check
|
||||||
@ -189,14 +193,16 @@ func (cl *Limiter) AddClient(addr net.IP) error {
|
|||||||
|
|
||||||
// RemoveClient removes the given address from our population
|
// RemoveClient removes the given address from our population
|
||||||
func (cl *Limiter) RemoveClient(addr net.IP) {
|
func (cl *Limiter) RemoveClient(addr net.IP) {
|
||||||
|
flat := flatip.FromNetIP(addr)
|
||||||
|
|
||||||
cl.Lock()
|
cl.Lock()
|
||||||
defer cl.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
if !cl.config.Count || utils.IPInNets(addr, cl.config.exemptedNets) {
|
if !cl.config.Count || flatip.IPInNets(flat, cl.config.exemptedNets) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrString, _, _ := cl.addrToKey(addr)
|
addrString, _, _ := cl.addrToKey(flat)
|
||||||
count := cl.limiter[addrString]
|
count := cl.limiter[addrString]
|
||||||
count -= 1
|
count -= 1
|
||||||
if count < 0 {
|
if count < 0 {
|
||||||
@ -207,14 +213,16 @@ func (cl *Limiter) RemoveClient(addr net.IP) {
|
|||||||
|
|
||||||
// ResetThrottle resets the throttle count for an IP
|
// ResetThrottle resets the throttle count for an IP
|
||||||
func (cl *Limiter) ResetThrottle(addr net.IP) {
|
func (cl *Limiter) ResetThrottle(addr net.IP) {
|
||||||
|
flat := flatip.FromNetIP(addr)
|
||||||
|
|
||||||
cl.Lock()
|
cl.Lock()
|
||||||
defer cl.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
if !cl.config.Throttle || utils.IPInNets(addr, cl.config.exemptedNets) {
|
if !cl.config.Throttle || flatip.IPInNets(flat, cl.config.exemptedNets) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrString, _, _ := cl.addrToKey(addr)
|
addrString, _, _ := cl.addrToKey(flat)
|
||||||
delete(cl.throttler, addrString)
|
delete(cl.throttler, addrString)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,10 +232,10 @@ func (cl *Limiter) ApplyConfig(config *LimiterConfig) {
|
|||||||
defer cl.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
if cl.limiter == nil {
|
if cl.limiter == nil {
|
||||||
cl.limiter = make(map[string]int)
|
cl.limiter = make(map[limiterKey]int)
|
||||||
}
|
}
|
||||||
if cl.throttler == nil {
|
if cl.throttler == nil {
|
||||||
cl.throttler = make(map[string]ThrottleDetails)
|
cl.throttler = make(map[limiterKey]ThrottleDetails)
|
||||||
}
|
}
|
||||||
|
|
||||||
cl.config = config
|
cl.config = config
|
||||||
|
@ -4,9 +4,12 @@
|
|||||||
package connection_limits
|
package connection_limits
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/md5"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oragono/oragono/irc/flatip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func easyParseIP(ipstr string) (result net.IP) {
|
func easyParseIP(ipstr string) (result net.IP) {
|
||||||
@ -17,6 +20,11 @@ func easyParseIP(ipstr string) (result net.IP) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func easyParseFlat(ipstr string) (result flatip.IP) {
|
||||||
|
r1 := easyParseIP(ipstr)
|
||||||
|
return flatip.FromNetIP(r1)
|
||||||
|
}
|
||||||
|
|
||||||
var baseConfig = LimiterConfig{
|
var baseConfig = LimiterConfig{
|
||||||
rawLimiterConfig: rawLimiterConfig{
|
rawLimiterConfig: rawLimiterConfig{
|
||||||
Count: true,
|
Count: true,
|
||||||
@ -47,18 +55,23 @@ func TestKeying(t *testing.T) {
|
|||||||
var limiter Limiter
|
var limiter Limiter
|
||||||
limiter.ApplyConfig(&config)
|
limiter.ApplyConfig(&config)
|
||||||
|
|
||||||
key, maxConc, maxWin := limiter.addrToKey(easyParseIP("1.1.1.1"))
|
// an ipv4 /32 looks like a /128 to us after applying the 4-in-6 mapping
|
||||||
assertEqual(key, "1.1.1.1/32", t)
|
key, maxConc, maxWin := limiter.addrToKey(easyParseFlat("1.1.1.1"))
|
||||||
|
assertEqual(key.prefixLen, uint8(128), t)
|
||||||
|
assertEqual(key.maskedIP[12:], []byte{1, 1, 1, 1}, t)
|
||||||
assertEqual(maxConc, 4, t)
|
assertEqual(maxConc, 4, t)
|
||||||
assertEqual(maxWin, 8, t)
|
assertEqual(maxWin, 8, t)
|
||||||
|
|
||||||
key, maxConc, maxWin = limiter.addrToKey(easyParseIP("2607:5301:201:3100::7426"))
|
testIPv6 := easyParseFlat("2607:5301:201:3100::7426")
|
||||||
assertEqual(key, "2607:5301:201:3100::/64", t)
|
key, maxConc, maxWin = limiter.addrToKey(testIPv6)
|
||||||
|
assertEqual(key.prefixLen, uint8(64), t)
|
||||||
|
assertEqual(key.maskedIP[:], []byte(easyParseIP("2607:5301:201:3100::")), t)
|
||||||
assertEqual(maxConc, 4, t)
|
assertEqual(maxConc, 4, t)
|
||||||
assertEqual(maxWin, 8, t)
|
assertEqual(maxWin, 8, t)
|
||||||
|
|
||||||
key, maxConc, maxWin = limiter.addrToKey(easyParseIP("8.8.4.4"))
|
key, maxConc, maxWin = limiter.addrToKey(easyParseFlat("8.8.4.4"))
|
||||||
assertEqual(key, "*google", t)
|
assertEqual(key.prefixLen, uint8(0), t)
|
||||||
|
assertEqual([16]byte(key.maskedIP), md5.Sum([]byte("google")), t)
|
||||||
assertEqual(maxConc, 128, t)
|
assertEqual(maxConc, 128, t)
|
||||||
assertEqual(maxWin, 256, t)
|
assertEqual(maxWin, 256, t)
|
||||||
}
|
}
|
||||||
|
74
irc/dline.go
74
irc/dline.go
@ -11,6 +11,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oragono/oragono/irc/flatip"
|
||||||
"github.com/oragono/oragono/irc/utils"
|
"github.com/oragono/oragono/irc/utils"
|
||||||
"github.com/tidwall/buntdb"
|
"github.com/tidwall/buntdb"
|
||||||
)
|
)
|
||||||
@ -54,34 +55,22 @@ func (info IPBanInfo) BanMessage(message string) string {
|
|||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
// dLineNet contains the net itself and expiration time for a given network.
|
|
||||||
type dLineNet struct {
|
|
||||||
// Network is the network that is blocked.
|
|
||||||
// This is always an IPv6 CIDR; IPv4 CIDRs are translated with the 4-in-6 prefix,
|
|
||||||
// individual IPv4 and IPV6 addresses are translated to the relevant /128.
|
|
||||||
Network net.IPNet
|
|
||||||
// Info contains information on the ban.
|
|
||||||
Info IPBanInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// DLineManager manages and dlines.
|
// DLineManager manages and dlines.
|
||||||
type DLineManager struct {
|
type DLineManager struct {
|
||||||
sync.RWMutex // tier 1
|
sync.RWMutex // tier 1
|
||||||
persistenceMutex sync.Mutex // tier 2
|
persistenceMutex sync.Mutex // tier 2
|
||||||
// networks that are dlined:
|
// networks that are dlined:
|
||||||
// XXX: the keys of this map (which are also the database persistence keys)
|
networks map[flatip.IPNet]IPBanInfo
|
||||||
// are the human-readable representations returned by NetToNormalizedString
|
|
||||||
networks map[string]dLineNet
|
|
||||||
// this keeps track of expiration timers for temporary bans
|
// this keeps track of expiration timers for temporary bans
|
||||||
expirationTimers map[string]*time.Timer
|
expirationTimers map[flatip.IPNet]*time.Timer
|
||||||
server *Server
|
server *Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDLineManager returns a new DLineManager.
|
// NewDLineManager returns a new DLineManager.
|
||||||
func NewDLineManager(server *Server) *DLineManager {
|
func NewDLineManager(server *Server) *DLineManager {
|
||||||
var dm DLineManager
|
var dm DLineManager
|
||||||
dm.networks = make(map[string]dLineNet)
|
dm.networks = make(map[flatip.IPNet]IPBanInfo)
|
||||||
dm.expirationTimers = make(map[string]*time.Timer)
|
dm.expirationTimers = make(map[flatip.IPNet]*time.Timer)
|
||||||
dm.server = server
|
dm.server = server
|
||||||
|
|
||||||
dm.loadFromDatastore()
|
dm.loadFromDatastore()
|
||||||
@ -96,9 +85,8 @@ func (dm *DLineManager) AllBans() map[string]IPBanInfo {
|
|||||||
dm.RLock()
|
dm.RLock()
|
||||||
defer dm.RUnlock()
|
defer dm.RUnlock()
|
||||||
|
|
||||||
// map keys are already the human-readable forms, just return a copy of the map
|
|
||||||
for key, info := range dm.networks {
|
for key, info := range dm.networks {
|
||||||
allb[key] = info.Info
|
allb[key.String()] = info
|
||||||
}
|
}
|
||||||
|
|
||||||
return allb
|
return allb
|
||||||
@ -122,9 +110,9 @@ func (dm *DLineManager) AddNetwork(network net.IPNet, duration time.Duration, re
|
|||||||
return dm.persistDline(id, info)
|
return dm.persistDline(id, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id string) {
|
func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id flatip.IPNet) {
|
||||||
network = utils.NormalizeNet(network)
|
flatnet := flatip.FromNetIPNet(network)
|
||||||
id = utils.NetToNormalizedString(network)
|
id = flatnet
|
||||||
|
|
||||||
var timeLeft time.Duration
|
var timeLeft time.Duration
|
||||||
if info.Duration != 0 {
|
if info.Duration != 0 {
|
||||||
@ -137,12 +125,9 @@ func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (i
|
|||||||
dm.Lock()
|
dm.Lock()
|
||||||
defer dm.Unlock()
|
defer dm.Unlock()
|
||||||
|
|
||||||
dm.networks[id] = dLineNet{
|
dm.networks[flatnet] = info
|
||||||
Network: network,
|
|
||||||
Info: info,
|
|
||||||
}
|
|
||||||
|
|
||||||
dm.cancelTimer(id)
|
dm.cancelTimer(flatnet)
|
||||||
|
|
||||||
if info.Duration == 0 {
|
if info.Duration == 0 {
|
||||||
return
|
return
|
||||||
@ -154,29 +139,29 @@ func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (i
|
|||||||
dm.Lock()
|
dm.Lock()
|
||||||
defer dm.Unlock()
|
defer dm.Unlock()
|
||||||
|
|
||||||
netBan, ok := dm.networks[id]
|
banInfo, ok := dm.networks[flatnet]
|
||||||
if ok && netBan.Info.TimeCreated.Equal(timeCreated) {
|
if ok && banInfo.TimeCreated.Equal(timeCreated) {
|
||||||
delete(dm.networks, id)
|
delete(dm.networks, flatnet)
|
||||||
// TODO(slingamn) here's where we'd remove it from the radix tree
|
// TODO(slingamn) here's where we'd remove it from the radix tree
|
||||||
delete(dm.expirationTimers, id)
|
delete(dm.expirationTimers, flatnet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dm.expirationTimers[id] = time.AfterFunc(timeLeft, processExpiration)
|
dm.expirationTimers[flatnet] = time.AfterFunc(timeLeft, processExpiration)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dm *DLineManager) cancelTimer(id string) {
|
func (dm *DLineManager) cancelTimer(flatnet flatip.IPNet) {
|
||||||
oldTimer := dm.expirationTimers[id]
|
oldTimer := dm.expirationTimers[flatnet]
|
||||||
if oldTimer != nil {
|
if oldTimer != nil {
|
||||||
oldTimer.Stop()
|
oldTimer.Stop()
|
||||||
delete(dm.expirationTimers, id)
|
delete(dm.expirationTimers, flatnet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dm *DLineManager) persistDline(id string, info IPBanInfo) error {
|
func (dm *DLineManager) persistDline(id flatip.IPNet, info IPBanInfo) error {
|
||||||
// save in datastore
|
// save in datastore
|
||||||
dlineKey := fmt.Sprintf(keyDlineEntry, id)
|
dlineKey := fmt.Sprintf(keyDlineEntry, id.String())
|
||||||
// assemble json from ban info
|
// assemble json from ban info
|
||||||
b, err := json.Marshal(info)
|
b, err := json.Marshal(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -199,8 +184,8 @@ func (dm *DLineManager) persistDline(id string, info IPBanInfo) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dm *DLineManager) unpersistDline(id string) error {
|
func (dm *DLineManager) unpersistDline(id flatip.IPNet) error {
|
||||||
dlineKey := fmt.Sprintf(keyDlineEntry, id)
|
dlineKey := fmt.Sprintf(keyDlineEntry, id.String())
|
||||||
return dm.server.store.Update(func(tx *buntdb.Tx) error {
|
return dm.server.store.Update(func(tx *buntdb.Tx) error {
|
||||||
_, err := tx.Delete(dlineKey)
|
_, err := tx.Delete(dlineKey)
|
||||||
return err
|
return err
|
||||||
@ -212,7 +197,7 @@ func (dm *DLineManager) RemoveNetwork(network net.IPNet) error {
|
|||||||
dm.persistenceMutex.Lock()
|
dm.persistenceMutex.Lock()
|
||||||
defer dm.persistenceMutex.Unlock()
|
defer dm.persistenceMutex.Unlock()
|
||||||
|
|
||||||
id := utils.NetToNormalizedString(utils.NormalizeNet(network))
|
id := flatip.FromNetIPNet(network)
|
||||||
|
|
||||||
present := func() bool {
|
present := func() bool {
|
||||||
dm.Lock()
|
dm.Lock()
|
||||||
@ -241,8 +226,8 @@ func (dm *DLineManager) RemoveIP(addr net.IP) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckIP returns whether or not an IP address was banned, and how long it is banned for.
|
// CheckIP returns whether or not an IP address was banned, and how long it is banned for.
|
||||||
func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) {
|
func (dm *DLineManager) CheckIP(netAddr net.IP) (isBanned bool, info IPBanInfo) {
|
||||||
addr = addr.To16() // almost certainly unnecessary
|
addr := flatip.FromNetIP(netAddr)
|
||||||
if addr.IsLoopback() {
|
if addr.IsLoopback() {
|
||||||
return // #671
|
return // #671
|
||||||
}
|
}
|
||||||
@ -252,13 +237,12 @@ func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) {
|
|||||||
|
|
||||||
// check networks
|
// check networks
|
||||||
// TODO(slingamn) use a radix tree as the data plane for this
|
// TODO(slingamn) use a radix tree as the data plane for this
|
||||||
for _, netBan := range dm.networks {
|
for flatnet, info := range dm.networks {
|
||||||
if netBan.Network.Contains(addr) {
|
if flatnet.Contains(addr) {
|
||||||
return true, netBan.Info
|
return true, info
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// no matches!
|
// no matches!
|
||||||
isBanned = false
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
217
irc/flatip/flatip.go
Normal file
217
irc/flatip/flatip.go
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
// Copyright 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||||
|
// Copyright 2009 The Go Authors
|
||||||
|
|
||||||
|
package flatip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
|
||||||
|
|
||||||
|
IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||||
|
|
||||||
|
ErrInvalidIPString = errors.New("String could not be interpreted as an IP address")
|
||||||
|
)
|
||||||
|
|
||||||
|
// packed versions of net.IP and net.IPNet; these are pure value types,
|
||||||
|
// so they can be compared with == and used as map keys.
|
||||||
|
|
||||||
|
// IP is the 128-bit representation of the IPv6 address, using the 4-in-6 mapping
|
||||||
|
// if necessary:
|
||||||
|
type IP [16]byte
|
||||||
|
|
||||||
|
// IPNet is a IP network. In a valid value, all bits after PrefixLen are zeroes.
|
||||||
|
type IPNet struct {
|
||||||
|
IP
|
||||||
|
PrefixLen uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetIP converts an IP into a net.IP.
|
||||||
|
func (ip IP) NetIP() (result net.IP) {
|
||||||
|
result = make(net.IP, 16)
|
||||||
|
copy(result[:], ip[:])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromNetIP converts a net.IP into an IP.
|
||||||
|
func FromNetIP(ip net.IP) (result IP) {
|
||||||
|
if len(ip) == 16 {
|
||||||
|
copy(result[:], ip[:])
|
||||||
|
} else {
|
||||||
|
result[10] = 0xff
|
||||||
|
result[11] = 0xff
|
||||||
|
copy(result[12:], ip[:])
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv4 returns the IP address representation of a.b.c.d
|
||||||
|
func IPv4(a, b, c, d byte) (result IP) {
|
||||||
|
copy(result[:12], v4InV6Prefix)
|
||||||
|
result[12] = a
|
||||||
|
result[13] = b
|
||||||
|
result[14] = c
|
||||||
|
result[15] = d
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseIP parses a string representation of an IP address into an IP.
|
||||||
|
// Unlike net.ParseIP, it returns an error instead of a zero value on failure,
|
||||||
|
// since the zero value of `IP` is a representation of a valid IP (::0, the
|
||||||
|
// IPv6 "unspecified address").
|
||||||
|
func ParseIP(ipstr string) (ip IP, err error) {
|
||||||
|
// TODO reimplement this without net.ParseIP
|
||||||
|
netip := net.ParseIP(ipstr)
|
||||||
|
if netip == nil {
|
||||||
|
err = ErrInvalidIPString
|
||||||
|
return
|
||||||
|
}
|
||||||
|
netip = netip.To16()
|
||||||
|
copy(ip[:], netip)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of an IP
|
||||||
|
func (ip IP) String() string {
|
||||||
|
// TODO reimplement this without using (net.IP).String()
|
||||||
|
return (net.IP)(ip[:]).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIPv4 returns whether the IP is an IPv4 address.
|
||||||
|
func (ip IP) IsIPv4() bool {
|
||||||
|
return bytes.Equal(ip[:12], v4InV6Prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLoopback returns whether the IP is a loopback address.
|
||||||
|
func (ip IP) IsLoopback() bool {
|
||||||
|
if ip.IsIPv4() {
|
||||||
|
return ip[12] == 127
|
||||||
|
} else {
|
||||||
|
return ip == IPv6loopback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawCidrMask(length int) (m IP) {
|
||||||
|
n := uint(length)
|
||||||
|
for i := 0; i < 16; i++ {
|
||||||
|
if n >= 8 {
|
||||||
|
m[i] = 0xff
|
||||||
|
n -= 8
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m[i] = ^byte(0xff >> n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip IP) applyMask(mask IP) (result IP) {
|
||||||
|
for i := 0; i < 16; i += 1 {
|
||||||
|
result[i] = ip[i] & mask[i]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func cidrMask(ones, bits int) (result IP) {
|
||||||
|
switch bits {
|
||||||
|
case 32:
|
||||||
|
return rawCidrMask(96 + ones)
|
||||||
|
case 128:
|
||||||
|
return rawCidrMask(ones)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask returns the result of masking ip with the CIDR mask of
|
||||||
|
// length 'ones', out of a total of 'bits' (which must be either
|
||||||
|
// 32 for an IPv4 subnet or 128 for an IPv6 subnet).
|
||||||
|
func (ip IP) Mask(ones, bits int) (result IP) {
|
||||||
|
return ip.applyMask(cidrMask(ones, bits))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToNetIPNet converts an IPNet into a net.IPNet.
|
||||||
|
func (cidr IPNet) ToNetIPNet() (result net.IPNet) {
|
||||||
|
return net.IPNet{
|
||||||
|
IP: cidr.IP.NetIP(),
|
||||||
|
Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains retuns whether the network contains `ip`.
|
||||||
|
func (cidr IPNet) Contains(ip IP) bool {
|
||||||
|
maskedIP := ip.Mask(int(cidr.PrefixLen), 128)
|
||||||
|
return cidr.IP == maskedIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromNetIPnet converts a net.IPNet into an IPNet.
|
||||||
|
func FromNetIPNet(network net.IPNet) (result IPNet) {
|
||||||
|
ones, _ := network.Mask.Size()
|
||||||
|
if len(network.IP) == 16 {
|
||||||
|
copy(result.IP[:], network.IP[:])
|
||||||
|
} else {
|
||||||
|
result.IP[10] = 0xff
|
||||||
|
result.IP[11] = 0xff
|
||||||
|
copy(result.IP[12:], network.IP[:])
|
||||||
|
ones += 96
|
||||||
|
}
|
||||||
|
// perform masking so that equal CIDRs are ==
|
||||||
|
result.IP = result.IP.Mask(ones, 128)
|
||||||
|
result.PrefixLen = uint8(ones)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of an IPNet.
|
||||||
|
func (cidr IPNet) String() string {
|
||||||
|
ip := make(net.IP, 16)
|
||||||
|
copy(ip[:], cidr.IP[:])
|
||||||
|
ipnet := net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
|
||||||
|
}
|
||||||
|
return ipnet.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCIDR parses a string representation of an IP network in CIDR notation,
|
||||||
|
// then returns it as an IPNet (along with the original, unmasked address).
|
||||||
|
func ParseCIDR(netstr string) (ip IP, ipnet IPNet, err error) {
|
||||||
|
// TODO reimplement this without net.ParseCIDR
|
||||||
|
nip, nipnet, err := net.ParseCIDR(netstr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return FromNetIP(nip), FromNetIPNet(*nipnet), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// begin ad-hoc utilities
|
||||||
|
|
||||||
|
// ParseToNormalizedNet attempts to interpret a string either as an IP
|
||||||
|
// network in CIDR notation, returning an IPNet, or as an IP address,
|
||||||
|
// returning an IPNet that contains only that address.
|
||||||
|
func ParseToNormalizedNet(netstr string) (ipnet IPNet, err error) {
|
||||||
|
_, ipnet, err = ParseCIDR(netstr)
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ip, err := ParseIP(netstr)
|
||||||
|
if err == nil {
|
||||||
|
ipnet.IP = ip
|
||||||
|
ipnet.PrefixLen = 128
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPInNets is a convenience function for testing whether an IP is contained
|
||||||
|
// in any member of a slice of IPNet's.
|
||||||
|
func IPInNets(addr IP, nets []IPNet) bool {
|
||||||
|
for _, net := range nets {
|
||||||
|
if net.Contains(addr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
174
irc/flatip/flatip_test.go
Normal file
174
irc/flatip/flatip_test.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
package flatip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func easyParseIP(ipstr string) (result net.IP) {
|
||||||
|
result = net.ParseIP(ipstr)
|
||||||
|
if result == nil {
|
||||||
|
panic(ipstr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func easyParseFlat(ipstr string) (result IP) {
|
||||||
|
x := easyParseIP(ipstr)
|
||||||
|
return FromNetIP(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func easyParseIPNet(nipstr string) (result net.IPNet) {
|
||||||
|
_, nip, err := net.ParseCIDR(nipstr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return *nip
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasic(t *testing.T) {
|
||||||
|
nip := easyParseIP("8.8.8.8")
|
||||||
|
flatip := FromNetIP(nip)
|
||||||
|
if flatip.String() != "8.8.8.8" {
|
||||||
|
t.Errorf("conversions don't work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoopback(t *testing.T) {
|
||||||
|
localhost_v4 := easyParseFlat("127.0.0.1")
|
||||||
|
localhost_v4_again := easyParseFlat("127.2.3.4")
|
||||||
|
google := easyParseFlat("8.8.8.8")
|
||||||
|
loopback_v6 := easyParseFlat("::1")
|
||||||
|
google_v6 := easyParseFlat("2607:f8b0:4006:801::2004")
|
||||||
|
|
||||||
|
if !(localhost_v4.IsLoopback() && localhost_v4_again.IsLoopback() && loopback_v6.IsLoopback()) {
|
||||||
|
t.Errorf("can't detect loopbacks")
|
||||||
|
}
|
||||||
|
|
||||||
|
if google_v6.IsLoopback() || google.IsLoopback() {
|
||||||
|
t.Errorf("incorrectly detected loopbacks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContains(t *testing.T) {
|
||||||
|
nipnet := easyParseIPNet("8.8.0.0/16")
|
||||||
|
flatipnet := FromNetIPNet(nipnet)
|
||||||
|
nip := easyParseIP("8.8.8.8")
|
||||||
|
flatip_ := FromNetIP(nip)
|
||||||
|
if !flatipnet.Contains(flatip_) {
|
||||||
|
t.Errorf("contains doesn't work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var testIPStrs = []string{
|
||||||
|
"8.8.8.8",
|
||||||
|
"127.0.0.1",
|
||||||
|
"1.1.1.1",
|
||||||
|
"128.127.65.64",
|
||||||
|
"2001:0db8::1",
|
||||||
|
"::1",
|
||||||
|
"255.255.255.255",
|
||||||
|
}
|
||||||
|
|
||||||
|
func doMaskingTest(ip net.IP, t *testing.T) {
|
||||||
|
flat := FromNetIP(ip)
|
||||||
|
netLen := len(ip) * 8
|
||||||
|
for i := 0; i < netLen; i++ {
|
||||||
|
masked := flat.Mask(i, netLen)
|
||||||
|
netMask := net.CIDRMask(i, netLen)
|
||||||
|
netMasked := ip.Mask(netMask)
|
||||||
|
if !bytes.Equal(masked[:], netMasked.To16()) {
|
||||||
|
t.Errorf("Masking %s with %d/%d; expected %s, got %s", ip.String(), i, netLen, netMasked.String(), masked.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMasking(t *testing.T) {
|
||||||
|
for _, ipstr := range testIPStrs {
|
||||||
|
doMaskingTest(easyParseIP(ipstr), t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaskingFuzz(t *testing.T) {
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
r.Read(buf)
|
||||||
|
doMaskingTest(net.IP(buf), t)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = make([]byte, 16)
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
r.Read(buf)
|
||||||
|
doMaskingTest(net.IP(buf), t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMasking(b *testing.B) {
|
||||||
|
ip := easyParseIP("2001:0db8::42")
|
||||||
|
flat := FromNetIP(ip)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
flat.Mask(64, 128)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMaskingLegacy(b *testing.B) {
|
||||||
|
ip := easyParseIP("2001:0db8::42")
|
||||||
|
mask := net.CIDRMask(64, 128)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ip.Mask(mask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMaskingCached(b *testing.B) {
|
||||||
|
i := easyParseIP("2001:0db8::42")
|
||||||
|
flat := FromNetIP(i)
|
||||||
|
mask := cidrMask(64, 128)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
flat.applyMask(mask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMaskingConstruct(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cidrMask(69, 128)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContains(b *testing.B) {
|
||||||
|
ip := easyParseIP("2001:0db8::42")
|
||||||
|
flat := FromNetIP(ip)
|
||||||
|
_, ipnet, err := net.ParseCIDR("2001:0db8::/64")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
flatnet := FromNetIPNet(*ipnet)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
flatnet.Contains(flat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContainsLegacy(b *testing.B) {
|
||||||
|
ip := easyParseIP("2001:0db8::42")
|
||||||
|
_, ipnetptr, err := net.ParseCIDR("2001:0db8::/64")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ipnet := *ipnetptr
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ipnet.Contains(ip)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user