refactor(client): Clean up client dns resolver

This commit is contained in:
TwiN 2022-06-13 19:16:34 -04:00
parent fea95b8479
commit 326ea1c3d1
3 changed files with 81 additions and 71 deletions

View File

@ -56,52 +56,6 @@ func TestPing(t *testing.T) {
}
}
func TestDNSResolverConfig(t *testing.T) {
type args struct {
dnsResolver string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid resolver",
args: args{
dnsResolver: "tcp://1.1.1.1:53",
},
wantErr: false,
},
{
name: "invalid resolver port",
args: args{
dnsResolver: "tcp://127.0.0.1:99999",
},
wantErr: true,
},
{
name: "invalid resolver format",
args: args{
dnsResolver: "foobar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
DNSResolver: tt.args.dnsResolver,
}
client := GetHTTPClient(cfg)
_, err := client.Get("https://example.org")
if (err != nil) != tt.wantErr {
t.Errorf("TestDNSResolverConfig err=%v, wantErr=%v", err, tt.wantErr)
return
}
})
}
}
func TestCanPerformStartTLS(t *testing.T) {
type args struct {
address string
@ -221,7 +175,6 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
}
mockHttpClient := &http.Client{
Transport: test.MockRoundTripper(func(r *http.Request) *http.Response {
// if the mock HTTP client tries to get a token from the `token-server`
// we provide the expected token response
if r.Host == "token-server.local" {

View File

@ -4,10 +4,11 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"regexp"
"strconv"
"time"
"golang.org/x/oauth2"
@ -20,6 +21,7 @@ const (
var (
ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}")
ErrInvalidDNSResolverPort = errors.New("invalid DNS resolver port")
ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required")
defaultConfig = Config{
@ -46,8 +48,8 @@ type Config struct {
// Timeout for the client
Timeout time.Duration `yaml:"timeout"`
// DNSResolver override for the HTTPClient
// Expected format is {protocol}://{host}:{port}
// DNSResolver override for the HTTP client
// Expected format is {protocol}://{host}:{port}, e.g. tcp://1.1.1.1:53
DNSResolver string `yaml:"dns-resolver,omitempty"`
// OAuth2Config is the OAuth2 configuration used for the client.
@ -80,9 +82,9 @@ func (c *Config) ValidateAndSetDefaults() error {
c.Timeout = 10 * time.Second
}
if c.HasCustomDNSResolver() {
_, err := c.ParseDNSResolver()
if err != nil {
return ErrInvalidDNSResolver
// Validate the DNS resolver now to make sure it will not return an error later.
if _, err := c.parseDNSResolver(); err != nil {
return err
}
}
if c.HasOAuth2Config() && !c.OAuth2Config.isValid() {
@ -91,17 +93,17 @@ func (c *Config) ValidateAndSetDefaults() error {
return nil
}
// Returns true if the DNSResolver is set in the configuration
// HasCustomDNSResolver returns whether a custom DNSResolver is configured
func (c *Config) HasCustomDNSResolver() bool {
return len(c.DNSResolver) > 0
}
// Parses the DNSResolver configuration string into the DNSResolverConfig struct
func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) {
// parseDNSResolver parses the DNS resolver into the DNSResolverConfig struct
func (c *Config) parseDNSResolver() (*DNSResolverConfig, error) {
re := regexp.MustCompile(`^(?P<proto>(.*))://(?P<host>[A-Za-z0-9\-\.]+):(?P<port>[0-9]+)?(.*)$`)
matches := re.FindStringSubmatch(c.DNSResolver)
if len(matches) == 0 {
return DNSResolverConfig{}, errors.New("ParseError")
return nil, ErrInvalidDNSResolver
}
r := make(map[string]string)
for i, k := range re.SubexpNames() {
@ -109,8 +111,14 @@ func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) {
r[k] = matches[i]
}
}
return DNSResolverConfig{
port, err := strconv.Atoi(r["port"])
if err != nil {
return nil, err
}
if port < 1 || port > 65535 {
return nil, ErrInvalidDNSResolverPort
}
return &DNSResolverConfig{
Protocol: r["proto"],
Host: r["host"],
Port: r["port"],
@ -150,20 +158,25 @@ func (c *Config) getHTTPClient() *http.Client {
},
}
if c.HasCustomDNSResolver() {
dnsResolver, _ := c.ParseDNSResolver()
dialer := &net.Dialer{
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, dnsResolver.Protocol, fmt.Sprintf("%s:%s", dnsResolver.Host, dnsResolver.Port))
dnsResolver, err := c.parseDNSResolver()
if err != nil {
// We're ignoring the error, because it should have been validated on startup ValidateAndSetDefaults.
// It shouldn't happen, but if it does, we'll log it... Better safe than sorry ;)
log.Println("[client][getHTTPClient] THIS SHOULD NOT HAPPEN. Silently ignoring invalid DNS resolver due to error:", err.Error())
} else {
dialer := &net.Dialer{
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, dnsResolver.Protocol, dnsResolver.Host+":"+dnsResolver.Port)
},
},
},
}
c.httpClient.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, addr)
}
}
dialCtx := func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, addr)
}
c.httpClient.Transport.(*http.Transport).DialContext = dialCtx
}
if c.HasOAuth2Config() {
c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config)

View File

@ -35,3 +35,47 @@ func TestConfig_getHTTPClient(t *testing.T) {
t.Error("expected Config.IgnoreRedirect set to true to cause the HTTP client's CheckRedirect to return http.ErrUseLastResponse")
}
}
func TestConfig_ValidateAndSetDefaults_withCustomDNSResolver(t *testing.T) {
type args struct {
dnsResolver string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "with-valid-resolver",
args: args{
dnsResolver: "tcp://1.1.1.1:53",
},
wantErr: false,
},
{
name: "with-invalid-resolver-port",
args: args{
dnsResolver: "tcp://127.0.0.1:99999",
},
wantErr: true,
},
{
name: "with-invalid-resolver-format",
args: args{
dnsResolver: "foobar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
DNSResolver: tt.args.dnsResolver,
}
err := cfg.ValidateAndSetDefaults()
if (err != nil) != tt.wantErr {
t.Errorf("ValidateAndSetDefaults() error=%v, wantErr=%v", err, tt.wantErr)
}
})
}
}