Refactored client and server and interfaced out dependencies

This commit is contained in:
Tim Beatham 2023-10-02 16:03:41 +01:00
parent 52e5e3d33c
commit a069b89a9a
15 changed files with 426 additions and 203 deletions

21
cert/cert.pem Normal file
View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUT6VZnyJjB25my9JrUt/qfdX+J8QwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCVUsxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzEwMDExNDM5MjFaFw0zMzA5
MjgxNDM5MjFaMEUxCzAJBgNVBAYTAlVLMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQCVrc2ZbkM+ICgr9M9AahLijQOmbqhH03PtqUOprMuX
KGzKiG8v6VWCzdqrDMJTBe24/Ph9KUda8J63ra+uEfPXfTgox/NkbMVkd4qz5vIW
a6Q22g3RU2W8LpSczlcAdEvWBKxakWVnPvi1Sw/gj9Yn//HZxOvANeaTzr+wWNJa
VpTTXBPnvkpDY5GkfkSVkt1cZqCntZQAx85xBW1Bth860d0lZPibJBBtdtX3QO7r
PxeOgARB97J964M2DDvScaLiTH5+qQFzj/bS06Km+7s2rmA9ilPK/GlZb6Wc8f3Q
NdanZwF/odoLKFkW4cj0dG3vrRqJGKSO1tTk6OGrQfBTAgMBAAGjUzBRMB0GA1Ud
DgQWBBRLjaTwD74slcrdH0AWwqnCIBzDvzAfBgNVHSMEGDAWgBRLjaTwD74slcrd
H0AWwqnCIBzDvzAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCQ
50dhW6+cdzv6vfTrhe5ABOlJ288cxrOnpqKZLK0kWgtXTBXuJdIMqKO7f1dNGGAF
fbhcIoo8YsTVYAHvK0e0nUvKKTj5Jq39YXX4jSmLZMhV9RCxHiuzn3a0Szly2FRG
oLhmz+ib0WmROmspLD+T500toayGi3gfoWALo/LtOSYqUI9JNlFXPEyOfg1dkKfE
op/8Nx4DY73mHtp25dKL3mG1FAa0MQQvDnYTv5BNMRiG2k3N4AL2nORR60PXZV+S
oW9vF+bDWo++GJjmTVgbJPX3joH2B4mg97f4L9i4KqXW38hSB890iyp02j7vXM8Y
vT7vM2Qae3Y48SeOdZIX
-----END CERTIFICATE-----

