diff --git a/README.md b/README.md index 8a9b0fd5..5f3fb18b 100644 --- a/README.md +++ b/README.md @@ -282,6 +282,7 @@ the client used to send the request. | `client.insecure` | Whether to skip verifying the server's certificate chain and host name. | `false` | | `client.ignore-redirect` | Whether to ignore redirects (true) or follow them (false, default). | `false` | | `client.timeout` | Duration before timing out. | `10s` | +| `client.dns-resolver` | Override the DNS resolver using the format `{proto}://{host}:{port}`. | `""` | | `client.oauth2` | OAuth2 client configuration. | `{}` | | `client.oauth2.token-url` | The token endpoint URL | required `""` | | `client.oauth2.client-id` | The client id which should be used for the `Client credentials flow` | required `""` | @@ -313,6 +314,17 @@ endpoints: - "[STATUS] == 200" ``` +This example shows how you can use a `custom DNS Resolver`: +```yaml +endpoints: + - name: website + url: "https://your.health.api/getHealth" + client: + dns-resolver: "tcp://1.1.1.1:53" + conditions: + - "[STATUS] == 200" +``` + This example shows how you can use the `client.oauth2` configuration to query a backend API with `Bearer token`: ```yaml endpoints: diff --git a/client/client_test.go b/client/client_test.go index c713a399..fef97b39 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -15,6 +15,7 @@ func TestGetHTTPClient(t *testing.T) { Insecure: false, IgnoreRedirect: false, Timeout: 0, + DNSResolver: "tcp://1.1.1.1:53", OAuth2Config: &OAuth2Config{ ClientID: "00000000-0000-0000-0000-000000000000", ClientSecret: "secretsauce", @@ -22,7 +23,10 @@ func TestGetHTTPClient(t *testing.T) { Scopes: []string{"https://application.local/.default"}, }, } - cfg.ValidateAndSetDefaults() + err := cfg.ValidateAndSetDefaults() + if err != nil { + t.Errorf("expected error to be nil, but got: `%s`", err) + } if GetHTTPClient(cfg) == nil { t.Error("expected client to not be nil") } @@ -52,6 +56,54 @@ func TestPing(t *testing.T) { } } +func TestDNSResolverConfig(t *testing.T) { + type args struct { + resolver string + } + + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Valid resolver", + args: args{ + resolver: "tcp://1.1.1.1:53", + }, + wantErr: false, + }, + { + name: "Invalid resolver address/port", + args: args{ + resolver: "tcp://127.0.0.1:99999", + }, + wantErr: true, + }, + { + name: "Invalid resolver format", + args: args{ + resolver: "foobaz", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{ + DNSResolver: tt.args.resolver, + } + client := GetHTTPClient(cfg) + _, err := client.Get("https://www.google.com") + if (err != nil) != tt.wantErr { + t.Errorf("TestDNSResolverConfig err=%v, wantErr=%v", err, tt.wantErr) + return + } + }) + } +} + func TestCanPerformStartTLS(t *testing.T) { type args struct { address string diff --git a/client/config.go b/client/config.go index 48830a70..b8e63813 100644 --- a/client/config.go +++ b/client/config.go @@ -4,7 +4,10 @@ import ( "context" "crypto/tls" "errors" + "fmt" + "net" "net/http" + "regexp" "time" "golang.org/x/oauth2" @@ -16,6 +19,7 @@ const ( ) var ( + ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}") ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required") defaultConfig = Config{ @@ -42,6 +46,10 @@ type Config struct { // Timeout for the client Timeout time.Duration `yaml:"timeout"` + // DNSResolver override for the HTTPClient + // Expected format is {protocol}://{host}:{port} + DNSResolver string `yaml:"dns-resolver,omitempty"` + // OAuth2Config is the OAuth2 configuration used for the client. // // If non-nil, the http.Client returned by getHTTPClient will automatically retrieve a token if necessary. @@ -51,6 +59,13 @@ type Config struct { httpClient *http.Client } +// DNSResolverConfig is the parsed configuration from the DNSResolver config string. +type DNSResolverConfig struct { + Protocol string + Host string + Port string +} + // OAuth2Config is the configuration for the OAuth2 client credentials flow type OAuth2Config struct { TokenURL string `yaml:"token-url"` // e.g. https://dev-12345678.okta.com/token @@ -64,12 +79,44 @@ func (c *Config) ValidateAndSetDefaults() error { if c.Timeout < time.Millisecond { c.Timeout = 10 * time.Second } + if c.HasCustomDNSResolver() { + _, err := c.ParseDNSResolver() + if err != nil { + return ErrInvalidDNSResolver + } + } if c.HasOAuth2Config() && !c.OAuth2Config.isValid() { return ErrInvalidClientOAuth2Config } return nil } +// Returns true if the DNSResolver is set in the configuration +func (c *Config) HasCustomDNSResolver() bool { + return len(c.DNSResolver) > 0 +} + +// Parses the DNSResolver configuration string into the DNSResolverConfig struct +func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) { + re := regexp.MustCompile(`^(?P(.*))://(?P[A-Za-z0-9\-\.]+):(?P[0-9]+)?(.*)$`) + matches := re.FindStringSubmatch(c.DNSResolver) + if len(matches) == 0 { + return DNSResolverConfig{}, errors.New("ParseError") + } + r := make(map[string]string) + for i, k := range re.SubexpNames() { + if i != 0 && k != "" { + r[k] = matches[i] + } + } + + return DNSResolverConfig{ + Protocol: r["proto"], + Host: r["host"], + Port: r["port"], + }, nil +} + // HasOAuth2Config returns true if the client has OAuth2 configuration parameters func (c *Config) HasOAuth2Config() bool { return c.OAuth2Config != nil @@ -102,6 +149,22 @@ func (c *Config) getHTTPClient() *http.Client { return nil }, } + if c.HasCustomDNSResolver() { + dnsResolver, _ := c.ParseDNSResolver() + dialer := &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, dnsResolver.Protocol, fmt.Sprintf("%s:%s", dnsResolver.Host, dnsResolver.Port)) + }, + }, + } + dialCtx := func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, addr) + } + c.httpClient.Transport.(*http.Transport).DialContext = dialCtx + } if c.HasOAuth2Config() { c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config) }