1
0
forked from extern/smegmesh

Tested with large number of nodes

This commit is contained in:
Tim Beatham 2023-10-24 00:12:38 +01:00
parent ef2b57047d
commit 8e89281484
20 changed files with 326 additions and 547 deletions

View File

@ -31,7 +31,7 @@ func listMeshes(client *ipcRpc.Client) {
err := client.Call("RobinIpc.ListMeshes", "", &reply) err := client.Call("RobinIpc.ListMeshes", "", &reply)
if err != nil { if err != nil {
logging.ErrorLog.Println(err.Error()) logging.Log.WriteErrorf(err.Error())
return return
} }

View File

@ -36,7 +36,6 @@ func main() {
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
authProvider.Manager = ctrlServer.ConnectionServer.JwtManager
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2)
@ -50,7 +49,8 @@ func main() {
robinIpc = robin.NewRobinIpc(robinIpcParams) robinIpc = robin.NewRobinIpc(robinIpcParams)
if err != nil { if err != nil {
logging.ErrorLog.Fatalln(err.Error()) logging.Log.WriteErrorf(err.Error())
return
} }
log.Println("Running IPC Handler") log.Println("Running IPC Handler")
@ -61,9 +61,12 @@ func main() {
err = ctrlServer.ConnectionServer.Listen() err = ctrlServer.ConnectionServer.Listen()
if err != nil { if err != nil {
logging.ErrorLog.Fatalln(err.Error()) logging.Log.WriteErrorf(err.Error())
return
} }
defer wgClient.Close()
defer syncScheduler.Stop() defer syncScheduler.Stop()
defer ctrlServer.Close()
defer wgClient.Close()
} }

View File

@ -1,140 +0,0 @@
package auth
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// JwtMesh contains all the sessions with the mesh network
type JwtMesh struct {
meshId string
// nodes contains a set of nodes with the string being the jwt token
nodes map[string]interface{}
}
// JwtManager manages jwt tokens indicating a session
// between this host and another within a specific mesh
type JwtManager struct {
secretKey []byte
tokenDuration time.Duration
// meshes contains all the meshes that we have sessions with
meshes map[string]*JwtMesh
}
// JwtNode represents a jwt node in the mesh network
type JwtNode struct {
MeshId string `json:"meshId"`
Alias string `json:"alias"`
jwt.RegisteredClaims
}
func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager {
meshes := make(map[string]*JwtMesh)
return &JwtManager{[]byte(secretKey), tokenDuration, meshes}
}
func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) {
logging.InfoLog.Println("MeshID: " + meshId)
logging.InfoLog.Println("Token Duration: " + strconv.Itoa(int(m.tokenDuration)))
node := JwtNode{
MeshId: meshId,
Alias: alias,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(m.tokenDuration)),
},
}
mesh, contains := m.meshes[meshId]
if !contains {
mesh = new(JwtMesh)
mesh.meshId = meshId
mesh.nodes = make(map[string]interface{})
mesh.nodes[meshId] = mesh
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, node)
signedString, err := token.SignedString(m.secretKey)
if err != nil {
fmt.Println(err.Error())
return nil, err
}
_, exists := mesh.nodes[signedString]
if exists {
return nil, errors.New("Node already exists")
}
mesh.nodes[signedString] = struct{}{}
return &signedString, nil
}
func (m *JwtManager) Verify(accessToken string) (*JwtNode, bool) {
token, err := jwt.ParseWithClaims(accessToken, &JwtNode{}, func(t *jwt.Token) (interface{}, error) {
return m.secretKey, nil
})
if err != nil {
return nil, false
}
if !token.Valid {
return nil, token.Valid
}
claims, ok := token.Claims.(*JwtNode)
return claims, ok
}
func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
if strings.Contains(info.FullMethod, "") {
return handler(ctx, req)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided")
}
values := md["authorization"]
for _, w := range values {
logging.InfoLog.Printf(w)
}
if len(values) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided")
}
acessToken := values[0]
_, valid := m.Verify(acessToken)
if !valid {
return nil, status.Errorf(codes.Unauthenticated, "Invalid access token: %s", acessToken)
}
return handler(ctx, req)
}
}

