diff --git a/Makefile b/Makefile index 8659506c..e69da6c4 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ test: cd irc/cloaks && go test . && go vet . cd irc/connection_limits && go test . && go vet . cd irc/email && go test . && go vet . + cd irc/flatip && go test . && go vet . cd irc/history && go test . && go vet . cd irc/isupport && go test . && go vet . cd irc/migrations && go test . && go vet . diff --git a/default.yaml b/default.yaml index 4a394bcb..1ab18be0 100644 --- a/default.yaml +++ b/default.yaml @@ -247,9 +247,6 @@ server: window: 10m # maximum number of new connections per IP/CIDR within the given duration max-connections-per-window: 32 - # how long to ban offenders for. after banning them, the number of connections is - # reset, which lets you use /UNDLINE to unban people - throttle-ban-duration: 10m # how wide the CIDR should be for IPv4 (a /32 is a fully specified IPv4 address) cidr-len-ipv4: 32 diff --git a/irc/client.go b/irc/client.go index 38b7287e..68bc680b 100644 --- a/irc/client.go +++ b/irc/client.go @@ -21,6 +21,7 @@ import ( ident "github.com/oragono/go-ident" "github.com/oragono/oragono/irc/caps" "github.com/oragono/oragono/irc/connection_limits" + "github.com/oragono/oragono/irc/flatip" "github.com/oragono/oragono/irc/history" "github.com/oragono/oragono/irc/modes" "github.com/oragono/oragono/irc/sno" @@ -1477,7 +1478,7 @@ func (client *Client) destroy(session *Session) { if session.proxiedIP != nil { ip = session.proxiedIP } - client.server.connectionLimiter.RemoveClient(ip) + client.server.connectionLimiter.RemoveClient(flatip.FromNetIP(ip)) source = ip.String() } client.server.logger.Info("connect-ip", fmt.Sprintf("disconnecting session of %s from %s", details.nick, source)) diff --git a/irc/connection_limits/limiter.go b/irc/connection_limits/limiter.go index 988448ac..867cb0f8 100644 --- a/irc/connection_limits/limiter.go +++ b/irc/connection_limits/limiter.go @@ -4,12 +4,13 @@ package connection_limits import ( + "crypto/md5" "errors" "fmt" - "net" "sync" "time" + "github.com/oragono/oragono/irc/flatip" "github.com/oragono/oragono/irc/utils" ) @@ -26,10 +27,15 @@ type CustomLimitConfig struct { // tuples the key-value pair of a CIDR and its custom limit/throttle values type customLimit struct { - name string + name [16]byte maxConcurrent int maxPerWindow int - nets []net.IPNet + nets []flatip.IPNet +} + +type limiterKey struct { + maskedIP flatip.IP + prefixLen uint8 // 0 for the fake nets we generate for custom limits } // LimiterConfig controls the automated connection limits. @@ -41,8 +47,7 @@ type rawLimiterConfig struct { Throttle bool Window time.Duration - MaxPerWindow int `yaml:"max-connections-per-window"` - BanDuration time.Duration `yaml:"throttle-ban-duration"` + MaxPerWindow int `yaml:"max-connections-per-window"` CidrLenIPv4 int `yaml:"cidr-len-ipv4"` CidrLenIPv6 int `yaml:"cidr-len-ipv6"` @@ -55,9 +60,7 @@ type rawLimiterConfig struct { type LimiterConfig struct { rawLimiterConfig - ipv4Mask net.IPMask - ipv6Mask net.IPMask - exemptedNets []net.IPNet + exemptedNets []flatip.IPNet customLimits []customLimit } @@ -69,15 +72,19 @@ func (config *LimiterConfig) UnmarshalYAML(unmarshal func(interface{}) error) (e } func (config *LimiterConfig) postprocess() (err error) { - config.exemptedNets, err = utils.ParseNetList(config.Exempted) + exemptedNets, err := utils.ParseNetList(config.Exempted) if err != nil { return fmt.Errorf("Could not parse limiter exemption list: %v", err.Error()) } + config.exemptedNets = make([]flatip.IPNet, len(exemptedNets)) + for i, exempted := range exemptedNets { + config.exemptedNets[i] = flatip.FromNetIPNet(exempted) + } for identifier, customLimitConf := range config.CustomLimits { - nets := make([]net.IPNet, len(customLimitConf.Nets)) + nets := make([]flatip.IPNet, len(customLimitConf.Nets)) for i, netStr := range customLimitConf.Nets { - normalizedNet, err := utils.NormalizedNetFromString(netStr) + normalizedNet, err := flatip.ParseToNormalizedNet(netStr) if err != nil { return fmt.Errorf("Bad net %s in custom-limits block %s: %w", netStr, identifier, err) } @@ -86,23 +93,20 @@ func (config *LimiterConfig) postprocess() (err error) { if len(customLimitConf.Nets) == 0 { // see #1421: this is the legacy config format where the // dictionary key of the block is a CIDR string - normalizedNet, err := utils.NormalizedNetFromString(identifier) + normalizedNet, err := flatip.ParseToNormalizedNet(identifier) if err != nil { return fmt.Errorf("Custom limit block %s has no defined nets", identifier) } - nets = []net.IPNet{normalizedNet} + nets = []flatip.IPNet{normalizedNet} } config.customLimits = append(config.customLimits, customLimit{ maxConcurrent: customLimitConf.MaxConcurrent, maxPerWindow: customLimitConf.MaxPerWindow, - name: "*" + identifier, + name: md5.Sum([]byte(identifier)), nets: nets, }) } - config.ipv4Mask = net.CIDRMask(config.CidrLenIPv4, 32) - config.ipv6Mask = net.CIDRMask(config.CidrLenIPv6, 128) - return nil } @@ -113,53 +117,56 @@ type Limiter struct { config *LimiterConfig // IP/CIDR -> count of clients connected from there: - limiter map[string]int + limiter map[limiterKey]int // IP/CIDR -> throttle state: - throttler map[string]ThrottleDetails + throttler map[limiterKey]ThrottleDetails } // addrToKey canonicalizes `addr` to a string key, and returns // the relevant connection limit and throttle max-per-window values -func (cl *Limiter) addrToKey(addr net.IP) (key string, limit int, throttle int) { - // `key` will be a CIDR string like "8.8.8.8/32" or "2001:0db8::/32" +func (cl *Limiter) addrToKey(addr flatip.IP) (key limiterKey, limit int, throttle int) { for _, custom := range cl.config.customLimits { for _, net := range custom.nets { if net.Contains(addr) { - return custom.name, custom.maxConcurrent, custom.maxPerWindow + return limiterKey{maskedIP: custom.name, prefixLen: 0}, custom.maxConcurrent, custom.maxPerWindow } } } - var ipNet net.IPNet - addrv4 := addr.To4() - if addrv4 != nil { - ipNet = net.IPNet{ - IP: addrv4.Mask(cl.config.ipv4Mask), - Mask: cl.config.ipv4Mask, - } + var prefixLen int + if addr.IsIPv4() { + prefixLen = cl.config.CidrLenIPv4 + addr = addr.Mask(prefixLen, 32) + prefixLen += 96 } else { - ipNet = net.IPNet{ - IP: addr.Mask(cl.config.ipv6Mask), - Mask: cl.config.ipv6Mask, - } + prefixLen = cl.config.CidrLenIPv6 + addr = addr.Mask(prefixLen, 128) } - return ipNet.String(), cl.config.MaxConcurrent, cl.config.MaxPerWindow + + return limiterKey{maskedIP: addr, prefixLen: uint8(prefixLen)}, cl.config.MaxConcurrent, cl.config.MaxPerWindow } // AddClient adds a client to our population if possible. If we can't, throws an error instead. -func (cl *Limiter) AddClient(addr net.IP) error { +func (cl *Limiter) AddClient(addr flatip.IP) error { cl.Lock() defer cl.Unlock() // we don't track populations for exempted addresses or nets - this is by design - if utils.IPInNets(addr, cl.config.exemptedNets) { + if flatip.IPInNets(addr, cl.config.exemptedNets) { return nil } addrString, maxConcurrent, maxPerWindow := cl.addrToKey(addr) - // XXX check throttle first; if we checked limit first and then checked throttle, - // we'd have to decrement the limit on an unsuccessful throttle check + // check limiter + var count int + if cl.config.Count { + count = cl.limiter[addrString] + 1 + if count > maxConcurrent { + return ErrLimitExceeded + } + } + if cl.config.Throttle { details := cl.throttler[addrString] // retrieve mutable throttle state from the map // add in constant state to process the limiting operation @@ -171,16 +178,13 @@ func (cl *Limiter) AddClient(addr net.IP) error { throttled, _ := g.Touch() // actually check the limit cl.throttler[addrString] = g.ThrottleDetails // store modified mutable state if throttled { + // back out the limiter add return ErrThrottleExceeded } } - // now check limiter + // success, record in limiter if cl.config.Count { - count := cl.limiter[addrString] + 1 - if count > maxConcurrent { - return ErrLimitExceeded - } cl.limiter[addrString] = count } @@ -188,11 +192,11 @@ func (cl *Limiter) AddClient(addr net.IP) error { } // RemoveClient removes the given address from our population -func (cl *Limiter) RemoveClient(addr net.IP) { +func (cl *Limiter) RemoveClient(addr flatip.IP) { cl.Lock() defer cl.Unlock() - if !cl.config.Count || utils.IPInNets(addr, cl.config.exemptedNets) { + if !cl.config.Count || flatip.IPInNets(addr, cl.config.exemptedNets) { return } @@ -206,11 +210,11 @@ func (cl *Limiter) RemoveClient(addr net.IP) { } // ResetThrottle resets the throttle count for an IP -func (cl *Limiter) ResetThrottle(addr net.IP) { +func (cl *Limiter) ResetThrottle(addr flatip.IP) { cl.Lock() defer cl.Unlock() - if !cl.config.Throttle || utils.IPInNets(addr, cl.config.exemptedNets) { + if !cl.config.Throttle || flatip.IPInNets(addr, cl.config.exemptedNets) { return } @@ -224,10 +228,10 @@ func (cl *Limiter) ApplyConfig(config *LimiterConfig) { defer cl.Unlock() if cl.limiter == nil { - cl.limiter = make(map[string]int) + cl.limiter = make(map[limiterKey]int) } if cl.throttler == nil { - cl.throttler = make(map[string]ThrottleDetails) + cl.throttler = make(map[limiterKey]ThrottleDetails) } cl.config = config diff --git a/irc/connection_limits/limiter_test.go b/irc/connection_limits/limiter_test.go index bf852b58..3bc0b39e 100644 --- a/irc/connection_limits/limiter_test.go +++ b/irc/connection_limits/limiter_test.go @@ -4,15 +4,17 @@ package connection_limits import ( - "net" + "crypto/md5" "testing" "time" + + "github.com/oragono/oragono/irc/flatip" ) -func easyParseIP(ipstr string) (result net.IP) { - result = net.ParseIP(ipstr) - if result == nil { - panic(ipstr) +func easyParseIP(ipstr string) (result flatip.IP) { + result, err := flatip.ParseIP(ipstr) + if err != nil { + panic(err) } return } @@ -47,18 +49,23 @@ func TestKeying(t *testing.T) { var limiter Limiter limiter.ApplyConfig(&config) + // an ipv4 /32 looks like a /128 to us after applying the 4-in-6 mapping key, maxConc, maxWin := limiter.addrToKey(easyParseIP("1.1.1.1")) - assertEqual(key, "1.1.1.1/32", t) + assertEqual(key.prefixLen, uint8(128), t) + assertEqual(key.maskedIP[12:], []byte{1, 1, 1, 1}, t) assertEqual(maxConc, 4, t) assertEqual(maxWin, 8, t) - key, maxConc, maxWin = limiter.addrToKey(easyParseIP("2607:5301:201:3100::7426")) - assertEqual(key, "2607:5301:201:3100::/64", t) + testIPv6 := easyParseIP("2607:5301:201:3100::7426") + key, maxConc, maxWin = limiter.addrToKey(testIPv6) + assertEqual(key.prefixLen, uint8(64), t) + assertEqual(flatip.IP(key.maskedIP), easyParseIP("2607:5301:201:3100::"), t) assertEqual(maxConc, 4, t) assertEqual(maxWin, 8, t) key, maxConc, maxWin = limiter.addrToKey(easyParseIP("8.8.4.4")) - assertEqual(key, "*google", t) + assertEqual(key.prefixLen, uint8(0), t) + assertEqual([16]byte(key.maskedIP), md5.Sum([]byte("google")), t) assertEqual(maxConc, 128, t) assertEqual(maxWin, 256, t) } diff --git a/irc/connection_limits/throttler_test.go b/irc/connection_limits/throttler_test.go index 093c02bf..d44ec860 100644 --- a/irc/connection_limits/throttler_test.go +++ b/irc/connection_limits/throttler_test.go @@ -4,7 +4,6 @@ package connection_limits import ( - "net" "reflect" "testing" "time" @@ -83,7 +82,7 @@ func makeTestThrottler(v4len, v6len int) *Limiter { func TestConnectionThrottle(t *testing.T) { throttler := makeTestThrottler(32, 64) - addr := net.ParseIP("8.8.8.8") + addr := easyParseIP("8.8.8.8") for i := 0; i < 3; i += 1 { err := throttler.AddClient(addr) @@ -97,14 +96,14 @@ func TestConnectionThrottleIPv6(t *testing.T) { throttler := makeTestThrottler(32, 64) var err error - err = throttler.AddClient(net.ParseIP("2001:0db8::1")) + err = throttler.AddClient(easyParseIP("2001:0db8::1")) assertEqual(err, nil, t) - err = throttler.AddClient(net.ParseIP("2001:0db8::2")) + err = throttler.AddClient(easyParseIP("2001:0db8::2")) assertEqual(err, nil, t) - err = throttler.AddClient(net.ParseIP("2001:0db8::3")) + err = throttler.AddClient(easyParseIP("2001:0db8::3")) assertEqual(err, nil, t) - err = throttler.AddClient(net.ParseIP("2001:0db8::4")) + err = throttler.AddClient(easyParseIP("2001:0db8::4")) assertEqual(err, ErrThrottleExceeded, t) } @@ -112,13 +111,13 @@ func TestConnectionThrottleIPv4(t *testing.T) { throttler := makeTestThrottler(24, 64) var err error - err = throttler.AddClient(net.ParseIP("192.168.1.101")) + err = throttler.AddClient(easyParseIP("192.168.1.101")) assertEqual(err, nil, t) - err = throttler.AddClient(net.ParseIP("192.168.1.102")) + err = throttler.AddClient(easyParseIP("192.168.1.102")) assertEqual(err, nil, t) - err = throttler.AddClient(net.ParseIP("192.168.1.103")) + err = throttler.AddClient(easyParseIP("192.168.1.103")) assertEqual(err, nil, t) - err = throttler.AddClient(net.ParseIP("192.168.1.104")) + err = throttler.AddClient(easyParseIP("192.168.1.104")) assertEqual(err, ErrThrottleExceeded, t) } diff --git a/irc/dline.go b/irc/dline.go index d22d90b4..3de5a621 100644 --- a/irc/dline.go +++ b/irc/dline.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/oragono/oragono/irc/flatip" "github.com/oragono/oragono/irc/utils" "github.com/tidwall/buntdb" ) @@ -54,34 +55,22 @@ func (info IPBanInfo) BanMessage(message string) string { return message } -// dLineNet contains the net itself and expiration time for a given network. -type dLineNet struct { - // Network is the network that is blocked. - // This is always an IPv6 CIDR; IPv4 CIDRs are translated with the 4-in-6 prefix, - // individual IPv4 and IPV6 addresses are translated to the relevant /128. - Network net.IPNet - // Info contains information on the ban. - Info IPBanInfo -} - // DLineManager manages and dlines. type DLineManager struct { sync.RWMutex // tier 1 persistenceMutex sync.Mutex // tier 2 // networks that are dlined: - // XXX: the keys of this map (which are also the database persistence keys) - // are the human-readable representations returned by NetToNormalizedString - networks map[string]dLineNet + networks map[flatip.IPNet]IPBanInfo // this keeps track of expiration timers for temporary bans - expirationTimers map[string]*time.Timer + expirationTimers map[flatip.IPNet]*time.Timer server *Server } // NewDLineManager returns a new DLineManager. func NewDLineManager(server *Server) *DLineManager { var dm DLineManager - dm.networks = make(map[string]dLineNet) - dm.expirationTimers = make(map[string]*time.Timer) + dm.networks = make(map[flatip.IPNet]IPBanInfo) + dm.expirationTimers = make(map[flatip.IPNet]*time.Timer) dm.server = server dm.loadFromDatastore() @@ -96,9 +85,8 @@ func (dm *DLineManager) AllBans() map[string]IPBanInfo { dm.RLock() defer dm.RUnlock() - // map keys are already the human-readable forms, just return a copy of the map for key, info := range dm.networks { - allb[key] = info.Info + allb[key.String()] = info } return allb @@ -122,9 +110,9 @@ func (dm *DLineManager) AddNetwork(network net.IPNet, duration time.Duration, re return dm.persistDline(id, info) } -func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id string) { - network = utils.NormalizeNet(network) - id = utils.NetToNormalizedString(network) +func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id flatip.IPNet) { + flatnet := flatip.FromNetIPNet(network) + id = flatnet var timeLeft time.Duration if info.Duration != 0 { @@ -137,12 +125,9 @@ func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (i dm.Lock() defer dm.Unlock() - dm.networks[id] = dLineNet{ - Network: network, - Info: info, - } + dm.networks[flatnet] = info - dm.cancelTimer(id) + dm.cancelTimer(flatnet) if info.Duration == 0 { return @@ -154,29 +139,29 @@ func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (i dm.Lock() defer dm.Unlock() - netBan, ok := dm.networks[id] - if ok && netBan.Info.TimeCreated.Equal(timeCreated) { - delete(dm.networks, id) + banInfo, ok := dm.networks[flatnet] + if ok && banInfo.TimeCreated.Equal(timeCreated) { + delete(dm.networks, flatnet) // TODO(slingamn) here's where we'd remove it from the radix tree - delete(dm.expirationTimers, id) + delete(dm.expirationTimers, flatnet) } } - dm.expirationTimers[id] = time.AfterFunc(timeLeft, processExpiration) + dm.expirationTimers[flatnet] = time.AfterFunc(timeLeft, processExpiration) return } -func (dm *DLineManager) cancelTimer(id string) { - oldTimer := dm.expirationTimers[id] +func (dm *DLineManager) cancelTimer(flatnet flatip.IPNet) { + oldTimer := dm.expirationTimers[flatnet] if oldTimer != nil { oldTimer.Stop() - delete(dm.expirationTimers, id) + delete(dm.expirationTimers, flatnet) } } -func (dm *DLineManager) persistDline(id string, info IPBanInfo) error { +func (dm *DLineManager) persistDline(id flatip.IPNet, info IPBanInfo) error { // save in datastore - dlineKey := fmt.Sprintf(keyDlineEntry, id) + dlineKey := fmt.Sprintf(keyDlineEntry, id.String()) // assemble json from ban info b, err := json.Marshal(info) if err != nil { @@ -199,8 +184,8 @@ func (dm *DLineManager) persistDline(id string, info IPBanInfo) error { return err } -func (dm *DLineManager) unpersistDline(id string) error { - dlineKey := fmt.Sprintf(keyDlineEntry, id) +func (dm *DLineManager) unpersistDline(id flatip.IPNet) error { + dlineKey := fmt.Sprintf(keyDlineEntry, id.String()) return dm.server.store.Update(func(tx *buntdb.Tx) error { _, err := tx.Delete(dlineKey) return err @@ -212,7 +197,7 @@ func (dm *DLineManager) RemoveNetwork(network net.IPNet) error { dm.persistenceMutex.Lock() defer dm.persistenceMutex.Unlock() - id := utils.NetToNormalizedString(utils.NormalizeNet(network)) + id := flatip.FromNetIPNet(network) present := func() bool { dm.Lock() @@ -241,8 +226,7 @@ func (dm *DLineManager) RemoveIP(addr net.IP) error { } // CheckIP returns whether or not an IP address was banned, and how long it is banned for. -func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) { - addr = addr.To16() // almost certainly unnecessary +func (dm *DLineManager) CheckIP(addr flatip.IP) (isBanned bool, info IPBanInfo) { if addr.IsLoopback() { return // #671 } @@ -252,13 +236,12 @@ func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) { // check networks // TODO(slingamn) use a radix tree as the data plane for this - for _, netBan := range dm.networks { - if netBan.Network.Contains(addr) { - return true, netBan.Info + for flatnet, info := range dm.networks { + if flatnet.Contains(addr) { + return true, info } } // no matches! - isBanned = false return } diff --git a/irc/flatip/adhoc.go b/irc/flatip/adhoc.go new file mode 100644 index 00000000..6c994c56 --- /dev/null +++ b/irc/flatip/adhoc.go @@ -0,0 +1,33 @@ +// Copyright 2020 Shivaram Lingamneni +// Released under the MIT license + +package flatip + +// begin ad-hoc utilities + +// ParseToNormalizedNet attempts to interpret a string either as an IP +// network in CIDR notation, returning an IPNet, or as an IP address, +// returning an IPNet that contains only that address. +func ParseToNormalizedNet(netstr string) (ipnet IPNet, err error) { + _, ipnet, err = ParseCIDR(netstr) + if err == nil { + return + } + ip, err := ParseIP(netstr) + if err == nil { + ipnet.IP = ip + ipnet.PrefixLen = 128 + } + return +} + +// IPInNets is a convenience function for testing whether an IP is contained +// in any member of a slice of IPNet's. +func IPInNets(addr IP, nets []IPNet) bool { + for _, net := range nets { + if net.Contains(addr) { + return true + } + } + return false +} diff --git a/irc/flatip/flatip.go b/irc/flatip/flatip.go new file mode 100644 index 00000000..7ebdbb50 --- /dev/null +++ b/irc/flatip/flatip.go @@ -0,0 +1,202 @@ +// Copyright 2020 Shivaram Lingamneni +// Copyright 2009 The Go Authors +// Released under the MIT license + +package flatip + +import ( + "bytes" + "errors" + "net" +) + +var ( + v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff} + + IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + IPv6zero = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + IPv4zero = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0, 0, 0, 0} + + ErrInvalidIPString = errors.New("String could not be interpreted as an IP address") +) + +// packed versions of net.IP and net.IPNet; these are pure value types, +// so they can be compared with == and used as map keys. + +// IP is a 128-bit representation of an IP address, using the 4-in-6 mapping +// to represent IPv4 addresses. +type IP [16]byte + +// IPNet is a IP network. In a valid value, all bits after PrefixLen are zeroes. +type IPNet struct { + IP + PrefixLen uint8 +} + +// NetIP converts an IP into a net.IP. +func (ip IP) NetIP() (result net.IP) { + result = make(net.IP, 16) + copy(result[:], ip[:]) + return +} + +// FromNetIP converts a net.IP into an IP. +func FromNetIP(ip net.IP) (result IP) { + if len(ip) == 16 { + copy(result[:], ip[:]) + } else { + result[10] = 0xff + result[11] = 0xff + copy(result[12:], ip[:]) + } + return +} + +// IPv4 returns the IP address representation of a.b.c.d +func IPv4(a, b, c, d byte) (result IP) { + copy(result[:12], v4InV6Prefix) + result[12] = a + result[13] = b + result[14] = c + result[15] = d + return +} + +// ParseIP parses a string representation of an IP address into an IP. +// Unlike net.ParseIP, it returns an error instead of a zero value on failure, +// since the zero value of `IP` is a representation of a valid IP (::0, the +// IPv6 "unspecified address"). +func ParseIP(ipstr string) (ip IP, err error) { + // TODO reimplement this without net.ParseIP + netip := net.ParseIP(ipstr) + if netip == nil { + err = ErrInvalidIPString + return + } + netip = netip.To16() + copy(ip[:], netip) + return +} + +// String returns the string representation of an IP +func (ip IP) String() string { + // TODO reimplement this without using (net.IP).String() + return (net.IP)(ip[:]).String() +} + +// IsIPv4 returns whether the IP is an IPv4 address. +func (ip IP) IsIPv4() bool { + return bytes.Equal(ip[:12], v4InV6Prefix) +} + +// IsLoopback returns whether the IP is a loopback address. +func (ip IP) IsLoopback() bool { + if ip.IsIPv4() { + return ip[12] == 127 + } else { + return ip == IPv6loopback + } +} + +func (ip IP) IsUnspecified() bool { + return ip == IPv4zero || ip == IPv6zero +} + +func rawCidrMask(length int) (m IP) { + n := uint(length) + for i := 0; i < 16; i++ { + if n >= 8 { + m[i] = 0xff + n -= 8 + continue + } + m[i] = ^byte(0xff >> n) + return + } + return +} + +func (ip IP) applyMask(mask IP) (result IP) { + for i := 0; i < 16; i += 1 { + result[i] = ip[i] & mask[i] + } + return +} + +func cidrMask(ones, bits int) (result IP) { + switch bits { + case 32: + return rawCidrMask(96 + ones) + case 128: + return rawCidrMask(ones) + default: + return + } +} + +// Mask returns the result of masking ip with the CIDR mask of +// length 'ones', out of a total of 'bits' (which must be either +// 32 for an IPv4 subnet or 128 for an IPv6 subnet). +func (ip IP) Mask(ones, bits int) (result IP) { + return ip.applyMask(cidrMask(ones, bits)) +} + +// ToNetIPNet converts an IPNet into a net.IPNet. +func (cidr IPNet) ToNetIPNet() (result net.IPNet) { + return net.IPNet{ + IP: cidr.IP.NetIP(), + Mask: net.CIDRMask(int(cidr.PrefixLen), 128), + } +} + +// Contains retuns whether the network contains `ip`. +func (cidr IPNet) Contains(ip IP) bool { + maskedIP := ip.Mask(int(cidr.PrefixLen), 128) + return cidr.IP == maskedIP +} + +// FromNetIPnet converts a net.IPNet into an IPNet. +func FromNetIPNet(network net.IPNet) (result IPNet) { + ones, _ := network.Mask.Size() + if len(network.IP) == 16 { + copy(result.IP[:], network.IP[:]) + } else { + result.IP[10] = 0xff + result.IP[11] = 0xff + copy(result.IP[12:], network.IP[:]) + ones += 96 + } + // perform masking so that equal CIDRs are == + result.IP = result.IP.Mask(ones, 128) + result.PrefixLen = uint8(ones) + return +} + +// String returns a string representation of an IPNet. +func (cidr IPNet) String() string { + ip := make(net.IP, 16) + copy(ip[:], cidr.IP[:]) + ipnet := net.IPNet{ + IP: ip, + Mask: net.CIDRMask(int(cidr.PrefixLen), 128), + } + return ipnet.String() +} + +// IsZero tests whether ipnet is the zero value of an IPNet, 0::0/0. +// Although this is a valid subnet, it can still be used as a sentinel +// value in some contexts. +func (ipnet IPNet) IsZero() bool { + return ipnet == IPNet{} +} + +// ParseCIDR parses a string representation of an IP network in CIDR notation, +// then returns it as an IPNet (along with the original, unmasked address). +func ParseCIDR(netstr string) (ip IP, ipnet IPNet, err error) { + // TODO reimplement this without net.ParseCIDR + nip, nipnet, err := net.ParseCIDR(netstr) + if err != nil { + return + } + return FromNetIP(nip), FromNetIPNet(*nipnet), nil +} diff --git a/irc/flatip/flatip_test.go b/irc/flatip/flatip_test.go new file mode 100644 index 00000000..c2aae9a8 --- /dev/null +++ b/irc/flatip/flatip_test.go @@ -0,0 +1,174 @@ +package flatip + +import ( + "bytes" + "math/rand" + "net" + "testing" + "time" +) + +func easyParseIP(ipstr string) (result net.IP) { + result = net.ParseIP(ipstr) + if result == nil { + panic(ipstr) + } + return +} + +func easyParseFlat(ipstr string) (result IP) { + x := easyParseIP(ipstr) + return FromNetIP(x) +} + +func easyParseIPNet(nipstr string) (result net.IPNet) { + _, nip, err := net.ParseCIDR(nipstr) + if err != nil { + panic(err) + } + return *nip +} + +func TestBasic(t *testing.T) { + nip := easyParseIP("8.8.8.8") + flatip := FromNetIP(nip) + if flatip.String() != "8.8.8.8" { + t.Errorf("conversions don't work") + } +} + +func TestLoopback(t *testing.T) { + localhost_v4 := easyParseFlat("127.0.0.1") + localhost_v4_again := easyParseFlat("127.2.3.4") + google := easyParseFlat("8.8.8.8") + loopback_v6 := easyParseFlat("::1") + google_v6 := easyParseFlat("2607:f8b0:4006:801::2004") + + if !(localhost_v4.IsLoopback() && localhost_v4_again.IsLoopback() && loopback_v6.IsLoopback()) { + t.Errorf("can't detect loopbacks") + } + + if google_v6.IsLoopback() || google.IsLoopback() { + t.Errorf("incorrectly detected loopbacks") + } +} + +func TestContains(t *testing.T) { + nipnet := easyParseIPNet("8.8.0.0/16") + flatipnet := FromNetIPNet(nipnet) + nip := easyParseIP("8.8.8.8") + flatip_ := FromNetIP(nip) + if !flatipnet.Contains(flatip_) { + t.Errorf("contains doesn't work") + } +} + +var testIPStrs = []string{ + "8.8.8.8", + "127.0.0.1", + "1.1.1.1", + "128.127.65.64", + "2001:0db8::1", + "::1", + "255.255.255.255", +} + +func doMaskingTest(ip net.IP, t *testing.T) { + flat := FromNetIP(ip) + netLen := len(ip) * 8 + for i := 0; i < netLen; i++ { + masked := flat.Mask(i, netLen) + netMask := net.CIDRMask(i, netLen) + netMasked := ip.Mask(netMask) + if !bytes.Equal(masked[:], netMasked.To16()) { + t.Errorf("Masking %s with %d/%d; expected %s, got %s", ip.String(), i, netLen, netMasked.String(), masked.String()) + } + } +} + +func TestMasking(t *testing.T) { + for _, ipstr := range testIPStrs { + doMaskingTest(easyParseIP(ipstr), t) + } +} + +func TestMaskingFuzz(t *testing.T) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + buf := make([]byte, 4) + for i := 0; i < 10000; i++ { + r.Read(buf) + doMaskingTest(net.IP(buf), t) + } + + buf = make([]byte, 16) + for i := 0; i < 10000; i++ { + r.Read(buf) + doMaskingTest(net.IP(buf), t) + } +} + +func BenchmarkMasking(b *testing.B) { + ip := easyParseIP("2001:0db8::42") + flat := FromNetIP(ip) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + flat.Mask(64, 128) + } +} + +func BenchmarkMaskingLegacy(b *testing.B) { + ip := easyParseIP("2001:0db8::42") + mask := net.CIDRMask(64, 128) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + ip.Mask(mask) + } +} + +func BenchmarkMaskingCached(b *testing.B) { + i := easyParseIP("2001:0db8::42") + flat := FromNetIP(i) + mask := cidrMask(64, 128) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + flat.applyMask(mask) + } +} + +func BenchmarkMaskingConstruct(b *testing.B) { + for i := 0; i < b.N; i++ { + cidrMask(69, 128) + } +} + +func BenchmarkContains(b *testing.B) { + ip := easyParseIP("2001:0db8::42") + flat := FromNetIP(ip) + _, ipnet, err := net.ParseCIDR("2001:0db8::/64") + if err != nil { + panic(err) + } + flatnet := FromNetIPNet(*ipnet) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + flatnet.Contains(flat) + } +} + +func BenchmarkContainsLegacy(b *testing.B) { + ip := easyParseIP("2001:0db8::42") + _, ipnetptr, err := net.ParseCIDR("2001:0db8::/64") + if err != nil { + panic(err) + } + ipnet := *ipnetptr + b.ResetTimer() + + for i := 0; i < b.N; i++ { + ipnet.Contains(ip) + } +} diff --git a/irc/gateways.go b/irc/gateways.go index 388687bc..e7dbbecf 100644 --- a/irc/gateways.go +++ b/irc/gateways.go @@ -9,6 +9,7 @@ import ( "errors" "net" + "github.com/oragono/oragono/irc/flatip" "github.com/oragono/oragono/irc/modes" "github.com/oragono/oragono/irc/utils" ) @@ -87,7 +88,7 @@ func (client *Client) ApplyProxiedIP(session *Session, proxiedIP net.IP, tls boo } // successfully added a limiter entry for the proxied IP; // remove the entry for the real IP if applicable (#197) - client.server.connectionLimiter.RemoveClient(session.realIP) + client.server.connectionLimiter.RemoveClient(flatip.FromNetIP(session.realIP)) // given IP is sane! override the client's current IP client.server.logger.Info("connect-ip", "Accepted proxy IP for client", proxiedIP.String()) diff --git a/irc/handlers.go b/irc/handlers.go index 8bf96e35..a03d2368 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -24,6 +24,7 @@ import ( "github.com/goshuirc/irc-go/ircmsg" "github.com/oragono/oragono/irc/caps" "github.com/oragono/oragono/irc/custime" + "github.com/oragono/oragono/irc/flatip" "github.com/oragono/oragono/irc/history" "github.com/oragono/oragono/irc/jwt" "github.com/oragono/oragono/irc/modes" @@ -2798,6 +2799,11 @@ func unDLineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *R // get host hostString := msg.Params[0] + // TODO(#1447) consolidate this into the "unban" command + if flatip, ipErr := flatip.ParseIP(hostString); ipErr == nil { + server.connectionLimiter.ResetThrottle(flatip) + } + // check host hostNet, err := utils.NormalizedNetFromString(hostString) diff --git a/irc/server.go b/irc/server.go index 1f11c122..397ba1d1 100644 --- a/irc/server.go +++ b/irc/server.go @@ -23,6 +23,7 @@ import ( "github.com/oragono/oragono/irc/caps" "github.com/oragono/oragono/irc/connection_limits" + "github.com/oragono/oragono/irc/flatip" "github.com/oragono/oragono/irc/history" "github.com/oragono/oragono/irc/logger" "github.com/oragono/oragono/irc/modes" @@ -160,31 +161,23 @@ func (server *Server) checkBans(config *Config, ipaddr net.IP, checkScripts bool } } + flat := flatip.FromNetIP(ipaddr) + // check DLINEs - isBanned, info := server.dlines.CheckIP(ipaddr) + isBanned, info := server.dlines.CheckIP(flat) if isBanned { - server.logger.Info("connect-ip", fmt.Sprintf("Client from %v rejected by d-line", ipaddr)) + server.logger.Info("connect-ip", "Client rejected by d-line", ipaddr.String()) return true, false, info.BanMessage("You are banned from this server (%s)") } // check connection limits - err := server.connectionLimiter.AddClient(ipaddr) + err := server.connectionLimiter.AddClient(flat) if err == connection_limits.ErrLimitExceeded { // too many connections from one client, tell the client and close the connection - server.logger.Info("connect-ip", fmt.Sprintf("Client from %v rejected for connection limit", ipaddr)) + server.logger.Info("connect-ip", "Client rejected for connection limit", ipaddr.String()) return true, false, "Too many clients from your network" } else if err == connection_limits.ErrThrottleExceeded { - duration := config.Server.IPLimits.BanDuration - if duration != 0 { - server.dlines.AddIP(ipaddr, duration, throttleMessage, - "Exceeded automated connection throttle", "auto.connection.throttler") - // they're DLINE'd for 15 minutes or whatever, so we can reset the connection throttle now, - // and once their temporary DLINE is finished they can fill up the throttler again - server.connectionLimiter.ResetThrottle(ipaddr) - } - server.logger.Info( - "connect-ip", - fmt.Sprintf("Client from %v exceeded connection throttle, d-lining for %v", ipaddr, duration)) + server.logger.Info("connect-ip", "Client exceeded connection throttle", ipaddr.String()) return true, false, throttleMessage } else if err != nil { server.logger.Warning("internal", "unexpected ban result", err.Error()) @@ -211,7 +204,7 @@ func (server *Server) checkBans(config *Config, ipaddr net.IP, checkScripts bool } if output.Result == IPBanned { // XXX roll back IP connection/throttling addition for the IP - server.connectionLimiter.RemoveClient(ipaddr) + server.connectionLimiter.RemoveClient(flat) server.logger.Info("connect-ip", "Rejected client due to ip-check-script", ipaddr.String()) return true, false, output.BanMessage } else if output.Result == IPRequireSASL { diff --git a/traditional.yaml b/traditional.yaml index 227ed980..5f7c79aa 100644 --- a/traditional.yaml +++ b/traditional.yaml @@ -220,9 +220,6 @@ server: window: 10m # maximum number of new connections per IP/CIDR within the given duration max-connections-per-window: 32 - # how long to ban offenders for. after banning them, the number of connections is - # reset, which lets you use /UNDLINE to unban people - throttle-ban-duration: 10m # how wide the CIDR should be for IPv4 (a /32 is a fully specified IPv4 address) cidr-len-ipv4: 32