#120: Add support for StartTLS protocol

* add starttls

* remove starttls from default config

Co-authored-by: Gopher Johns <gopher.johns28@gmail.com>
This commit is contained in:
gopher-johns 2021-06-05 21:47:11 +02:00 committed by GitHub
parent 81aeb7a48e
commit 2131fa4412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 101 additions and 2 deletions

View File

@ -2,10 +2,14 @@ package client
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt"
"net" "net"
"net/http" "net/http"
"net/smtp"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/go-ping/ping" "github.com/go-ping/ping"
@ -74,6 +78,36 @@ func CanCreateTCPConnection(address string) bool {
return true return true
} }
func CanPerformStartTls(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) {
tokens := strings.Split(address, ":")
if len(tokens) != 2 {
err = fmt.Errorf("invalid address for starttls, must HOST:PORT")
return
}
tlsconfig := &tls.Config{
InsecureSkipVerify: insecure,
ServerName: tokens[0],
}
c, err := smtp.Dial(address)
if err != nil {
return
}
err = c.StartTLS(tlsconfig)
if err != nil {
return
}
if state, ok := c.TLSConnectionState(); ok {
certificate = state.PeerCertificates[0]
} else {
err = fmt.Errorf("could not get TLS connection state")
return
}
connected = true
return
}
// Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged
// //
// Note that this function takes at least 100ms, even if the address is 127.0.0.1 // Note that this function takes at least 100ms, even if the address is 127.0.0.1

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"crypto/x509"
"testing" "testing"
"time" "time"
) )
@ -49,3 +50,56 @@ func TestPing(t *testing.T) {
} }
} }
} }
func TestCanPerformStartTls(t *testing.T) {
type args struct {
address string
insecure bool
}
tests := []struct {
name string
args args
wantConnected bool
wantCertificate *x509.Certificate
wantErr bool
}{
{
name: "invalid address",
args: args{
address: "test",
},
wantConnected: false,
wantCertificate: nil,
wantErr: true,
},
{
name: "error dial",
args: args{
address: "test:1234",
},
wantConnected: false,
wantCertificate: nil,
wantErr: true,
},
{
name: "valid starttls",
args: args{
address: "smtp.gmail.com:587",
},
wantConnected: true,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotConnected, _, err := CanPerformStartTls(tt.args.address, tt.args.insecure)
if (err != nil) != tt.wantErr {
t.Errorf("CanPerformStartTls() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotConnected != tt.wantConnected {
t.Errorf("CanPerformStartTls() gotConnected = %v, want %v", gotConnected, tt.wantConnected)
}
})
}
}

View File

@ -2,6 +2,7 @@ package core
import ( import (
"bytes" "bytes"
"crypto/x509"
"encoding/json" "encoding/json"
"errors" "errors"
"io/ioutil" "io/ioutil"
@ -178,10 +179,12 @@ func (service *Service) call(result *Result) {
var request *http.Request var request *http.Request
var response *http.Response var response *http.Response
var err error var err error
var certificate *x509.Certificate
isServiceDNS := service.DNS != nil isServiceDNS := service.DNS != nil
isServiceTCP := strings.HasPrefix(service.URL, "tcp://") isServiceTCP := strings.HasPrefix(service.URL, "tcp://")
isServiceICMP := strings.HasPrefix(service.URL, "icmp://") isServiceICMP := strings.HasPrefix(service.URL, "icmp://")
isServiceHTTP := !isServiceDNS && !isServiceTCP && !isServiceICMP isServiceStartTLS := strings.HasPrefix(service.URL, "starttls://")
isServiceHTTP := !isServiceDNS && !isServiceTCP && !isServiceICMP && !isServiceStartTLS
if isServiceHTTP { if isServiceHTTP {
request = service.buildHTTPRequest() request = service.buildHTTPRequest()
} }
@ -189,6 +192,14 @@ func (service *Service) call(result *Result) {
if isServiceDNS { if isServiceDNS {
service.DNS.query(service.URL, result) service.DNS.query(service.URL, result)
result.Duration = time.Since(startTime) result.Duration = time.Since(startTime)
} else if isServiceStartTLS {
result.Connected, certificate, err = client.CanPerformStartTls(strings.TrimPrefix(service.URL, "starttls://"), service.Insecure)
if err != nil {
result.Errors = append(result.Errors, err.Error())
return
}
result.Duration = time.Since(startTime)
result.CertificateExpiration = time.Until(certificate.NotAfter)
} else if isServiceTCP { } else if isServiceTCP {
result.Connected = client.CanCreateTCPConnection(strings.TrimPrefix(service.URL, "tcp://")) result.Connected = client.CanCreateTCPConnection(strings.TrimPrefix(service.URL, "tcp://"))
result.Duration = time.Since(startTime) result.Duration = time.Since(startTime)
@ -203,7 +214,7 @@ func (service *Service) call(result *Result) {
} }
defer response.Body.Close() defer response.Body.Close()
if response.TLS != nil && len(response.TLS.PeerCertificates) > 0 { if response.TLS != nil && len(response.TLS.PeerCertificates) > 0 {
certificate := response.TLS.PeerCertificates[0] certificate = response.TLS.PeerCertificates[0]
result.CertificateExpiration = time.Until(certificate.NotAfter) result.CertificateExpiration = time.Until(certificate.NotAfter)
} }
result.HTTPStatus = response.StatusCode result.HTTPStatus = response.StatusCode