From 1a864b7c808e5059b0d9173327e002016e5fed30 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Tue, 7 Nov 2023 19:48:53 +0000 Subject: [PATCH] Removed interface manipulation via os.Exec into rtnetlink calls --- go.mod | 2 + pkg/automerge/automerge.go | 25 +++ pkg/ctrlserver/ctrlserver.go | 7 +- pkg/lib/rtnetlink.go | 300 +++++++++++++++++++++++++++++++++++ pkg/mesh/manager.go | 96 ++++++++--- pkg/mesh/manager_test.go | 3 +- pkg/mesh/route.go | 130 +++++++++++++++ pkg/mesh/route_stub.go | 16 ++ pkg/mesh/stub_types.go | 10 ++ pkg/mesh/types.go | 2 + pkg/robin/requester.go | 5 + pkg/sync/syncer.go | 6 +- pkg/wg/stubs.go | 6 +- pkg/wg/types.go | 7 +- pkg/wg/wg.go | 102 +++++------- 15 files changed, 618 insertions(+), 99 deletions(-) create mode 100644 pkg/lib/rtnetlink.go create mode 100644 pkg/mesh/route_stub.go diff --git a/go.mod b/go.mod index d146e93..6660fcf 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,11 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/josharian/native v1.1.0 // indirect + github.com/jsimonetti/rtnetlink v1.3.5 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.0 // indirect + github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect golang.org/x/crypto v0.13.0 // indirect golang.org/x/net v0.15.0 // indirect golang.org/x/sync v0.3.0 // indirect diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 3d20024..761a374 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -207,6 +207,31 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { return nil } +// DeleteRoutes deletes the specified routes +func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { + nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) + + if err != nil { + return err + } + + if nodeVal.Kind() != automerge.KindMap { + return fmt.Errorf("node is not a map") + } + + routeMap, err := nodeVal.Map().Get("routes") + + if err != nil { + return err + } + + for _, route := range routes { + err = routeMap.Map().Delete(route) + } + + return err +} + func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { return NewAutomergeSync(m) } diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index a3ed638..e5acef3 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -35,7 +35,6 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ipAllocator := &ip.ULABuilder{} interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) - var meshManager mesh.MeshManagerImpl configApplyer := mesh.NewWgMeshConfigApplyer() meshManagerParams := &mesh.NewMeshManagerParams{ @@ -49,8 +48,8 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ConfigApplyer: configApplyer, } - configApplyer.SetMeshManager(&meshManager) ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) + configApplyer.SetMeshManager(ctrlServer.MeshManager) ctrlServer.Conf = params.Conf connManagerParams := conn.NewConnectionManagerParams{ @@ -112,6 +111,10 @@ func (s *MeshCtrlServer) Close() error { logging.Log.WriteErrorf(err.Error()) } + if err := s.MeshManager.Close(); err != nil { + logging.Log.WriteErrorf(err.Error()) + } + if err := s.ConnectionServer.Close(); err != nil { logging.Log.WriteErrorf(err.Error()) } diff --git a/pkg/lib/rtnetlink.go b/pkg/lib/rtnetlink.go new file mode 100644 index 0000000..1ac5b63 --- /dev/null +++ b/pkg/lib/rtnetlink.go @@ -0,0 +1,300 @@ +package lib + +import ( + "encoding/binary" + "fmt" + "net" + + "github.com/jsimonetti/rtnetlink" + logging "github.com/tim-beatham/wgmesh/pkg/log" + "golang.org/x/sys/unix" +) + +type RtNetlinkConfig struct { + conn *rtnetlink.Conn +} + +func NewRtNetlinkConfig() (*RtNetlinkConfig, error) { + conn, err := rtnetlink.Dial(nil) + + if err != nil { + return nil, err + } + + return &RtNetlinkConfig{conn: conn}, nil +} + +const WIREGUARD_MTU = 1420 + +// Create a netlink interface if it does not exist. ifName is the name of the netlink interface +func (c *RtNetlinkConfig) CreateLink(ifName string) error { + _, err := net.InterfaceByName(ifName) + + if err == nil { + return fmt.Errorf("interface %s already exists", ifName) + } + + err = c.conn.Link.New(&rtnetlink.LinkMessage{ + Family: unix.AF_UNSPEC, + Flags: unix.IFF_UP, + Attributes: &rtnetlink.LinkAttributes{ + Name: ifName, + Info: &rtnetlink.LinkInfo{Kind: "wireguard"}, + MTU: uint32(WIREGUARD_MTU), + }, + }) + + if err != nil { + return fmt.Errorf("failed to create wireguard interface: %w", err) + } + + return nil +} + +// Delete link delete the specified interface +func (c *RtNetlinkConfig) DeleteLink(ifName string) error { + iface, err := net.InterfaceByName(ifName) + + if err != nil { + return fmt.Errorf("failed to get interface %s %w", ifName, err) + } + + err = c.conn.Link.Delete(uint32(iface.Index)) + + if err != nil { + return fmt.Errorf("failed to delete wg interface %w", err) + } + + return nil +} + +// AddAddress adds an address to the given interface. +func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error { + iface, err := net.InterfaceByName(ifName) + + if err != nil { + return fmt.Errorf("failed to get interface %s error: %w", ifName, err) + } + + addr, cidr, err := net.ParseCIDR(address) + + if err != nil { + return fmt.Errorf("failed to parse CIDR %s error: %w", addr, err) + } + + family := unix.AF_INET6 + + ipv4 := cidr.IP.To4() + + if ipv4 != nil { + family = unix.AF_INET + } + + // Calculate the prefix length + ones, _ := cidr.Mask.Size() + + // Calculate the broadcast IP + // Only used when family is AF_INET + var brd net.IP + if ipv4 != nil { + brd = make(net.IP, len(ipv4)) + binary.BigEndian.PutUint32(brd, binary.BigEndian.Uint32(ipv4)|^binary.BigEndian.Uint32(net.IP(cidr.Mask).To4())) + } + + err = c.conn.Address.New(&rtnetlink.AddressMessage{ + Family: uint8(family), + PrefixLength: uint8(ones), + Scope: unix.RT_SCOPE_UNIVERSE, + Index: uint32(iface.Index), + Attributes: &rtnetlink.AddressAttributes{ + Address: addr, + Local: addr, + Broadcast: brd, + }, + }) + + if err != nil { + err = fmt.Errorf("failed to add address to link %w", err) + } + + return err +} + +// AddRoute: adds a route to the routing table. +// ifName is the intrface to add the route to +// gateway is the IP of the gateway device to hop to +// dst is the network prefix of the advertised destination +func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error { + iface, err := net.InterfaceByName(ifName) + + if err != nil { + return fmt.Errorf("failed accessing interface %s error %w", ifName, err) + } + + gw := route.Gateway + dst := route.Destination + + var family uint8 = unix.AF_INET6 + + if dst.IP.To4() != nil { + family = unix.AF_INET + } + + attr := rtnetlink.RouteAttributes{ + Dst: dst.IP, + OutIface: uint32(iface.Index), + Gateway: gw, + } + + ones, _ := dst.Mask.Size() + + err = c.conn.Route.Replace(&rtnetlink.RouteMessage{ + Family: family, + Table: unix.RT_TABLE_MAIN, + Protocol: unix.RTPROT_BOOT, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + DstLength: uint8(ones), + Attributes: attr, + }) + + if err != nil { + return fmt.Errorf("failed to add route %w", err) + } + + return nil +} + +// DeleteRoute deletes routes with the gateway and destination +func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error { + iface, err := net.InterfaceByName(ifName) + + if err != nil { + return fmt.Errorf("failed accessing interface %s error %w", ifName, err) + } + + gw := route.Gateway + dst := route.Destination + + var family uint8 = unix.AF_INET6 + + if dst.IP.To4() != nil { + family = unix.AF_INET + } + + attr := rtnetlink.RouteAttributes{ + Dst: dst.IP, + OutIface: uint32(iface.Index), + Gateway: gw, + } + + ones, _ := dst.Mask.Size() + + err = c.conn.Route.Delete(&rtnetlink.RouteMessage{ + Family: family, + Table: unix.RT_TABLE_MAIN, + Protocol: unix.RTPROT_BOOT, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + DstLength: uint8(ones), + Attributes: attr, + }) + + if err != nil { + return fmt.Errorf("failed to delete route %w", err) + } + + return nil +} + +type Route struct { + Gateway net.IP + Destination net.IPNet +} + +func (r1 Route) equal(r2 Route) bool { + return r1.Gateway.String() == r2.Gateway.String() && + r1.Destination.String() == r2.Destination.String() +} + +// DeleteRoutes deletes all routes not in exclude +func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { + routes := make([]rtnetlink.RouteMessage, 0) + + if len(exclude) != 0 { + lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway) + + if err != nil { + return err + } + + routes = lRoutes + } + + ifRoutes := make([]Route, 0) + + for _, rtRoute := range routes { + logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String()) + maskSize := 128 + + if family == unix.AF_INET { + maskSize = 32 + } + + cidr := net.CIDRMask(int(rtRoute.DstLength), maskSize) + route := Route{ + Gateway: rtRoute.Attributes.Gateway, + Destination: net.IPNet{IP: rtRoute.Attributes.Dst, Mask: cidr}, + } + + ifRoutes = append(ifRoutes, route) + } + + shouldExclude := func(r Route) bool { + for _, route := range exclude { + if route.equal(r) { + return false + } + } + return true + } + + toDelete := Filter(ifRoutes, shouldExclude) + + for _, route := range toDelete { + logging.Log.WriteInfof("Deleting route %s", route.Destination.String()) + err := c.DeleteRoute(ifName, route) + + if err != nil { + return err + } + } + + return nil +} + +// listRoutes lists all routes on the interface +func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) { + iface, err := net.InterfaceByName(ifName) + + if err != nil { + return nil, fmt.Errorf("failed accessing interface %s error %w", ifName, err) + } + + routes, err := c.conn.Route.List() + + if err != nil { + return nil, fmt.Errorf("failed to get route %w", err) + } + + filterFunc := func(r rtnetlink.RouteMessage) bool { + return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index) + } + + routes = Filter(routes, filterFunc) + return routes, nil +} + +func (c *RtNetlinkConfig) Close() error { + return c.conn.Close() +} diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index ed3698f..bb412b4 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -29,6 +29,7 @@ type MeshManager interface { GetClient() *wgctrl.Client GetMeshes() map[string]MeshProvider Prune() error + Close() error } type MeshManagerImpl struct { @@ -77,7 +78,7 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { }) if err != nil { - return "", err + return "", fmt.Errorf("error creating mesh: %w", err) } err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ @@ -86,15 +87,10 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { }) if err != nil { - return "", nil + return "", fmt.Errorf("error creating mesh: %w", err) } m.Meshes[meshId] = nodeManager - - if err != nil { - logging.Log.WriteErrorf(err.Error()) - } - return meshId, nil } @@ -152,25 +148,13 @@ func (s *MeshManagerImpl) EnableInterface(meshId string) error { return err } - meshNode, err := s.GetSelf(meshId) + err = s.RouteManager.InstallRoutes() if err != nil { return err } - mesh := s.GetMesh(meshId) - - if err != nil { - return err - } - - dev, err := mesh.GetDevice() - - if err != nil { - return err - } - - return s.interfaceManipulator.EnableInterface(dev.Name, meshNode.GetWgHost().String()) + return nil } // GetPublicKey: Gets the public key of the WireGuard mesh @@ -202,6 +186,12 @@ type AddSelfParams struct { // AddSelf adds this host to the mesh func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { + mesh := s.GetMesh(params.MeshId) + + if mesh == nil { + return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) + } + pubKey, err := s.GetPublicKey(params.MeshId) if err != nil { @@ -221,21 +211,45 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { Endpoint: params.Endpoint, }) + device, err := mesh.GetDevice() + + if err != nil { + return fmt.Errorf("failed to get device %w", err) + } + + err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP)) + + if err != nil { + return fmt.Errorf("addSelf: failed to add address to dev %w", err) + } + s.Meshes[params.MeshId].AddNode(node) return s.RouteManager.UpdateRoutes() } // LeaveMesh leaves the mesh network func (s *MeshManagerImpl) LeaveMesh(meshId string) error { - _, exists := s.Meshes[meshId] + mesh, exists := s.Meshes[meshId] if !exists { return fmt.Errorf("mesh %s does not exist", meshId) } - // For now just delete the mesh with the ID. + err := s.RouteManager.RemoveRoutes(meshId) + + if err != nil { + return err + } + + device, err := mesh.GetDevice() + + if err != nil { + return err + } + + err = s.interfaceManipulator.RemoveInterface(device.Name) delete(s.Meshes, meshId) - return nil + return err } func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { @@ -261,7 +275,13 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { } func (s *MeshManagerImpl) ApplyConfig() error { - return s.configApplyer.ApplyConfig() + err := s.configApplyer.ApplyConfig() + + if err != nil { + return err + } + + return s.RouteManager.InstallRoutes() } func (s *MeshManagerImpl) SetDescription(description string) error { @@ -307,6 +327,24 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { return s.Meshes } +func (s *MeshManagerImpl) Close() error { + for _, mesh := range s.Meshes { + dev, err := mesh.GetDevice() + + if err != nil { + return err + } + + err = s.interfaceManipulator.RemoveInterface(dev.Name) + + if err != nil { + return err + } + } + + return nil +} + // NewMeshManagerParams params required to create an instance of a mesh manager type NewMeshManagerParams struct { Conf conf.WgMeshConfiguration @@ -317,6 +355,7 @@ type NewMeshManagerParams struct { IPAllocator ip.IPAllocator InterfaceManipulator wg.WgInterfaceManipulator ConfigApplyer MeshConfigApplyer + RouteManager RouteManager } // Creates a new instance of a mesh manager with the given parameters @@ -342,7 +381,12 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { } m.configApplyer = params.ConfigApplyer - m.RouteManager = NewRouteManager(m) + m.RouteManager = params.RouteManager + + if m.RouteManager == nil { + m.RouteManager = NewRouteManager(m) + } + m.idGenerator = params.IdGenerator m.ipAllocator = params.IPAllocator m.interfaceManipulator = params.InterfaceManipulator diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index 8393b88..65105b1 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -32,6 +32,7 @@ func getMeshManager() *MeshManagerImpl { IPAllocator: &ip.ULABuilder{}, InterfaceManipulator: &wg.WgInterfaceManipulatorStub{}, ConfigApplyer: &MeshConfigApplyerStub{}, + RouteManager: &RouteManagerStub{}, }) return manager @@ -186,7 +187,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) { err = manager.LeaveMesh(meshId) if err != nil { - t.Error(err) + t.Fatalf("%s", err.Error()) } _, exists := manager.Meshes[meshId] diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 83d7bc7..e8d8f4f 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -1,13 +1,20 @@ package mesh import ( + "fmt" + "net" + "github.com/tim-beatham/wgmesh/pkg/ip" + "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/route" + "golang.org/x/sys/unix" ) type RouteManager interface { UpdateRoutes() error + InstallRoutes() error + RemoveRoutes(meshId string) error } type RouteManagerImpl struct { @@ -49,6 +56,129 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return nil } +// removeRoutes: removes all meshes we are no longer a part of +func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { + ulaBuilder := new(ip.ULABuilder) + meshes := r.meshManager.GetMeshes() + + ipNet, err := ulaBuilder.GetIPNet(meshId) + + if err != nil { + return err + } + + for _, mesh1 := range meshes { + self, err := r.meshManager.GetSelf(meshId) + + if err != nil { + return err + } + + mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String()) + } + return nil +} + +// AddRoute adds a route to the given interface +func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error { + rtnl, err := lib.NewRtNetlinkConfig() + + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + defer rtnl.Close() + + // Delete any routes that may be vacant + err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...) + + if err != nil { + return err + } + + for _, route := range routes { + if route.Destination.String() == meshPrefix { + continue + } + + err = rtnl.AddRoute(ifName, route) + + if err != nil { + return err + } + } + + return nil +} + +func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error { + routeMapFunc := func(route string) lib.Route { + _, cidr, _ := net.ParseCIDR(route) + + r := lib.Route{ + Destination: *cidr, + Gateway: node.GetWgHost().IP, + } + return r + } + + ipBuilder := &ip.ULABuilder{} + ipNet, err := ipBuilder.GetIPNet(meshid) + + if err != nil { + return err + } + + routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc) + return m.addRoute(ifName, ipNet.String(), routes...) +} + +func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error { + mesh, err := meshProvider.GetMesh() + + if err != nil { + return err + } + + dev, err := meshProvider.GetDevice() + + if err != nil { + return err + } + + self, err := m.meshManager.GetSelf(meshProvider.GetMeshId()) + + if err != nil { + return err + } + + for _, node := range mesh.GetNodes() { + if self.GetHostEndpoint() == node.GetHostEndpoint() { + continue + } + + err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node) + + if err != nil { + return err + } + } + + return nil +} + +// InstallRoutes installs all routes to the RIB +func (r *RouteManagerImpl) InstallRoutes() error { + for _, mesh := range r.meshManager.GetMeshes() { + err := r.installRoutes(mesh) + + if err != nil { + return err + } + } + + return nil +} + func NewRouteManager(m MeshManager) RouteManager { return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()} } diff --git a/pkg/mesh/route_stub.go b/pkg/mesh/route_stub.go new file mode 100644 index 0000000..4939bde --- /dev/null +++ b/pkg/mesh/route_stub.go @@ -0,0 +1,16 @@ +package mesh + +type RouteManagerStub struct { +} + +func (r *RouteManagerStub) UpdateRoutes() error { + return nil +} + +func (r *RouteManagerStub) InstallRoutes() error { + return nil +} + +func (r *RouteManagerStub) RemoveRoutes(meshId string) error { + return nil +} diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index aa3de11..c328c76 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -66,6 +66,11 @@ type MeshProviderStub struct { snapshot *MeshSnapshotStub } +// RemoveRoutes implements MeshProvider. +func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { + panic("unimplemented") +} + // Prune implements MeshProvider. func (*MeshProviderStub) Prune(pruneAmount int) error { return nil @@ -166,6 +171,11 @@ type MeshManagerStub struct { meshes map[string]MeshProvider } +// Close implements MeshManager. +func (*MeshManagerStub) Close() error { + panic("unimplemented") +} + // Prune implements MeshManager. func (*MeshManagerStub) Prune() error { return nil diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 7bffd88..5543820 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -64,6 +64,8 @@ type MeshProvider interface { UpdateTimeStamp(nodeId string) error // AddRoutes: adds routes to the given node AddRoutes(nodeId string, route ...string) error + // DeleteRoutes: deletes the routes from the node + RemoveRoutes(nodeId string, route ...string) error // GetSyncer: returns the automerge syncer for sync GetSyncer() MeshSyncer // SetDescription: sets the description of this automerge data type diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index a00bf9a..7dba7a4 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -30,6 +30,10 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { Endpoint: args.Endpoint, }) + if err != nil { + return err + } + *reply = meshId return err } @@ -122,6 +126,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { if mesh == nil { return errors.New("mesh does not exist") } + nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes())) i := 0 diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index c7f4d6d..834317b 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -31,7 +31,11 @@ type SyncerImpl struct { // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { logging.Log.WriteInfof("UPDATING WG CONF") - s.manager.ApplyConfig() + err := s.manager.ApplyConfig() + + if err != nil { + logging.Log.WriteInfof("Failed to update config %w", err) + } if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { logging.Log.WriteInfof("No changes for %s", meshId) diff --git a/pkg/wg/stubs.go b/pkg/wg/stubs.go index ff9e16e..5adcfc5 100644 --- a/pkg/wg/stubs.go +++ b/pkg/wg/stubs.go @@ -6,6 +6,10 @@ func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfacePara return nil } -func (i *WgInterfaceManipulatorStub) EnableInterface(ifName string, ip string) error { +func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error { + return nil +} + +func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error { return nil } diff --git a/pkg/wg/types.go b/pkg/wg/types.go index 12685a7..a7a44b9 100644 --- a/pkg/wg/types.go +++ b/pkg/wg/types.go @@ -16,7 +16,8 @@ type CreateInterfaceParams struct { type WgInterfaceManipulator interface { // CreateInterface creates a WireGuard interface CreateInterface(params *CreateInterfaceParams) error - // Enable interface enables the given interface with - // the IP. It overrides the IP at the interface - EnableInterface(ifName string, ip string) error + // AddAddress adds an address to the given interface name + AddAddress(ifName string, addr string) error + // RemoveInterface removes the specified interface + RemoveInterface(ifName string) error } diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index ceea489..691d197 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -1,50 +1,37 @@ package wg import ( - "errors" "fmt" - "net" - "os/exec" + "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// createInterface uses ip link to create an interface. If the interface exists -// it returns an error -func createInterface(ifName string) error { - _, err := net.InterfaceByName(ifName) - - if err == nil { - err = flushInterface(ifName) - return err - } - - // Check if the interface exists - cmd := exec.Command("/usr/bin/ip", "link", "add", "dev", ifName, "type", "wireguard") - - if err := cmd.Run(); err != nil { - return err - } - return nil -} - type WgInterfaceManipulatorImpl struct { client *wgctrl.Client } +// CreateInterface creates a WireGuard interface func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { - err := createInterface(params.IfName) + rtnl, err := lib.NewRtNetlinkConfig() if err != nil { - return err + return fmt.Errorf("failed to access link: %w", err) + } + defer rtnl.Close() + + err = rtnl.CreateLink(params.IfName) + + if err != nil { + return fmt.Errorf("failed to create link: %w", err) } privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { - return err + return fmt.Errorf("failed to create private key: %w", err) } var cfg wgtypes.Config = wgtypes.Config{ @@ -52,59 +39,44 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfacePara ListenPort: ¶ms.Port, } - m.client.ConfigureDevice(params.IfName, cfg) + err = m.client.ConfigureDevice(params.IfName, cfg) + + if err != nil { + return fmt.Errorf("failed to configure dev: %w", err) + } + + logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) return nil } -// flushInterface flushes the specified interface -func flushInterface(ifName string) error { - _, err := net.InterfaceByName(ifName) +// Add an address to the given interface +func (m *WgInterfaceManipulatorImpl) AddAddress(ifName string, addr string) error { + rtnl, err := lib.NewRtNetlinkConfig() if err != nil { - return &WgError{msg: fmt.Sprintf("Interface %s does not exist cannot flush", ifName)} + return fmt.Errorf("failed to create config: %w", err) + } + defer rtnl.Close() + + err = rtnl.AddAddress(ifName, addr) + + if err != nil { + err = fmt.Errorf("failed to add address: %w", err) } - cmd := exec.Command("/usr/bin/ip", "addr", "flush", "dev", ifName) - - if err := cmd.Run(); err != nil { - logging.Log.WriteErrorf(fmt.Sprintf("%s error flushing interface %s", err.Error(), ifName)) - return &WgError{msg: fmt.Sprintf("Failed to flush interface %s", ifName)} - } - - return nil + return err } -// EnableInterface flushes the interface and sets the ip address of the -// interface -func (m *WgInterfaceManipulatorImpl) EnableInterface(ifName string, ip string) error { - if len(ifName) == 0 { - return errors.New("ifName not provided") - } - - err := flushInterface(ifName) +// RemoveInterface implements WgInterfaceManipulator. +func (*WgInterfaceManipulatorImpl) RemoveInterface(ifName string) error { + rtnl, err := lib.NewRtNetlinkConfig() if err != nil { - return err + return fmt.Errorf("failed to create config: %w", err) } + defer rtnl.Close() - cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName) - - if err := cmd.Run(); err != nil { - return err - } - - hostIp, _, err := net.ParseCIDR(ip) - - if err != nil { - return err - } - - cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName) - - if err := cmd.Run(); err != nil { - return err - } - return nil + return rtnl.DeleteLink(ifName) } func NewWgInterfaceManipulator(client *wgctrl.Client) WgInterfaceManipulator {