diff --git a/irc/utils/semaphores.go b/irc/utils/semaphores.go index 8d3d19ff..2a34c378 100644 --- a/irc/utils/semaphores.go +++ b/irc/utils/semaphores.go @@ -1,10 +1,12 @@ // Copyright (c) 2018 Shivaram Lingamneni +// released under the MIT license package utils import ( "log" "runtime/debug" + "time" ) // Semaphore is a counting semaphore. Note that a capacity of n requires O(n) storage. @@ -35,6 +37,25 @@ func (semaphore *Semaphore) TryAcquire() (acquired bool) { } } +// AcquireWithTimeout tries to acquire a semaphore, blocking for a maximum +// of approximately `d` while waiting for it. It returns whether the acquire +// was successful. +func (semaphore *Semaphore) AcquireWithTimeout(timeout time.Duration) (acquired bool) { + if timeout < 0 { + return semaphore.TryAcquire() + } + + timer := time.NewTimer(timeout) + select { + case <-(*semaphore): + acquired = true + case <-timer.C: + acquired = false + } + timer.Stop() + return +} + // Release releases a semaphore. It never blocks. (This is not a license // to program spurious releases.) func (semaphore *Semaphore) Release() { diff --git a/irc/utils/semaphores_test.go b/irc/utils/semaphores_test.go new file mode 100644 index 00000000..b047ed56 --- /dev/null +++ b/irc/utils/semaphores_test.go @@ -0,0 +1,48 @@ +// Copyright (c) 2019 Shivaram Lingamneni +// released under the MIT license + +package utils + +import ( + "testing" + "time" +) + +func TestTryAcquire(t *testing.T) { + count := 3 + var sem Semaphore + sem.Initialize(count) + + for i := 0; i < count; i++ { + assertEqual(sem.TryAcquire(), true, t) + } + // used up the capacity + assertEqual(sem.TryAcquire(), false, t) + sem.Release() + // got one slot back + assertEqual(sem.TryAcquire(), true, t) +} + +func TestAcquireWithTimeout(t *testing.T) { + var sem Semaphore + sem.Initialize(1) + + assertEqual(sem.TryAcquire(), true, t) + + // cannot acquire the held semaphore + assertEqual(sem.AcquireWithTimeout(100*time.Millisecond), false, t) + + sem.Release() + // can acquire the released semaphore + assertEqual(sem.AcquireWithTimeout(100*time.Millisecond), true, t) + sem.Release() + + // XXX this test could fail if the machine is extremely overloaded + sem.Acquire() + go func() { + time.Sleep(100 * time.Millisecond) + sem.Release() + }() + // we should acquire successfully after approximately 100 msec + assertEqual(sem.AcquireWithTimeout(1*time.Second), true, t) +}