diff --git a/client/client.go b/client/client.go index f782bd75..6d1d5a49 100644 --- a/client/client.go +++ b/client/client.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "golang.org/x/net/websocket" "net" "net/http" "net/smtp" @@ -19,6 +18,7 @@ import ( "github.com/ishidawataru/sctp" ping "github.com/prometheus-community/pro-bing" "golang.org/x/crypto/ssh" + "golang.org/x/net/websocket" ) var ( @@ -261,24 +261,25 @@ func Ping(address string, config *Config) (bool, time.Duration) { return true, 0 } -// Open a websocket connection, write `body` and return a message from the server -func QueryWebSocket(address string, config *Config, body string) (bool, []byte, error) { +// QueryWebSocket opens a websocket connection, write `body` and return a message from the server +func QueryWebSocket(address, body string, config *Config) (bool, []byte, error) { const ( Origin = "http://localhost/" MaximumMessageSize = 1024 // in bytes ) - wsConfig, err := websocket.NewConfig(address, Origin) if err != nil { return false, nil, fmt.Errorf("error configuring websocket connection: %w", err) } + if config != nil { + wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout} + } // Dial URL ws, err := websocket.DialConfig(wsConfig) if err != nil { return false, nil, fmt.Errorf("error dialing websocket: %w", err) } defer ws.Close() - connected := true // Write message if _, err := ws.Write([]byte(body)); err != nil { return false, nil, fmt.Errorf("error writing websocket body: %w", err) @@ -289,7 +290,7 @@ func QueryWebSocket(address string, config *Config, body string) (bool, []byte, if n, err = ws.Read(msg); err != nil { return false, nil, fmt.Errorf("error reading websocket message: %w", err) } - return connected, msg[:n], nil + return true, msg[:n], nil } // InjectHTTPClient is used to inject a custom HTTP client for testing purposes diff --git a/client/client_test.go b/client/client_test.go index e11381f1..3e91dda8 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -253,3 +253,14 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) { t.Error("exptected `secret-token` as Bearer token in the mocked response header `X-Org-Authorization`, but got", response.Header.Get("X-Org-Authorization")) } } + +func TestQueryWebSocket(t *testing.T) { + _, _, err := QueryWebSocket("", "body", &Config{Timeout: 2 * time.Second}) + if err == nil { + t.Error("expected an error due to the address being invalid") + } + _, _, err = QueryWebSocket("ws://example.org", "body", &Config{Timeout: 2 * time.Second}) + if err == nil { + t.Error("expected an error due to the target not being websocket-friendly") + } +} diff --git a/client/config.go b/client/config.go index eea948a0..752969e1 100644 --- a/client/config.go +++ b/client/config.go @@ -16,7 +16,7 @@ import ( ) const ( - defaultHTTPTimeout = 10 * time.Second + defaultTimeout = 10 * time.Second ) var ( @@ -27,7 +27,7 @@ var ( defaultConfig = Config{ Insecure: false, IgnoreRedirect: false, - Timeout: defaultHTTPTimeout, + Timeout: defaultTimeout, } ) diff --git a/client/config_test.go b/client/config_test.go index b09e7e57..9862e08e 100644 --- a/client/config_test.go +++ b/client/config_test.go @@ -13,7 +13,7 @@ func TestConfig_getHTTPClient(t *testing.T) { if !(insecureClient.Transport).(*http.Transport).TLSClientConfig.InsecureSkipVerify { t.Error("expected Config.Insecure set to true to cause the HTTP client to skip certificate verification") } - if insecureClient.Timeout != defaultHTTPTimeout { + if insecureClient.Timeout != defaultTimeout { t.Error("expected Config.Timeout to default the HTTP client to a timeout of 10s") } request, _ := http.NewRequest("GET", "", nil) diff --git a/core/endpoint.go b/core/endpoint.go index 54acba2e..1276f483 100644 --- a/core/endpoint.go +++ b/core/endpoint.go @@ -139,7 +139,7 @@ type SSH struct { Password string `yaml:"password,omitempty"` } -// Validate validates the endpoint +// ValidateAndSetDefaults validates the endpoint func (s *SSH) ValidateAndSetDefaults() error { if s.Username == "" { return ErrEndpointWithoutSSHUsername @@ -376,12 +376,12 @@ func (endpoint *Endpoint) call(result *Result) { } else if endpointType == EndpointTypeICMP { result.Connected, result.Duration = client.Ping(strings.TrimPrefix(endpoint.URL, "icmp://"), endpoint.ClientConfig) } else if endpointType == EndpointTypeWS { - result.Connected, result.Body, err = client.QueryWebSocket(endpoint.URL, endpoint.ClientConfig, endpoint.Body) - result.Duration = time.Since(startTime) + result.Connected, result.Body, err = client.QueryWebSocket(endpoint.URL, endpoint.Body, endpoint.ClientConfig) if err != nil { result.AddError(err.Error()) return } + result.Duration = time.Since(startTime) } else if endpointType == EndpointTypeSSH { var cli *ssh.Client result.Connected, cli, err = client.CanCreateSSHConnection(strings.TrimPrefix(endpoint.URL, "ssh://"), endpoint.SSH.Username, endpoint.SSH.Password, endpoint.ClientConfig) diff --git a/go.sum b/go.sum index 972920d3..ea499636 100644 --- a/go.sum +++ b/go.sum @@ -160,6 +160,7 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=