From ca977fefa8bc0613e4f0e275eb9115960c458002 Mon Sep 17 00:00:00 2001 From: TwinProduction Date: Sat, 5 Jun 2021 16:35:52 -0400 Subject: [PATCH] Minor improvements --- client/client.go | 35 +++++++++++++++-------------------- client/client_test.go | 30 +++++++++++++----------------- 2 files changed, 28 insertions(+), 37 deletions(-) diff --git a/client/client.go b/client/client.go index a07b675a..13aa559d 100644 --- a/client/client.go +++ b/client/client.go @@ -3,7 +3,7 @@ package client import ( "crypto/tls" "crypto/x509" - "fmt" + "errors" "net" "net/http" "net/smtp" @@ -78,34 +78,29 @@ func CanCreateTCPConnection(address string) bool { return true } -func CanPerformStartTls(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) { - tokens := strings.Split(address, ":") - if len(tokens) != 2 { - err = fmt.Errorf("invalid address for starttls, must HOST:PORT") +// CanPerformStartTLS checks whether a connection can be established to an address using the STARTTLS protocol +func CanPerformStartTLS(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) { + hostAndPort := strings.Split(address, ":") + if len(hostAndPort) != 2 { + return false, nil, errors.New("invalid address for starttls, format must be host:port") + } + smtpClient, err := smtp.Dial(address) + if err != nil { return } - tlsconfig := &tls.Config{ + err = smtpClient.StartTLS(&tls.Config{ InsecureSkipVerify: insecure, - ServerName: tokens[0], - } - - c, err := smtp.Dial(address) + ServerName: hostAndPort[0], + }) if err != nil { return } - - err = c.StartTLS(tlsconfig) - if err != nil { - return - } - if state, ok := c.TLSConnectionState(); ok { + if state, ok := smtpClient.TLSConnectionState(); ok { certificate = state.PeerCertificates[0] } else { - err = fmt.Errorf("could not get TLS connection state") - return + return false, nil, errors.New("could not get TLS connection state") } - connected = true - return + return true, certificate, nil } // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged diff --git a/client/client_test.go b/client/client_test.go index ecbc7150..a800a024 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,7 +1,6 @@ package client import ( - "crypto/x509" "testing" "time" ) @@ -51,35 +50,32 @@ func TestPing(t *testing.T) { } } -func TestCanPerformStartTls(t *testing.T) { +func TestCanPerformStartTLS(t *testing.T) { type args struct { address string insecure bool } tests := []struct { - name string - args args - wantConnected bool - wantCertificate *x509.Certificate - wantErr bool + name string + args args + wantConnected bool + wantErr bool }{ { name: "invalid address", args: args{ address: "test", }, - wantConnected: false, - wantCertificate: nil, - wantErr: true, + wantConnected: false, + wantErr: true, }, { name: "error dial", args: args{ address: "test:1234", }, - wantConnected: false, - wantCertificate: nil, - wantErr: true, + wantConnected: false, + wantErr: true, }, { name: "valid starttls", @@ -92,13 +88,13 @@ func TestCanPerformStartTls(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotConnected, _, err := CanPerformStartTls(tt.args.address, tt.args.insecure) + connected, _, err := CanPerformStartTLS(tt.args.address, tt.args.insecure) if (err != nil) != tt.wantErr { - t.Errorf("CanPerformStartTls() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("CanPerformStartTLS() err=%v, wantErr=%v", err, tt.wantErr) return } - if gotConnected != tt.wantConnected { - t.Errorf("CanPerformStartTls() gotConnected = %v, want %v", gotConnected, tt.wantConnected) + if connected != tt.wantConnected { + t.Errorf("CanPerformStartTLS() connected=%v, wantConnected=%v", connected, tt.wantConnected) } }) }