View File

@ -193,14 +193,13 @@ func (m *CrdtNodeManager) Length() int {
return m.doc.Path("nodes").Map().Len() return m.doc.Path("nodes").Map().Len()
} }
const threshold = 2
const thresholdVotes = 0.1 const thresholdVotes = 0.1
func (m *CrdtNodeManager) HasFailed(endpoint string) bool { func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
node, err := m.GetNode(endpoint) node, err := m.GetNode(endpoint)
if err != nil { if err != nil {
logging.InfoLog.Printf("Cannot get node node: %s\n", endpoint) logging.Log.WriteErrorf("Cannot get node node: %s\n", endpoint)
return true return true
} }
@ -215,14 +214,12 @@ func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
for _, value := range values { for _, value := range values {
count := value.Int64() count := value.Int64()
if count >= threshold { if count >= 1 {
countFailed++ countFailed++
} }
} }
logging.InfoLog.Printf("Count Failed Value: %d\n", countFailed) return countFailed >= 4
logging.InfoLog.Printf("Threshold Value: %d\n", int(thresholdVotes*float64(m.Length())+1))
return countFailed >= int(thresholdVotes*float64(m.Length())+1)
} }
func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
@ -232,7 +229,6 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode
for _, n := range nodes { for _, n := range nodes {
peer, err := m.convertMeshNode(n) peer, err := m.convertMeshNode(n)
logging.InfoLog.Println(n.HostEndpoint)
if err != nil { if err != nil {
return err return err

View File

@ -24,14 +24,14 @@ func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
yamlBytes, err := os.ReadFile(filePath) yamlBytes, err := os.ReadFile(filePath)
if err != nil { if err != nil {
logging.ErrorLog.Printf("Read file error: %s\n", err.Error()) logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
return nil, err return nil, err
} }
err = yaml.Unmarshal(yamlBytes, &conf) err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil { if err != nil {
logging.ErrorLog.Printf("Unmarshal error: %s\n", err.Error()) logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
return nil, err return nil, err
} }

View File

@ -1,116 +0,0 @@
// conn manages gRPC connections between peers.
// Includes timers.
package conn
import (
"context"
"crypto/tls"
"errors"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
)
// PeerConnection interfacing for a secure connection between
// two peers.
type PeerConnection interface {
Connect() error
Close() error
Authenticate(meshId string) error
GetClient() (*grpc.ClientConn, error)
CreateAuthContext(meshId string) (context.Context, error)
}
type WgCtrlConnection struct {
clientConfig *tls.Config
conn *grpc.ClientConn
endpoint string
// tokens maps a meshID to the corresponding token
tokens map[string]string
}
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
var conn WgCtrlConnection
conn.tokens = make(map[string]string)
conn.clientConfig = clientConfig
conn.endpoint = server
return &conn, nil
}
func (c *WgCtrlConnection) Authenticate(meshId string) error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)))
defer conn.Close()
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
client := rpc.NewAuthenticationClient(conn)
defer cancel()
authRequest := rpc.JoinAuthMeshRequest{
MeshId: meshId,
Alias: lib.GetOutboundIP().String(),
}
reply, err := client.JoinMesh(ctx, &authRequest)
if err != nil {
return err
}
c.tokens[meshId] = *reply.Token
return nil
}
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
func (c *WgCtrlConnection) Connect() error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)),
)
if err != nil {
logging.ErrorLog.Printf("Could not connect: %s\n", err.Error())
return err
}
c.conn = conn
return nil
}
// Close: Closes the client connections
func (c *WgCtrlConnection) Close() error {
return c.conn.Close()
}
// GetClient: Gets the client connection
func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
var err error = nil
if c.conn == nil {
err = errors.New("The client's config does not exist")
}
return c.conn, err
}
// TODO: Implement a mechanism to attach a security token
func (c *WgCtrlConnection) CreateAuthContext(meshId string) (context.Context, error) {
token, ok := c.tokens[meshId]
if !ok {
return nil, errors.New("MeshID: " + meshId + " does not exist")
}
ctx := context.Background()
return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil
}