28
cert/key.pem Normal file
View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCVrc2ZbkM+ICgr
9M9AahLijQOmbqhH03PtqUOprMuXKGzKiG8v6VWCzdqrDMJTBe24/Ph9KUda8J63
ra+uEfPXfTgox/NkbMVkd4qz5vIWa6Q22g3RU2W8LpSczlcAdEvWBKxakWVnPvi1
Sw/gj9Yn//HZxOvANeaTzr+wWNJaVpTTXBPnvkpDY5GkfkSVkt1cZqCntZQAx85x
BW1Bth860d0lZPibJBBtdtX3QO7rPxeOgARB97J964M2DDvScaLiTH5+qQFzj/bS
06Km+7s2rmA9ilPK/GlZb6Wc8f3QNdanZwF/odoLKFkW4cj0dG3vrRqJGKSO1tTk
6OGrQfBTAgMBAAECggEAC4kwrmGUJyadUf31Nza1q+ZIYLxoldiTN77y6xHZQxYn
hFiNkTi/kWxCLSq3k2SClN5SXHsg975RzUBCqPzTOUl6WZJHjPbhI8Qe2Yy0HcxA
BMY7iGWQErfYVlmE1REhgyYrDnPkR9fPnVFisOEFFWIhhrIppU/CLKQjm/jMhY/G
jdVaaTUcz9ee80BB8S6RCOWZLVc2/yYeIkby7AdGf8TUMqOvw/7AnLo4KwC3t06d
ZK+bCVpTD6O4d4VcXzy5eBDwsYUfWdLSp9JEuLJRQlsq449nXF9V7xjHCD8zOoqI
9PYh6xvPNB6fr9zSaOzLH9A4v+1zQqOoLHRliG2IcQKBgQDFlsYSzKUw+ae8AtqZ
qWQiHSXhEoeosjGHLvDWyXmqdORpDparDOw1b0UlFvt3wm/QXU8q3UtNcWn7/aP8
f/QsKZshJnuk8/+piJBv0v+pHZjpx7lTGVnfTUXpaP5T1EWz9p6HUX5qLHA7XUpr
hQJLt0evvYv06GDtFLIjzHAHSQKBgQDB7UWO1n9TUofYHxG9zhoKsdCvNxeACxJJ
EA1Ue0Ri+y3FnUYY3H9JqQ4d4k3xm731rbgV6TD15xzqv+RGL0+pQ1dDKy4T0lH8
+bsjRShrq+QVxVLBWff5rike1LTk8Q2bmFlv1COft+edrMsZOpK4af9QINbGc/wF
te5d77GuuwKBgC1bQvSlzXXEmWBrN0r2u2mpTzyvSDzNStlBST/E2Azs8FG9a5Cw
UrihZjnxYKBJHemywa2RRuvsEOwreS1JIf/RPS8K6m8fI50DIETLJqzngmaH1l7g
/uRnlJjT5S3RGH8LKbDeYCp3MPwvmhm8Wp6O4AHTfQEnJrjFe28ESuMhAoGAUiAT
dvwri7PFx6bQsprXuHO5NpqUHyuRINPlcUOKoIhSx/9ksh6e4Sjwy4MNEyareaGJ
9e19SIYJXvjIyVg72iikidN9ffNxuTphH/yns4Fl5DpeY3egZmJ1E5Ns0A+tfZk1
NwCV3YvaUJHeqN5/SA3Li7l8eyqfLiPvwGRD0QUCgYEAs4VG+7f8qyTQ/9l4VzTj
1G4naIfDxOS8UTnbc3KJqk48yNuPHLUoAxXmmA+ulqsaLhW1Xn+PWTXdFVaHQ5eB
WCsgnrvi9zrznqyVi54y0lrQTt6dMsLpul/29zKR/464Uyzcdy0008Khl3dDTk0o
91xucId8s41do8dEqaHVEhE=
-----END PRIVATE KEY-----

View File

@ -1,3 +1,4 @@
certificatePath: "../../cert/cert.pem" certificatePath: "../../cert/cert.pem"
privateKeyPath: "../../cert/key.pem" privateKeyPath: "../../cert/key.pem"
skipCertVerification: true skipCertVerification: true
ifName: "wgmesh"

View File

