mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2024-12-12 17:41:08 +01:00
93 lines
2.2 KiB
Go
93 lines
2.2 KiB
Go
package conn
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/tim-beatham/wgmesh/pkg/auth"
|
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
)
|
|
|
|
// ConnectionServer manages the gRPC server
|
|
type ConnectionServer struct {
|
|
severConfig *tls.Config
|
|
JwtManager *auth.JwtManager
|
|
server *grpc.Server
|
|
authProvider rpc.AuthenticationServer
|
|
ctrlProvider rpc.MeshCtrlServerServer
|
|
}
|
|
|
|
type NewConnectionServerParams struct {
|
|
CertificatePath string
|
|
PrivateKey string
|
|
SkipCertVerification bool
|
|
AuthProvider rpc.AuthenticationServer
|
|
CtrlProvider rpc.MeshCtrlServerServer
|
|
}
|
|
|
|
// NewConnectionServer: create a new gRPC connection server instance
|
|
func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, error) {
|
|
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
|
|
|
|
if err != nil {
|
|
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
|
|
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
|
|
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
|
|
return nil, err
|
|
}
|
|
|
|
serverAuth := tls.RequireAndVerifyClientCert
|
|
|
|
if params.SkipCertVerification {
|
|
serverAuth = tls.RequireAnyClientCert
|
|
}
|
|
|
|
serverConfig := &tls.Config{
|
|
ClientAuth: serverAuth,
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
|
|
jwtManager := auth.NewJwtManager("tim123", 24*time.Hour)
|
|
|
|
server := grpc.NewServer(
|
|
grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()),
|
|
grpc.Creds(credentials.NewTLS(serverConfig)),
|
|
)
|
|
|
|
authProvider := params.AuthProvider
|
|
ctrlProvider := params.CtrlProvider
|
|
|
|
connServer := ConnectionServer{
|
|
serverConfig,
|
|
jwtManager,
|
|
server,
|
|
authProvider,
|
|
ctrlProvider,
|
|
}
|
|
|
|
return &connServer, nil
|
|
}
|
|
|
|
func (s *ConnectionServer) Listen() error {
|
|
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
|
|
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
|
|
|
|
lis, err := net.Listen("tcp", ":8080")
|
|
|
|
if err != nil {
|
|
logging.ErrorLog.Println(err.Error())
|
|
return err
|
|
}
|
|
|
|
if err := s.server.Serve(lis); err != nil {
|
|
logging.ErrorLog.Println(err.Error())
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|