// Copyright (c) 2018 Shivaram Lingamneni
// released under the MIT license

package connection_limits

import (
	"reflect"
	"testing"
	"time"
)

// assertEqual records a test error (without stopping the test) when supplied
// and expected are not equal according to reflect.DeepEqual.
//
// DeepEqual is type-sensitive: for example a time.Duration value never equals
// a plain int, so callers must pass expected values of the matching type.
func assertEqual(supplied, expected interface{}, t *testing.T) {
	// Mark this as a helper so failures are attributed to the caller's line.
	t.Helper()
	if !reflect.DeepEqual(supplied, expected) {
		t.Errorf("expected %v but got %v", expected, supplied)
	}
}

// TestGenericThrottle drives a GenericThrottle (limit 2 per one-minute
// window) through touch with an explicit clock, checking when it throttles
// and how much time it reports as remaining.
func TestGenericThrottle(t *testing.T) {
	// Typed constants replace time.ParseDuration calls whose errors were
	// previously discarded. zero must stay a time.Duration because
	// assertEqual compares with reflect.DeepEqual, which is type-sensitive.
	var zero time.Duration

	throttler := GenericThrottle{
		Duration: time.Minute,
		Limit:    2,
	}

	// First touch: within the limit.
	now := time.Now()
	throttled, remaining := throttler.touch(now)
	assertEqual(throttled, false, t)
	assertEqual(remaining, zero, t)

	// Second touch one second later: still within the limit.
	now = now.Add(time.Second)
	throttled, remaining = throttler.touch(now)
	assertEqual(throttled, false, t)
	assertEqual(remaining, zero, t)

	// Third touch, 2s after the first: over the limit. The expected
	// remaining time is 1m minus those 2s.
	now = now.Add(time.Second)
	throttled, remaining = throttler.touch(now)
	assertEqual(throttled, true, t)
	assertEqual(remaining, 58*time.Second, t)

	// A full minute later the throttle is expected to allow touches again.
	now = now.Add(time.Minute)
	throttled, remaining = throttler.touch(now)
	assertEqual(throttled, false, t)
	assertEqual(remaining, zero, t)
}

// TestGenericThrottleDisabled checks that a GenericThrottle with Limit 0
// (a disabled throttler) never reports throttling, however often Touch is
// called.
func TestGenericThrottleDisabled(t *testing.T) {
	throttler := GenericThrottle{
		Duration: time.Minute,
		Limit:    0,
	}

	for i := 0; i < 1024; i++ {
		if throttled, _ := throttler.Touch(); throttled {
			// Stop at the first failure instead of emitting up to 1024
			// identical errors; report which touch failed.
			t.Fatalf("touch %d: disabled throttler should not throttle", i)
		}
	}
}

// makeTestThrottler builds a Limiter for tests with Count disabled and
// Throttle enabled. It allows MaxPerWindow = 3 per one-minute Window.
// v4len and v6len become CidrLenIPv4 and CidrLenIPv6 respectively.
func makeTestThrottler(v4len, v6len int) *Limiter {
	// The connection tests in this file rely on this limit being 3.
	const maxConnections = 3

	config := LimiterConfig{
		rawLimiterConfig: rawLimiterConfig{
			Count:        false,
			Throttle:     true,
			CidrLenIPv4:  v4len,
			CidrLenIPv6:  v6len,
			MaxPerWindow: maxConnections,
			Window:       time.Minute,
		},
	}
	// NOTE(review): any value returned by postprocess is discarded here; if
	// it reports an error, consider failing the test on it.
	config.postprocess()

	var limiter Limiter
	limiter.ApplyConfig(&config)
	return &limiter
}

// TestConnectionThrottle checks that repeated connections from one IPv4
// address are accepted three times and that the next attempt is rejected
// with ErrThrottleExceeded.
func TestConnectionThrottle(t *testing.T) {
	limiter := makeTestThrottler(32, 64)
	ip := easyParseIP("8.8.8.8")

	const accepted = 3
	for attempt := 0; attempt < accepted; attempt++ {
		assertEqual(limiter.AddClient(ip), nil, t)
	}

	// The attempt after the accepted ones is expected to be throttled.
	assertEqual(limiter.AddClient(ip), ErrThrottleExceeded, t)
}

// TestConnectionThrottleIPv6 uses an IPv6 CIDR length of 64. With that
// setting, distinct addresses under 2001:db8::/64 are expected to share one
// throttle: three connections succeed and a fourth, from yet another
// address in that prefix, is rejected.
func TestConnectionThrottleIPv6(t *testing.T) {
	limiter := makeTestThrottler(32, 64)

	allowed := []string{"2001:0db8::1", "2001:0db8::2", "2001:0db8::3"}
	for _, addr := range allowed {
		assertEqual(limiter.AddClient(easyParseIP(addr)), nil, t)
	}

	assertEqual(limiter.AddClient(easyParseIP("2001:0db8::4")), ErrThrottleExceeded, t)
}

// TestConnectionThrottleIPv4 uses an IPv4 CIDR length of 24. With that
// setting, distinct addresses under 192.168.1.0/24 are expected to share one
// throttle: three connections succeed and a fourth, from yet another
// address in that range, is rejected.
func TestConnectionThrottleIPv4(t *testing.T) {
	limiter := makeTestThrottler(24, 64)

	allowed := []string{"192.168.1.101", "192.168.1.102", "192.168.1.103"}
	for _, addr := range allowed {
		assertEqual(limiter.AddClient(easyParseIP(addr)), nil, t)
	}

	assertEqual(limiter.AddClient(easyParseIP("192.168.1.104")), ErrThrottleExceeded, t)
}