@ -2,56 +2,52 @@ package main
import ( import (
"log" "log"
"net"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/middleware" "github.com/tim-beatham/wgmesh/pkg/middleware"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/rpc"
wg "github.com/tim-beatham/wgmesh/pkg/wg" wg "github.com/tim-beatham/wgmesh/pkg/wg"
) )
const ifName = "wgmesh"
func main() { func main() {
wgClient, err := wg.CreateClient(ifName)
if err != nil {
log.Fatalf("Could not create interface %s\n", ifName)
}
conf, err := conf.ParseConfiguration("./configuration.yaml") conf, err := conf.ParseConfiguration("./configuration.yaml")
if err != nil {
newConnParams := conn.NewConnectionsParams{ log.Fatalln("Could not parse configuration")
CertificatePath: conf.CertificatePath,
PrivateKey: conf.PrivateKeyPath,
SkipCertVerification: conf.SkipCertVerification,
} }
conn, err := conn.NewConnection(&newConnParams) wgClient, err := wg.CreateClient(conf.IfName)
var robinRpc robin.RobinRpc
var robinIpc robin.RobinIpc
var authProvider middleware.AuthRpcProvider
ctrlServerParams := ctrlserver.NewCtrlServerParams{
WgClient: wgClient,
Conf: conf,
AuthProvider: &authProvider,
CtrlProvider: &robinRpc,
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
authProvider.Manager = ctrlServer.ConnectionServer.JwtManager
robinRpc.Server = ctrlServer
robinIpc.Server = ctrlServer
if err != nil { if err != nil {
return logging.ErrorLog.Fatalln(err.Error())
} }
ctrlServer := ctrlserver.NewCtrlServer(wgClient, conn, "wgmesh")
log.Println("Running IPC Handler") log.Println("Running IPC Handler")
robinIpc := robin.NewRobinIpc(ctrlServer) go ipc.RunIpcHandler(&robinIpc)
robinRpc := robin.NewRobinRpc(ctrlServer)
go ipc.RunIpcHandler(robinIpc) err = ctrlServer.ConnectionServer.Listen()
grpc := conn.Listen(ctrlServer.JwtManager.GetAuthInterceptor()) if err != nil {
rpc.NewRpcServer(grpc, robinRpc, middleware.NewAuthProvider(ctrlServer)) logging.ErrorLog.Fatalln(err.Error())
lis, err := net.Listen("tcp", ":8080")
if err := grpc.Serve(lis); err != nil {
log.Fatal(err.Error())
} }
defer wgClient.Close() defer wgClient.Close()

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"time" "time"
@ -44,6 +45,8 @@ func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager {
} }
func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) { func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) {
logging.InfoLog.Println("MeshID: " + meshId)
logging.InfoLog.Println("Token Duration: " + strconv.Itoa(int(m.tokenDuration)))
node := JwtNode{ node := JwtNode{
MeshId: meshId, MeshId: meshId,
Alias: alias, Alias: alias,

View File

@ -1,51 +0,0 @@
package auth
import (
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type TokenMesh struct {
Tokens map[string]string
}
type TokenManager struct {
Meshes map[string]*TokenMesh
}
func (m *TokenManager) AddToken(meshId, endpoint, token string) error {
mesh, ok := m.Meshes[endpoint]
if !ok {
mesh = new(TokenMesh)
mesh.Tokens = make(map[string]string)
m.Meshes[endpoint] = mesh
}
mesh.Tokens[meshId] = token
return nil
}
func (m *TokenManager) GetToken(meshId, endpoint string) (string, error) {
mesh, ok := m.Meshes[endpoint]
if !ok {
logging.ErrorLog.Printf("Endpoint doesnot exist: %s\n", endpoint)
return "", errors.New("Endpoint does not exist in the token manager")
}
token, ok := mesh.Tokens[meshId]
if !ok {
return "", errors.New("MeshId does not exist")
}
return token, nil
}
func NewTokenManager() *TokenManager {
var manager *TokenManager = new(TokenManager)
manager.Meshes = make(map[string]*TokenMesh)
return manager
}

View File

@ -12,6 +12,7 @@ type WgMeshConfiguration struct {
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath"`
PrivateKeyPath string `yaml:"privateKeyPath"` PrivateKeyPath string `yaml:"privateKeyPath"`
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
IfName string `yaml:"ifName"`
} }
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {

View File

@ -3,79 +3,122 @@
package conn package conn
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
) )
// PeerConnection interfacing for a secure connection between // PeerConnection interfacing for a secure connection between
// two peers. // two peers.
type PeerConnection interface { type PeerConnection interface {
Connect() error Connect() error
Close() error
Authenticate(meshId string) error
GetClient() (*grpc.ClientConn, error)
CreateAuthContext(meshId string) (context.Context, error)
} }
type WgCtrlConnection struct { type WgCtrlConnection struct {
serverConfig *tls.Config
clientConfig *tls.Config clientConfig *tls.Config
conn *grpc.ClientConn conn *grpc.ClientConn
endpoint string
// tokens maps a meshID to the corresponding token
tokens map[string]string
} }
type NewConnectionsParams struct { var keepAliveParams = keepalive.ClientParameters{
CertificatePath string Time: 5 * time.Minute,
PrivateKey string Timeout: time.Second,
SkipCertVerification bool PermitWithoutStream: true,
} }
func NewConnection(params *NewConnectionsParams) (*WgCtrlConnection, error) { func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey) var conn WgCtrlConnection
conn.tokens = make(map[string]string)
conn.clientConfig = clientConfig
conn.endpoint = server
return &conn, nil
}
func (c *WgCtrlConnection) Authenticate(meshId string) error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)))
defer conn.Close()
if err != nil { if err != nil {
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error()) return err
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
} }
serverAuth := tls.RequireAndVerifyClientCert ctx, cancel := context.WithTimeout(context.Background(), time.Second)
if params.SkipCertVerification { client := rpc.NewAuthenticationClient(conn)
serverAuth = tls.RequireAnyClientCert defer cancel()
authRequest := rpc.JoinAuthMeshRequest{
MeshId: meshId,
Alias: lib.GetOutboundIP().String(),
} }
tlsConfig := &tls.Config{ reply, err := client.JoinMesh(ctx, &authRequest)
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, if err != nil {
return err
} }
clientConfig := &tls.Config{ c.tokens[meshId] = *reply.Token
Certificates: []tls.Certificate{cert}, return nil
InsecureSkipVerify: params.SkipCertVerification,
}
wgConnection := WgCtrlConnection{serverConfig: tlsConfig, clientConfig: clientConfig}
return &wgConnection, nil
} }
// Connect: Connects to a new gRPC peer given the address of the other server // ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
func (c *WgCtrlConnection) Connect(server string) (*grpc.ClientConn, error) { func (c *WgCtrlConnection) Connect() error {
conn, err := grpc.Dial(server, grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig))) conn, err := grpc.Dial(c.endpoint,
grpc.WithKeepaliveParams(keepAliveParams),
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)),
)
if err != nil { if err != nil {
logging.ErrorLog.Printf("Could not connect: %s\n", err.Error()) logging.ErrorLog.Printf("Could not connect: %s\n", err.Error())
return nil, err return err
} }
return conn, nil c.conn = conn
return nil
} }
// Listen: listens to incoming messages // Close: Closes the client connections
func (c *WgCtrlConnection) Listen(i grpc.UnaryServerInterceptor) *grpc.Server { func (c *WgCtrlConnection) Close() error {
server := grpc.NewServer( return c.conn.Close()
grpc.UnaryInterceptor(i), }
grpc.Creds(credentials.NewTLS(c.serverConfig)),
) // GetClient: Gets the client connection
return server func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
var err error = nil
if c.conn == nil {
err = errors.New("The client's config does not exist")
}
return c.conn, err
}
// TODO: Implement a mechanism to attach a security token
func (c *WgCtrlConnection) CreateAuthContext(meshId string) (context.Context, error) {
token, ok := c.tokens[meshId]
if !ok {
return nil, errors.New("MeshID: " + meshId + " does not exist")
}
ctx := context.Background()
return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil
} }

