diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 526beae..956c9ff 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -13,21 +13,21 @@ import ( const SockAddr = "/tmp/wgmesh_ipc.sock" type CreateMeshParams struct { - Client *ipcRpc.Client - WgPort int - Endpoint string - Role string + Client *ipcRpc.Client + Endpoint string + Role string + WgArgs ipc.WireGuardArgs + AdvertiseRoutes bool + AdvertiseDefault bool } -func createMesh(args *CreateMeshParams) string { +func createMesh(params *CreateMeshParams) string { var reply string newMeshParams := ipc.NewMeshArgs{ - WgPort: args.WgPort, - Endpoint: args.Endpoint, - Role: args.Role, + WgArgs: params.WgArgs, } - err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) + err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) if err != nil { return err.Error() @@ -52,13 +52,14 @@ func listMeshes(client *ipcRpc.Client) { } type JoinMeshParams struct { - Client *ipcRpc.Client - MeshId string - IpAddress string - IfName string - WgPort int - Endpoint string - Role string + Client *ipcRpc.Client + MeshId string + IpAddress string + Endpoint string + Role string + WgArgs ipc.WireGuardArgs + AdvertiseRoutes bool + AdvertiseDefault bool } func joinMesh(params *JoinMeshParams) string { @@ -67,8 +68,7 @@ func joinMesh(params *JoinMeshParams) string { args := ipc.JoinMeshArgs{ MeshId: params.MeshId, IpAdress: params.IpAddress, - Port: params.WgPort, - Role: params.Role, + WgArgs: params.WgArgs, } err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply) @@ -93,19 +93,6 @@ func leaveMesh(client *ipcRpc.Client, meshId string) { fmt.Println(reply) } -func enableInterface(client *ipcRpc.Client, meshId string) { - var reply string - - err := client.Call("IpcHandler.EnableInterface", &meshId, &reply) - - if err != nil { - fmt.Println(err.Error()) - return - } - - fmt.Println(reply) -} - func getGraph(client *ipcRpc.Client, meshId string) { var reply string @@ -191,31 +178,13 @@ func deleteService(client *ipcRpc.Client, service string) { fmt.Println(reply) } -func getNode(client *ipcRpc.Client, nodeId, meshId string) { - var reply string - args := &ipc.GetNodeArgs{ - NodeId: nodeId, - MeshId: meshId, - } - - err := client.Call("IpcHandler.GetNode", &args, &reply) - - if err != nil { - fmt.Println(err.Error()) - return - } - - fmt.Println(reply) -} - func main() { parser := argparse.NewParser("wg-mesh", - "wg-mesh Manipulate WireGuard meshes") + "wg-mesh Manipulate WireGuard mesh networks") newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh") listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to") joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network") - enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath") @@ -223,38 +192,115 @@ func main() { putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node") setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") - getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh") - var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{}) - var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) - var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{}) + var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{ + Default: 0, + Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.", + }) - var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) - var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) - var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{}) - var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) - var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{}) + var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{ + Help: "Publicly routeable endpoint to advertise within the mesh", + }) - var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) + var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ + Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + + " protocol", + }) + var newMeshKeepAliveWg *int = newMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{ + Default: 0, + Help: "WireGuard KeepAlive value for NAT traversal and firewall holepunching", + }) - var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) + var newMeshAdvertiseRoutes *bool = newMeshCmd.Flag("a", "advertise", &argparse.Options{ + Help: "Advertise routes to other mesh network into the mesh", + }) - var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) + var newMeshAdvertiseDefaults *bool = newMeshCmd.Flag("d", "defaults", &argparse.Options{ + Help: "Advertise ::/0 into the mesh network", + }) - var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) - var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{Required: true}) + var joinMeshId *string = joinMeshCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) - var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true}) + var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{ + Required: true, + Help: "IP address of the bootstrapping node to join through", + }) - var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true}) + var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{ + Help: "Publicly routeable endpoint to advertise within the mesh", + }) - var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true}) - var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true}) + var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ + Default: "Peer", + Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + + " protocol", + }) - var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true}) + var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{ + Default: 0, + Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.", + }) - var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true}) - var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true}) + var joinMeshKeepAliveWg *int = joinMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{ + Default: 0, + Help: "WireGuard KeepAlive value for NAT traversal and firewall ho;lepunching", + }) + + var joinMeshAdvertiseRoutes *bool = joinMeshCmd.Flag("a", "advertise", &argparse.Options{ + Help: "Advertise routes to other mesh network into the mesh", + }) + + var joinMeshAdvertiseDefaults *bool = joinMeshCmd.Flag("d", "defaults", &argparse.Options{ + Help: "Advertise ::/0 into the mesh network", + }) + + var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{ + Required: true, + Help: "MeshID of the graph to get", + }) + + var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh to leave", + }) + + var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh to query", + }) + var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{ + Required: true, + Help: "JMESPath Query Of The Mesh Network To Query", + }) + + var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{ + Required: true, + Help: "Description of the node in the mesh", + }) + + var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{ + Required: true, + Help: "Alias of the node to set can be used in DNS to lookup an IP address", + }) + + var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{ + Required: true, + Help: "Key of the service to advertise in the mesh network", + }) + var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{ + Required: true, + Help: "Value of the service to advertise in the mesh network", + }) + + var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{ + Required: true, + Help: "Key of the service to remove", + }) err := parser.Parse(os.Args) @@ -272,9 +318,16 @@ func main() { if newMeshCmd.Happened() { fmt.Println(createMesh(&CreateMeshParams{ Client: client, - WgPort: *newMeshPort, Endpoint: *newMeshEndpoint, Role: *newMeshRole, + WgArgs: ipc.WireGuardArgs{ + Endpoint: *newMeshEndpoint, + Role: *newMeshRole, + WgPort: *newMeshPort, + KeepAliveWg: *newMeshKeepAliveWg, + AdvertiseDefaultRoute: *newMeshAdvertiseDefaults, + AdvertiseRoutes: *newMeshAdvertiseRoutes, + }, })) } @@ -285,11 +338,18 @@ func main() { if joinMeshCmd.Happened() { fmt.Println(joinMesh(&JoinMeshParams{ Client: client, - WgPort: *joinMeshPort, IpAddress: *joinMeshIpAddress, MeshId: *joinMeshId, Endpoint: *joinMeshEndpoint, Role: *joinMeshRole, + WgArgs: ipc.WireGuardArgs{ + Endpoint: *joinMeshEndpoint, + Role: *joinMeshRole, + WgPort: *joinMeshPort, + KeepAliveWg: *joinMeshKeepAliveWg, + AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults, + AdvertiseRoutes: *joinMeshAdvertiseRoutes, + }, })) } @@ -297,10 +357,6 @@ func main() { getGraph(client, *getGraphMeshId) } - if enableInterfaceCmd.Happened() { - enableInterface(client, *enableInterfaceMeshId) - } - if leaveMeshCmd.Happened() { leaveMesh(client, *leaveMeshMeshId) } @@ -324,8 +380,4 @@ func main() { if deleteServiceCmd.Happened() { deleteService(client, *deleteServiceKey) } - - if getNodeCmd.Happened() { - getNode(client, *getNodeNodeId, *getNodeMeshId) - } } diff --git a/pkg/api/apiserver.go b/pkg/api/apiserver.go index 6e40760..d11fc56 100644 --- a/pkg/api/apiserver.go +++ b/pkg/api/apiserver.go @@ -99,7 +99,9 @@ func (s *SmegServer) CreateMesh(c *gin.Context) { } ipcRequest := ipc.NewMeshArgs{ - WgPort: createMesh.WgPort, + WgArgs: ipc.WireGuardArgs{ + WgPort: createMesh.WgPort, + }, } var reply string @@ -132,7 +134,9 @@ func (s *SmegServer) JoinMesh(c *gin.Context) { ipcRequest := ipc.JoinMeshArgs{ MeshId: joinMesh.MeshId, IpAdress: joinMesh.Bootstrap, - Port: joinMesh.WgPort, + WgArgs: ipc.WireGuardArgs{ + WgPort: joinMesh.WgPort, + }, } var reply string diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 90a3701..c8821d2 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -37,10 +37,10 @@ type WgConfiguration struct { // service for IPDiscoverability IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"` // AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes - AdvertiseRoutes *bool `yaml:"advertiseRoute"` + AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"` // AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route // for all nodes to route their packets to - AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults"` + AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults" validate:"required"` // Endpoint contains what value should be set as the public endpoint of this node Endpoint *string `yaml:"publicEndpoint"` // Role specifies whether or not the user is globally accessible. diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index 00f2b52..8496578 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -10,13 +10,29 @@ import ( "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ) -type NewMeshArgs struct { +// WireGuardArgs are provided args specific to WireGuard +type WireGuardArgs struct { // WgPort is the WireGuard port to expose WgPort int + // KeepAliveWg is the number of seconds to keep alive + // for WireGuard NAT/firewall traversal + KeepAliveWg int + // AdvertiseRoutes whether or not to advertise routes to and from the + // mesh network + AdvertiseRoutes bool + // AdvertiseDefaultRoute whether or not to advertise the default route + // into the mesh network + AdvertiseDefaultRoute bool // Endpoint is the routable alias of the machine. Can be an IP // or DNS entry Endpoint string - Role string + // Role is the role of the individual in the mesh + Role string +} + +type NewMeshArgs struct { + // WgArgs are specific WireGuard args to use + WgArgs WireGuardArgs } type JoinMeshArgs struct { @@ -24,14 +40,8 @@ type JoinMeshArgs struct { MeshId string // IpAddress is a routable IP in another mesh IpAdress string - // Port is the WireGuard port to expose - Port int - // Endpoint to use to override the default - Endpoint string - // Client specifies whether we should join as a client of the peer - // we are connecting to - Client bool - Role string + // WgArgs is the WireGuard parameters to use. + WgArgs WireGuardArgs } type PutServiceArgs struct { @@ -52,11 +62,6 @@ type QueryMesh struct { Query string } -type GetNodeArgs struct { - NodeId string - MeshId string -} - type MeshIpc interface { CreateMesh(args *NewMeshArgs, reply *string) error ListMeshes(name string, reply *ListMeshReply) error @@ -68,7 +73,6 @@ type MeshIpc interface { PutDescription(description string, reply *string) error PutAlias(alias string, reply *string) error PutService(args PutServiceArgs, reply *string) error - GetNode(args GetNodeArgs, reply *string) error DeleteService(service string, reply *string) error } diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 9d39fb4..f7fd4c4 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -20,6 +20,10 @@ func (r *RouteManagerImpl) UpdateRoutes() error { routes := make(map[string][]Route) for _, mesh1 := range meshes { + if !*mesh1.GetConfiguration().AdvertiseRoutes { + continue + } + self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) if err != nil { diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 3a9a3e0..75ef61e 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -2,7 +2,6 @@ package robin import ( "context" - "encoding/json" "errors" "fmt" "strconv" @@ -12,7 +11,6 @@ import ( "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/mesh" - "github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/wgmesh/pkg/rpc" ) @@ -20,8 +18,8 @@ type IpcHandler struct { Server ctrlserver.CtrlServer } -func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { - overrideConf := &conf.WgConfiguration{} +func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration { + overrideConf := conf.WgConfiguration{} if args.Role != "" { role := conf.NodeType(args.Role) @@ -32,13 +30,26 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { overrideConf.Endpoint = &args.Endpoint } + if args.KeepAliveWg != 0 { + keepAliveWg := args.KeepAliveWg + overrideConf.KeepAliveWg = &keepAliveWg + } + + overrideConf.AdvertiseRoutes = &args.AdvertiseRoutes + overrideConf.AdvertiseDefaultRoute = &args.AdvertiseDefaultRoute + return overrideConf +} + +func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { + overrideConf := getOverrideConfiguration(&args.WgArgs) + if overrideConf.Role != nil && *overrideConf.Role == conf.CLIENT_ROLE { return fmt.Errorf("cannot create a mesh with no public endpoint") } meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{ - Port: args.WgPort, - Conf: overrideConf, + Port: args.WgArgs.WgPort, + Conf: &overrideConf, }) if err != nil { @@ -47,8 +58,8 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{ MeshId: meshId, - WgPort: args.WgPort, - Endpoint: args.Endpoint, + WgPort: args.WgArgs.WgPort, + Endpoint: args.WgArgs.Endpoint, }) if err != nil { @@ -73,16 +84,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { } func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { - overrideConf := &conf.WgConfiguration{} - - if args.Role != "" { - role := conf.NodeType(args.Role) - overrideConf.Role = &role - } - - if args.Endpoint != "" { - overrideConf.Endpoint = &args.Endpoint - } + overrideConf := getOverrideConfiguration(&args.WgArgs) peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) @@ -115,9 +117,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ MeshId: args.MeshId, - WgPort: args.Port, + WgPort: args.WgArgs.WgPort, MeshBytes: meshReply.Mesh, - Conf: overrideConf, + Conf: &overrideConf, }) if err != nil { @@ -126,8 +128,8 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{ MeshId: args.MeshId, - WgPort: args.Port, - Endpoint: args.Endpoint, + WgPort: args.WgArgs.WgPort, + Endpoint: args.WgArgs.Endpoint, }) if err != nil { @@ -248,27 +250,6 @@ func (n *IpcHandler) DeleteService(service string, reply *string) error { return nil } -func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error { - node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId) - - if node == nil { - *reply = "nil" - return nil - } - - queryNode := query.MeshNodeToQueryNode(node) - - bytes, err := json.Marshal(queryNode) - - if err != nil { - *reply = err.Error() - return nil - } - - *reply = string(bytes) - return nil -} - type RobinIpcParams struct { CtrlServer ctrlserver.CtrlServer }