From 326ea1c3d1e65d2fd24fcee150e0c483ee148172 Mon Sep 17 00:00:00 2001 From: TwiN Date: Mon, 13 Jun 2022 19:16:34 -0400 Subject: [PATCH] refactor(client): Clean up client dns resolver --- client/client_test.go | 47 --------------------------------- client/config.go | 61 ++++++++++++++++++++++++++----------------- client/config_test.go | 44 +++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 71 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 811b6a29..5da2f55a 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -56,52 +56,6 @@ func TestPing(t *testing.T) { } } -func TestDNSResolverConfig(t *testing.T) { - type args struct { - dnsResolver string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "valid resolver", - args: args{ - dnsResolver: "tcp://1.1.1.1:53", - }, - wantErr: false, - }, - { - name: "invalid resolver port", - args: args{ - dnsResolver: "tcp://127.0.0.1:99999", - }, - wantErr: true, - }, - { - name: "invalid resolver format", - args: args{ - dnsResolver: "foobar", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := &Config{ - DNSResolver: tt.args.dnsResolver, - } - client := GetHTTPClient(cfg) - _, err := client.Get("https://example.org") - if (err != nil) != tt.wantErr { - t.Errorf("TestDNSResolverConfig err=%v, wantErr=%v", err, tt.wantErr) - return - } - }) - } -} - func TestCanPerformStartTLS(t *testing.T) { type args struct { address string @@ -221,7 +175,6 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) { } mockHttpClient := &http.Client{ Transport: test.MockRoundTripper(func(r *http.Request) *http.Response { - // if the mock HTTP client tries to get a token from the `token-server` // we provide the expected token response if r.Host == "token-server.local" { diff --git a/client/config.go b/client/config.go index b8e63813..6fb407de 100644 --- a/client/config.go +++ b/client/config.go @@ -4,10 +4,11 @@ import ( "context" "crypto/tls" "errors" - "fmt" + "log" "net" "net/http" "regexp" + "strconv" "time" "golang.org/x/oauth2" @@ -20,6 +21,7 @@ const ( var ( ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}") + ErrInvalidDNSResolverPort = errors.New("invalid DNS resolver port") ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required") defaultConfig = Config{ @@ -46,8 +48,8 @@ type Config struct { // Timeout for the client Timeout time.Duration `yaml:"timeout"` - // DNSResolver override for the HTTPClient - // Expected format is {protocol}://{host}:{port} + // DNSResolver override for the HTTP client + // Expected format is {protocol}://{host}:{port}, e.g. tcp://1.1.1.1:53 DNSResolver string `yaml:"dns-resolver,omitempty"` // OAuth2Config is the OAuth2 configuration used for the client. @@ -80,9 +82,9 @@ func (c *Config) ValidateAndSetDefaults() error { c.Timeout = 10 * time.Second } if c.HasCustomDNSResolver() { - _, err := c.ParseDNSResolver() - if err != nil { - return ErrInvalidDNSResolver + // Validate the DNS resolver now to make sure it will not return an error later. + if _, err := c.parseDNSResolver(); err != nil { + return err } } if c.HasOAuth2Config() && !c.OAuth2Config.isValid() { @@ -91,17 +93,17 @@ func (c *Config) ValidateAndSetDefaults() error { return nil } -// Returns true if the DNSResolver is set in the configuration +// HasCustomDNSResolver returns whether a custom DNSResolver is configured func (c *Config) HasCustomDNSResolver() bool { return len(c.DNSResolver) > 0 } -// Parses the DNSResolver configuration string into the DNSResolverConfig struct -func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) { +// parseDNSResolver parses the DNS resolver into the DNSResolverConfig struct +func (c *Config) parseDNSResolver() (*DNSResolverConfig, error) { re := regexp.MustCompile(`^(?P(.*))://(?P[A-Za-z0-9\-\.]+):(?P[0-9]+)?(.*)$`) matches := re.FindStringSubmatch(c.DNSResolver) if len(matches) == 0 { - return DNSResolverConfig{}, errors.New("ParseError") + return nil, ErrInvalidDNSResolver } r := make(map[string]string) for i, k := range re.SubexpNames() { @@ -109,8 +111,14 @@ func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) { r[k] = matches[i] } } - - return DNSResolverConfig{ + port, err := strconv.Atoi(r["port"]) + if err != nil { + return nil, err + } + if port < 1 || port > 65535 { + return nil, ErrInvalidDNSResolverPort + } + return &DNSResolverConfig{ Protocol: r["proto"], Host: r["host"], Port: r["port"], @@ -150,20 +158,25 @@ func (c *Config) getHTTPClient() *http.Client { }, } if c.HasCustomDNSResolver() { - dnsResolver, _ := c.ParseDNSResolver() - dialer := &net.Dialer{ - Resolver: &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - d := net.Dialer{} - return d.DialContext(ctx, dnsResolver.Protocol, fmt.Sprintf("%s:%s", dnsResolver.Host, dnsResolver.Port)) + dnsResolver, err := c.parseDNSResolver() + if err != nil { + // We're ignoring the error, because it should have been validated on startup ValidateAndSetDefaults. + // It shouldn't happen, but if it does, we'll log it... Better safe than sorry ;) + log.Println("[client][getHTTPClient] THIS SHOULD NOT HAPPEN. Silently ignoring invalid DNS resolver due to error:", err.Error()) + } else { + dialer := &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, dnsResolver.Protocol, dnsResolver.Host+":"+dnsResolver.Port) + }, }, - }, + } + c.httpClient.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, addr) + } } - dialCtx := func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, network, addr) - } - c.httpClient.Transport.(*http.Transport).DialContext = dialCtx } if c.HasOAuth2Config() { c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config) diff --git a/client/config_test.go b/client/config_test.go index bf493d34..b09e7e57 100644 --- a/client/config_test.go +++ b/client/config_test.go @@ -35,3 +35,47 @@ func TestConfig_getHTTPClient(t *testing.T) { t.Error("expected Config.IgnoreRedirect set to true to cause the HTTP client's CheckRedirect to return http.ErrUseLastResponse") } } + +func TestConfig_ValidateAndSetDefaults_withCustomDNSResolver(t *testing.T) { + type args struct { + dnsResolver string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "with-valid-resolver", + args: args{ + dnsResolver: "tcp://1.1.1.1:53", + }, + wantErr: false, + }, + { + name: "with-invalid-resolver-port", + args: args{ + dnsResolver: "tcp://127.0.0.1:99999", + }, + wantErr: true, + }, + { + name: "with-invalid-resolver-format", + args: args{ + dnsResolver: "foobar", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{ + DNSResolver: tt.args.dnsResolver, + } + err := cfg.ValidateAndSetDefaults() + if (err != nil) != tt.wantErr { + t.Errorf("ValidateAndSetDefaults() error=%v, wantErr=%v", err, tt.wantErr) + } + }) + } +}