86
pkg/conn/conn_manager.go Normal file
View File

@ -0,0 +1,86 @@
package conn
import (
"crypto/tls"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
// ConnectionManager manages connections between other peers
// in the control plane.
type ConnectionManager struct {
// clientConnections maps an endpoint to a connection
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
}
type NewConnectionManagerParams struct {
CertificatePath string
PrivateKey string
SkipCertVerification bool
}
func NewConnectionManager(params *NewConnectionManagerParams) (*ConnectionManager, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
}
connections := make(map[string]PeerConnection)
connMgr := ConnectionManager{connections, serverConfig, clientConfig}
return &connMgr, nil
}
func (m *ConnectionManager) GetConnection(endpoint string) (PeerConnection, error) {
conn, exists := m.clientConnections[endpoint]
if !exists {
return nil, errors.New("endpoint: " + endpoint + " does not exist")
}
return conn, nil
}
type AddConnectionParams struct {
TokenId string
}
// AddToken: Adds a connection to the list of connections to manage
func (m *ConnectionManager) AddConnection(endPoint string) (PeerConnection, error) {
_, exists := m.clientConnections[endPoint]
if exists {
return nil, errors.New("token already exists in the connections")
}
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
if err != nil {
return nil, err
}
m.clientConnections[endPoint] = connections
return connections, nil
}

92
pkg/conn/conn_server.go Normal file
View File

