From c200544ceefec83f185b91b4aeb7553c1902217a Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Fri, 20 Oct 2023 12:41:06 +0100 Subject: [PATCH] Timer in go that syncs with random nodes in the mesh every given time interval. --- cmd/wgmeshd/configuration.yaml | 2 +- cmd/wgmeshd/main.go | 10 +++++- pkg/auth/jwt.go | 2 +- pkg/automerge/automerge.go | 5 +++ pkg/conf/conf.go | 5 +-- pkg/conn/conn_manager.go | 10 ++++-- pkg/conn/conn_server.go | 13 ++++++-- pkg/ctrlserver/ctrlserver.go | 2 ++ pkg/lib/conv.go | 22 ++++++++++++ pkg/lib/random.go | 25 ++++++++++++++ pkg/robin/robin_requester.go | 12 ++++--- pkg/sync/syncer.go | 61 +++++++++++++++++++++++++++++++++- pkg/sync/syncrequester.go | 33 ++++++++++++++++-- pkg/sync/syncscheduler.go | 19 ++++++++--- pkg/sync/syncservice.go | 12 ++++--- pkg/wg/wg.go | 5 ++- 16 files changed, 208 insertions(+), 30 deletions(-) create mode 100644 pkg/lib/conv.go create mode 100644 pkg/lib/random.go diff --git a/cmd/wgmeshd/configuration.yaml b/cmd/wgmeshd/configuration.yaml index 2a83469..6aa9c79 100644 --- a/cmd/wgmeshd/configuration.yaml +++ b/cmd/wgmeshd/configuration.yaml @@ -3,5 +3,5 @@ privateKeyPath: "../../cert/key.pem" skipCertVerification: true ifName: "wgmesh" wgPort: 51820 -gRPCPort: 8080 +gRPCPort: "8080" secret: "abc123" \ No newline at end of file diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index ff33091..82defbd 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -10,6 +10,7 @@ import ( logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/middleware" "github.com/tim-beatham/wgmesh/pkg/robin" + "github.com/tim-beatham/wgmesh/pkg/sync" wg "github.com/tim-beatham/wgmesh/pkg/wg" ) @@ -19,21 +20,26 @@ func main() { log.Fatalln("Could not parse configuration") } - wgClient, err := wg.CreateClient(conf.IfName) + wgClient, err := wg.CreateClient(conf.IfName, conf.WgPort) var robinRpc robin.RobinRpc var robinIpc robin.RobinIpc var authProvider middleware.AuthRpcProvider + var syncProvider sync.SyncServiceImpl ctrlServerParams := ctrlserver.NewCtrlServerParams{ WgClient: wgClient, Conf: conf, AuthProvider: &authProvider, CtrlProvider: &robinRpc, + SyncProvider: &syncProvider, } ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) authProvider.Manager = ctrlServer.ConnectionServer.JwtManager + syncProvider.Server = ctrlServer + syncRequester := sync.NewSyncRequester(ctrlServer) + syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2) robinIpcParams := robin.RobinIpcParams{ CtrlServer: ctrlServer, @@ -50,6 +56,7 @@ func main() { log.Println("Running IPC Handler") go ipc.RunIpcHandler(&robinIpc) + go syncScheduler.Run() err = ctrlServer.ConnectionServer.Listen() @@ -58,4 +65,5 @@ func main() { } defer wgClient.Close() + defer syncScheduler.Stop() } diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index eb0043a..4947847 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -107,7 +107,7 @@ func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor { handler grpc.UnaryHandler, ) (interface{}, error) { - if strings.Contains(info.FullMethod, "Auth") { + if strings.Contains(info.FullMethod, "") { return handler(ctx, req) } diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 99cefce..7aea8ea 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -2,6 +2,7 @@ package crdt import ( "net" + "strings" "github.com/automerge/automerge-go" logging "github.com/tim-beatham/wgmesh/pkg/log" @@ -110,6 +111,10 @@ func convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { return &peerConfig, nil } +func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int { + return strings.Compare(m1.PublicKey, m2.PublicKey) +} + func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index e549fce..8fd93f6 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -13,8 +13,9 @@ type WgMeshConfiguration struct { PrivateKeyPath string `yaml:"privateKeyPath"` SkipCertVerification bool `yaml:"skipCertVerification"` IfName string `yaml:"ifName"` - WgPort string `yaml:"wgPort"` - GrpcPort string `yaml:"grpcPort"` + WgPort int `yaml:"wgPort"` + GrpcPort string `yaml:"gRPCPort"` + Secret string `yaml:"secret"` } func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { diff --git a/pkg/conn/conn_manager.go b/pkg/conn/conn_manager.go index ed5c88c..8045180 100644 --- a/pkg/conn/conn_manager.go +++ b/pkg/conn/conn_manager.go @@ -10,6 +10,7 @@ import ( type ConnectionManager interface { AddConnection(endPoint string) (PeerConnection, error) GetConnection(endPoint string) (PeerConnection, error) + HasConnection(endPoint string) bool } // ConnectionManager manages connections between other peers @@ -70,10 +71,10 @@ func (m *JwtConnectionManager) GetConnection(endpoint string) (PeerConnection, e // AddToken: Adds a connection to the list of connections to manage func (m *JwtConnectionManager) AddConnection(endPoint string) (PeerConnection, error) { - _, exists := m.clientConnections[endPoint] + conn, exists := m.clientConnections[endPoint] if exists { - return nil, errors.New("token already exists in the connections") + return conn, nil } connections, err := NewWgCtrlConnection(m.clientConfig, endPoint) @@ -85,3 +86,8 @@ func (m *JwtConnectionManager) AddConnection(endPoint string) (PeerConnection, e m.clientConnections[endPoint] = connections return connections, nil } + +func (m *JwtConnectionManager) HasConnection(endPoint string) bool { + _, exists := m.clientConnections[endPoint] + return exists +} diff --git a/pkg/conn/conn_server.go b/pkg/conn/conn_server.go index 05be073..3c32dd2 100644 --- a/pkg/conn/conn_server.go +++ b/pkg/conn/conn_server.go @@ -20,6 +20,7 @@ type ConnectionServer struct { server *grpc.Server authProvider rpc.AuthenticationServer ctrlProvider rpc.MeshCtrlServerServer + syncProvider rpc.SyncServiceServer Conf *conf.WgMeshConfiguration } @@ -27,6 +28,7 @@ type NewConnectionServerParams struct { Conf *conf.WgMeshConfiguration AuthProvider rpc.AuthenticationServer CtrlProvider rpc.MeshCtrlServerServer + SyncProvider rpc.SyncServiceServer } // NewConnectionServer: create a new gRPC connection server instance @@ -51,7 +53,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, Certificates: []tls.Certificate{cert}, } - jwtManager := auth.NewJwtManager("tim123", 24*time.Hour) + jwtManager := auth.NewJwtManager(params.Conf.Secret, 24*time.Hour) server := grpc.NewServer( grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()), @@ -60,6 +62,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, authProvider := params.AuthProvider ctrlProvider := params.CtrlProvider + syncProvider := params.SyncProvider connServer := ConnectionServer{ serverConfig, @@ -67,6 +70,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, server, authProvider, ctrlProvider, + syncProvider, params.Conf, } @@ -77,7 +81,12 @@ func (s *ConnectionServer) Listen() error { rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterAuthenticationServer(s.server, s.authProvider) - lis, err := net.Listen("tcp", s.Conf.GrpcPort) + logging.InfoLog.Println(s.syncProvider) + rpc.RegisterSyncServiceServer(s.server, s.syncProvider) + + lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort) + + logging.InfoLog.Printf("GRPC listening on %s\n", s.Conf.GrpcPort) if err != nil { logging.ErrorLog.Println(err.Error()) diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 5917871..df255b8 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -18,6 +18,7 @@ type NewCtrlServerParams struct { Conf *conf.WgMeshConfiguration AuthProvider rpc.AuthenticationServer CtrlProvider rpc.MeshCtrlServerServer + SyncProvider rpc.SyncServiceServer } /* @@ -50,6 +51,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { Conf: params.Conf, AuthProvider: params.AuthProvider, CtrlProvider: params.CtrlProvider, + SyncProvider: params.SyncProvider, } connServer, err := conn.NewConnectionServer(&connServerParams) diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go new file mode 100644 index 0000000..f148c9c --- /dev/null +++ b/pkg/lib/conv.go @@ -0,0 +1,22 @@ +package lib + +// MapToSlice converts a map to a slice in go +func MapValues[K comparable, V any](m map[K]V) []V { + return MapValuesWithExclude(m, map[K]struct{}{}) +} + +func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V { + values := make([]V, len(m)-len(exclude)) + + i := 0 + for k, v := range m { + if _, excluded := exclude[k]; excluded { + continue + } + + values[i] = v + i++ + } + + return values +} diff --git a/pkg/lib/random.go b/pkg/lib/random.go new file mode 100644 index 0000000..90a8414 --- /dev/null +++ b/pkg/lib/random.go @@ -0,0 +1,25 @@ +package lib + +import "math/rand" + +// RandomSubsetOfLength: Given an array of nodes generate of random +// subset of 'num' length. +func RandomSubsetOfLength[V any](vs []V, num int) []V { + randomSubset := make([]V, 0) + selectedIndices := make(map[int]struct{}) + + for i := 0; i < num; { + if len(selectedIndices) == len(vs) { + return randomSubset + } + + randomIndex := rand.Intn(len(vs)) + + if _, ok := selectedIndices[randomIndex]; !ok { + randomSubset = append(randomSubset, vs[randomIndex]) + i++ + } + } + + return randomSubset +} diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go index 4334be5..b421b1c 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/robin_requester.go @@ -81,7 +81,7 @@ func (n *RobinIpc) Authenticate(meshId, endpoint string) error { return err } -func (n *RobinIpc) updatePeers(meshId string) error { +func (n *RobinIpc) authenticatePeers(meshId string) error { theMesh := n.Server.MeshManager.GetMesh(meshId) if theMesh == nil { @@ -101,11 +101,9 @@ func (n *RobinIpc) updatePeers(meshId string) error { continue } - var reply string - err := n.JoinMesh(ipc.JoinMeshArgs{MeshId: meshId, IpAdress: node.HostEndpoint}, &reply) + err := n.Authenticate(meshId, node.HostEndpoint) if err != nil { - logging.InfoLog.Println(err) return err } } @@ -199,7 +197,11 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { } if joinReply.GetSuccess() { - err = n.updatePeers(args.MeshId) + err = n.authenticatePeers(args.MeshId) + } + + if err != nil { + return err } *reply = strconv.FormatBool(joinReply.GetSuccess()) diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 61cd631..968219d 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -1,5 +1,13 @@ package sync +import ( + "errors" + + crdt "github.com/tim-beatham/wgmesh/pkg/automerge" + "github.com/tim-beatham/wgmesh/pkg/lib" + "github.com/tim-beatham/wgmesh/pkg/manager" +) + // Syncer: picks random nodes from the mesh type Syncer interface { Sync(meshId string) error @@ -7,14 +15,65 @@ type Syncer interface { } type SyncerImpl struct { + manager *manager.MeshManger + requester SyncRequester + authenticatedNodes []crdt.MeshNodeCrdt } +const subSetLength = 5 +const maxAuthentications = 30 + // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { + mesh := s.manager.GetMesh(meshId) + + if mesh == nil { + return errors.New("the provided mesh does not exist") + } + + snapshot, err := mesh.GetCrdt() + + if err != nil { + return err + } + + pubKey, err := s.manager.GetPublicKey(meshId) + + if err != nil { + return err + } + + excludedNodes := map[string]struct{}{ + pubKey.String(): {}, + } + + meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes) + randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength) + + for _, n := range randomSubset { + err := s.requester.SyncMesh(meshId, n.HostEndpoint) + + if err != nil { + return err + } + } + return nil } -// SyncMeshes: +// SyncMeshes: Sync all meshes func (s *SyncerImpl) SyncMeshes() error { + for _, m := range s.manager.Meshes { + err := s.Sync(m.MeshId) + + if err != nil { + return err + } + } + return nil } + +func NewSyncer(m *manager.MeshManger, r SyncRequester) Syncer { + return &SyncerImpl{manager: m, requester: r} +} diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index 885ef12..3bde281 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -6,19 +6,37 @@ import ( "time" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" + logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/rpc" ) // SyncRequester: coordinates the syncing of meshes type SyncRequester interface { - GetMesh(meshId string) error - SyncMesh(meshid string) error + GetMesh(meshId string, endPoint string) error + SyncMesh(meshid string, endPoint string) error } type SyncRequesterImpl struct { server *ctrlserver.MeshCtrlServer } +func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error { + + peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint) + + if err != nil { + return err + } + + err = peerConnection.Authenticate(meshId) + + if err != nil { + return err + } + + return err +} + // GetMesh: Retrieves the local state of the mesh at the endpoint func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) @@ -60,7 +78,11 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { } // SyncMesh: Proactively send a sync request to the other mesh -func (s *SyncRequesterImpl) SyncMesh(meshId string, endpoint string) error { +func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { + if !s.server.ConnectionManager.HasConnection(endpoint) { + s.Authenticate(meshId, endpoint) + } + peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) if err != nil { @@ -107,5 +129,10 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, endpoint string) error { return err } + logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId) return nil } + +func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester { + return &SyncRequesterImpl{server: s} +} diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 5547a87..df09eda 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -4,6 +4,7 @@ import ( "time" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" + logging "github.com/tim-beatham/wgmesh/pkg/log" ) // SyncScheduler: Loops through all nodes in the mesh and runs a schedule to @@ -14,13 +15,15 @@ type SyncScheduler interface { } type SyncSchedulerImpl struct { - quit chan struct{} - server *ctrlserver.MeshCtrlServer + syncRate int + quit chan struct{} + server *ctrlserver.MeshCtrlServer + syncer Syncer } // Run implements SyncScheduler. func (s *SyncSchedulerImpl) Run() error { - ticker := time.NewTicker(time.Second) + ticker := time.NewTicker(time.Duration(s.syncRate) * time.Second) quit := make(chan struct{}) s.quit = quit @@ -28,6 +31,11 @@ func (s *SyncSchedulerImpl) Run() error { for { select { case <-ticker.C: + err := s.syncer.SyncMeshes() + + if err != nil { + logging.ErrorLog.Println(err.Error()) + } break case <-quit: break @@ -41,6 +49,7 @@ func (s *SyncSchedulerImpl) Stop() error { return nil } -func NewSyncScheduler(s *ctrlserver.MeshCtrlServer) SyncScheduler { - return &SyncSchedulerImpl{server: s} +func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncRate int) SyncScheduler { + syncer := NewSyncer(s.MeshManager, syncRequester) + return &SyncSchedulerImpl{server: s, syncRate: syncRate, syncer: syncer} } diff --git a/pkg/sync/syncservice.go b/pkg/sync/syncservice.go index 95f5f1e..7f2bd5b 100644 --- a/pkg/sync/syncservice.go +++ b/pkg/sync/syncservice.go @@ -10,12 +10,13 @@ import ( ) type SyncServiceImpl struct { - server *ctrlserver.MeshCtrlServer + rpc.UnimplementedSyncServiceServer + Server *ctrlserver.MeshCtrlServer } // GetMesh: Gets a nodes local mesh configuration as a CRDT func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) { - mesh := s.server.MeshManager.GetMesh(request.MeshId) + mesh := s.Server.MeshManager.GetMesh(request.MeshId) if mesh == nil { return nil, errors.New("mesh does not exist") @@ -32,13 +33,13 @@ func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfR // Sync: Pings a node and syncs the mesh configuration with the other node func (s *SyncServiceImpl) SyncMesh(conext context.Context, request *rpc.SyncMeshRequest) (*rpc.SyncMeshReply, error) { - mesh := s.server.MeshManager.GetMesh(request.MeshId) + mesh := s.Server.MeshManager.GetMesh(request.MeshId) if mesh == nil { return nil, errors.New("mesh does not exist") } - err := s.server.MeshManager.UpdateMesh(request.MeshId, request.Changes) + err := s.Server.MeshManager.UpdateMesh(request.MeshId, request.Changes) if err != nil { return nil, err @@ -47,3 +48,6 @@ func (s *SyncServiceImpl) SyncMesh(conext context.Context, request *rpc.SyncMesh return &rpc.SyncMeshReply{Success: true}, nil } +func NewSyncService(server *ctrlserver.MeshCtrlServer) *SyncServiceImpl { + return &SyncServiceImpl{Server: server} +} diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index ce79de9..51cc002 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -30,7 +30,7 @@ func CreateInterface(ifName string) error { /* * Create and configure a new WireGuard client */ -func CreateClient(ifName string) (*wgctrl.Client, error) { +func CreateClient(ifName string, port int) (*wgctrl.Client, error) { err := CreateInterface(ifName) if err != nil { @@ -43,7 +43,6 @@ func CreateClient(ifName string) (*wgctrl.Client, error) { return nil, err } - wgListenPort := 51820 privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -52,7 +51,7 @@ func CreateClient(ifName string) (*wgctrl.Client, error) { var cfg wgtypes.Config = wgtypes.Config{ PrivateKey: &privateKey, - ListenPort: &wgListenPort, + ListenPort: &port, } client.ConfigureDevice(ifName, cfg)