smegmesh/pkg/conn/conn_server.go

93 lines
2.2 KiB
Go
Raw Normal View History

package conn
import (
"crypto/tls"
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/auth"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// ConnectionServer manages the gRPC server
type ConnectionServer struct {
severConfig *tls.Config
JwtManager *auth.JwtManager
server *grpc.Server
authProvider rpc.AuthenticationServer
ctrlProvider rpc.MeshCtrlServerServer
}
type NewConnectionServerParams struct {
CertificatePath string
PrivateKey string
SkipCertVerification bool
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
}
// NewConnectionServer: create a new gRPC connection server instance
func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
}
jwtManager := auth.NewJwtManager("tim123", 24*time.Hour)
server := grpc.NewServer(
grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()),
grpc.Creds(credentials.NewTLS(serverConfig)),
)
authProvider := params.AuthProvider
ctrlProvider := params.CtrlProvider
connServer := ConnectionServer{
serverConfig,
jwtManager,
server,
authProvider,
ctrlProvider,
}
return &connServer, nil
}
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
lis, err := net.Listen("tcp", ":8080")
if err != nil {
logging.ErrorLog.Println(err.Error())
return err
}
if err := s.server.Serve(lis); err != nil {
logging.ErrorLog.Println(err.Error())
return err
}
return nil
}