@ -0,0 +1,92 @@
package conn
import (
"crypto/tls"
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/auth"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// ConnectionServer manages the gRPC server
type ConnectionServer struct {
severConfig *tls.Config
JwtManager *auth.JwtManager
server *grpc.Server
authProvider rpc.AuthenticationServer
ctrlProvider rpc.MeshCtrlServerServer
}
type NewConnectionServerParams struct {
CertificatePath string
PrivateKey string
SkipCertVerification bool
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
}
// NewConnectionServer: create a new gRPC connection server instance
func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
}
jwtManager := auth.NewJwtManager("tim123", 24*time.Hour)
server := grpc.NewServer(
grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()),
grpc.Creds(credentials.NewTLS(serverConfig)),
)
authProvider := params.AuthProvider
ctrlProvider := params.CtrlProvider
connServer := ConnectionServer{
serverConfig,
jwtManager,
server,
authProvider,
ctrlProvider,
}
return &connServer, nil
}
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
lis, err := net.Listen("tcp", ":8080")
if err != nil {
logging.ErrorLog.Println(err.Error())
return err
}
if err := s.server.Serve(lis); err != nil {
logging.ErrorLog.Println(err.Error())
return err
}
return nil
}

View File

@ -6,35 +6,67 @@
package ctrlserver package ctrlserver
import ( import (
"context"
"errors" "errors"
"net" "net"
"time"
"github.com/tim-beatham/wgmesh/pkg/auth" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/metadata"
) )
type NewCtrlServerParams struct {
WgClient *wgctrl.Client
Conf *conf.WgMeshConfiguration
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
}
/* /*
* NewCtrlServer creates a new instance of the ctrlserver. * NewCtrlServer creates a new instance of the ctrlserver.
* It is associated with a WireGuard client and an interface. * It is associated with a WireGuard client and an interface.
* wgClient: Represents the WireGuard control client. * wgClient: Represents the WireGuard control client.
* ifName: WireGuard interface name * ifName: WireGuard interface name
*/ */
func NewCtrlServer(wgClient *wgctrl.Client, conn *conn.WgCtrlConnection, ifName string) *MeshCtrlServer { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
ctrlServer.Meshes = make(map[string]Mesh) ctrlServer.Meshes = make(map[string]Mesh)
ctrlServer.Client = wgClient ctrlServer.Client = params.WgClient
ctrlServer.Conn = conn ctrlServer.IfName = params.Conf.IfName
ctrlServer.IfName = ifName
ctrlServer.JwtManager = auth.NewJwtManager("bob123", 24*time.Hour) connManagerParams := conn.NewConnectionManagerParams{
ctrlServer.TokenManager = auth.NewTokenManager() CertificatePath: params.Conf.CertificatePath,
return ctrlServer PrivateKey: params.Conf.PrivateKeyPath,
SkipCertVerification: params.Conf.SkipCertVerification,
}
connMgr, err := conn.NewConnectionManager(&connManagerParams)
if err != nil {
return nil, err
}
ctrlServer.ConnectionManager = connMgr
connServerParams := conn.NewConnectionServerParams{
CertificatePath: params.Conf.CertificatePath,
PrivateKey: params.Conf.PrivateKeyPath,
SkipCertVerification: params.Conf.SkipCertVerification,
AuthProvider: params.AuthProvider,
CtrlProvider: params.CtrlProvider,
}
connServer, err := conn.NewConnectionServer(&connServerParams)
if err != nil {
return nil, err
}
ctrlServer.ConnectionServer = connServer
return ctrlServer, nil
} }
/* /*
@ -195,13 +227,3 @@ func (s *MeshCtrlServer) EnableInterface(meshId string) error {
return wg.EnableInterface(s.IfName, node.WgHost) return wg.EnableInterface(s.IfName, node.WgHost)
} }
func (s *MeshCtrlServer) AddToken(ctx context.Context, endpoint, meshId string) (context.Context, error) {
token, err := s.TokenManager.GetToken(meshId, endpoint)
if err != nil {
return nil, err
}
return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil
}

View File

@ -1,7 +1,6 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/auth"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/wgmesh/pkg/conn"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -27,10 +26,9 @@ type Mesh struct {
* is running * is running
*/ */
type MeshCtrlServer struct { type MeshCtrlServer struct {
Client *wgctrl.Client Client *wgctrl.Client
Meshes map[string]Mesh Meshes map[string]Mesh
IfName string IfName string
Conn *conn.WgCtrlConnection ConnectionManager *conn.ConnectionManager
JwtManager *auth.JwtManager ConnectionServer *conn.ConnectionServer
TokenManager *auth.TokenManager
} }

