diff --git a/irc/connection_limits/throttler_test.go b/irc/connection_limits/throttler_test.go index b7862375..c31c4276 100644 --- a/irc/connection_limits/throttler_test.go +++ b/irc/connection_limits/throttler_test.go @@ -62,25 +62,59 @@ func TestGenericThrottleDisabled(t *testing.T) { } } -func TestConnectionThrottle(t *testing.T) { +func makeTestThrottler(v4len, v6len int) *Throttler { minute, _ := time.ParseDuration("1m") maxConnections := 3 config := ThrottlerConfig{ Enabled: true, - CidrLenIPv4: 32, - CidrLenIPv6: 64, + CidrLenIPv4: v4len, + CidrLenIPv6: v6len, ConnectionsPerCidr: maxConnections, Duration: minute, } throttler := NewThrottler() throttler.ApplyConfig(config) + return throttler +} +func TestConnectionThrottle(t *testing.T) { + throttler := makeTestThrottler(32, 64) addr := net.ParseIP("8.8.8.8") - for i := 0; i < maxConnections; i += 1 { + for i := 0; i < 3; i += 1 { err := throttler.AddClient(addr) assertEqual(err, nil, t) } err := throttler.AddClient(addr) assertEqual(err, errTooManyClients, t) } + +func TestConnectionThrottleIPv6(t *testing.T) { + throttler := makeTestThrottler(32, 64) + + var err error + err = throttler.AddClient(net.ParseIP("2001:0db8::1")) + assertEqual(err, nil, t) + err = throttler.AddClient(net.ParseIP("2001:0db8::2")) + assertEqual(err, nil, t) + err = throttler.AddClient(net.ParseIP("2001:0db8::3")) + assertEqual(err, nil, t) + + err = throttler.AddClient(net.ParseIP("2001:0db8::4")) + assertEqual(err, errTooManyClients, t) +} + +func TestConnectionThrottleIPv4(t *testing.T) { + throttler := makeTestThrottler(24, 64) + + var err error + err = throttler.AddClient(net.ParseIP("192.168.1.101")) + assertEqual(err, nil, t) + err = throttler.AddClient(net.ParseIP("192.168.1.102")) + assertEqual(err, nil, t) + err = throttler.AddClient(net.ParseIP("192.168.1.103")) + assertEqual(err, nil, t) + + err = throttler.AddClient(net.ParseIP("192.168.1.104")) + assertEqual(err, errTooManyClients, t) +}