diff --git a/lib/limits.py b/lib/limits.py new file mode 100644 index 0000000..2ff32e6 --- /dev/null +++ b/lib/limits.py @@ -0,0 +1,108 @@ +""" +Connection limitation. + +Number of connections from one IP is limited. +We have nothing against scripting and automated queries. +Even the opposite, we encourage them. But there are some +connection limits that even we can't handle. +Currently the limits are quite restrictive, but they will be relaxed +in the future. + +Usage: + + limits = Limits() + not_allowed = limits.check_ip(ip_address) + if not_allowed: + return "ERROR: %s" % not_allowed + +[Taken from github.com/chubin/cheat.sh] +""" + +import time +from globals import log + +def _time_caps(minutes, hours, days): + return { + 'min': minutes, + 'hour': hours, + 'day': days, + } + +class Limits(object): + """ + Queries limitation (by IP). + + Exports: + + check_ip(ip_address) + """ + + def __init__(self, whitelist=None, limits=None): + self.intervals = ['min', 'hour', 'day'] + + self.divisor = _time_caps(60, 3600, 86400) + self.last_update = _time_caps(0, 0, 0) + + if limits: + self.limit = _time_caps(*limits) + else: + self.limit = _time_caps(30, 600, 1000) + + if whitelist: + self.whitelist = whitelist[:] + else: + self.whitelist = [] + + self.counter = { + 'min': {}, + 'hour': {}, + 'day': {}, + } + + self._clear_counters_if_needed() + + def _log_visit(self, interval, ip_address): + if ip_address not in self.counter[interval]: + self.counter[interval][ip_address] = 0 + self.counter[interval][ip_address] += 1 + + def _limit_exceeded(self, interval, ip_address): + visits = self.counter[interval][ip_address] + limit = self._get_limit(interval) + return visits > limit + + def _get_limit(self, interval): + return self.limit[interval] + + def _report_excessive_visits(self, interval, ip_address): + log("%s LIMITED [%s for %s]" % (ip_address, self._get_limit(interval), interval)) + + def check_ip(self, ip_address): + """ + Check if `ip_address` is allowed, and if not raise an RuntimeError exception. + Return True otherwise + """ + if ip_address in self.whitelist: + return None + self._clear_counters_if_needed() + for interval in self.intervals: + self._log_visit(interval, ip_address) + if self._limit_exceeded(interval, ip_address): + self._report_excessive_visits(interval, ip_address) + return ("Not so fast! Number of queries per %s is limited to %s" + % (interval, self._get_limit(interval))) + return None + + def reset(self): + """ + Reset all counters for all IPs + """ + for interval in self.intervals: + self.counter[interval] = {} + + def _clear_counters_if_needed(self): + current_time = int(time.time()) + for interval in self.intervals: + if current_time // self.divisor[interval] != self.last_update[interval]: + self.counter[interval] = {} + self.last_update[interval] = current_time / self.divisor[interval]