View File

@ -4,13 +4,14 @@ import (
"context" "context"
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/auth"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
type AuthRpcProvider struct { type AuthRpcProvider struct {
rpc.UnimplementedAuthenticationServer rpc.UnimplementedAuthenticationServer
server *ctrlserver.MeshCtrlServer Manager *auth.JwtManager
} }
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) { func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
@ -20,7 +21,8 @@ func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequ
return nil, errors.New("Must specify the meshId") return nil, errors.New("Must specify the meshId")
} }
token, err := a.server.JwtManager.CreateClaims(in.MeshId, "sharedSecret") logging.InfoLog.Println("MeshID: " + in.MeshId)
token, err := a.Manager.CreateClaims(in.MeshId, in.Alias)
if err != nil { if err != nil {
return nil, err return nil, err
@ -28,7 +30,3 @@ func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequ
return &rpc.JoinAuthMeshReply{Success: true, Token: token}, nil return &rpc.JoinAuthMeshReply{Success: true, Token: token}, nil
} }
func NewAuthProvider(ctrlServer *ctrlserver.MeshCtrlServer) *AuthRpcProvider {
return &AuthRpcProvider{server: ctrlServer}
}

View File

@ -21,10 +21,8 @@ type RobinIpc struct {
Server *ctrlserver.MeshCtrlServer Server *ctrlserver.MeshCtrlServer
} }
const MeshIfName = "wgmesh"
func (n *RobinIpc) CreateMesh(name string, reply *string) error { func (n *RobinIpc) CreateMesh(name string, reply *string) error {
wg.CreateInterface(MeshIfName) wg.CreateInterface(n.Server.IfName)
mesh, err := n.Server.CreateMesh() mesh, err := n.Server.CreateMesh()
ula, _ := slaac.NewULA(n.Server.GetDevice().PublicKey, "0") ula, _ := slaac.NewULA(n.Server.GetDevice().PublicKey, "0")
@ -55,21 +53,23 @@ func (n *RobinIpc) ListMeshes(name string, reply *map[string]ctrlserver.Mesh) er
} }
func updateMesh(n *RobinIpc, meshId string, endPoint string) error { func updateMesh(n *RobinIpc, meshId string, endPoint string) error {
conn, err := n.Server.Conn.Connect(endPoint) peerConn, err := n.Server.ConnectionManager.GetConnection(endPoint)
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() conn, err := peerConn.GetClient()
c := rpc.NewMeshCtrlServerClient(conn) c := rpc.NewMeshCtrlServerClient(conn)
ctx, err := n.Server.AddToken(context.Background(), endPoint, meshId) authContext, err := peerConn.CreateAuthContext(meshId)
if err != nil { if err != nil {
return err return err
} }
ctx, cancel := context.WithTimeout(ctx, time.Second) ctx, cancel := context.WithTimeout(authContext, time.Second)
defer cancel() defer cancel()
getMeshReq := rpc.GetMeshRequest{ getMeshReq := rpc.GetMeshRequest{
@ -108,34 +108,33 @@ func updateMesh(n *RobinIpc, meshId string, endPoint string) error {
} }
func updatePeer(n *RobinIpc, node ctrlserver.MeshNode, wgHost string, meshId string) error { func updatePeer(n *RobinIpc, node ctrlserver.MeshNode, wgHost string, meshId string) error {
token, err := n.Authenticate(meshId, node.HostEndpoint) err := n.Authenticate(meshId, node.HostEndpoint)
if err != nil { if err != nil {
return err return err
} }
err = n.Server.TokenManager.AddToken(meshId, node.HostEndpoint, token) peerConnection, err := n.Server.ConnectionManager.GetConnection(node.HostEndpoint)
if err != nil { if err != nil {
return err return err
} }
conn, err := n.Server.Conn.Connect(node.HostEndpoint) conn, err := peerConnection.GetClient()
if err != nil { if err != nil {
return err return err
} }
defer conn.Close()
c := rpc.NewMeshCtrlServerClient(conn) c := rpc.NewMeshCtrlServerClient(conn)
ctx, err := n.Server.AddToken(context.Background(), node.HostEndpoint, meshId) authContext, err := peerConnection.CreateAuthContext(meshId)
if err != nil { if err != nil {
return err return err
} }
ctx, cancel := context.WithTimeout(ctx, time.Second) ctx, cancel := context.WithTimeout(authContext, time.Second)
defer cancel() defer cancel()
dev := n.Server.GetDevice() dev := n.Server.GetDevice()
@ -178,61 +177,51 @@ func updatePeers(n *RobinIpc, meshId string, wgHost string, nodesToExclude []str
return nil return nil
} }
func (n *RobinIpc) Authenticate(meshId, endpoint string) (string, error) { func (n *RobinIpc) Authenticate(meshId, endpoint string) error {
conn, err := n.Server.Conn.Connect(endpoint) peerConnection, err := n.Server.ConnectionManager.AddConnection(endpoint)
if err != nil { if err != nil {
return "", err return err
} }
defer conn.Close() err = peerConnection.Authenticate(meshId)
c := rpc.NewAuthenticationClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
authRequest := rpc.JoinAuthMeshRequest{
MeshId: meshId,
Alias: lib.GetOutboundIP().String(),
}
reply, err := c.JoinMesh(ctx, &authRequest)
if err != nil { if err != nil {
return "", err return err
} }
logging.InfoLog.Printf("Token: %s\n", *reply.Token) err = peerConnection.Connect()
return err
return *reply.Token, err
} }
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
token, err := n.Authenticate(args.MeshId, args.IpAdress+":8080") err := n.Authenticate(args.MeshId, args.IpAdress+":8080")
if err != nil { if err != nil {
return err return err
} }
n.Server.TokenManager.AddToken(args.MeshId, args.IpAdress+":8080", token) peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress + ":8080")
conn, err := n.Server.Conn.Connect(args.IpAdress + ":8080")
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() client, err := peerConnection.GetClient()
c := rpc.NewMeshCtrlServerClient(conn)
ctx, err := n.Server.AddToken(context.Background(), args.IpAdress+":8080", args.MeshId)
if err != nil { if err != nil {
return err return err
} }
ctx, cancel := context.WithTimeout(ctx, time.Second) c := rpc.NewMeshCtrlServerClient(client)
authContext, err := peerConnection.CreateAuthContext(args.MeshId)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(authContext, time.Second)
defer cancel() defer cancel()
dev := n.Server.GetDevice() dev := n.Server.GetDevice()

View File

@ -13,7 +13,7 @@ import (
type RobinRpc struct { type RobinRpc struct {
rpc.UnimplementedMeshCtrlServerServer rpc.UnimplementedMeshCtrlServerServer
server *ctrlserver.MeshCtrlServer Server *ctrlserver.MeshCtrlServer
} }
func nodeToRpcNode(node ctrlserver.MeshNode) *rpc.MeshNode { func nodeToRpcNode(node ctrlserver.MeshNode) *rpc.MeshNode {
@ -40,7 +40,7 @@ func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode {
} }
func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) { func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh, contains := m.server.Meshes[request.MeshId] mesh, contains := m.Server.Meshes[request.MeshId]
if !contains { if !contains {
return nil, errors.New("Element is not in the mesh") return nil, errors.New("Element is not in the mesh")
@ -77,7 +77,7 @@ func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (
WgIp: wgIp, WgIp: wgIp,
} }
err = m.server.AddHost(addHostArgs) err = m.Server.AddHost(addHostArgs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -85,7 +85,3 @@ func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (
return &rpc.JoinMeshReply{Success: true, MeshIp: &wgIp}, nil return &rpc.JoinMeshReply{Success: true, MeshIp: &wgIp}, nil
} }
func NewRobinRpc(ctrlServer *ctrlserver.MeshCtrlServer) *RobinRpc {
return &RobinRpc{server: ctrlServer}
}