smegmesh/pkg/conn/connectionmanager.go

162 lines
4.1 KiB
Go
Raw Permalink Normal View History

2023-10-24 01:12:38 +02:00
package conn
import (
"crypto/tls"
2023-10-27 18:49:18 +02:00
"crypto/x509"
"errors"
"os"
2023-10-24 01:12:38 +02:00
"sync"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
// ConnectionManager defines an interface for maintaining peer connections.
type ConnectionManager interface {
	// AddConnection adds an instance of a connection at the given endpoint,
	// or returns an error if something went wrong.
	AddConnection(endPoint string) (PeerConnection, error)
	// GetConnection returns an instance of a connection at the given endpoint.
	// If no connection exists for the endpoint then one is added. Returns an
	// error if something went wrong.
	GetConnection(endPoint string) (PeerConnection, error)
	// HasConnection returns true if a connection has already been registered
	// at the given endpoint, or false otherwise.
	HasConnection(endPoint string) bool
	// Close goes through all the connections and closes each one.
	Close() error
}
// ConnectionManager manages connections between other peers
// in the control plane.
type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection
conLoc sync.RWMutex
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
2023-11-05 19:03:58 +01:00
connFactory PeerConnectionFactory
2023-10-24 01:12:38 +02:00
}
// Create a new instance of a connection manager.
2023-11-05 19:03:58 +01:00
type NewConnectionManagerParams struct {
2023-10-24 01:12:38 +02:00
// The path to the certificate
CertificatePath string
// The private key of the node
PrivateKey string
// Whether or not to skip certificate verification
SkipCertVerification bool
2023-10-27 18:49:18 +02:00
CaCert string
2023-11-05 19:03:58 +01:00
ConnFactory PeerConnectionFactory
2023-10-24 01:12:38 +02:00
}
// NewConnectionManager: Creates a new instance of a ConnectionManager or an error
// if something went wrong.
2023-11-05 19:03:58 +01:00
func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager, error) {
2023-10-24 01:12:38 +02:00
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error())
logging.Log.WriteErrorf("Certificate Path: %s\n", params.CertificatePath)
logging.Log.WriteErrorf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
2023-10-27 18:49:18 +02:00
certPool := x509.NewCertPool()
if !params.SkipCertVerification {
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
}
2023-10-24 01:12:38 +02:00
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
2023-10-27 18:49:18 +02:00
RootCAs: certPool,
2023-10-24 01:12:38 +02:00
}
connections := make(map[string]PeerConnection)
2023-11-05 19:03:58 +01:00
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
2023-11-03 16:24:18 +01:00
connections,
serverConfig,
clientConfig,
2023-11-05 19:03:58 +01:00
params.ConnFactory,
2023-11-03 16:24:18 +01:00
}
2023-10-24 01:12:38 +02:00
return &connMgr, nil
}
// GetConnection returns the connection registered at endpoint if it exists.
// If it does not exist then the connection is added via AddConnection.
// Returns an error if something went wrong.
func (m *ConnectionManagerImpl) GetConnection(endpoint string) (PeerConnection, error) {
	// The lookup only reads the map, so a shared lock is enough and lets
	// concurrent callers proceed in parallel.
	m.conLoc.RLock()
	conn, exists := m.clientConnections[endpoint]
	m.conLoc.RUnlock()

	if !exists {
		return m.AddConnection(endpoint)
	}

	return conn, nil
}
// AddConnection: Adds a connection to the list of connections to manage.
func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection, error) {
m.conLoc.Lock()
conn, exists := m.clientConnections[endPoint]
m.conLoc.Unlock()
if exists {
return conn, nil
}
2023-11-05 19:03:58 +01:00
connections, err := m.connFactory(m.clientConfig, endPoint)
2023-10-24 01:12:38 +02:00
if err != nil {
return nil, err
}
m.conLoc.Lock()
m.clientConnections[endPoint] = connections
m.conLoc.Unlock()
2023-11-03 16:24:18 +01:00
2023-10-24 01:12:38 +02:00
return connections, nil
}
// HasConnection returns true if a connection is registered at endPoint.
func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
	// AddConnection may write to the map concurrently, so the read must be
	// guarded as well.
	m.conLoc.RLock()
	defer m.conLoc.RUnlock()

	_, exists := m.clientConnections[endPoint]
	return exists
}
// Close closes every managed connection and removes them from the manager.
// Every connection is closed even if an earlier one fails; the first error
// encountered is returned.
func (m *ConnectionManagerImpl) Close() error {
	m.conLoc.Lock()
	defer m.conLoc.Unlock()

	var firstErr error
	for _, conn := range m.clientConnections {
		if err := conn.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	// Forget the closed connections so a later GetConnection creates a fresh
	// one instead of handing out a closed connection.
	m.clientConnections = make(map[string]PeerConnection)

	return firstErr
}