View File

@ -1,93 +0,0 @@
package conn
import (
"crypto/tls"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type ConnectionManager interface {
AddConnection(endPoint string) (PeerConnection, error)
GetConnection(endPoint string) (PeerConnection, error)
HasConnection(endPoint string) bool
}
// ConnectionManager manages connections between other peers
// in the control plane.
type JwtConnectionManager struct {
// clientConnections maps an endpoint to a connection
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
}
type NewJwtConnectionManagerParams struct {
CertificatePath string
PrivateKey string
SkipCertVerification bool
}
func NewJwtConnectionManager(params *NewJwtConnectionManagerParams) (ConnectionManager, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
}
connections := make(map[string]PeerConnection)
connMgr := JwtConnectionManager{connections, serverConfig, clientConfig}
return &connMgr, nil
}
func (m *JwtConnectionManager) GetConnection(endpoint string) (PeerConnection, error) {
conn, exists := m.clientConnections[endpoint]
if !exists {
return nil, errors.New("endpoint: " + endpoint + " does not exist")
}
return conn, nil
}
// AddToken: Adds a connection to the list of connections to manage
func (m *JwtConnectionManager) AddConnection(endPoint string) (PeerConnection, error) {
conn, exists := m.clientConnections[endPoint]
if exists {
return conn, nil
}
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
if err != nil {
return nil, err
}
m.clientConnections[endPoint] = connections
return connections, nil
}
func (m *JwtConnectionManager) HasConnection(endPoint string) bool {
_, exists := m.clientConnections[endPoint]
return exists
}

70
pkg/conn/connection.go Normal file
View File

@ -0,0 +1,70 @@
// conn manages gRPC connections between peers.
// Includes timers.
package conn
import (
"crypto/tls"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// PeerConnection represents a client-side connection between two
// peers.
type PeerConnection interface {
Close() error
GetClient() (*grpc.ClientConn, error)
}
// WgCtrlConnection implements PeerConnection.
type WgCtrlConnection struct {
clientConfig *tls.Config
conn *grpc.ClientConn
endpoint string
}
// NewWgCtrlConnection creates a new instance of a WireGuard control connection
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
var conn WgCtrlConnection
conn.clientConfig = clientConfig
conn.endpoint = server
if err := conn.createGrpcConn(); err != nil {
return nil, err
}
return &conn, nil
}
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
func (c *WgCtrlConnection) createGrpcConn() error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)),
)
if err != nil {
logging.Log.WriteErrorf("Could not connect: %s\n", err.Error())
return err
}
c.conn = conn
return nil
}
// Close: Closes the client connections
func (c *WgCtrlConnection) Close() error {
return c.conn.Close()
}
// GetClient: Gets the client connection
func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
var err error = nil
if c.conn == nil {
err = errors.New("The client's config does not exist")
}
return c.conn, err
}

View File

