diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index f9fc673..006900a 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -2,8 +2,7 @@ package ratelimiter /* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ -/* This file contains a port of the ratelimited from the linux kernel version - */ +/* This file contains a port of the rate-limiter from the linux kernel version */ import ( "net" @@ -12,11 +11,11 @@ import ( ) const ( - RatelimiterPacketsPerSecond = 20 - RatelimiterPacketsBurstable = 5 - RatelimiterGarbageCollectTime = time.Second - RatelimiterPacketCost = 1000000000 / RatelimiterPacketsPerSecond - RatelimiterMaxTokens = RatelimiterPacketCost * RatelimiterPacketsBurstable + packetsPerSecond = 20 + packetsBurstable = 5 + garbageCollectTime = time.Second + packetCost = 1000000000 / packetsPerSecond + maxTokens = packetCost * packetsBurstable ) type RatelimiterEntry struct { @@ -45,6 +44,8 @@ func (rate *Ratelimiter) Init() { rate.mutex.Lock() defer rate.mutex.Unlock() + // stop any ongoing garbage collection routine + if rate.stop != nil { close(rate.stop) } @@ -53,6 +54,8 @@ func (rate *Ratelimiter) Init() { rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + // start garbage collection routine + go func() { timer := time.NewTimer(time.Second) for { @@ -60,39 +63,32 @@ func (rate *Ratelimiter) Init() { case <-rate.stop: return case <-timer.C: - rate.garbageCollectEntries() + func() { + rate.mutex.Lock() + defer rate.mutex.Unlock() + + for key, entry := range rate.tableIPv4 { + entry.mutex.Lock() + if time.Now().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv4, key) + } + entry.mutex.Unlock() + } + + for key, entry := range rate.tableIPv6 { + entry.mutex.Lock() + if time.Now().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv6, key) + } + entry.mutex.Unlock() + } + }() timer.Reset(time.Second) } } }() } -func (rate *Ratelimiter) garbageCollectEntries() { - rate.mutex.Lock() - - // remove unused IPv4 entries - - for key, entry := range rate.tableIPv4 { - entry.mutex.Lock() - if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.mutex.Unlock() - } - - // remove unused IPv6 entries - - for key, entry := range rate.tableIPv6 { - entry.mutex.Lock() - if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mutex.Unlock() - } - - rate.mutex.Unlock() -} - func (rate *Ratelimiter) Allow(ip net.IP) bool { var entry *RatelimiterEntry var KeyIPv4 [net.IPv4len]byte @@ -120,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { if entry == nil { rate.mutex.Lock() entry = new(RatelimiterEntry) - entry.tokens = RatelimiterMaxTokens - RatelimiterPacketCost + entry.tokens = maxTokens - packetCost entry.lastTime = time.Now() if IPv4 != nil { rate.tableIPv4[KeyIPv4] = entry @@ -137,14 +133,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { now := time.Now() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.lastTime = now - if entry.tokens > RatelimiterMaxTokens { - entry.tokens = RatelimiterMaxTokens + if entry.tokens > maxTokens { + entry.tokens = maxTokens } // subtract cost of packet - if entry.tokens > RatelimiterPacketCost { - entry.tokens -= RatelimiterPacketCost + if entry.tokens > packetCost { + entry.tokens -= packetCost entry.mutex.Unlock() return true } diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index a6f618b..37339ee 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -28,7 +28,7 @@ func TestRatelimiter(t *testing.T) { ) } - for i := 0; i < RatelimiterPacketsBurstable; i++ { + for i := 0; i < packetsBurstable; i++ { Add(RatelimiterResult{ allowed: true, text: "inital burst", @@ -42,7 +42,7 @@ func TestRatelimiter(t *testing.T) { Add(RatelimiterResult{ allowed: true, - wait: Nano(time.Second.Nanoseconds() / RatelimiterPacketsPerSecond), + wait: Nano(time.Second.Nanoseconds() / packetsPerSecond), text: "filling tokens for single packet", }) @@ -53,7 +53,7 @@ func TestRatelimiter(t *testing.T) { Add(RatelimiterResult{ allowed: true, - wait: 2 * Nano(time.Second.Nanoseconds()/RatelimiterPacketsPerSecond), + wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)), text: "filling tokens for two packet burst", })