mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-19 16:44:13 +01:00
87 lines
1.9 KiB
Go
87 lines
1.9 KiB
Go
// Copyright (c) 2018 Shivaram Lingamneni
|
|
// released under the MIT license
|
|
|
|
package connection_limits
|
|
|
|
import (
|
|
"net"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// assertEqual reports a test error on t when supplied and expected are not
// deeply equal according to reflect.DeepEqual. The mismatch is reported with
// t.Errorf, so the calling test keeps running after a failed comparison.
func assertEqual(supplied, expected interface{}, t *testing.T) {
	// Mark this function as a test helper so that a failure is attributed to
	// the caller's line instead of to the t.Errorf call below.
	t.Helper()
	if !reflect.DeepEqual(supplied, expected) {
		t.Errorf("expected %v but got %v", expected, supplied)
	}
}
|
|
|
|
func TestGenericThrottle(t *testing.T) {
|
|
minute, _ := time.ParseDuration("1m")
|
|
second, _ := time.ParseDuration("1s")
|
|
zero, _ := time.ParseDuration("0s")
|
|
|
|
throttler := GenericThrottle{
|
|
Duration: minute,
|
|
Limit: 2,
|
|
}
|
|
|
|
now := time.Now()
|
|
throttled, remaining := throttler.touch(now)
|
|
assertEqual(throttled, false, t)
|
|
assertEqual(remaining, zero, t)
|
|
|
|
now = now.Add(second)
|
|
throttled, remaining = throttler.touch(now)
|
|
assertEqual(throttled, false, t)
|
|
assertEqual(remaining, zero, t)
|
|
|
|
now = now.Add(second)
|
|
throttled, remaining = throttler.touch(now)
|
|
assertEqual(throttled, true, t)
|
|
assertEqual(remaining, 58*second, t)
|
|
|
|
now = now.Add(minute)
|
|
throttled, remaining = throttler.touch(now)
|
|
assertEqual(throttled, false, t)
|
|
assertEqual(remaining, zero, t)
|
|
}
|
|
|
|
func TestGenericThrottleDisabled(t *testing.T) {
|
|
minute, _ := time.ParseDuration("1m")
|
|
throttler := GenericThrottle{
|
|
Duration: minute,
|
|
Limit: 0,
|
|
}
|
|
|
|
for i := 0; i < 1024; i += 1 {
|
|
throttled, _ := throttler.Touch()
|
|
if throttled {
|
|
t.Error("disabled throttler should not throttle")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestConnectionThrottle(t *testing.T) {
|
|
minute, _ := time.ParseDuration("1m")
|
|
maxConnections := 3
|
|
config := ThrottlerConfig{
|
|
Enabled: true,
|
|
CidrLenIPv4: 32,
|
|
CidrLenIPv6: 64,
|
|
ConnectionsPerCidr: maxConnections,
|
|
Duration: minute,
|
|
}
|
|
throttler := NewThrottler()
|
|
throttler.ApplyConfig(config)
|
|
|
|
addr := net.ParseIP("8.8.8.8")
|
|
|
|
for i := 0; i < maxConnections; i += 1 {
|
|
err := throttler.AddClient(addr)
|
|
assertEqual(err, nil, t)
|
|
}
|
|
err := throttler.AddClient(addr)
|
|
assertEqual(err, errTooManyClients, t)
|
|
}
|