diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 95c1e35..c0d1220 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -87,6 +87,19 @@ func enableInterface(client *ipcRpc.Client, meshId string) { fmt.Println(reply) } +func getGraph(client *ipcRpc.Client, meshId string) { + var reply string + + err := client.Call("RobinIpc.GetDOT", &meshId, &reply) + + if err != nil { + fmt.Println(err.Error()) + return + } + + fmt.Println(reply) +} + func main() { parser := argparse.NewParser("wg-mesh", "wg-mesh Manipulate WireGuard meshes") @@ -96,12 +109,13 @@ func main() { joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network") getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network") enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") + getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") var meshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var ipAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) - var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) + var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) err := parser.Parse(os.Args) @@ -132,6 +146,10 @@ func main() { getMesh(client, *getMeshId) } + if getGraphCmd.Happened() { + getGraph(client, *getGraphMeshId) + } + if enableInterfaceCmd.Happened() { enableInterface(client, *enableInterfaceMeshId) } diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index b6cfff8..4f468d8 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -2,6 +2,7 @@ package crdt import ( "errors" + "fmt" "net" "strings" @@ -188,13 +189,13 @@ func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { return meshNode, nil } -const threshold = 2 -const thresholdVotes = 0.1 - func (m *CrdtNodeManager) Length() int { return m.doc.Path("nodes").Map().Len() } +const threshold = 2 +const thresholdVotes = 0.1 + func (m *CrdtNodeManager) HasFailed(endpoint string) bool { node, err := m.GetNode(endpoint) @@ -249,3 +250,7 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode client.ConfigureDevice(devName, cfg) return nil } + +func (n *MeshNodeCrdt) GetEscapedIP() string { + return fmt.Sprintf("\"%s\"", n.WgHost) +} diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index cba66ca..0afa349 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -8,7 +8,7 @@ package ctrlserver import ( "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conn" - "github.com/tim-beatham/wgmesh/pkg/manager" + "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/rpc" "golang.zx2c4.com/wireguard/wgctrl" ) @@ -30,7 +30,7 @@ type NewCtrlServerParams struct { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer := new(MeshCtrlServer) ctrlServer.Client = params.WgClient - ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient, *params.Conf) + ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf) ctrlServer.Conf = params.Conf connManagerParams := conn.NewJwtConnectionManagerParams{ diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index 53c301e..6f1bb2f 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -3,7 +3,7 @@ package ctrlserver import ( "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conn" - "github.com/tim-beatham/wgmesh/pkg/manager" + "github.com/tim-beatham/wgmesh/pkg/mesh" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -30,7 +30,7 @@ type Mesh struct { */ type MeshCtrlServer struct { Client *wgctrl.Client - MeshManager *manager.MeshManger + MeshManager *mesh.MeshManger ConnectionManager conn.ConnectionManager ConnectionServer *conn.ConnectionServer Conf *conf.WgMeshConfiguration diff --git a/pkg/graph/graph.go b/pkg/graph/graph.go new file mode 100644 index 0000000..2a244db --- /dev/null +++ b/pkg/graph/graph.go @@ -0,0 +1,169 @@ +// Graph allows the definition of a DOT graph in golang +package graph + +import ( + "errors" + "fmt" + "hash/fnv" + "strings" + + "github.com/tim-beatham/wgmesh/pkg/lib" +) + +type GraphType string + +const ( + GRAPH GraphType = "graph" + DIGRAPH = "digraph" +) + +type Graph struct { + Type GraphType + Label string + nodes map[string]*Node + edges []Edge +} + +type Node struct { + Name string +} + +type Edge interface { + Dottable +} + +type DirectedEdge struct { + Label string + From *Node + To *Node +} + +type UndirectedEdge struct { + Label string + From *Node + To *Node +} + +// Dottable means an implementer can convert the struct to DOT representation +type Dottable interface { + GetDOT() (string, error) +} + +func NewGraph(label string, graphType GraphType) *Graph { + return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)} +} + +// AddNode: adds a node to the graph +func (g *Graph) AddNode(label string) error { + _, exists := g.nodes[label] + + if exists { + return errors.New(fmt.Sprintf("Node %s already exists", label)) + } + + g.nodes[label] = &Node{Name: label} + return nil +} + +func writeContituents[D Dottable](result *strings.Builder, elements ...D) error { + for _, node := range elements { + dot, err := node.GetDOT() + + if err != nil { + return err + } + + _, err = result.WriteString(dot) + + if err != nil { + return err + } + } + return nil +} + +func (g *Graph) GetDOT() (string, error) { + var result strings.Builder + + _, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type)) + + if err != nil { + return "", err + } + + _, err = result.WriteString("node [colorscheme=set312];\n") + + if err != nil { + return "", err + } + + nodes := lib.MapValues(g.nodes) + + err = writeContituents(&result, nodes...) + + if err != nil { + return "", err + } + + err = writeContituents(&result, g.edges...) + + if err != nil { + return "", err + } + + _, err = result.WriteString("}") + + if err != nil { + return "", err + } + + return result.String(), nil +} + +func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge { + switch g.Type { + case DIGRAPH: + return &DirectedEdge{Label: label, From: from, To: to} + default: + return &UndirectedEdge{Label: label, From: from, To: to} + } +} + +// AddEdge: adds an edge between two nodes in the graph +func (g *Graph) AddEdge(label string, from string, to string) error { + fromNode, exists := g.nodes[from] + + if !exists { + return errors.New(fmt.Sprintf("Node %s does not exist", from)) + } + + toNode, exists := g.nodes[to] + + if !exists { + return errors.New(fmt.Sprintf("Node %s does not exist", to)) + } + + g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode)) + return nil +} + +const numColours = 12 + +func (n *Node) hash() int { + h := fnv.New32a() + h.Write([]byte(n.Name)) + return (int(h.Sum32()) % numColours) + 1 +} + +func (n *Node) GetDOT() (string, error) { + return fmt.Sprintf("node[shape=circle, style=\"filled\", fillcolor=%d] %s;\n", + n.hash(), n.Name), nil +} + +func (e *DirectedEdge) GetDOT() (string, error) { + return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil +} + +func (e *UndirectedEdge) GetDOT() (string, error) { + return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil +} diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index 76a4f9e..f087077 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -29,6 +29,7 @@ type MeshIpc interface { JoinMesh(args JoinMeshArgs, reply *string) error GetMesh(meshId string, reply *GetMeshReply) error EnableInterface(meshId string, reply *string) error + GetDOT(meshId string, reply *string) error } const SockAddr = "/tmp/wgmesh_ipc.sock" diff --git a/pkg/mesh/graph_generator.go b/pkg/mesh/graph_generator.go new file mode 100644 index 0000000..2aa16b9 --- /dev/null +++ b/pkg/mesh/graph_generator.go @@ -0,0 +1,60 @@ +package mesh + +import ( + "errors" + "fmt" + + "github.com/tim-beatham/wgmesh/pkg/graph" + "github.com/tim-beatham/wgmesh/pkg/lib" +) + +type MeshGraphConverter interface { + // convert the mesh to textual form + Generate(meshId string) (string, error) +} + +type MeshDOTConverter struct { + manager *MeshManger +} + +func (c *MeshDOTConverter) Generate(meshId string) (string, error) { + mesh := c.manager.GetMesh(meshId) + + if mesh == nil { + return "", errors.New("mesh does not exist") + } + + g := graph.NewGraph(meshId, graph.GRAPH) + + snapshot, err := mesh.GetCrdt() + + if err != nil { + return "", err + } + + for _, node := range snapshot.Nodes { + g.AddNode(node.GetEscapedIP()) + } + + nodes := lib.MapValues(snapshot.Nodes) + + for i, node1 := range nodes[:len(nodes)-1] { + if mesh.HasFailed(node1.HostEndpoint) { + continue + } + + for _, node2 := range nodes[i+1:] { + if node1 == node2 || mesh.HasFailed(node2.HostEndpoint) { + continue + } + + g.AddEdge(fmt.Sprintf("%s to %s", node1.GetEscapedIP(), node2.GetEscapedIP()), node1.GetEscapedIP(), node2.GetEscapedIP()) + } + } + + return g.GetDOT() +} + +func NewMeshDotConverter(m *MeshManger) MeshGraphConverter { + return &MeshDOTConverter{manager: m} +} diff --git a/pkg/manager/mesh_manager.go b/pkg/mesh/mesh_manager.go similarity index 99% rename from pkg/manager/mesh_manager.go rename to pkg/mesh/mesh_manager.go index 4393f61..b4f074c 100644 --- a/pkg/manager/mesh_manager.go +++ b/pkg/mesh/mesh_manager.go @@ -1,4 +1,4 @@ -package manager +package mesh import ( "errors" diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go index f2206d8..2296468 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/robin_requester.go @@ -12,6 +12,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/wg" ) @@ -261,6 +262,19 @@ func (n *RobinIpc) EnableInterface(meshId string, reply *string) error { return nil } +func (n *RobinIpc) GetDOT(meshId string, reply *string) error { + g := mesh.NewMeshDotConverter(n.Server.MeshManager) + + result, err := g.Generate(meshId) + + if err != nil { + return err + } + + *reply = result + return nil +} + type RobinIpcParams struct { CtrlServer *ctrlserver.MeshCtrlServer Allocator ip.IPAllocator diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index abc4bc5..3f0a1f1 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -5,7 +5,7 @@ import ( crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/lib" - "github.com/tim-beatham/wgmesh/pkg/manager" + "github.com/tim-beatham/wgmesh/pkg/mesh" ) // Syncer: picks random nodes from the mesh @@ -15,7 +15,7 @@ type Syncer interface { } type SyncerImpl struct { - manager *manager.MeshManger + manager *mesh.MeshManger requester SyncRequester authenticatedNodes []crdt.MeshNodeCrdt } @@ -45,6 +45,12 @@ func (s *SyncerImpl) Sync(meshId string) error { s.manager.HostEndpoint: {}, } + for _, node := range snapshot.Nodes { + if mesh.HasFailed(node.HostEndpoint) { + excludedNodes[node.HostEndpoint] = struct{}{} + } + } + meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes) randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength) @@ -72,6 +78,6 @@ func (s *SyncerImpl) SyncMeshes() error { return s.manager.ApplyWg() } -func NewSyncer(m *manager.MeshManger, r SyncRequester) Syncer { +func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer { return &SyncerImpl{manager: m, requester: r} } diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go index 04f9cbc..3b4cc6e 100644 --- a/pkg/sync/syncererror.go +++ b/pkg/sync/syncererror.go @@ -2,7 +2,7 @@ package sync import ( logging "github.com/tim-beatham/wgmesh/pkg/log" - "github.com/tim-beatham/wgmesh/pkg/manager" + "github.com/tim-beatham/wgmesh/pkg/mesh" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -12,7 +12,7 @@ type SyncErrorHandler interface { } type SyncErrorHandlerImpl struct { - meshManager *manager.MeshManger + meshManager *mesh.MeshManger } func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool { @@ -44,6 +44,6 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) return false } -func NewSyncErrorHandler(m *manager.MeshManger) SyncErrorHandler { +func NewSyncErrorHandler(m *mesh.MeshManger) SyncErrorHandler { return &SyncErrorHandlerImpl{meshManager: m} }