diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index bbaf274..e73bdab 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -22,7 +22,7 @@ func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string { WgPort: wgPort, } - err := client.Call("RobinIpc.CreateMesh", &newMeshParams, &reply) + err := client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) if err != nil { return err.Error() @@ -34,7 +34,7 @@ func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string { func listMeshes(client *ipcRpc.Client) { reply := new(ipc.ListMeshReply) - err := client.Call("RobinIpc.ListMeshes", "", &reply) + err := client.Call("IpcHandler.ListMeshes", "", &reply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -56,7 +56,7 @@ func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName str Port: wgPort, } - err := client.Call("RobinIpc.JoinMesh", &args, &reply) + err := client.Call("IpcHandler.JoinMesh", &args, &reply) if err != nil { return err.Error() @@ -68,7 +68,7 @@ func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName str func getMesh(client *ipcRpc.Client, meshId string) { reply := new(ipc.GetMeshReply) - err := client.Call("RobinIpc.GetMesh", &meshId, &reply) + err := client.Call("IpcHandler.GetMesh", &meshId, &reply) if err != nil { log.Panic(err.Error()) @@ -92,7 +92,7 @@ func getMesh(client *ipcRpc.Client, meshId string) { func enableInterface(client *ipcRpc.Client, meshId string) { var reply string - err := client.Call("RobinIpc.EnableInterface", &meshId, &reply) + err := client.Call("IpcHandler.EnableInterface", &meshId, &reply) if err != nil { fmt.Println(err.Error()) @@ -105,7 +105,7 @@ func enableInterface(client *ipcRpc.Client, meshId string) { func getGraph(client *ipcRpc.Client, meshId string) { var reply string - err := client.Call("RobinIpc.GetDOT", &meshId, &reply) + err := client.Call("IpcHandler.GetDOT", &meshId, &reply) if err != nil { fmt.Println(err.Error()) diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index ba7b349..5b325c3 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -28,8 +28,8 @@ func main() { return } - var robinRpc robin.RobinRpc - var robinIpc robin.RobinIpc + var robinRpc robin.WgRpc + var robinIpc robin.IpcHandler var authProvider middleware.AuthRpcProvider var syncProvider sync.SyncServiceImpl diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 76c5a1b..c99439d 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -2,21 +2,22 @@ package crdt import ( "errors" - "fmt" "net" "strings" "time" "github.com/automerge/automerge-go" "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/wg" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// CrdtNodeManager manages nodes in the crdt mesh -type CrdtNodeManager struct { +// CrdtMeshManager manages nodes in the crdt mesh +type CrdtMeshManager struct { MeshId string IfName string NodeId string @@ -26,57 +27,63 @@ type CrdtNodeManager struct { conf *conf.WgMeshConfiguration } -const maxFails = 5 +func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { + crdt, ok := node.(*MeshNodeCrdt) + + if !ok { + panic("node must be of type *MeshNodeCrdt") + } -func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) { - crdt.FailedMap = automerge.NewMap() crdt.Timestamp = time.Now().Unix() c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint) nodeVal.Map().Set("routes", automerge.NewMap()) } -func (c *CrdtNodeManager) ApplyWg() error { - snapshot, err := c.GetCrdt() +func (c *CrdtMeshManager) ApplyWg() error { + // snapshot, err := c.GetMesh() - if err != nil { - return err - } + // if err != nil { + // return err + // } - c.updateWgConf(c.IfName, snapshot.Nodes, *c.Client) + // c.updateWgConf(c.IfName, snapshot.GetNodes(), *c.Client) return nil } -// GetCrdt(): Converts the document into a struct -func (c *CrdtNodeManager) GetCrdt() (*MeshCrdt, error) { +// GetMesh(): Converts the document into a struct +func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { return automerge.As[*MeshCrdt](c.doc.Root()) } +// GetMeshId returns the meshid of the mesh +func (c *CrdtMeshManager) GetMeshId() string { + return c.MeshId +} + +// Save: Save an entire mesh network +func (c *CrdtMeshManager) Save() []byte { + return c.doc.Save() +} + // Load: Load an entire mesh network -func (c *CrdtNodeManager) Load(bytes []byte) error { +func (c *CrdtMeshManager) Load(bytes []byte) error { doc, err := automerge.Load(bytes) if err != nil { return err } - c.doc = doc return nil } -// Save: Save an entire mesh network -func (c *CrdtNodeManager) Save() []byte { - return c.doc.Save() -} - // NewCrdtNodeManager: Create a new crdt node manager -func NewCrdtNodeManager(meshId, hostId, devName string, port int, conf conf.WgMeshConfiguration, client *wgctrl.Client) (*CrdtNodeManager, error) { - var manager CrdtNodeManager +func NewCrdtNodeManager(meshId, devName string, port int, conf conf.WgMeshConfiguration, client *wgctrl.Client) (*CrdtMeshManager, error) { + var manager CrdtMeshManager manager.MeshId = meshId manager.doc = automerge.New() manager.IfName = devName manager.Client = client - manager.NodeId = hostId manager.conf = &conf err := wg.CreateWgInterface(client, devName, port) @@ -88,7 +95,7 @@ func NewCrdtNodeManager(meshId, hostId, devName string, port int, conf conf.WgMe return &manager, nil } -func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { +func (m *CrdtMeshManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { peerEndpoint, err := net.ResolveUDPAddr("udp", node.WgEndpoint) if err != nil { @@ -125,45 +132,7 @@ func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfi return &peerConfig, nil } -func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int { - return strings.Compare(m1.PublicKey, m2.PublicKey) -} - -func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error { - node, err := c.doc.Path("nodes").Map().Get(endpoint) - - if err != nil { - return err - } - - counterMap, err := node.Map().Get("failedMap") - - if counterMap.Kind() == automerge.KindVoid { - return errors.New("Something went wrong map does not exist") - } - - counter, _ := counterMap.Map().Get(c.NodeId) - - if counter.Kind() == automerge.KindVoid { - err = counterMap.Map().Set(c.NodeId, incAmount) - } else { - if counter.Int64()+incAmount < 0 { - return nil - } - - err = counterMap.Map().Set(c.NodeId, counter.Int64()+1) - } - - return err -} - -// Increment failed count increments the number of times we have attempted -// to contact the node and it's failed -func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error { - return c.changeFailedCount(c.MeshId, endpoint, +1) -} - -func (c *CrdtNodeManager) removeNode(endpoint string) error { +func (c *CrdtMeshManager) removeNode(endpoint string) error { err := c.doc.Path("nodes").Map().Delete(endpoint) if err != nil { @@ -173,14 +142,8 @@ func (c *CrdtNodeManager) removeNode(endpoint string) error { return nil } -// Decrement failed count decrements the number of times we have attempted to -// contact the node and it's failed -func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error { - return c.changeFailedCount(c.MeshId, endpoint, -1) -} - // GetNode: returns a mesh node crdt. -func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { +func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { node, err := m.doc.Path("nodes").Map().Get(endpoint) if err != nil { @@ -196,11 +159,11 @@ func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { return meshNode, nil } -func (m *CrdtNodeManager) Length() int { +func (m *CrdtMeshManager) Length() int { return m.doc.Path("nodes").Map().Len() } -func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) { +func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) { dev, err := m.Client.Device(m.IfName) if err != nil { @@ -211,7 +174,7 @@ func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) { } // HasChanges returns true if we have changes since the last time we synced -func (m *CrdtNodeManager) HasChanges() bool { +func (m *CrdtMeshManager) HasChanges() bool { changes, err := m.doc.Changes(m.LastHash) logging.Log.WriteInfof("Changes %s", m.LastHash.String()) @@ -224,34 +187,11 @@ func (m *CrdtNodeManager) HasChanges() bool { return len(changes) > 0 } -func (m *CrdtNodeManager) HasFailed(endpoint string) bool { - node, err := m.GetNode(endpoint) - - if err != nil { - logging.Log.WriteErrorf("Cannot get node node: %s\n", endpoint) - return true - } - - values, err := node.FailedMap.Values() - - if err != nil { - return true - } - - countFailed := 0 - - for _, value := range values { - count := value.Int64() - - if count >= 1 { - countFailed++ - } - } - - return countFailed >= 4 +func (m *CrdtMeshManager) HasFailed(endpoint string) bool { + return false } -func (m *CrdtNodeManager) SaveChanges() { +func (m *CrdtMeshManager) SaveChanges() { hashes := m.doc.Heads() hash := hashes[len(hashes)-1] @@ -259,13 +199,17 @@ func (m *CrdtNodeManager) SaveChanges() { m.LastHash = hash } -func (m *CrdtNodeManager) UpdateTimeStamp() error { - node, err := m.doc.Path("nodes").Map().Get(m.NodeId) +func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { + node, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil { return err } + if node.Kind() != automerge.KindMap { + return errors.New("node is not a map") + } + err = node.Map().Set("timestamp", time.Now().Unix()) if err == nil { @@ -275,7 +219,32 @@ func (m *CrdtNodeManager) UpdateTimeStamp() error { return err } -func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { +// AddRoutes: adds routes to the specific nodeId +func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { + nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) + + if err != nil { + return err + } + + routeMap, err := nodeVal.Map().Get("routes") + + if err != nil { + return err + } + + for _, route := range routes { + err = routeMap.Map().Set(route, struct{}{}) + + if err != nil { + return err + } + } + + return nil +} + +func (m *CrdtMeshManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) var count int = 0 @@ -300,35 +269,58 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode return nil } -// AddRoutes: adds routes to the specific nodeId -func (m *CrdtNodeManager) AddRoutes(routes ...string) error { - nodeVal, err := m.doc.Path("nodes").Map().Get(m.NodeId) - - if err != nil { - return err - } - - routeMap, err := nodeVal.Map().Get("routes") - - if err != nil { - return err - } - - for _, route := range routes { - err = routeMap.Map().Set(route, struct{}{}) - - if err != nil { - return err - } - } - - return nil -} - -func (m *CrdtNodeManager) GetSyncer() *AutomergeSync { +func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { return NewAutomergeSync(m) } -func (n *MeshNodeCrdt) GetEscapedIP() string { - return fmt.Sprintf("\"%s\"", n.WgHost) +func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int { + return strings.Compare(m1.PublicKey, m2.PublicKey) +} + +func (m *MeshNodeCrdt) GetHostEndpoint() string { + return m.HostEndpoint +} + +func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) { + return wgtypes.ParseKey(m.PublicKey) +} + +func (m *MeshNodeCrdt) GetWgEndpoint() string { + return m.HostEndpoint +} + +func (m *MeshNodeCrdt) GetWgHost() *net.IPNet { + _, ipnet, err := net.ParseCIDR(m.WgHost) + + if err != nil { + logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error()) + return nil + } + + return ipnet +} + +func (m *MeshNodeCrdt) GetTimeStamp() int64 { + return m.Timestamp +} + +func (m *MeshNodeCrdt) GetRoutes() []string { + return lib.MapKeys(m.Routes) +} + +func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { + nodes := make(map[string]mesh.MeshNode) + + for _, node := range m.Nodes { + nodes[node.HostEndpoint] = &MeshNodeCrdt{ + HostEndpoint: node.HostEndpoint, + WgEndpoint: node.WgEndpoint, + PublicKey: node.PublicKey, + WgHost: node.WgHost, + Timestamp: node.Timestamp, + Routes: node.Routes, + } + } + + return nodes } diff --git a/pkg/automerge/automerge_sync.go b/pkg/automerge/automerge_sync.go index d99a1af..1c6de90 100644 --- a/pkg/automerge/automerge_sync.go +++ b/pkg/automerge/automerge_sync.go @@ -7,7 +7,7 @@ import ( type AutomergeSync struct { state *automerge.SyncState - manager *CrdtNodeManager + manager *CrdtMeshManager } func (a *AutomergeSync) GenerateMessage() ([]byte, bool) { @@ -35,7 +35,7 @@ func (a *AutomergeSync) Complete() { a.manager.SaveChanges() } -func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync { +func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { return &AutomergeSync{ state: automerge.NewSyncState(manager.doc), manager: manager, diff --git a/pkg/automerge/automergefactory.go b/pkg/automerge/automergefactory.go new file mode 100644 index 0000000..69fc12d --- /dev/null +++ b/pkg/automerge/automergefactory.go @@ -0,0 +1,10 @@ +package crdt + +import "github.com/tim-beatham/wgmesh/pkg/mesh" + +type CrdtProviderFactory struct{} + +func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { + return NewCrdtNodeManager(params.MeshId, params.DevName, params.Port, + *params.Conf, params.Client) +} diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index 01b2b7a..81b10f4 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -1,7 +1,5 @@ package crdt -import "github.com/automerge/automerge-go" - // MeshNodeCrdt: Represents a CRDT for a mesh nodes type MeshNodeCrdt struct { HostEndpoint string `automerge:"hostEndpoint"` @@ -9,7 +7,6 @@ type MeshNodeCrdt struct { PublicKey string `automerge:"publicKey"` WgHost string `automerge:"wgHost"` Timestamp int64 `automerge:"timestamp"` - FailedMap *automerge.Map `automerge:"failedMap"` Routes map[string]interface{} `automerge:"routes"` } diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index f3f2fc5..42e330b 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -13,7 +13,11 @@ type WgMeshConfiguration struct { PrivateKeyPath string `yaml:"privateKeyPath"` SkipCertVerification bool `yaml:"skipCertVerification"` GrpcPort string `yaml:"gRPCPort"` - AdvertiseRoutes bool `yaml:"advertiseRoutes"` + // AdvertiseRoutes advertises other meshes if the node is in multiple meshes + AdvertiseRoutes bool `yaml:"advertiseRoutes"` + // PublicEndpoint is the IP in which this computer is publicly reachable. + // usecase is when the node is behind NAT. + PublicEndpoint string `yaml:"publicEndpoint"` } func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 417a081..e1698bf 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -1,6 +1,7 @@ package ctrlserver import ( + crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/wgmesh/pkg/mesh" @@ -21,7 +22,8 @@ type NewCtrlServerParams struct { // operation failed func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer := new(MeshCtrlServer) - ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client) + factory := crdt.CrdtProviderFactory{} + ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client, &factory) ctrlServer.Conf = params.Conf connManagerParams := conn.NewConnectionManageParams{ diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index c6c055b..d107139 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -16,7 +16,6 @@ type MeshNode struct { WgEndpoint string PublicKey string WgHost string - Failed bool Timestamp int64 Routes []string } @@ -32,7 +31,7 @@ type Mesh struct { */ type MeshCtrlServer struct { Client *wgctrl.Client - MeshManager *mesh.MeshManger + MeshManager *mesh.MeshManager ConnectionManager conn.ConnectionManager ConnectionServer *conn.ConnectionServer Conf *conf.WgMeshConfiguration diff --git a/pkg/ctrlserver/rpchandler.go b/pkg/ctrlserver/rpchandler.go deleted file mode 100644 index 4e3d3ee..0000000 --- a/pkg/ctrlserver/rpchandler.go +++ /dev/null @@ -1,15 +0,0 @@ -/* - * RPC component of the server - */ -package ctrlserver - -import ( - "github.com/tim-beatham/wgmesh/pkg/rpc" - "google.golang.org/grpc" -) - -func NewRpcServer(server rpc.MeshCtrlServerServer) *grpc.Server { - grpc := grpc.NewServer() - rpc.RegisterMeshCtrlServerServer(grpc, server) - return grpc -} diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go index 3ecddbc..1f115aa 100644 --- a/pkg/lib/conv.go +++ b/pkg/lib/conv.go @@ -1,5 +1,9 @@ package lib +import ( + logging "github.com/tim-beatham/wgmesh/pkg/log" +) + // MapToSlice converts a map to a slice in go func MapValues[K comparable, V any](m map[K]V) []V { return MapValuesWithExclude(m, map[K]struct{}{}) @@ -19,6 +23,8 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{} continue } + logging.Log.WriteInfof("Key %s", k) + values[i] = v i++ } diff --git a/pkg/lib/ip.go b/pkg/lib/ip.go index b7bf683..ef1524e 100644 --- a/pkg/lib/ip.go +++ b/pkg/lib/ip.go @@ -5,14 +5,13 @@ import ( "net" ) +// GetOutboundIP: gets the oubound IP of this packet func GetOutboundIP() net.IP { conn, err := net.Dial("udp", "8.8.8.8:80") if err != nil { log.Fatal(err) } defer conn.Close() - localAddr := conn.LocalAddr().(*net.UDPAddr) - return localAddr.IP } diff --git a/pkg/mesh/graph_generator.go b/pkg/mesh/graphgenerator.go similarity index 55% rename from pkg/mesh/graph_generator.go rename to pkg/mesh/graphgenerator.go index 661671e..15a3ce9 100644 --- a/pkg/mesh/graph_generator.go +++ b/pkg/mesh/graphgenerator.go @@ -8,13 +8,14 @@ import ( "github.com/tim-beatham/wgmesh/pkg/lib" ) +// MeshGraphConverter converts a mesh to a graph type MeshGraphConverter interface { // convert the mesh to textual form Generate(meshId string) (string, error) } type MeshDOTConverter struct { - manager *MeshManger + manager *MeshManager } func (c *MeshDOTConverter) Generate(meshId string) (string, error) { @@ -26,35 +27,33 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) { g := graph.NewGraph(meshId, graph.GRAPH) - snapshot, err := mesh.GetCrdt() + snapshot, err := mesh.GetMesh() if err != nil { return "", err } - for _, node := range snapshot.Nodes { - g.AddNode(node.GetEscapedIP()) + for _, node := range snapshot.GetNodes() { + g.AddNode(fmt.Sprintf("\"%s\"", node.GetWgHost().IP.String())) } - nodes := lib.MapValues(snapshot.Nodes) + nodes := lib.MapValues(snapshot.GetNodes()) for i, node1 := range nodes[:len(nodes)-1] { - if mesh.HasFailed(node1.HostEndpoint) { - continue - } - for _, node2 := range nodes[i+1:] { - if node1.WgEndpoint == node2.WgEndpoint || mesh.HasFailed(node2.HostEndpoint) { + if node1.GetWgEndpoint() == node2.GetWgEndpoint() { continue } - g.AddEdge(fmt.Sprintf("%s to %s", node1.GetEscapedIP(), node2.GetEscapedIP()), node1.GetEscapedIP(), node2.GetEscapedIP()) + node1Id := fmt.Sprintf("\"%s\"", node1.GetWgHost().IP.String()) + node2Id := fmt.Sprintf("\"%s\"", node2.GetWgHost().IP.String()) + g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id) } } return g.GetDOT() } -func NewMeshDotConverter(m *MeshManger) MeshGraphConverter { +func NewMeshDotConverter(m *MeshManager) MeshGraphConverter { return &MeshDOTConverter{manager: m} } diff --git a/pkg/mesh/mesh_manager.go b/pkg/mesh/mesh_manager.go deleted file mode 100644 index 3ccf481..0000000 --- a/pkg/mesh/mesh_manager.go +++ /dev/null @@ -1,159 +0,0 @@ -package mesh - -import ( - "errors" - "fmt" - - crdt "github.com/tim-beatham/wgmesh/pkg/automerge" - "github.com/tim-beatham/wgmesh/pkg/conf" - "github.com/tim-beatham/wgmesh/pkg/lib" - "github.com/tim-beatham/wgmesh/pkg/wg" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -type MeshManger struct { - Meshes map[string]*crdt.CrdtNodeManager - RouteManager RouteManager - Client *wgctrl.Client - HostEndpoint string - conf *conf.WgMeshConfiguration -} - -func (m *MeshManger) MeshExists(meshId string) bool { - _, inMesh := m.Meshes[meshId] - return inMesh -} - -// CreateMesh: Creates a new mesh, stores it and returns the mesh id -func (m *MeshManger) CreateMesh(devName string, port int) (string, error) { - key, err := wgtypes.GenerateKey() - - if err != nil { - return "", err - } - - nodeManager, err := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, port, *m.conf, m.Client) - - if err != nil { - return "", err - } - - m.Meshes[key.String()] = nodeManager - - return key.String(), err -} - -// AddMesh: Add the mesh to the list of meshes -func (m *MeshManger) AddMesh(meshId string, devName string, port int, meshBytes []byte) error { - mesh, err := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, port, *m.conf, m.Client) - - if err != nil { - return err - } - - err = mesh.Load(meshBytes) - - if err != nil { - return err - } - - m.Meshes[meshId] = mesh - return nil -} - -// AddMeshNode: Add a mesh node -func (m *MeshManger) AddMeshNode(meshId string, node crdt.MeshNodeCrdt) { - m.Meshes[meshId].AddNode(node) - - if m.conf.AdvertiseRoutes { - m.RouteManager.UpdateRoutes() - } -} - -func (m *MeshManger) HasChanges(meshId string) bool { - return m.Meshes[meshId].HasChanges() -} - -func (m *MeshManger) GetMesh(meshId string) *crdt.CrdtNodeManager { - theMesh, _ := m.Meshes[meshId] - return theMesh -} - -// EnableInterface: Enables the given WireGuard interface. -func (s *MeshManger) EnableInterface(meshId string) error { - mesh, contains := s.Meshes[meshId] - - if !contains { - return errors.New("Mesh does not exist") - } - - crdt, err := mesh.GetCrdt() - - if err != nil { - return err - } - - node, contains := crdt.Nodes[s.HostEndpoint] - - if !contains { - return errors.New("Node does not exist in the mesh") - } - - err = mesh.ApplyWg() - - if err != nil { - return err - } - - err = wg.EnableInterface(mesh.IfName, node.WgHost) - - if s.conf.AdvertiseRoutes { - s.RouteManager.ApplyWg(mesh) - } - - return nil -} - -// GetPublicKey: Gets the public key of the WireGuard mesh -func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) { - mesh, ok := s.Meshes[meshId] - - if !ok { - return nil, errors.New("mesh does not exist") - } - - dev, err := mesh.GetDevice() - - if err != nil { - return nil, err - } - - return &dev.PublicKey, nil -} - -// UpdateTimeStamp updates the timestamp of this node in all meshes -func (s *MeshManger) UpdateTimeStamp() error { - for _, mesh := range s.Meshes { - err := mesh.UpdateTimeStamp() - - if err != nil { - return err - } - } - - return nil -} - -func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client) *MeshManger { - ip := lib.GetOutboundIP() - m := &MeshManger{ - Meshes: make(map[string]*crdt.CrdtNodeManager), - HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort), - Client: client, - conf: &conf, - } - - m.RouteManager = NewRouteManager(m) - return m -} diff --git a/pkg/mesh/meshconfig.go b/pkg/mesh/meshconfig.go new file mode 100644 index 0000000..f42ab33 --- /dev/null +++ b/pkg/mesh/meshconfig.go @@ -0,0 +1,100 @@ +package mesh + +import ( + "net" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// MeshConfigApplyer abstracts applying the mesh configuration +type MeshConfigApplyer interface { + ApplyConfig() error +} + +// WgMeshConfigApplyer applies WireGuard configuration +type WgMeshConfigApplyer struct { + meshManager *MeshManager +} + +func ConvertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { + endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) + + if err != nil { + return nil, err + } + + pubKey, err := node.GetPublicKey() + + if err != nil { + return nil, err + } + + allowedips := make([]net.IPNet, 1) + allowedips[0] = *node.GetWgHost() + + for _, route := range node.GetRoutes() { + _, ipnet, _ := net.ParseCIDR(route) + allowedips = append(allowedips, *ipnet) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + Endpoint: endpoint, + AllowedIPs: allowedips, + } + + return &peerConfig, nil +} + +func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { + snap, err := mesh.GetMesh() + + if err != nil { + return err + } + + nodes := snap.GetNodes() + peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) + + var count int = 0 + + for _, n := range nodes { + peer, err := ConvertMeshNode(n) + + if err != nil { + return err + } + + peerConfigs[count] = *peer + count++ + } + + cfg := wgtypes.Config{ + Peers: peerConfigs, + ReplacePeers: true, + } + + dev, err := mesh.GetDevice() + + if err != nil { + return err + } + + return m.meshManager.Client.ConfigureDevice(dev.Name, cfg) +} + +func (m *WgMeshConfigApplyer) ApplyConfig() error { + for _, mesh := range m.meshManager.Meshes { + err := m.updateWgConf(mesh) + + if err != nil { + return err + } + } + + return nil +} + +func NewWgMeshConfigApplyer(manager *MeshManager) MeshConfigApplyer { + return &WgMeshConfigApplyer{meshManager: manager} +} diff --git a/pkg/mesh/meshinterface.go b/pkg/mesh/meshinterface.go new file mode 100644 index 0000000..a3d214b --- /dev/null +++ b/pkg/mesh/meshinterface.go @@ -0,0 +1,43 @@ +package mesh + +import ( + "errors" + + "github.com/tim-beatham/wgmesh/pkg/wg" +) + +// MeshInterfaces manipulates interfaces to do with meshes +type MeshInterface interface { + EnableInterface(meshId string) error +} + +type WgMeshInterface struct { + manager *MeshManager +} + +// EnableInterface enables the interface at the given endpoint +func (m *WgMeshInterface) EnableInterface(meshId string) error { + mesh, ok := m.manager.Meshes[meshId] + + if !ok { + return errors.New("the provided mesh does not exist") + } + + dev, err := mesh.GetDevice() + + if err != nil { + return err + } + + self, err := m.manager.GetSelf(meshId) + + if err != nil { + return err + } + + return wg.EnableInterface(dev.Name, self.GetWgHost().String()) +} + +func NewWgMeshInterface(manager *MeshManager) MeshInterface { + return &WgMeshInterface{manager: manager} +} diff --git a/pkg/mesh/meshmanager.go b/pkg/mesh/meshmanager.go new file mode 100644 index 0000000..635aa57 --- /dev/null +++ b/pkg/mesh/meshmanager.go @@ -0,0 +1,179 @@ +package mesh + +import ( + "errors" + "fmt" + + "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" + logging "github.com/tim-beatham/wgmesh/pkg/log" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type MeshManager struct { + Meshes map[string]MeshProvider + RouteManager RouteManager + Client *wgctrl.Client + // HostParameters contains information that uniquely locates + // the node in the mesh network. + HostParameters *HostParameters + conf *conf.WgMeshConfiguration + meshProviderFactory MeshProviderFactory + configApplyer MeshConfigApplyer + interfaceEnabler MeshInterface +} + +// CreateMesh: Creates a new mesh, stores it and returns the mesh id +func (m *MeshManager) CreateMesh(devName string, port int) (string, error) { + key, err := wgtypes.GenerateKey() + + if err != nil { + return "", err + } + + nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ + DevName: devName, + Port: port, + Conf: m.conf, + Client: m.Client, + MeshId: key.String(), + }) + + if err != nil { + return "", err + } + + m.Meshes[key.String()] = nodeManager + + return key.String(), err +} + +// AddMesh: Add the mesh to the list of meshes +func (m *MeshManager) AddMesh(meshId string, devName string, port int, meshBytes []byte) error { + meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ + DevName: devName, + Port: port, + Conf: m.conf, + Client: m.Client, + MeshId: meshId, + }) + + if err != nil { + return err + } + + err = meshProvider.Load(meshBytes) + + if err != nil { + return err + } + + m.Meshes[meshId] = meshProvider + return nil +} + +// AddMeshNode: Add a mesh node +func (m *MeshManager) AddMeshNode(meshId string, node MeshNode) { + m.Meshes[meshId].AddNode(node) +} + +// HasChanges returns true if the mesh has changes +func (m *MeshManager) HasChanges(meshId string) bool { + return m.Meshes[meshId].HasChanges() +} + +// GetMesh returns the mesh with the given meshid +func (m *MeshManager) GetMesh(meshId string) MeshProvider { + theMesh, _ := m.Meshes[meshId] + return theMesh +} + +// EnableInterface: Enables the given WireGuard interface. +func (s *MeshManager) EnableInterface(meshId string) error { + err := s.configApplyer.ApplyConfig() + + if err != nil { + return err + } + + return s.interfaceEnabler.EnableInterface(meshId) +} + +// GetPublicKey: Gets the public key of the WireGuard mesh +func (s *MeshManager) GetPublicKey(meshId string) (*wgtypes.Key, error) { + mesh, ok := s.Meshes[meshId] + + if !ok { + return nil, errors.New("mesh does not exist") + } + + dev, err := mesh.GetDevice() + + if err != nil { + return nil, err + } + + return &dev.PublicKey, nil +} + +func (s *MeshManager) GetSelf(meshId string) (MeshNode, error) { + meshInstance, ok := s.Meshes[meshId] + + if !ok { + return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId)) + } + + snapshot, err := meshInstance.GetMesh() + + if err != nil { + return nil, err + } + + node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint] + + if !ok { + return nil, errors.New("the node doesn't exist in the mesh") + } + + return node, nil +} + +// UpdateTimeStamp updates the timestamp of this node in all meshes +func (s *MeshManager) UpdateTimeStamp() error { + for _, mesh := range s.Meshes { + err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) + + if err != nil { + return err + } + } + + return nil +} + +// Creates a new instance of a mesh manager with the given parameters +func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client, meshProvider MeshProviderFactory) *MeshManager { + hostParams := HostParameters{} + + switch conf.PublicEndpoint { + case "": + hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), conf.GrpcPort) + default: + hostParams.HostEndpoint = fmt.Sprintf("%s:%s", conf.PublicEndpoint, conf.GrpcPort) + } + + logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint) + + m := &MeshManager{ + Meshes: make(map[string]MeshProvider), + HostParameters: &hostParams, + meshProviderFactory: meshProvider, + Client: client, + conf: &conf, + } + m.configApplyer = NewWgMeshConfigApplyer(m) + m.RouteManager = NewRouteManager(m) + m.interfaceEnabler = NewWgMeshInterface(m) + return m +} diff --git a/pkg/mesh/route_manager.go b/pkg/mesh/route_manager.go deleted file mode 100644 index f3d8c0c..0000000 --- a/pkg/mesh/route_manager.go +++ /dev/null @@ -1,78 +0,0 @@ -package mesh - -import ( - "net" - - crdt "github.com/tim-beatham/wgmesh/pkg/automerge" - "github.com/tim-beatham/wgmesh/pkg/ip" - logging "github.com/tim-beatham/wgmesh/pkg/log" - "github.com/tim-beatham/wgmesh/pkg/route" -) - -type RouteManager interface { - UpdateRoutes() error - ApplyWg(mesh *crdt.CrdtNodeManager) error -} - -type RouteManagerImpl struct { - meshManager *MeshManger - routeInstaller route.RouteInstaller -} - -func (r *RouteManagerImpl) UpdateRoutes() error { - meshes := r.meshManager.Meshes - ulaBuilder := new(ip.ULABuilder) - - for _, mesh1 := range meshes { - for _, mesh2 := range meshes { - if mesh1 == mesh2 { - continue - } - - ipNet, err := ulaBuilder.GetIPNet(mesh2.MeshId) - - if err != nil { - logging.Log.WriteErrorf(err.Error()) - return err - } - - mesh1.AddRoutes(ipNet.String()) - } - } - - return nil -} - -func (r *RouteManagerImpl) ApplyWg(mesh *crdt.CrdtNodeManager) error { - snapshot, err := mesh.GetCrdt() - - if err != nil { - return err - } - - for _, node := range snapshot.Nodes { - if node.HostEndpoint == r.meshManager.HostEndpoint { - continue - } - - for route, _ := range node.Routes { - _, netIP, err := net.ParseCIDR(route) - - if err != nil { - return err - } - - err = r.routeInstaller.InstallRoutes(mesh.IfName, netIP) - - if err != nil { - return err - } - } - } - - return nil -} - -func NewRouteManager(m *MeshManger) RouteManager { - return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()} -} diff --git a/pkg/mesh/routemanager.go b/pkg/mesh/routemanager.go new file mode 100644 index 0000000..6bff5e0 --- /dev/null +++ b/pkg/mesh/routemanager.go @@ -0,0 +1,73 @@ +package mesh + +import ( + "github.com/tim-beatham/wgmesh/pkg/route" +) + +type RouteManager interface { + UpdateRoutes() error + ApplyWg() error +} + +type RouteManagerImpl struct { + meshManager *MeshManager + routeInstaller route.RouteInstaller +} + +func (r *RouteManagerImpl) UpdateRoutes() error { + // // meshes := r.meshManager.Meshes + // // ulaBuilder := new(ip.ULABuilder) + + // for _, mesh1 := range meshes { + // for _, mesh2 := range meshes { + // if mesh1 == mesh2 { + // continue + // } + + // ipNet, err := ulaBuilder.GetIPNet(mesh2.MeshId) + + // if err != nil { + // logging.Log.WriteErrorf(err.Error()) + // return err + // } + + // mesh1.AddRoutes(ipNet.String()) + // } + // } + + return nil +} + +func (r *RouteManagerImpl) ApplyWg() error { + // snapshot, err := mesh.GetMesh() + + // if err != nil { + // return err + // } + + // for _, node := range snapshot.Nodes { + // if node.HostEndpoint == r.meshManager.HostEndpoint { + // continue + // } + + // for route, _ := range node.Routes { + // _, netIP, err := net.ParseCIDR(route) + + // if err != nil { + // return err + // } + + // err = r.routeInstaller.InstallRoutes(mesh.IfName, netIP) + + // if err != nil { + // return err + // } + // } + // } + + return nil +} + +func NewRouteManager(m *MeshManager) RouteManager { + return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()} +} diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go new file mode 100644 index 0000000..36d8196 --- /dev/null +++ b/pkg/mesh/types.go @@ -0,0 +1,84 @@ +// mesh provides implementation agnostic logic for managing +// the mesh +package mesh + +import ( + "net" + + "github.com/tim-beatham/wgmesh/pkg/conf" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// MeshNode represents an implementation of a node in a mesh +type MeshNode interface { + // GetHostEndpoint: gets the gRPC endpoint of the node + GetHostEndpoint() string + // GetPublicKey: gets the public key of the node + GetPublicKey() (wgtypes.Key, error) + // GetWgEndpoint(): get IP and port of the wireguard endpoint + GetWgEndpoint() string + // GetWgHost: get the IP address of the WireGuard node + GetWgHost() *net.IPNet + // GetTimestamp: get the UNIX time stamp of the ndoe + GetTimeStamp() int64 + // GetRoutes: returns the routes that the nodes provides + GetRoutes() []string +} + +type MeshSnapshot interface { + // GetNodes() returns the nodes in the mesh + GetNodes() map[string]MeshNode +} + +// MeshSyncer syncs two meshes +type MeshSyncer interface { + GenerateMessage() ([]byte, bool) + RecvMessage(mesg []byte) error + Complete() +} + +// Mesh: Represents an implementation of a mesh +type MeshProvider interface { + // AddNode() adds a node to the mesh + AddNode(node MeshNode) + // GetMesh() returns a snapshot of the mesh provided by the mesh provider + GetMesh() (MeshSnapshot, error) + // GetMeshId() returns the ID of the mesh network + GetMeshId() string + // Save() saves the mesh network + Save() []byte + // Load() loads a mesh network + Load([]byte) error + // GetDevice() get the device corresponding with the mesh + GetDevice() (*wgtypes.Device, error) + // HasChanges returns true if we have changes since last time we synced + HasChanges() bool + // Record that we have changges and save the corresponding changes + SaveChanges() + // UpdateTimeStamp: update the timestamp of the given node + UpdateTimeStamp(nodeId string) error + // AddRoutes: adds routes to the given node + AddRoutes(nodeId string, route ...string) error + GetSyncer() MeshSyncer +} + +// HostParameters contains the IDs of a node +type HostParameters struct { + HostEndpoint string + // TODO: Contain the WireGuard identifier in this +} + +// MeshProviderFactoryParams parameters required to build a mesh provider +type MeshProviderFactoryParams struct { + DevName string + MeshId string + Port int + Conf *conf.WgMeshConfiguration + Client *wgctrl.Client +} + +// MeshProviderFactory creates an instance of a mesh provider +type MeshProviderFactory interface { + CreateMesh(params *MeshProviderFactoryParams) (MeshProvider, error) +} diff --git a/pkg/robin/robin_requester.go b/pkg/robin/requester.go similarity index 71% rename from pkg/robin/robin_requester.go rename to pkg/robin/requester.go index a180b08..a4d1787 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/requester.go @@ -18,12 +18,12 @@ import ( "github.com/tim-beatham/wgmesh/pkg/wg" ) -type RobinIpc struct { +type IpcHandler struct { Server *ctrlserver.MeshCtrlServer ipAllocator ip.IPAllocator } -func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { +func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { wg.CreateInterface(args.IfName) meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort) @@ -54,7 +54,7 @@ func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { Routes: map[string]interface{}{}, } - n.Server.MeshManager.AddMeshNode(meshId, meshNode) + n.Server.MeshManager.AddMeshNode(meshId, &meshNode) if err != nil { return err @@ -64,12 +64,12 @@ func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { return nil } -func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error { +func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { meshNames := make([]string, len(n.Server.MeshManager.Meshes)) i := 0 - for _, mesh := range n.Server.MeshManager.Meshes { - meshNames[i] = mesh.MeshId + for meshId, _ := range n.Server.MeshManager.Meshes { + meshNames[i] = meshId i++ } @@ -77,7 +77,7 @@ func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error { return nil } -func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { +func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress) client, err := peerConnection.GetClient() @@ -130,47 +130,51 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { WgHost: ipAddr.String() + "/128", Routes: make(map[string]interface{}), } - - n.Server.MeshManager.AddMeshNode(args.MeshId, node) + n.Server.MeshManager.AddMeshNode(args.MeshId, &node) *reply = strconv.FormatBool(true) return nil } -func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error { +func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { mesh := n.Server.MeshManager.GetMesh(meshId) - meshSnapshot, err := mesh.GetCrdt() + meshSnapshot, err := mesh.GetMesh() if err != nil { return err } - if mesh != nil { - nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes)) - - i := 0 - for _, node := range meshSnapshot.Nodes { - node := ctrlserver.MeshNode{ - HostEndpoint: node.HostEndpoint, - WgEndpoint: node.WgEndpoint, - PublicKey: node.PublicKey, - WgHost: node.WgHost, - Failed: mesh.HasFailed(node.HostEndpoint), - Timestamp: node.Timestamp, - Routes: lib.MapKeys(node.Routes), - } - - nodes[i] = node - i += 1 - } - - *reply = ipc.GetMeshReply{Nodes: nodes} - } else { + if mesh == nil { return errors.New("mesh does not exist") } + nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes())) + + i := 0 + for _, node := range meshSnapshot.GetNodes() { + pubKey, _ := node.GetPublicKey() + + if err != nil { + return err + } + + node := ctrlserver.MeshNode{ + HostEndpoint: node.GetHostEndpoint(), + WgEndpoint: node.GetWgEndpoint(), + PublicKey: pubKey.String(), + WgHost: node.GetWgHost().String(), + Timestamp: node.GetTimeStamp(), + Routes: node.GetRoutes(), + } + + nodes[i] = node + i += 1 + } + + *reply = ipc.GetMeshReply{Nodes: nodes} + return nil } -func (n *RobinIpc) EnableInterface(meshId string, reply *string) error { +func (n *IpcHandler) EnableInterface(meshId string, reply *string) error { err := n.Server.MeshManager.EnableInterface(meshId) if err != nil { @@ -182,7 +186,7 @@ func (n *RobinIpc) EnableInterface(meshId string, reply *string) error { return nil } -func (n *RobinIpc) GetDOT(meshId string, reply *string) error { +func (n *IpcHandler) GetDOT(meshId string, reply *string) error { g := mesh.NewMeshDotConverter(n.Server.MeshManager) result, err := g.Generate(meshId) @@ -200,8 +204,8 @@ type RobinIpcParams struct { Allocator ip.IPAllocator } -func NewRobinIpc(ipcParams RobinIpcParams) RobinIpc { - return RobinIpc{ +func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler { + return IpcHandler{ Server: ipcParams.CtrlServer, ipAllocator: ipcParams.Allocator, } diff --git a/pkg/robin/robin_responder.go b/pkg/robin/responder.go similarity index 79% rename from pkg/robin/robin_responder.go rename to pkg/robin/responder.go index 9fe1b8c..ade93ae 100644 --- a/pkg/robin/robin_responder.go +++ b/pkg/robin/responder.go @@ -8,7 +8,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/rpc" ) -type RobinRpc struct { +type WgRpc struct { rpc.UnimplementedMeshCtrlServerServer Server *ctrlserver.MeshCtrlServer } @@ -36,7 +36,7 @@ func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode { return meshNodes } -func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) { +func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) { mesh := m.Server.MeshManager.GetMesh(request.MeshId) if mesh == nil { @@ -52,6 +52,6 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r return &reply, nil } -func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { +func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { return &rpc.JoinMeshReply{Success: true}, nil } diff --git a/pkg/rpc/rpc.go b/pkg/rpc/rpc.go deleted file mode 100644 index cec334d..0000000 --- a/pkg/rpc/rpc.go +++ /dev/null @@ -1,9 +0,0 @@ -package rpc - -import grpc "google.golang.org/grpc" - -func NewRpcServer(rpcServer *grpc.Server, server MeshCtrlServerServer, auth AuthenticationServer) *grpc.Server { - RegisterMeshCtrlServerServer(rpcServer, server) - RegisterAuthenticationServer(rpcServer, auth) - return rpcServer -} diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index b8ab5a3..45bb94b 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -18,13 +18,12 @@ type Syncer interface { } type SyncerImpl struct { - manager *mesh.MeshManger + manager *mesh.MeshManager requester SyncRequester authenticatedNodes []crdt.MeshNodeCrdt } -const subSetLength = 5 -const maxAuthentications = 30 +const subSetLength = 3 // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { @@ -39,27 +38,23 @@ func (s *SyncerImpl) Sync(meshId string) error { return errors.New("the provided mesh does not exist") } - snapshot, err := mesh.GetCrdt() + snapshot, err := mesh.GetMesh() if err != nil { return err } - if len(snapshot.Nodes) <= 1 { + nodes := snapshot.GetNodes() + + if len(nodes) <= 1 { return nil } excludedNodes := map[string]struct{}{ - s.manager.HostEndpoint: {}, + s.manager.HostParameters.HostEndpoint: {}, } - for _, node := range snapshot.Nodes { - if mesh.HasFailed(node.HostEndpoint) { - excludedNodes[node.HostEndpoint] = struct{}{} - } - } - - meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes) + meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes) randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength) before := time.Now() @@ -71,7 +66,7 @@ func (s *SyncerImpl) Sync(meshId string) error { syncMeshFunc := func() error { defer waitGroup.Done() - err := s.requester.SyncMesh(meshId, n.HostEndpoint) + err := s.requester.SyncMesh(meshId, n.GetHostEndpoint()) return err } @@ -86,8 +81,8 @@ func (s *SyncerImpl) Sync(meshId string) error { // SyncMeshes: Sync all meshes func (s *SyncerImpl) SyncMeshes() error { - for _, m := range s.manager.Meshes { - err := s.Sync(m.MeshId) + for meshId, _ := range s.manager.Meshes { + err := s.Sync(meshId) if err != nil { return err @@ -97,6 +92,6 @@ func (s *SyncerImpl) SyncMeshes() error { return nil } -func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer { +func NewSyncer(m *mesh.MeshManager, r SyncRequester) Syncer { return &SyncerImpl{manager: m, requester: r} } diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go index 4cbd655..489613b 100644 --- a/pkg/sync/syncererror.go +++ b/pkg/sync/syncererror.go @@ -7,12 +7,14 @@ import ( "google.golang.org/grpc/status" ) +// SyncErrorHandler: Handles errors when attempting to sync type SyncErrorHandler interface { Handle(meshId string, endpoint string, err error) bool } +// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler type SyncErrorHandlerImpl struct { - meshManager *mesh.MeshManger + meshManager *mesh.MeshManager } func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool { @@ -22,12 +24,6 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri return false } - err := mesh.IncrementFailedCount(endpoint) - - if err != nil { - return false - } - return true } @@ -44,6 +40,6 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) return false } -func NewSyncErrorHandler(m *mesh.MeshManger) SyncErrorHandler { +func NewSyncErrorHandler(m *mesh.MeshManager) SyncErrorHandler { return &SyncErrorHandlerImpl{meshManager: m} } diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index f2b60f7..5e3b0c3 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -6,9 +6,9 @@ import ( "io" "time" - crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/rpc" ) @@ -94,11 +94,10 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { } logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) - mesh.DecrementFailedCount(endpoint) return nil } -func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncServiceClient) error { +func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error { stream, err := client.SyncMesh(ctx) syncer := mesh.GetSyncer() @@ -110,7 +109,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe for { msg, moreMessages := syncer.GenerateMessage() - err := stream.Send(&rpc.SyncMeshRequest{MeshId: mesh.MeshId, Changes: msg}) + err := stream.Send(&rpc.SyncMeshRequest{MeshId: mesh.GetMeshId(), Changes: msg}) if err != nil { return err diff --git a/pkg/sync/syncservice.go b/pkg/sync/syncservice.go index 82423b5..f2e9b67 100644 --- a/pkg/sync/syncservice.go +++ b/pkg/sync/syncservice.go @@ -6,8 +6,8 @@ import ( "errors" "io" - crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" + "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/rpc" ) @@ -37,7 +37,7 @@ func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfR // SyncMesh: syncs the two streams changes func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error { var meshId = "" - var syncer *crdt.AutomergeSync = nil + var syncer mesh.MeshSyncer = nil for { in, err := stream.Recv() diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go index df4503e..36218b5 100644 --- a/pkg/timestamp/timestamp.go +++ b/pkg/timestamp/timestamp.go @@ -14,7 +14,7 @@ type TimestampScheduler interface { } type TimeStampSchedulerImpl struct { - meshManager *mesh.MeshManger + meshManager *mesh.MeshManager updateRate int quit chan struct{} }