@ -0,0 +1,129 @@
package conn
import (
"crypto/tls"
"sync"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
// ConnectionManager defines an interface for maintaining peer connections
type ConnectionManager interface {
// AddConnection adds an instance of a connection at the given endpoint
// or error if something went wrong
AddConnection(endPoint string) (PeerConnection, error)
// GetConnection returns an instance of a connection at the given endpoint.
// If the endpoint does not exist then add the connection. Returns an error
// if something went wrong
GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne
// endpoint or false otherwise.
HasConnection(endPoint string) bool
// Goes through all the connections and closes eachone
Close() error
}
// ConnectionManager manages connections between other peers
// in the control plane.
type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection
conLoc sync.RWMutex
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
}
// Create a new instance of a connection manager.
type NewConnectionManageParams struct {
// The path to the certificate
CertificatePath string
// The private key of the node
PrivateKey string
// Whether or not to skip certificate verification
SkipCertVerification bool
}
// NewConnectionManager: Creates a new instance of a ConnectionManager or an error
// if something went wrong.
func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error())
logging.Log.WriteErrorf("Certificate Path: %s\n", params.CertificatePath)
logging.Log.WriteErrorf("Private Key Path: %s\n", params.PrivateKey)
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
}
connections := make(map[string]PeerConnection)
connMgr := ConnectionManagerImpl{sync.RWMutex{}, connections, serverConfig, clientConfig}
return &connMgr, nil
}
// GetConnection: Returns the given connection if it exists. If it does not exist then add
// the connection. Returns an error if something went wrong
func (m *ConnectionManagerImpl) GetConnection(endpoint string) (PeerConnection, error) {
m.conLoc.Lock()
conn, exists := m.clientConnections[endpoint]
m.conLoc.Unlock()
if !exists {
return m.AddConnection(endpoint)
}
return conn, nil
}
// AddConnection: Adds a connection to the list of connections to manage.
func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection, error) {
m.conLoc.Lock()
conn, exists := m.clientConnections[endPoint]
m.conLoc.Unlock()
if exists {
return conn, nil
}
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
if err != nil {
return nil, err
}
m.conLoc.Lock()
m.clientConnections[endPoint] = connections
m.conLoc.Unlock()
return connections, nil
}
// HasConnection Returns TRUE if the given endpoint exists
func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
_, exists := m.clientConnections[endPoint]
return exists
}
func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil {
return err
}
}
return nil
}

View File

@ -3,9 +3,7 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"net" "net"
"time"
"github.com/tim-beatham/wgmesh/pkg/auth"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
@ -13,17 +11,23 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
// ConnectionServer manages the gRPC server // ConnectionServer manages gRPC server peer connections
type ConnectionServer struct { type ConnectionServer struct {
severConfig *tls.Config // tlsConfiguration of the server
JwtManager *auth.JwtManager serverConfig *tls.Config
server *grpc.Server // server an instance of the grpc server
server *grpc.Server
// the authentication service to authenticate nodes
authProvider rpc.AuthenticationServer authProvider rpc.AuthenticationServer
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
syncProvider rpc.SyncServiceServer syncProvider rpc.SyncServiceServer
Conf *conf.WgMeshConfiguration Conf *conf.WgMeshConfiguration
listener net.Listener
} }
// NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct { type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration Conf *conf.WgMeshConfiguration
AuthProvider rpc.AuthenticationServer AuthProvider rpc.AuthenticationServer
@ -36,9 +40,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
cert, err := tls.LoadX509KeyPair(params.Conf.CertificatePath, params.Conf.PrivateKeyPath) cert, err := tls.LoadX509KeyPair(params.Conf.CertificatePath, params.Conf.PrivateKeyPath)
if err != nil { if err != nil {
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error()) logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error())
logging.ErrorLog.Printf("Certificate Path: %s\n", params.Conf.CertificatePath)
logging.ErrorLog.Printf("Private Key Path: %s\n", params.Conf.PrivateKeyPath)
return nil, err return nil, err
} }
@ -53,10 +55,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
} }
jwtManager := auth.NewJwtManager(params.Conf.Secret, 24*time.Hour)
server := grpc.NewServer( server := grpc.NewServer(
grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()),
grpc.Creds(credentials.NewTLS(serverConfig)), grpc.Creds(credentials.NewTLS(serverConfig)),
) )
@ -65,38 +64,51 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider syncProvider := params.SyncProvider
connServer := ConnectionServer{ connServer := ConnectionServer{
serverConfig, serverConfig: serverConfig,
jwtManager, server: server,
server, authProvider: authProvider,
authProvider, ctrlProvider: ctrlProvider,
ctrlProvider, syncProvider: syncProvider,
syncProvider, Conf: params.Conf,
params.Conf,
} }
return &connServer, nil return &connServer, nil
} }
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error { func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterAuthenticationServer(s.server, s.authProvider) rpc.RegisterAuthenticationServer(s.server, s.authProvider)
logging.InfoLog.Println(s.syncProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort) lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort)
s.listener = lis
logging.InfoLog.Printf("GRPC listening on %s\n", s.Conf.GrpcPort) logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort)
if err != nil { if err != nil {
logging.ErrorLog.Println(err.Error()) logging.Log.WriteErrorf(err.Error())
return err return err
} }
if err := s.server.Serve(lis); err != nil { if err := s.server.Serve(lis); err != nil {
logging.ErrorLog.Println(err.Error()) logging.Log.WriteErrorf(err.Error())
return err return err
} }
return nil return nil
} }
// Close closes the connection server. Returns an error
// if something went wrong whilst attempting to close the connection
func (c *ConnectionServer) Close() error {
var err error = nil
c.server.Stop()
if c.listener != nil {
err = c.listener.Close()
}
return err
}

