diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index f396e3b..86b6a86 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock" type CreateMeshParams struct { Client *ipcRpc.Client - IfName string WgPort int Endpoint string } @@ -24,7 +23,6 @@ type CreateMeshParams struct { func createMesh(args *CreateMeshParams) string { var reply string newMeshParams := ipc.NewMeshArgs{ - IfName: args.IfName, WgPort: args.WgPort, Endpoint: args.Endpoint, } @@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string { args := ipc.JoinMeshArgs{ MeshId: params.MeshId, IpAdress: params.IpAddress, - IfName: params.IfName, Port: params.WgPort, } @@ -251,14 +248,12 @@ func main() { deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh") - var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) - var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) + var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) - var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) - var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) + var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) @@ -298,7 +293,6 @@ func main() { if newMeshCmd.Happened() { fmt.Println(createMesh(&CreateMeshParams{ Client: client, - IfName: *newMeshIfName, WgPort: *newMeshPort, Endpoint: *newMeshEndpoint, })) @@ -311,7 +305,6 @@ func main() { if joinMeshCmd.Happened() { fmt.Println(joinMesh(&JoinMeshParams{ Client: client, - IfName: *joinMeshIfName, WgPort: *joinMeshPort, IpAddress: *joinMeshIpAddress, MeshId: *joinMeshId, diff --git a/pkg/api/apiserver.go b/pkg/api/apiserver.go index f5ee232..655543f 100644 --- a/pkg/api/apiserver.go +++ b/pkg/api/apiserver.go @@ -62,11 +62,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) { c.JSON(http.StatusBadRequest, &gin.H{ "error": err.Error(), }) + return } ipcRequest := ipc.NewMeshArgs{ - IfName: createMesh.IfName, WgPort: createMesh.WgPort, } @@ -100,7 +100,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) { ipcRequest := ipc.JoinMeshArgs{ MeshId: joinMesh.MeshId, IpAdress: joinMesh.Bootstrap, - IfName: joinMesh.IfName, Port: joinMesh.WgPort, } diff --git a/pkg/api/types.go b/pkg/api/types.go index 5429fed..11d26a2 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -18,13 +18,11 @@ type SmegMesh struct { } type CreateMeshRequest struct { - IfName string `json:"ifName" binding:"required"` - WgPort int `json:"port" binding:"required,gte=1024,lt=65535"` + WgPort int `json:"port" binding:"gte=1024,lt=65535"` } type JoinMeshRequest struct { - IfName string `json:"ifName" binding:"required"` - WgPort int `json:"port" binding:"required,gte=1024,lt=65535"` + WgPort int `json:"port" binding:"gte=1024,lt=65535"` Bootstrap string `json:"bootstrap" binding:"required"` MeshId string `json:"meshid" binding:"required"` } diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index ab8fe8c..07487c6 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -11,8 +11,6 @@ import ( ) type NewMeshArgs struct { - // IfName is the interface that the mesh instance will run on - IfName string // WgPort is the WireGuard port to expose WgPort int // Endpoint is the routable alias of the machine. Can be an IP @@ -25,8 +23,6 @@ type JoinMeshArgs struct { MeshId string // IpAddress is a routable IP in another mesh IpAdress string - // IfName is the interface name of the mesh - IfName string // Port is the WireGuard port to expose Port int // Endpoint is the routable address of this machine. If not provided diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index ed78b9a..76eb608 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -14,7 +14,7 @@ import ( ) type MeshManager interface { - CreateMesh(devName string, port int) (string, error) + CreateMesh(port int) (string, error) AddMesh(params *AddMeshParams) error HasChanges(meshid string) bool GetMesh(meshId string) MeshProvider @@ -115,15 +115,25 @@ func (m *MeshManagerImpl) Prune() error { } // CreateMesh: Creates a new mesh, stores it and returns the mesh id -func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { +func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { meshId, err := m.idGenerator.GetId() + var ifName string = "" + if err != nil { return "", err } + if !m.conf.StubWg { + ifName, err = m.interfaceManipulator.CreateInterface(port) + + if err != nil { + return "", fmt.Errorf("error creating mesh: %w", err) + } + } + nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ - DevName: devName, + DevName: ifName, Port: port, Conf: m.conf, Client: m.Client, @@ -134,32 +144,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { return "", fmt.Errorf("error creating mesh: %w", err) } - if !m.conf.StubWg { - err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ - IfName: devName, - Port: port, - }) - - if err != nil { - return "", fmt.Errorf("error creating mesh: %w", err) - } - } - m.Meshes[meshId] = nodeManager return meshId, nil } type AddMeshParams struct { MeshId string - DevName string WgPort int MeshBytes []byte } // AddMesh: Add the mesh to the list of meshes func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { + var ifName string + var err error + + if !m.conf.StubWg { + ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort) + + if err != nil { + return err + } + } + meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ - DevName: params.DevName, + DevName: ifName, Port: params.WgPort, Conf: m.conf, Client: m.Client, @@ -177,14 +186,6 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { } m.Meshes[params.MeshId] = meshProvider - - if !m.conf.StubWg { - return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ - IfName: params.DevName, - Port: params.WgPort, - }) - } - return nil } diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index b13e9ac..78515ba 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -250,7 +250,7 @@ func NewMeshManagerStub() MeshManager { return &MeshManagerStub{meshes: make(map[string]MeshProvider)} } -func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { +func (m *MeshManagerStub) CreateMesh(port int) (string, error) { return "tim123", nil } diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 71c58ec..e98a228 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -20,7 +20,7 @@ type IpcHandler struct { } func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { - meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) + meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort) if err != nil { return err @@ -83,7 +83,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ MeshId: args.MeshId, - DevName: args.IfName, WgPort: args.Port, MeshBytes: meshReply.Mesh, }) diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index 3c40d72..472e3f0 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{ MeshId: meshId, - DevName: ifName, WgPort: port, MeshBytes: reply.Mesh, }) diff --git a/pkg/wg/stubs.go b/pkg/wg/stubs.go index 5adcfc5..4d8e704 100644 --- a/pkg/wg/stubs.go +++ b/pkg/wg/stubs.go @@ -2,7 +2,7 @@ package wg type WgInterfaceManipulatorStub struct{} -func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { +func (i *WgInterfaceManipulatorStub) CreateInterface(port int) error { return nil } diff --git a/pkg/wg/types.go b/pkg/wg/types.go index a7a44b9..99f22b1 100644 --- a/pkg/wg/types.go +++ b/pkg/wg/types.go @@ -8,14 +8,9 @@ func (m *WgError) Error() string { return m.msg } -type CreateInterfaceParams struct { - IfName string - Port int -} - type WgInterfaceManipulator interface { // CreateInterface creates a WireGuard interface - CreateInterface(params *CreateInterfaceParams) error + CreateInterface(port int) (string, error) // AddAddress adds an address to the given interface name AddAddress(ifName string, addr string) error // RemoveInterface removes the specified interface diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index 691d197..70ba6af 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -1,6 +1,9 @@ package wg import ( + "crypto" + "crypto/rand" + "encoding/base64" "fmt" "github.com/tim-beatham/wgmesh/pkg/lib" @@ -13,40 +16,54 @@ type WgInterfaceManipulatorImpl struct { client *wgctrl.Client } +const hashLength = 6 + // CreateInterface creates a WireGuard interface -func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { +func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) { rtnl, err := lib.NewRtNetlinkConfig() if err != nil { - return fmt.Errorf("failed to access link: %w", err) + return "", fmt.Errorf("failed to access link: %w", err) } defer rtnl.Close() - err = rtnl.CreateLink(params.IfName) + randomBuf := make([]byte, 32) + _, err = rand.Read(randomBuf) if err != nil { - return fmt.Errorf("failed to create link: %w", err) + return "", err + } + + md5 := crypto.MD5.New().Sum(randomBuf) + + md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength]) + + err = rtnl.CreateLink(md5Str) + + if err != nil { + return "", fmt.Errorf("failed to create link: %w", err) } privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { - return fmt.Errorf("failed to create private key: %w", err) + return "", fmt.Errorf("failed to create private key: %w", err) } var cfg wgtypes.Config = wgtypes.Config{ PrivateKey: &privateKey, - ListenPort: ¶ms.Port, + ListenPort: &port, } - err = m.client.ConfigureDevice(params.IfName, cfg) + err = m.client.ConfigureDevice(md5Str, cfg) if err != nil { - return fmt.Errorf("failed to configure dev: %w", err) + m.RemoveInterface(md5Str) + return "", fmt.Errorf("failed to configure dev: %w", err) } - logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) - return nil + logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str) + return md5Str, nil } // Add an address to the given interface