diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 5f66268..8745439 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -5,7 +5,7 @@ import ( "log" ipcRpc "net/rpc" "os" - "strconv" + "time" "github.com/akamensky/argparse" "github.com/tim-beatham/wgmesh/pkg/ipc" @@ -14,9 +14,14 @@ import ( const SockAddr = "/tmp/wgmesh_ipc.sock" -func createMesh(client *ipcRpc.Client) string { +func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string { var reply string - err := client.Call("RobinIpc.CreateMesh", "", &reply) + newMeshParams := ipc.NewMeshArgs{ + IfName: ifName, + WgPort: wgPort, + } + + err := client.Call("RobinIpc.CreateMesh", &newMeshParams, &reply) if err != nil { return err.Error() @@ -40,10 +45,15 @@ func listMeshes(client *ipcRpc.Client) { } } -func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string) string { +func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName string, wgPort int) string { var reply string - args := ipc.JoinMeshArgs{MeshId: meshId, IpAdress: ipAddress} + args := ipc.JoinMeshArgs{ + MeshId: meshId, + IpAdress: ipAddress, + IfName: ifName, + Port: wgPort, + } err := client.Call("RobinIpc.JoinMesh", &args, &reply) @@ -69,7 +79,7 @@ func getMesh(client *ipcRpc.Client, meshId string) { fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("Wg IP: " + node.WgHost) - fmt.Println("Failed Count: " + strconv.FormatBool(node.Failed)) + fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())) fmt.Println("---") } } @@ -111,8 +121,14 @@ func main() { enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") + var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) + var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) + var meshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var ipAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) + var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) + var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) + var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) @@ -131,7 +147,7 @@ func main() { } if newMeshCmd.Happened() { - fmt.Println(createMesh(client)) + fmt.Println(createMesh(client, *newMeshIfName, *newMeshPort)) } if listMeshCmd.Happened() { @@ -139,7 +155,7 @@ func main() { } if joinMeshCmd.Happened() { - fmt.Println(joinMesh(client, *meshId, *ipAddress)) + fmt.Println(joinMesh(client, *meshId, *ipAddress, *joinMeshIfName, *joinMeshPort)) } if getMeshCmd.Happened() { diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index 3e9ef25..ba7b349 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -11,7 +11,8 @@ import ( "github.com/tim-beatham/wgmesh/pkg/middleware" "github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/sync" - wg "github.com/tim-beatham/wgmesh/pkg/wg" + "github.com/tim-beatham/wgmesh/pkg/timestamp" + "golang.zx2c4.com/wireguard/wgctrl" ) func main() { @@ -20,7 +21,12 @@ func main() { log.Fatalln("Could not parse configuration") } - wgClient, err := wg.CreateClient(conf.IfName, conf.WgPort) + client, err := wgctrl.New() + + if err != nil { + logging.Log.WriteErrorf("Failed to create wgctrl client") + return + } var robinRpc robin.RobinRpc var robinIpc robin.RobinIpc @@ -28,17 +34,18 @@ func main() { var syncProvider sync.SyncServiceImpl ctrlServerParams := ctrlserver.NewCtrlServerParams{ - WgClient: wgClient, Conf: conf, AuthProvider: &authProvider, CtrlProvider: &robinRpc, SyncProvider: &syncProvider, + Client: client, } ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) syncProvider.Server = ctrlServer syncRequester := sync.NewSyncRequester(ctrlServer) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2) + timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer, 60) robinIpcParams := robin.RobinIpcParams{ CtrlServer: ctrlServer, @@ -57,6 +64,7 @@ func main() { go ipc.RunIpcHandler(&robinIpc) go syncScheduler.Run() + go timestampScheduler.Run() err = ctrlServer.ConnectionServer.Listen() @@ -67,6 +75,7 @@ func main() { } defer syncScheduler.Stop() + defer timestampScheduler.Stop() defer ctrlServer.Close() - defer wgClient.Close() + defer client.Close() } diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 0f1176d..6623d8a 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -5,26 +5,30 @@ import ( "fmt" "net" "strings" + "time" "github.com/automerge/automerge-go" logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/wg" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // CrdtNodeManager manages nodes in the crdt mesh type CrdtNodeManager struct { - MeshId string - IfName string - NodeId string - Client *wgctrl.Client - doc *automerge.Doc + MeshId string + IfName string + NodeId string + Client *wgctrl.Client + doc *automerge.Doc + LastHash automerge.ChangeHash } const maxFails = 5 func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) { crdt.FailedMap = automerge.NewMap() + crdt.Timestamp = time.Now().Unix() c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) } @@ -61,29 +65,22 @@ func (c *CrdtNodeManager) Save() []byte { return c.doc.Save() } -func (c *CrdtNodeManager) LoadChanges(changes []byte) error { - err := c.doc.LoadIncremental(changes) - - if err != nil { - return err - } - - return nil -} - -func (c *CrdtNodeManager) SaveChanges() []byte { - return c.doc.SaveIncremental() -} - // NewCrdtNodeManager: Create a new crdt node manager -func NewCrdtNodeManager(meshId, hostId, devName string, client *wgctrl.Client) *CrdtNodeManager { +func NewCrdtNodeManager(meshId, hostId, devName string, port int, client *wgctrl.Client) (*CrdtNodeManager, error) { var manager CrdtNodeManager manager.MeshId = meshId manager.doc = automerge.New() manager.IfName = devName manager.Client = client manager.NodeId = hostId - return &manager + + err := wg.CreateWgInterface(client, devName, port) + + if err != nil { + return nil, err + } + + return &manager, nil } func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { @@ -193,7 +190,29 @@ func (m *CrdtNodeManager) Length() int { return m.doc.Path("nodes").Map().Len() } -const thresholdVotes = 0.1 +func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) { + dev, err := m.Client.Device(m.IfName) + + if err != nil { + return nil, err + } + + return dev, nil +} + +// HasChanges returns true if we have changes since the last time we synced +func (m *CrdtNodeManager) HasChanges() bool { + changes, err := m.doc.Changes(m.LastHash) + + logging.Log.WriteInfof("Changes %s", m.LastHash.String()) + + if err != nil { + return false + } + + logging.Log.WriteInfof("Changes length %d", len(changes)) + return len(changes) > 0 +} func (m *CrdtNodeManager) HasFailed(endpoint string) bool { node, err := m.GetNode(endpoint) @@ -222,6 +241,30 @@ func (m *CrdtNodeManager) HasFailed(endpoint string) bool { return countFailed >= 4 } +func (m *CrdtNodeManager) SaveChanges() { + hashes := m.doc.Heads() + hash := hashes[len(hashes)-1] + + logging.Log.WriteInfof("Saved Hash %s", hash.String()) + m.LastHash = hash +} + +func (m *CrdtNodeManager) UpdateTimeStamp() error { + node, err := m.doc.Path("nodes").Map().Get(m.NodeId) + + if err != nil { + return err + } + + err = node.Map().Set("timestamp", time.Now().Unix()) + + if err == nil { + logging.Log.WriteInfof("Timestamp Updated for %s", m.MeshId) + } + + return err +} + func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) diff --git a/pkg/automerge/automerge_sync.go b/pkg/automerge/automerge_sync.go index bab8c7f..d99a1af 100644 --- a/pkg/automerge/automerge_sync.go +++ b/pkg/automerge/automerge_sync.go @@ -2,10 +2,12 @@ package crdt import ( "github.com/automerge/automerge-go" + logging "github.com/tim-beatham/wgmesh/pkg/log" ) type AutomergeSync struct { - state *automerge.SyncState + state *automerge.SyncState + manager *CrdtNodeManager } func (a *AutomergeSync) GenerateMessage() ([]byte, bool) { @@ -28,6 +30,14 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error { return nil } -func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync { - return &AutomergeSync{state: automerge.NewSyncState(manager.doc)} +func (a *AutomergeSync) Complete() { + logging.Log.WriteInfof("Sync Completed") + a.manager.SaveChanges() +} + +func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync { + return &AutomergeSync{ + state: automerge.NewSyncState(manager.doc), + manager: manager, + } } diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index c4f9e1f..8dfd33e 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -7,8 +7,8 @@ type MeshNodeCrdt struct { WgEndpoint string `automerge:"wgEndpoint"` PublicKey string `automerge:"publicKey"` WgHost string `automerge:"wgHost"` + Timestamp int64 `automerge:"timestamp"` FailedMap *automerge.Map `automerge:"failedMap"` - FailedInt int `automerge:"-"` } type MeshCrdt struct { diff --git a/pkg/conn/connection.go b/pkg/conn/connection.go index 2b84a78..57d3571 100644 --- a/pkg/conn/connection.go +++ b/pkg/conn/connection.go @@ -5,10 +5,12 @@ package conn import ( "crypto/tls" "errors" + "time" logging "github.com/tim-beatham/wgmesh/pkg/log" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" ) // PeerConnection represents a client-side connection between two @@ -42,7 +44,10 @@ func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnec func (c *WgCtrlConnection) createGrpcConn() error { conn, err := grpc.Dial(c.endpoint, grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)), - ) + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 10 * time.Minute, + Timeout: 30 * time.Minute, + })) if err != nil { logging.Log.WriteErrorf("Could not connect: %s\n", err.Error()) diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 3bcfd9b..417a081 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -10,8 +10,8 @@ import ( // NewCtrlServerParams are the params requried to create a new ctrl server type NewCtrlServerParams struct { - WgClient *wgctrl.Client Conf *conf.WgMeshConfiguration + Client *wgctrl.Client AuthProvider rpc.AuthenticationServer CtrlProvider rpc.MeshCtrlServerServer SyncProvider rpc.SyncServiceServer @@ -21,8 +21,7 @@ type NewCtrlServerParams struct { // operation failed func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer := new(MeshCtrlServer) - ctrlServer.Client = params.WgClient - ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf) + ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client) ctrlServer.Conf = params.Conf connManagerParams := conn.NewConnectionManageParams{ diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index 6f1bb2f..a4a855f 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -17,6 +17,7 @@ type MeshNode struct { PublicKey string WgHost string Failed bool + Timestamp int64 } type Mesh struct { diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index f087077..d2b15db 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -10,9 +10,16 @@ import ( "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ) +type NewMeshArgs struct { + IfName string + WgPort int +} + type JoinMeshArgs struct { MeshId string IpAdress string + IfName string + Port int } type GetMeshReply struct { @@ -24,7 +31,7 @@ type ListMeshReply struct { } type MeshIpc interface { - CreateMesh(name string, reply *string) error + CreateMesh(args *NewMeshArgs, reply *string) error ListMeshes(name string, reply *ListMeshReply) error JoinMesh(args JoinMeshArgs, reply *string) error GetMesh(meshId string, reply *GetMeshReply) error diff --git a/pkg/mesh/mesh_manager.go b/pkg/mesh/mesh_manager.go index b4f074c..090fe44 100644 --- a/pkg/mesh/mesh_manager.go +++ b/pkg/mesh/mesh_manager.go @@ -16,6 +16,7 @@ type MeshManger struct { Meshes map[string]*crdt.CrdtNodeManager Client *wgctrl.Client HostEndpoint string + conf *conf.WgMeshConfiguration } func (m *MeshManger) MeshExists(meshId string) bool { @@ -24,52 +25,32 @@ func (m *MeshManger) MeshExists(meshId string) bool { } // CreateMesh: Creates a new mesh, stores it and returns the mesh id -func (m *MeshManger) CreateMesh(devName string) (string, error) { +func (m *MeshManger) CreateMesh(devName string, port int) (string, error) { key, err := wgtypes.GenerateKey() if err != nil { return "", err } - nodeManager := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, m.Client) + nodeManager, err := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, port, m.Client) + + if err != nil { + return "", err + } + m.Meshes[key.String()] = nodeManager return key.String(), nil } -// UpdateMesh: merge the changes and save it to the device -func (m *MeshManger) UpdateMesh(meshId string, changes []byte) error { - mesh, ok := m.Meshes[meshId] - - if !ok { - return errors.New("mesh does not exist") - } - - err := mesh.LoadChanges(changes) +// AddMesh: Add the mesh to the list of meshes +func (m *MeshManger) AddMesh(meshId string, devName string, port int, meshBytes []byte) error { + mesh, err := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, port, m.Client) if err != nil { return err } - return nil -} - -// ApplyWg: applies the wireguard configuration changes -func (m *MeshManger) ApplyWg() error { - for _, mesh := range m.Meshes { - err := mesh.ApplyWg() - - if err != nil { - return err - } - } - - return nil -} - -// AddMesh: Add the mesh to the list of meshes -func (m *MeshManger) AddMesh(meshId string, devName string, meshBytes []byte) error { - mesh := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, m.Client) - err := mesh.Load(meshBytes) + err = mesh.Load(meshBytes) if err != nil { return err @@ -84,6 +65,10 @@ func (m *MeshManger) AddMeshNode(meshId string, node crdt.MeshNodeCrdt) { m.Meshes[meshId].AddNode(node) } +func (m *MeshManger) HasChanges(meshId string) bool { + return m.Meshes[meshId].HasChanges() +} + func (m *MeshManger) GetMesh(meshId string) *crdt.CrdtNodeManager { theMesh, _ := m.Meshes[meshId] return theMesh @@ -109,6 +94,12 @@ func (s *MeshManger) EnableInterface(meshId string) error { return errors.New("Node does not exist in the mesh") } + err = mesh.ApplyWg() + + if err != nil { + return err + } + return wg.EnableInterface(mesh.IfName, node.WgHost) } @@ -120,7 +111,7 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) { return nil, errors.New("mesh does not exist") } - dev, err := s.Client.Device(mesh.IfName) + dev, err := mesh.GetDevice() if err != nil { return nil, err @@ -129,12 +120,26 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) { return &dev.PublicKey, nil } -func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger { +// UpdateTimeStamp updates the timestamp of this node in all meshes +func (s *MeshManger) UpdateTimeStamp() error { + for _, mesh := range s.Meshes { + err := mesh.UpdateTimeStamp() + + if err != nil { + return err + } + } + + return nil +} + +func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client) *MeshManger { ip := lib.GetOutboundIP() return &MeshManger{ Meshes: make(map[string]*crdt.CrdtNodeManager), - Client: &client, HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort), + Client: client, + conf: &conf, } } diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go index 8cc070c..b0121e2 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/robin_requester.go @@ -3,6 +3,7 @@ package robin import ( "context" "errors" + "fmt" "strconv" "time" @@ -22,10 +23,10 @@ type RobinIpc struct { ipAllocator ip.IPAllocator } -func (n *RobinIpc) CreateMesh(name string, reply *string) error { - wg.CreateInterface(n.Server.Conf.IfName) +func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { + wg.CreateInterface(args.IfName) - meshId, err := n.Server.MeshManager.CreateMesh(n.Server.Conf.IfName) + meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort) if err != nil { return err @@ -46,9 +47,9 @@ func (n *RobinIpc) CreateMesh(name string, reply *string) error { outBoundIp := lib.GetOutboundIP() meshNode := crdt.MeshNodeCrdt{ - HostEndpoint: outBoundIp.String() + ":8080", + HostEndpoint: fmt.Sprintf("%s:%s", outBoundIp.String(), n.Server.Conf.GrpcPort), PublicKey: pubKey.String(), - WgEndpoint: outBoundIp.String() + ":51820", + WgEndpoint: fmt.Sprintf("%s:%d", outBoundIp.String(), args.WgPort), WgHost: nodeIP.String() + "/128", } @@ -99,7 +100,7 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { return err } - err = n.Server.MeshManager.AddMesh(args.MeshId, n.Server.Conf.IfName, meshReply.Mesh) + err = n.Server.MeshManager.AddMesh(args.MeshId, args.IfName, args.Port, meshReply.Mesh) if err != nil { return err @@ -122,31 +123,14 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { outBoundIP := lib.GetOutboundIP() node := crdt.MeshNodeCrdt{ - HostEndpoint: outBoundIP.String() + ":8080", - WgEndpoint: outBoundIP.String() + ":51820", + HostEndpoint: fmt.Sprintf("%s:%s", outBoundIP.String(), n.Server.Conf.GrpcPort), + WgEndpoint: fmt.Sprintf("%s:%d", outBoundIP.String(), args.Port), PublicKey: pubKey.String(), WgHost: ipAddr.String() + "/128", } n.Server.MeshManager.AddMeshNode(args.MeshId, node) - mesh := n.Server.MeshManager.GetMesh(args.MeshId) - - joinMeshRequest := rpc.JoinMeshRequest{ - MeshId: args.MeshId, - Changes: mesh.SaveChanges(), - } - - joinReply, err := c.JoinMesh(ctx, &joinMeshRequest) - - if err != nil { - return err - } - - if err != nil { - return err - } - - *reply = strconv.FormatBool(joinReply.GetSuccess()) + *reply = strconv.FormatBool(true) return nil } @@ -169,6 +153,7 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error { PublicKey: node.PublicKey, WgHost: node.WgHost, Failed: mesh.HasFailed(node.HostEndpoint), + Timestamp: node.Timestamp, } nodes[i] = node diff --git a/pkg/robin/robin_responder.go b/pkg/robin/robin_responder.go index 2f77838..9fe1b8c 100644 --- a/pkg/robin/robin_responder.go +++ b/pkg/robin/robin_responder.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" - logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/rpc" ) @@ -54,19 +53,5 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r } func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { - mesh := m.Server.MeshManager.GetMesh(request.MeshId) - - logging.Log.WriteInfof("[JOINING MESH]: " + request.MeshId) - - if mesh == nil { - return nil, errors.New("mesh does not exist") - } - - err := m.Server.MeshManager.UpdateMesh(request.MeshId, request.Changes) - - if err != nil { - return nil, err - } - return &rpc.JoinMeshReply{Success: true}, nil } diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 2b88b49..b8ab5a3 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -3,9 +3,11 @@ package sync import ( "errors" "sync" + "time" crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/lib" + logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/mesh" ) @@ -26,6 +28,11 @@ const maxAuthentications = 30 // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { + if !s.manager.HasChanges(meshId) { + logging.Log.WriteInfof("No changes for %s", meshId) + return nil + } + mesh := s.manager.GetMesh(meshId) if mesh == nil { @@ -55,6 +62,8 @@ func (s *SyncerImpl) Sync(meshId string) error { meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes) randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength) + before := time.Now() + var waitGroup sync.WaitGroup for _, n := range randomSubset { @@ -70,6 +79,8 @@ func (s *SyncerImpl) Sync(meshId string) error { } waitGroup.Wait() + + logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before)) return nil } @@ -83,7 +94,7 @@ func (s *SyncerImpl) SyncMeshes() error { } } - return s.manager.ApplyWg() + return nil } func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer { diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index b9c05be..f2b60f7 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -14,7 +14,7 @@ import ( // SyncRequester: coordinates the syncing of meshes type SyncRequester interface { - GetMesh(meshId string, endPoint string) error + GetMesh(meshId string, ifName string, port int, endPoint string) error SyncMesh(meshid string, endPoint string) error } @@ -24,7 +24,7 @@ type SyncRequesterImpl struct { } // GetMesh: Retrieves the local state of the mesh at the endpoint -func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { +func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error { peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) if err != nil { @@ -48,7 +48,7 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { return err } - err = s.server.MeshManager.AddMesh(meshId, s.server.Conf.IfName, reply.Mesh) + err = s.server.MeshManager.AddMesh(meshId, ifName, port, reply.Mesh) return err } @@ -137,7 +137,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe } } - logging.Log.WriteInfof("SYNC finished") + syncer.Complete() stream.CloseSend() return nil } diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 9e81b78..e58a500 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -14,6 +14,7 @@ type SyncScheduler interface { Stop() error } +// SyncSchedulerImpl scheduler for sync scheduling type SyncSchedulerImpl struct { syncRate int quit chan struct{} diff --git a/pkg/sync/syncservice.go b/pkg/sync/syncservice.go index 9a4c0cb..82423b5 100644 --- a/pkg/sync/syncservice.go +++ b/pkg/sync/syncservice.go @@ -8,7 +8,6 @@ import ( crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" - logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/rpc" ) @@ -41,11 +40,12 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error var syncer *crdt.AutomergeSync = nil for { - logging.Log.WriteInfof("Received Attempt") in, err := stream.Recv() - logging.Log.WriteInfof("Received Worked") if err == io.EOF { + if syncer != nil { + syncer.Complete() + } return nil } @@ -84,6 +84,9 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error } if !moreMessages || err == io.EOF { + if syncer != nil { + syncer.Complete() + } return nil } } diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go new file mode 100644 index 0000000..df4503e --- /dev/null +++ b/pkg/timestamp/timestamp.go @@ -0,0 +1,48 @@ +package timestamp + +import ( + "time" + + "github.com/tim-beatham/wgmesh/pkg/ctrlserver" + logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/mesh" +) + +type TimestampScheduler interface { + Run() error + Stop() error +} + +type TimeStampSchedulerImpl struct { + meshManager *mesh.MeshManger + updateRate int + quit chan struct{} +} + +func (s *TimeStampSchedulerImpl) Run() error { + ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second) + + s.quit = make(chan struct{}) + + for { + select { + case <-ticker.C: + err := s.meshManager.UpdateTimeStamp() + + if err != nil { + logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error()) + } + case <-s.quit: + break + } + } +} + +func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer, updateRate int) TimestampScheduler { + return &TimeStampSchedulerImpl{meshManager: ctrlServer.MeshManager, updateRate: updateRate} +} + +func (s *TimeStampSchedulerImpl) Stop() error { + close(s.quit) + return nil +} diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index 4d326c6..33ea89f 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -30,23 +30,17 @@ func CreateInterface(ifName string) error { /* * Create and configure a new WireGuard client */ -func CreateClient(ifName string, port int) (*wgctrl.Client, error) { +func CreateWgInterface(client *wgctrl.Client, ifName string, port int) error { err := CreateInterface(ifName) if err != nil { - return nil, err - } - - client, err := wgctrl.New() - - if err != nil { - return nil, err + return err } privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { - return nil, err + return err } var cfg wgtypes.Config = wgtypes.Config{ @@ -55,7 +49,7 @@ func CreateClient(ifName string, port int) (*wgctrl.Client, error) { } client.ConfigureDevice(ifName, cfg) - return client, nil + return nil } func EnableInterface(ifName string, ip string) error { @@ -72,7 +66,7 @@ func EnableInterface(ifName string, ip string) error { return err } - cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", "wgmesh") + cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName) if err := cmd.Run(); err != nil { return err diff --git a/wgmesh.go b/wgmesh.go deleted file mode 100644 index 15eb2b1..0000000 --- a/wgmesh.go +++ /dev/null @@ -1,49 +0,0 @@ -package main - -import ( - "fmt" - - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -func main() { - client, err := wgctrl.New() - - if err != nil { - return - } - - privateKey, err := wgtypes.GeneratePrivateKey() - var listenPort int = 5109 - - if err != nil { - return - } - - cfg := wgtypes.Config{ - PrivateKey: &privateKey, - ListenPort: &listenPort, - } - - err = client.ConfigureDevice("utun9", cfg) - - if err != nil { - return - } - - devices, err := client.Devices() - - if err != nil { - return - } - - fmt.Printf("Number of devices: %d\n", len(devices)) - - for _, device := range devices { - fmt.Printf("Device Name: %s\n", device.Name) - fmt.Printf("Listen Port: %d\n", device.ListenPort) - fmt.Printf("Private Key: %s\n", device.PrivateKey.String()) - fmt.Printf("Public Key: %s\n", device.PublicKey.String()) - } -}