View File

@ -1,8 +1,3 @@
/*
* ctrlserver controls the WireGuard mesh. Contains an IpcHandler for
* handling commands fired by wgmesh command.
* Contains an RpcHandler for handling commands fired by another server.
*/
package ctrlserver package ctrlserver
import ( import (
@ -13,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
// NewCtrlServerParams are the params requried to create a new ctrl server
type NewCtrlServerParams struct { type NewCtrlServerParams struct {
WgClient *wgctrl.Client WgClient *wgctrl.Client
Conf *conf.WgMeshConfiguration Conf *conf.WgMeshConfiguration
@ -21,32 +17,27 @@ type NewCtrlServerParams struct {
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
} }
/* // Create a new instance of the MeshCtrlServer or error if the
* NewCtrlServer creates a new instance of the ctrlserver. // operation failed
* It is associated with a WireGuard client and an interface.
* wgClient: Represents the WireGuard control client.
* ifName: WireGuard interface name
*/
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
ctrlServer.Client = params.WgClient ctrlServer.Client = params.WgClient
ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf) ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf)
ctrlServer.Conf = params.Conf ctrlServer.Conf = params.Conf
connManagerParams := conn.NewJwtConnectionManagerParams{ connManagerParams := conn.NewConnectionManageParams{
CertificatePath: params.Conf.CertificatePath, CertificatePath: params.Conf.CertificatePath,
PrivateKey: params.Conf.PrivateKeyPath, PrivateKey: params.Conf.PrivateKeyPath,
SkipCertVerification: params.Conf.SkipCertVerification, SkipCertVerification: params.Conf.SkipCertVerification,
} }
connMgr, err := conn.NewJwtConnectionManager(&connManagerParams) connMgr, err := conn.NewConnectionManager(&connManagerParams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctrlServer.ConnectionManager = connMgr ctrlServer.ConnectionManager = connMgr
connServerParams := conn.NewConnectionServerParams{ connServerParams := conn.NewConnectionServerParams{
Conf: params.Conf, Conf: params.Conf,
AuthProvider: params.AuthProvider, AuthProvider: params.AuthProvider,
@ -63,3 +54,16 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer.ConnectionServer = connServer ctrlServer.ConnectionServer = connServer
return ctrlServer, nil return ctrlServer, nil
} }
// Close closes the ctrl server tearing down any connections that exist
func (s *MeshCtrlServer) Close() error {
if err := s.ConnectionManager.Close(); err != nil {
return err
}
if err := s.ConnectionServer.Close(); err != nil {
return err
}
return nil
}

View File

@ -1,22 +1,51 @@
// Provides a generic interface for logging
package logging package logging
/*
* This package creates the info, warning and error loggers.
*/
import ( import (
"log"
"os" "os"
"github.com/sirupsen/logrus"
) )
var ( var (
InfoLog *log.Logger Log Logger
WarningLog *log.Logger
ErrorLog *log.Logger
) )
func init() { type Logger interface {
InfoLog = log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) WriteInfof(msg string, args ...interface{})
WarningLog = log.New(os.Stdout, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile) WriteErrorf(msg string, args ...interface{})
ErrorLog = log.New(os.Stderr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) WriteWarnf(msg string, args ...interface{})
}
type LogrusLogger struct {
logger *logrus.Logger
}
func (l *LogrusLogger) WriteInfof(msg string, args ...interface{}) {
l.logger.Infof(msg, args...)
}
func (l *LogrusLogger) WriteErrorf(msg string, args ...interface{}) {
l.logger.Errorf(msg, args...)
}
func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...)
}
func NewLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
logger.SetOutput(os.Stdout)
logger.SetLevel(logrus.InfoLevel)
return &LogrusLogger{logger: logger}
}
func init() {
SetLogger(NewLogrusLogger())
}
func SetLogger(l Logger) {
Log = l
} }

