diff --git a/common/network/Blacklist.cxx b/common/network/Blacklist.cxx index a5caeea..d1f7ff2 100644 --- a/common/network/Blacklist.cxx +++ b/common/network/Blacklist.cxx @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include @@ -35,6 +37,9 @@ static std::map hits; static std::map blacklist; +static pthread_mutex_t hitmutex = PTHREAD_MUTEX_INITIALIZER; +static pthread_mutex_t blmutex = PTHREAD_MUTEX_INITIALIZER; + unsigned char bl_isBlacklisted(const char *addr) { const unsigned char count = blacklist.count(addr); if (!count) @@ -43,19 +48,35 @@ unsigned char bl_isBlacklisted(const char *addr) { const time_t now = time(NULL); const unsigned timeout = rfb::Blacklist::initialTimeout; + if (pthread_mutex_lock(&blmutex)) + abort(); + if (now - timeout > blacklist[addr]) { blacklist.erase(addr); + pthread_mutex_unlock(&blmutex); + + if (pthread_mutex_lock(&hitmutex)) + abort(); hits.erase(addr); + pthread_mutex_unlock(&hitmutex); return 0; } else { blacklist[addr] = now; + pthread_mutex_unlock(&blmutex); return 1; } } void bl_addFailure(const char *addr) { + if (pthread_mutex_lock(&hitmutex)) + abort(); const unsigned num = ++hits[addr]; + pthread_mutex_unlock(&hitmutex); + if (num >= (unsigned) rfb::Blacklist::threshold) { + if (pthread_mutex_lock(&blmutex)) + abort(); blacklist[addr] = time(NULL); + pthread_mutex_unlock(&blmutex); } }