refactor: Move whois to client package and implement caching

This commit is contained in:
TwiN 2022-11-15 21:35:22 -05:00
parent c172e733be
commit d24ff5bd07
3 changed files with 77 additions and 19 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"net/smtp" "net/smtp"
@ -11,12 +12,19 @@ import (
"strings" "strings"
"time" "time"
"github.com/TwiN/gocache/v2"
"github.com/TwiN/whois"
"github.com/go-ping/ping" "github.com/go-ping/ping"
"github.com/ishidawataru/sctp" "github.com/ishidawataru/sctp"
) )
// injectedHTTPClient is used for testing purposes var (
var injectedHTTPClient *http.Client // injectedHTTPClient is used for testing purposes
injectedHTTPClient *http.Client
whoisClient = whois.NewClient().WithReferralCache(true)
whoisExpirationDateCache = gocache.NewCache().WithMaxSize(10000).WithDefaultTTL(24 * time.Hour)
)
// GetHTTPClient returns the shared HTTP client, or the client from the configuration passed // GetHTTPClient returns the shared HTTP client, or the client from the configuration passed
func GetHTTPClient(config *Config) *http.Client { func GetHTTPClient(config *Config) *http.Client {
@ -29,6 +37,35 @@ func GetHTTPClient(config *Config) *http.Client {
return config.getHTTPClient() return config.getHTTPClient()
} }
// GetDomainExpiration retrieves the duration until the domain provided expires
func GetDomainExpiration(hostname string) (domainExpiration time.Duration, err error) {
var retrievedCachedValue bool
if v, exists := whoisExpirationDateCache.Get(hostname); exists {
domainExpiration = time.Until(v.(time.Time))
retrievedCachedValue = true
// If the domain OR the TTL is not going to expire in less than 24 hours
// we don't have to refresh the cache. Otherwise, we'll refresh it.
cacheEntryTTL, _ := whoisExpirationDateCache.TTL(hostname)
if cacheEntryTTL > 24*time.Hour && domainExpiration > 24*time.Hour {
// No need to refresh, so we'll just return the cached values
return domainExpiration, nil
}
}
if whoisResponse, err := whoisClient.QueryAndParse(hostname); err != nil {
if !retrievedCachedValue { // Add an error unless we already retrieved a cached value
return 0, fmt.Errorf("error querying and parsing hostname using whois client: %w", err)
}
} else {
domainExpiration = time.Until(whoisResponse.ExpirationDate)
if domainExpiration > 720*time.Hour {
whoisExpirationDateCache.SetWithTTL(hostname, whoisResponse.ExpirationDate, 240*time.Hour)
} else {
whoisExpirationDateCache.SetWithTTL(hostname, whoisResponse.ExpirationDate, 72*time.Hour)
}
}
return domainExpiration, nil
}
// CanCreateTCPConnection checks whether a connection can be established with a TCP endpoint // CanCreateTCPConnection checks whether a connection can be established with a TCP endpoint
func CanCreateTCPConnection(address string, config *Config) bool { func CanCreateTCPConnection(address string, config *Config) bool {
conn, err := net.DialTimeout("tcp", address, config.Timeout) conn, err := net.DialTimeout("tcp", address, config.Timeout)

View File

@ -35,6 +35,33 @@ func TestGetHTTPClient(t *testing.T) {
} }
} }
func TestGetDomainExpiration(t *testing.T) {
if domainExpiration, err := GetDomainExpiration("example.com"); err != nil {
t.Fatalf("expected error to be nil, but got: `%s`", err)
} else if domainExpiration <= 0 {
t.Error("expected domain expiration to be higher than 0")
}
if domainExpiration, err := GetDomainExpiration("example.com"); err != nil {
t.Errorf("expected error to be nil, but got: `%s`", err)
} else if domainExpiration <= 0 {
t.Error("expected domain expiration to be higher than 0")
}
// Hack to pretend like the domain is expiring in 1 hour, which should trigger a refresh
whoisExpirationDateCache.SetWithTTL("example.com", time.Now().Add(time.Hour), 25*time.Hour)
if domainExpiration, err := GetDomainExpiration("example.com"); err != nil {
t.Errorf("expected error to be nil, but got: `%s`", err)
} else if domainExpiration <= 0 {
t.Error("expected domain expiration to be higher than 0")
}
// Make sure the refresh works when the ttl is <24 hours
whoisExpirationDateCache.SetWithTTL("example.com", time.Now().Add(35*time.Hour), 23*time.Hour)
if domainExpiration, err := GetDomainExpiration("example.com"); err != nil {
t.Errorf("expected error to be nil, but got: `%s`", err)
} else if domainExpiration <= 0 {
t.Error("expected domain expiration to be higher than 0")
}
}
func TestPing(t *testing.T) { func TestPing(t *testing.T) {
if success, rtt := Ping("127.0.0.1", &Config{Timeout: 500 * time.Millisecond}); !success { if success, rtt := Ping("127.0.0.1", &Config{Timeout: 500 * time.Millisecond}); !success {
t.Error("expected true") t.Error("expected true")

View File

@ -16,7 +16,6 @@ import (
"github.com/TwiN/gatus/v4/client" "github.com/TwiN/gatus/v4/client"
"github.com/TwiN/gatus/v4/core/ui" "github.com/TwiN/gatus/v4/core/ui"
"github.com/TwiN/gatus/v4/util" "github.com/TwiN/gatus/v4/util"
"github.com/TwiN/whois"
) )
type EndpointType string type EndpointType string
@ -66,8 +65,6 @@ var (
// This is because the free whois service we are using should not be abused, especially considering the fact that // This is because the free whois service we are using should not be abused, especially considering the fact that
// the data takes a while to be updated. // the data takes a while to be updated.
ErrInvalidEndpointIntervalForDomainExpirationPlaceholder = errors.New("the minimum interval for an endpoint with a condition using the " + DomainExpirationPlaceholder + " placeholder is 300s (5m)") ErrInvalidEndpointIntervalForDomainExpirationPlaceholder = errors.New("the minimum interval for an endpoint with a condition using the " + DomainExpirationPlaceholder + " placeholder is 300s (5m)")
whoisClient = whois.NewClient().WithReferralCache(true)
) )
// Endpoint is the configuration of a monitored // Endpoint is the configuration of a monitored
@ -257,11 +254,20 @@ func (endpoint *Endpoint) EvaluateHealth() *Result {
if endpoint.needsToRetrieveIP() { if endpoint.needsToRetrieveIP() {
endpoint.getIP(result) endpoint.getIP(result)
} }
// Retrieve domain expiration if necessary
if endpoint.needsToRetrieveDomainExpiration() && len(result.Hostname) > 0 {
var err error
if result.DomainExpiration, err = client.GetDomainExpiration(result.Hostname); err != nil {
result.AddError(err.Error())
}
}
// Call the endpoint (if there's no errors)
if len(result.Errors) == 0 { if len(result.Errors) == 0 {
endpoint.call(result) endpoint.call(result)
} else { } else {
result.Success = false result.Success = false
} }
// Evaluate the conditions
for _, condition := range endpoint.Conditions { for _, condition := range endpoint.Conditions {
success := condition.evaluate(result, endpoint.UIConfig.DontResolveFailedConditions) success := condition.evaluate(result, endpoint.UIConfig.DontResolveFailedConditions)
if !success { if !success {
@ -269,10 +275,6 @@ func (endpoint *Endpoint) EvaluateHealth() *Result {
} }
} }
result.Timestamp = time.Now() result.Timestamp = time.Now()
// Retrieve domain expiration if necessary
if endpoint.needsToRetrieveDomainExpiration() && len(result.Hostname) > 0 {
endpoint.getDomainExpiration(result)
}
// No need to keep the body after the endpoint has been evaluated // No need to keep the body after the endpoint has been evaluated
result.body = nil result.body = nil
// Clean up parameters that we don't need to keep in the results // Clean up parameters that we don't need to keep in the results
@ -291,19 +293,11 @@ func (endpoint *Endpoint) EvaluateHealth() *Result {
} }
func (endpoint *Endpoint) getIP(result *Result) { func (endpoint *Endpoint) getIP(result *Result) {
ips, err := net.LookupIP(result.Hostname) if ips, err := net.LookupIP(result.Hostname); err != nil {
if err != nil {
result.AddError(err.Error()) result.AddError(err.Error())
return return
}
result.IP = ips[0].String()
}
func (endpoint *Endpoint) getDomainExpiration(result *Result) {
if whoisResponse, err := whoisClient.QueryAndParse(result.Hostname); err != nil {
result.AddError("error querying and parsing hostname using whois client: " + err.Error())
} else { } else {
result.DomainExpiration = time.Until(whoisResponse.ExpirationDate) result.IP = ips[0].String()
} }
} }