View File

@ -4,16 +4,17 @@ import (
"context" "context"
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/auth"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
// AuthRpcProvider implements the AuthRpcProvider service
type AuthRpcProvider struct { type AuthRpcProvider struct {
rpc.UnimplementedAuthenticationServer rpc.UnimplementedAuthenticationServer
Manager *auth.JwtManager
} }
// JoinMesh handles a JoinMeshRequest. Succeeds by stating the node managed to join the mesh
// or returns an error if it failed
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) { func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
meshId := in.MeshId meshId := in.MeshId
@ -21,12 +22,8 @@ func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequ
return nil, errors.New("Must specify the meshId") return nil, errors.New("Must specify the meshId")
} }
logging.InfoLog.Println("MeshID: " + in.MeshId) logging.Log.WriteInfof("MeshID: " + in.MeshId)
token, err := a.Manager.CreateClaims(in.MeshId, in.Alias)
if err != nil { var token string = ""
return nil, err return &rpc.JoinAuthMeshReply{Success: true, Token: &token}, nil
}
return &rpc.JoinAuthMeshReply{Success: true, Token: token}, nil
} }

View File

@ -75,71 +75,9 @@ func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
return nil return nil
} }
func (n *RobinIpc) Authenticate(meshId, endpoint string) error {
peerConnection, err := n.Server.ConnectionManager.AddConnection(endpoint)
if err != nil {
return err
}
err = peerConnection.Authenticate(meshId)
if err != nil {
return err
}
return err
}
func (n *RobinIpc) authenticatePeers(meshId string) error {
theMesh := n.Server.MeshManager.GetMesh(meshId)
if theMesh == nil {
return errors.New("the mesh does not exist")
}
snapshot, _ := theMesh.GetCrdt()
publicKey, err := n.Server.MeshManager.GetPublicKey(meshId)
if err != nil {
return err
}
for nodeKey, node := range snapshot.Nodes {
logging.InfoLog.Println(nodeKey)
if nodeKey == publicKey.String() {
continue
}
err := n.Authenticate(meshId, node.HostEndpoint)
if err != nil {
return err
}
}
return nil
}
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err := n.Authenticate(args.MeshId, args.IpAdress)
if err != nil {
return err
}
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress) peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
if err != nil {
return err
}
err = peerConnection.Connect()
if err != nil {
return err
}
client, err := peerConnection.GetClient() client, err := peerConnection.GetClient()
if err != nil { if err != nil {
@ -148,13 +86,11 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
c := rpc.NewMeshCtrlServerClient(client) c := rpc.NewMeshCtrlServerClient(client)
authContext, err := peerConnection.CreateAuthContext(args.MeshId)
if err != nil { if err != nil {
return err return err
} }
ctx, cancel := context.WithTimeout(authContext, time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -181,7 +117,7 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err return err
} }
logging.InfoLog.Println("WgIP: " + ipAddr.String()) logging.Log.WriteInfof("WgIP: " + ipAddr.String())
outBoundIP := lib.GetOutboundIP() outBoundIP := lib.GetOutboundIP()
@ -206,10 +142,6 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err return err
} }
if joinReply.GetSuccess() {
err = n.authenticatePeers(args.MeshId)
}
if err != nil { if err != nil {
return err return err
} }

View File

