Minor improvements

This commit is contained in:
TwinProduction 2021-06-05 16:35:52 -04:00
parent d07d3434a6
commit ca977fefa8
2 changed files with 28 additions and 37 deletions

View File

@ -3,7 +3,7 @@ package client
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "errors"
"net" "net"
"net/http" "net/http"
"net/smtp" "net/smtp"
@ -78,34 +78,29 @@ func CanCreateTCPConnection(address string) bool {
return true return true
} }
func CanPerformStartTls(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) { // CanPerformStartTLS checks whether a connection can be established to an address using the STARTTLS protocol
tokens := strings.Split(address, ":") func CanPerformStartTLS(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) {
if len(tokens) != 2 { hostAndPort := strings.Split(address, ":")
err = fmt.Errorf("invalid address for starttls, must HOST:PORT") if len(hostAndPort) != 2 {
return false, nil, errors.New("invalid address for starttls, format must be host:port")
}
smtpClient, err := smtp.Dial(address)
if err != nil {
return return
} }
tlsconfig := &tls.Config{ err = smtpClient.StartTLS(&tls.Config{
InsecureSkipVerify: insecure, InsecureSkipVerify: insecure,
ServerName: tokens[0], ServerName: hostAndPort[0],
} })
c, err := smtp.Dial(address)
if err != nil { if err != nil {
return return
} }
if state, ok := smtpClient.TLSConnectionState(); ok {
err = c.StartTLS(tlsconfig)
if err != nil {
return
}
if state, ok := c.TLSConnectionState(); ok {
certificate = state.PeerCertificates[0] certificate = state.PeerCertificates[0]
} else { } else {
err = fmt.Errorf("could not get TLS connection state") return false, nil, errors.New("could not get TLS connection state")
return
} }
connected = true return true, certificate, nil
return
} }
// Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged

View File

@ -1,7 +1,6 @@
package client package client
import ( import (
"crypto/x509"
"testing" "testing"
"time" "time"
) )
@ -51,35 +50,32 @@ func TestPing(t *testing.T) {
} }
} }
func TestCanPerformStartTls(t *testing.T) { func TestCanPerformStartTLS(t *testing.T) {
type args struct { type args struct {
address string address string
insecure bool insecure bool
} }
tests := []struct { tests := []struct {
name string name string
args args args args
wantConnected bool wantConnected bool
wantCertificate *x509.Certificate wantErr bool
wantErr bool
}{ }{
{ {
name: "invalid address", name: "invalid address",
args: args{ args: args{
address: "test", address: "test",
}, },
wantConnected: false, wantConnected: false,
wantCertificate: nil, wantErr: true,
wantErr: true,
}, },
{ {
name: "error dial", name: "error dial",
args: args{ args: args{
address: "test:1234", address: "test:1234",
}, },
wantConnected: false, wantConnected: false,
wantCertificate: nil, wantErr: true,
wantErr: true,
}, },
{ {
name: "valid starttls", name: "valid starttls",
@ -92,13 +88,13 @@ func TestCanPerformStartTls(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotConnected, _, err := CanPerformStartTls(tt.args.address, tt.args.insecure) connected, _, err := CanPerformStartTLS(tt.args.address, tt.args.insecure)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("CanPerformStartTls() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("CanPerformStartTLS() err=%v, wantErr=%v", err, tt.wantErr)
return return
} }
if gotConnected != tt.wantConnected { if connected != tt.wantConnected {
t.Errorf("CanPerformStartTls() gotConnected = %v, want %v", gotConnected, tt.wantConnected) t.Errorf("CanPerformStartTLS() connected=%v, wantConnected=%v", connected, tt.wantConnected)
} }
}) })
} }