@ -56,7 +56,7 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r
func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId) mesh := m.Server.MeshManager.GetMesh(request.MeshId)
logging.InfoLog.Println("[JOINING MESH]: " + request.MeshId) logging.Log.WriteInfof("[JOINING MESH]: " + request.MeshId)
if mesh == nil { if mesh == nil {
return nil, errors.New("mesh does not exist") return nil, errors.New("mesh does not exist")

View File

@ -34,7 +34,7 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool { func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
errStatus, _ := status.FromError(err) errStatus, _ := status.FromError(err)
logging.WarningLog.Printf("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:

View File

@ -23,22 +23,6 @@ type SyncRequesterImpl struct {
errorHdlr SyncErrorHandler errorHdlr SyncErrorHandler
} }
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
if err != nil {
return err
}
err = peerConnection.Authenticate(meshId)
if err != nil {
return err
}
return err
}
// GetMesh: Retrieves the local state of the mesh at the endpoint // GetMesh: Retrieves the local state of the mesh at the endpoint
func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint)
@ -47,12 +31,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
return err return err
} }
err = peerConnection.Connect()
if err != nil {
return err
}
client, err := peerConnection.GetClient() client, err := peerConnection.GetClient()
if err != nil { if err != nil {
@ -60,13 +38,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
} }
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
authContext, err := peerConnection.CreateAuthContext(meshId)
if err != nil { ctx, cancel := context.WithTimeout(context.Background(), time.Second)
return err
}
ctx, cancel := context.WithTimeout(authContext, time.Second)
defer cancel() defer cancel()
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId}) reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
@ -91,34 +64,18 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
if !s.server.ConnectionManager.HasConnection(endpoint) {
s.Authenticate(meshId, endpoint)
}
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil { if err != nil {
return err return err
} }
err = peerConnection.Connect()
if err != nil {
return s.handleErr(meshId, endpoint, err)
}
client, err := peerConnection.GetClient() client, err := peerConnection.GetClient()
if err != nil { if err != nil {
return err return err
} }
authContext, err := peerConnection.CreateAuthContext(meshId)
if err != nil {
return err
}
mesh := s.server.MeshManager.GetMesh(meshId) mesh := s.server.MeshManager.GetMesh(meshId)
if mesh == nil { if mesh == nil {
@ -127,7 +84,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(authContext, 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
err = syncMesh(mesh, ctx, c) err = syncMesh(mesh, ctx, c)
@ -136,7 +93,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
return s.handleErr(meshId, endpoint, err) return s.handleErr(meshId, endpoint, err)
} }
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
mesh.DecrementFailedCount(endpoint) mesh.DecrementFailedCount(endpoint)
return nil return nil
} }
@ -162,7 +119,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
in, err := stream.Recv() in, err := stream.Recv()
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logging.ErrorLog.Printf("Stream recv error: %s\n", err.Error()) logging.Log.WriteInfof("Stream recv error: %s\n", err.Error())
return err return err
} }
@ -171,7 +128,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
} }
if err != nil { if err != nil {
logging.ErrorLog.Printf("Syncer recv error: %s\n", err.Error()) logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error())
return err return err
} }
@ -180,7 +137,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
} }
} }
logging.InfoLog.Println("SYNC finished") logging.Log.WriteInfof("SYNC finished")
stream.CloseSend() stream.CloseSend()
return nil return nil
} }

View File

@ -34,7 +34,7 @@ func (s *SyncSchedulerImpl) Run() error {
err := s.syncer.SyncMeshes() err := s.syncer.SyncMeshes()
if err != nil { if err != nil {
logging.ErrorLog.Println(err.Error()) logging.Log.WriteErrorf(err.Error())
} }
break break
case <-quit: case <-quit:

View File

@ -41,9 +41,9 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
var syncer *crdt.AutomergeSync = nil var syncer *crdt.AutomergeSync = nil
for { for {
logging.InfoLog.Println("Received Attempt") logging.Log.WriteInfof("Received Attempt")
in, err := stream.Recv() in, err := stream.Recv()
logging.InfoLog.Println("Received Worked") logging.Log.WriteInfof("Received Worked")
if err == io.EOF { if err == io.EOF {
return nil return nil
@ -84,7 +84,6 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
} }
if !moreMessages || err == io.EOF { if !moreMessages || err == io.EOF {
logging.InfoLog.Println("SYNC Completed")
return nil return nil
} }
} }

View File

@ -62,7 +62,7 @@ func EnableInterface(ifName string, ip string) error {
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName) cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
logging.ErrorLog.Println(err.Error()) logging.Log.WriteErrorf(err.Error())
return err return err
} }