From 1012172f0444cfcee0ed6911f751010e8809118d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 5 Sep 2022 09:06:35 +0200 Subject: [PATCH] Add routing peer support (#441) Handle routes updates from management Manage routing firewall rules Manage peer RIB table Add get peer and get notification channel from the status recorder Update interface peers allowed IPs --- README.md | 2 +- client/internal/engine.go | 36 ++ client/internal/engine_test.go | 140 ++++++ client/internal/routemanager/client.go | 285 +++++++++++++ .../routemanager/common_linux_test.go | 75 ++++ client/internal/routemanager/firewall.go | 12 + .../internal/routemanager/firewall_linux.go | 55 +++ .../routemanager/firewall_nonlinux.go | 27 ++ .../internal/routemanager/iptables_linux.go | 403 ++++++++++++++++++ .../routemanager/iptables_linux_test.go | 247 +++++++++++ client/internal/routemanager/manager.go | 181 ++++++++ client/internal/routemanager/manager_test.go | 370 ++++++++++++++++ client/internal/routemanager/mock.go | 27 ++ .../internal/routemanager/nftables_linux.go | 384 +++++++++++++++++ .../routemanager/nftables_linux_test.go | 270 ++++++++++++ client/internal/routemanager/server.go | 67 +++ client/internal/routemanager/systemops.go | 55 +++ .../internal/routemanager/systemops_linux.go | 73 ++++ .../routemanager/systemops_nonlinux.go | 41 ++ .../internal/routemanager/systemops_test.go | 68 +++ client/status/status.go | 44 +- client/status/status_test.go | 40 ++ go.mod | 4 + go.sum | 10 + iface/configuration.go | 118 +++++ iface/iface_test.go | 31 +- 26 files changed, 3030 insertions(+), 35 deletions(-) create mode 100644 client/internal/routemanager/client.go create mode 100644 client/internal/routemanager/common_linux_test.go create mode 100644 client/internal/routemanager/firewall.go create mode 100644 client/internal/routemanager/firewall_linux.go create mode 100644 client/internal/routemanager/firewall_nonlinux.go create mode 100644 client/internal/routemanager/iptables_linux.go create mode 100644 client/internal/routemanager/iptables_linux_test.go create mode 100644 client/internal/routemanager/manager.go create mode 100644 client/internal/routemanager/manager_test.go create mode 100644 client/internal/routemanager/mock.go create mode 100644 client/internal/routemanager/nftables_linux.go create mode 100644 client/internal/routemanager/nftables_linux_test.go create mode 100644 client/internal/routemanager/server.go create mode 100644 client/internal/routemanager/systemops.go create mode 100644 client/internal/routemanager/systemops_linux.go create mode 100644 client/internal/routemanager/systemops_nonlinux.go create mode 100644 client/internal/routemanager/systemops_test.go diff --git a/README.md b/README.md index 40672490a..0b5c06a88 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ NetBird creates an overlay peer-to-peer network connecting machines automaticall - \[x] Remote SSH access without managing SSH keys. **Coming soon:** -- \[ ] Router nodes +- \[ ] Network Routes. - \[ ] Private DNS. - \[ ] Mobile clients. - \[ ] Network Activity Monitoring. diff --git a/client/internal/engine.go b/client/internal/engine.go index 0a322c483..08dc4de4b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -3,8 +3,10 @@ package internal import ( "context" "fmt" + "github.com/netbirdio/netbird/client/internal/routemanager" nbssh "github.com/netbirdio/netbird/client/ssh" nbstatus "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/route" "math/rand" "net" "reflect" @@ -99,6 +101,8 @@ type Engine struct { sshServer nbssh.Server statusRecorder *nbstatus.Status + + routeManager routemanager.Manager } // Peer is an instance of the Connection Peer @@ -182,6 +186,10 @@ func (e *Engine) Stop() error { } } + if e.routeManager != nil { + e.routeManager.Stop() + } + log.Infof("stopped Netbird Engine") return nil @@ -232,6 +240,8 @@ func (e *Engine) Start() error { return err } + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) + e.receiveSignalEvents() e.receiveManagementEvents() @@ -619,11 +629,37 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } + protoRoutes := networkMap.GetRoutes() + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) + if err != nil { + log.Errorf("failed to update routes, err: %v", err) + } e.networkSerial = serial return nil } +func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { + routes := make([]*route.Route, 0) + for _, protoRoute := range protoRoutes { + _, prefix, _ := route.ParseNetwork(protoRoute.Network) + convertedRoute := &route.Route{ + ID: protoRoute.ID, + Network: prefix, + NetID: protoRoute.NetID, + NetworkType: route.NetworkType(protoRoute.NetworkType), + Peer: protoRoute.Peer, + Metric: int(protoRoute.Metric), + Masquerade: protoRoute.Masquerade, + } + routes = append(routes, convertedRoute) + } + return routes +} + // addNewPeers adds peers that were not know before but arrived from the Management service with the update func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { for _, p := range peersUpdate { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3f8b269a0..e68da6fb8 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -3,11 +3,14 @@ package internal import ( "context" "fmt" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" "github.com/stretchr/testify/assert" "net" + "net/netip" "os" "path/filepath" "runtime" @@ -196,6 +199,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { WgPort: 33100, }, nbstatus.NewRecorder()) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) type testCase struct { name string @@ -426,6 +430,142 @@ func TestEngine_Sync(t *testing.T) { } } +func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { + testCases := []struct { + name string + inputErr error + networkMap *mgmtProto.NetworkMap + expectedLen int + expectedRoutes []*route.Route + expectedSerial uint64 + }{ + { + name: "Routes Update Should Be Passed To Manager", + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeersIsEmpty: false, + Routes: []*mgmtProto.Route{ + { + ID: "a", + Network: "192.168.0.0/24", + NetID: "n1", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + { + ID: "b", + Network: "192.168.1.0/24", + NetID: "n2", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + }, + }, + expectedLen: 2, + expectedRoutes: []*route.Route{ + { + ID: "a", + Network: netip.MustParsePrefix("192.168.0.0/24"), + NetID: "n1", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + { + ID: "b", + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetID: "n2", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + }, + expectedSerial: 1, + }, + { + name: "Empty Routes Update Should Be Passed", + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeersIsEmpty: false, + Routes: nil, + }, + expectedLen: 0, + expectedRoutes: []*route.Route{}, + expectedSerial: 1, + }, + { + name: "Error Shouldn't Break Engine", + inputErr: fmt.Errorf("mocking error"), + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeersIsEmpty: false, + Routes: nil, + }, + expectedLen: 0, + expectedRoutes: []*route.Route{}, + expectedSerial: 1, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + // test setup + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wgIfaceName := fmt.Sprintf("utun%d", 104+n) + wgAddr := fmt.Sprintf("100.66.%d.1/24", n) + + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ + WgIfaceName: wgIfaceName, + WgAddr: wgAddr, + WgPrivateKey: key, + WgPort: 33100, + }, nbstatus.NewRecorder()) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) + assert.NoError(t, err, "shouldn't return error") + input := struct { + inputSerial uint64 + inputRoutes []*route.Route + }{} + + mockRouteManager := &routemanager.MockManager{ + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + input.inputSerial = updateSerial + input.inputRoutes = newRoutes + return testCase.inputErr + }, + } + + engine.routeManager = mockRouteManager + + defer func() { + exitErr := engine.Stop() + if exitErr != nil { + return + } + }() + + err = engine.updateNetworkMap(testCase.networkMap) + assert.NoError(t, err, "shouldn't return error") + assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") + assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match") + assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match") + }) + } +} + func TestEngine_MultiplePeers(t *testing.T) { // log.SetLevel(log.DebugLevel) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go new file mode 100644 index 000000000..c18b75e4d --- /dev/null +++ b/client/internal/routemanager/client.go @@ -0,0 +1,285 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "net/netip" +) + +type routerPeerStatus struct { + connected bool + relayed bool + direct bool +} + +type routesUpdate struct { + updateSerial uint64 + routes []*route.Route +} + +type clientNetwork struct { + ctx context.Context + stop context.CancelFunc + statusRecorder *status.Status + wgInterface *iface.WGIface + routes map[string]*route.Route + routeUpdate chan routesUpdate + peerStateUpdate chan struct{} + routePeersNotifiers map[string]chan struct{} + chosenRoute *route.Route + network netip.Prefix + updateSerial uint64 +} + +func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork { + ctx, cancel := context.WithCancel(ctx) + client := &clientNetwork{ + ctx: ctx, + stop: cancel, + statusRecorder: statusRecorder, + wgInterface: wgInterface, + routes: make(map[string]*route.Route), + routePeersNotifiers: make(map[string]chan struct{}), + routeUpdate: make(chan routesUpdate), + peerStateUpdate: make(chan struct{}), + network: network, + } + return client +} + +func getClientNetworkID(input *route.Route) string { + return input.NetID + "-" + input.Network.String() +} + +func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { + routePeerStatuses := make(map[string]routerPeerStatus) + for _, r := range c.routes { + peerStatus, err := c.statusRecorder.GetPeer(r.Peer) + if err != nil { + log.Debugf("couldn't fetch peer state: %v", err) + continue + } + routePeerStatuses[r.ID] = routerPeerStatus{ + connected: peerStatus.ConnStatus == peer.StatusConnected.String(), + relayed: peerStatus.Relayed, + direct: peerStatus.Direct, + } + } + return routePeerStatuses +} + +func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { + var chosen string + chosenScore := 0 + + currID := "" + if c.chosenRoute != nil { + currID = c.chosenRoute.ID + } + + for _, r := range c.routes { + tempScore := 0 + peerStatus, found := routePeerStatuses[r.ID] + if !found || !peerStatus.connected { + continue + } + if r.Metric < route.MaxMetric { + metricDiff := route.MaxMetric - r.Metric + tempScore = metricDiff * 10 + } + if !peerStatus.relayed { + tempScore++ + } + if !peerStatus.direct { + tempScore++ + } + if tempScore > chosenScore || (tempScore == chosenScore && currID == r.ID) { + chosen = r.ID + chosenScore = tempScore + } + } + + if chosen == "" { + var peers []string + for _, r := range c.routes { + peers = append(peers, r.Peer) + } + log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers) + } else if chosen != currID { + log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore) + } + + return chosen +} + +func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { + for { + select { + case <-ctx.Done(): + return + case <-closer: + return + case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey): + state, err := c.statusRecorder.GetPeer(peerKey) + if err != nil || state.ConnStatus == peer.StatusConnecting.String() { + continue + } + peerStateUpdate <- struct{}{} + log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus) + } + } +} + +func (c *clientNetwork) startPeersStatusChangeWatcher() { + for _, r := range c.routes { + _, found := c.routePeersNotifiers[r.Peer] + if !found { + c.routePeersNotifiers[r.Peer] = make(chan struct{}) + go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer]) + } + } +} + +func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { + state, err := c.statusRecorder.GetPeer(peerKey) + if err != nil || state.ConnStatus != peer.StatusConnected.String() { + return nil + } + + err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) + if err != nil { + return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + c.network, c.chosenRoute.Peer, err) + } + return nil +} + +func (c *clientNetwork) removeRouteFromPeerAndSystem() error { + if c.chosenRoute != nil { + err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err + } + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String()) + if err != nil { + return fmt.Errorf("couldn't remove route %s from system, err: %v", + c.network, err) + } + } + return nil +} + +func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { + + var err error + + routerPeerStatuses := c.getRouterPeerStatuses() + + chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + if chosen == "" { + err = c.removeRouteFromPeerAndSystem() + if err != nil { + return err + } + + c.chosenRoute = nil + + return nil + } + + if c.chosenRoute != nil && c.chosenRoute.ID == chosen { + if c.chosenRoute.IsEqual(c.routes[chosen]) { + return nil + } + } + + if c.chosenRoute != nil { + err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err + } + } else { + err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String()) + if err != nil { + return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", + c.chosenRoute.Network.String(), c.wgInterface.GetAddress().IP.String(), err) + } + } + + c.chosenRoute = c.routes[chosen] + err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { + log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", + c.network, c.chosenRoute.Peer, err) + } + + return nil +} + +func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { + go func() { + c.routeUpdate <- update + }() +} + +func (c *clientNetwork) handleUpdate(update routesUpdate) { + updateMap := make(map[string]*route.Route) + + for _, r := range update.routes { + updateMap[r.ID] = r + } + + for id, r := range c.routes { + _, found := updateMap[id] + if !found { + close(c.routePeersNotifiers[r.Peer]) + delete(c.routePeersNotifiers, r.Peer) + } + } + + c.routes = updateMap +} + +// peersStateAndUpdateWatcher is the main point of reacting on client network routing events. +// All the processing related to the client network should be done here. Thread-safe. +func (c *clientNetwork) peersStateAndUpdateWatcher() { + for { + select { + case <-c.ctx.Done(): + log.Debugf("stopping watcher for network %s", c.network) + err := c.removeRouteFromPeerAndSystem() + if err != nil { + log.Error(err) + } + return + case <-c.peerStateUpdate: + err := c.recalculateRouteAndUpdatePeerAndSystem() + if err != nil { + log.Error(err) + } + case update := <-c.routeUpdate: + if update.updateSerial < c.updateSerial { + log.Warnf("received a routes update with smaller serial number, ignoring it") + continue + } + + log.Debugf("received a new client network route update for %s", c.network) + + c.handleUpdate(update) + + c.updateSerial = update.updateSerial + + err := c.recalculateRouteAndUpdatePeerAndSystem() + if err != nil { + log.Error(err) + } + + c.startPeersStatusChangeWatcher() + } + } +} diff --git a/client/internal/routemanager/common_linux_test.go b/client/internal/routemanager/common_linux_test.go new file mode 100644 index 000000000..d27f532cd --- /dev/null +++ b/client/internal/routemanager/common_linux_test.go @@ -0,0 +1,75 @@ +package routemanager + +var insertRuleTestCases = []struct { + name string + inputPair routerPair + ipVersion string +}{ + { + name: "Insert Forwarding IPV4 Rule", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: false, + }, + ipVersion: ipv4, + }, + { + name: "Insert Forwarding And Nat IPV4 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: true, + }, + ipVersion: ipv4, + }, + { + name: "Insert Forwarding IPV6 Rule", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: false, + }, + ipVersion: ipv6, + }, + { + name: "Insert Forwarding And Nat IPV6 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: true, + }, + ipVersion: ipv6, + }, +} + +var removeRuleTestCases = []struct { + name string + inputPair routerPair + ipVersion string +}{ + { + name: "Remove Forwarding And Nat IPV4 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: true, + }, + ipVersion: ipv4, + }, + { + name: "Remove Forwarding And Nat IPV6 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: true, + }, + ipVersion: ipv6, + }, +} diff --git a/client/internal/routemanager/firewall.go b/client/internal/routemanager/firewall.go new file mode 100644 index 000000000..fc6ff58f1 --- /dev/null +++ b/client/internal/routemanager/firewall.go @@ -0,0 +1,12 @@ +package routemanager + +type firewallManager interface { + // RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules + RestoreOrCreateContainers() error + // InsertRoutingRules inserts a routing firewall rule + InsertRoutingRules(pair routerPair) error + // RemoveRoutingRules removes a routing firewall rule + RemoveRoutingRules(pair routerPair) error + // CleanRoutingRules cleans a firewall set of containers + CleanRoutingRules() +} diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go new file mode 100644 index 000000000..5673dd3fc --- /dev/null +++ b/client/internal/routemanager/firewall_linux.go @@ -0,0 +1,55 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/coreos/go-iptables/iptables" + log "github.com/sirupsen/logrus" +) +import "github.com/google/nftables" + +const ( + ipv6Forwarding = "netbird-rt-ipv6-forwarding" + ipv4Forwarding = "netbird-rt-ipv4-forwarding" + ipv6Nat = "netbird-rt-ipv6-nat" + ipv4Nat = "netbird-rt-ipv4-nat" + natFormat = "netbird-nat-%s" + forwardingFormat = "netbird-fwd-%s" + ipv6 = "ipv6" + ipv4 = "ipv4" +) + +func genKey(format string, input string) string { + return fmt.Sprintf(format, input) +} + +// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager +func NewFirewall(parentCTX context.Context) firewallManager { + ctx, cancel := context.WithCancel(parentCTX) + + if isIptablesSupported() { + log.Debugf("iptables is supported") + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + + return &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + } + + log.Debugf("iptables is not supported, using nftables") + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + return manager +} diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go new file mode 100644 index 000000000..172659f26 --- /dev/null +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -0,0 +1,27 @@ +//go:build !linux +// +build !linux + +package routemanager + +import "context" + +type unimplementedFirewall struct{} + +func (unimplementedFirewall) RestoreOrCreateContainers() error { + return nil +} +func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error { + return nil +} +func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error { + return nil +} + +func (unimplementedFirewall) CleanRoutingRules() { + return +} + +// NewFirewall returns an unimplemented Firewall manager +func NewFirewall(parentCtx context.Context) firewallManager { + return unimplementedFirewall{} +} diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go new file mode 100644 index 000000000..1bc56e44d --- /dev/null +++ b/client/internal/routemanager/iptables_linux.go @@ -0,0 +1,403 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/coreos/go-iptables/iptables" + log "github.com/sirupsen/logrus" + "net/netip" + "os/exec" + "strings" + "sync" +) + +func isIptablesSupported() bool { + _, err4 := exec.LookPath("iptables") + _, err6 := exec.LookPath("ip6tables") + return err4 == nil && err6 == nil +} + +// constants needed to manage and create iptable rules +const ( + iptablesFilterTable = "filter" + iptablesNatTable = "nat" + iptablesForwardChain = "FORWARD" + iptablesPostRoutingChain = "POSTROUTING" + iptablesRoutingNatChain = "NETBIRD-RT-NAT" + iptablesRoutingForwardingChain = "NETBIRD-RT-FWD" + routingFinalForwardJump = "ACCEPT" + routingFinalNatJump = "MASQUERADE" +) + +// some presets for building nftable rules +var ( + iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"} + iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} + iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"} + iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} +) + +type iptablesManager struct { + ctx context.Context + stop context.CancelFunc + ipv4Client *iptables.IPTables + ipv6Client *iptables.IPTables + rules map[string]map[string][]string + mux sync.Mutex +} + +// CleanRoutingRules cleans existing iptables resources that we created by the agent +func (i *iptablesManager) CleanRoutingRules() { + i.mux.Lock() + defer i.mux.Unlock() + + err := i.cleanJumpRules() + if err != nil { + log.Error(err) + } + + log.Debug("flushing tables") + errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v" + err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) + if err != nil { + log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) + } + + err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) + if err != nil { + log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) + } + + err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) + if err != nil { + log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) + } + + err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) + if err != nil { + log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) + } + + log.Info("done cleaning up iptables rules") +} + +// RestoreOrCreateContainers restores existing iptables containers (chains and rules) +// if they don't exist, we create them +func (i *iptablesManager) RestoreOrCreateContainers() error { + i.mux.Lock() + defer i.mux.Unlock() + + if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil { + return nil + } + + errMSGFormat := "iptables: failed creating %s chain %s,error: %v" + + err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) + } + + err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) + } + + err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) + } + + err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) + } + + err = i.restoreRules(i.ipv4Client) + if err != nil { + return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err) + } + + err = i.restoreRules(i.ipv6Client) + if err != nil { + return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err) + } + + err = i.addJumpRules() + if err != nil { + return fmt.Errorf("iptables: error while creating jump rules: %v", err) + } + + return nil +} + +// addJumpRules create jump rules to send packets to NetBird chains +func (i *iptablesManager) addJumpRules() error { + err := i.cleanJumpRules() + if err != nil { + return err + } + rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) + err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) + if err != nil { + return err + } + + i.rules[ipv4][ipv4Forwarding] = rule + + rule = append(iptablesDefaultNatRule, ipv4Nat) + err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) + if err != nil { + return err + } + i.rules[ipv4][ipv4Nat] = rule + + rule = append(iptablesDefaultForwardingRule, ipv6Forwarding) + err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) + if err != nil { + return err + } + i.rules[ipv6][ipv6Forwarding] = rule + + rule = append(iptablesDefaultNatRule, ipv6Nat) + err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) + if err != nil { + return err + } + i.rules[ipv6][ipv6Nat] = rule + + return nil +} + +// cleanJumpRules cleans jump rules that was sending packets to NetBird chains +func (i *iptablesManager) cleanJumpRules() error { + var err error + errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v" + rule, found := i.rules[ipv4][ipv4Forwarding] + if found { + log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding) + err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err) + } + } + rule, found = i.rules[ipv4][ipv4Nat] + if found { + log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat) + err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err) + } + } + rule, found = i.rules[ipv6][ipv6Forwarding] + if found { + log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding) + err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err) + } + } + rule, found = i.rules[ipv6][ipv6Nat] + if found { + log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat) + err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) + if err != nil { + return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err) + } + } + return nil +} + +func iptablesProtoToString(proto iptables.Protocol) string { + if proto == iptables.ProtocolIPv6 { + return ipv6 + } + return ipv4 +} + +// restoreRules restores existing NetBird rules +func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { + ipVersion := iptablesProtoToString(iptablesClient.Proto()) + + if i.rules[ipVersion] == nil { + i.rules[ipVersion] = make(map[string][]string) + } + table := iptablesFilterTable + for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} { + rules, err := iptablesClient.List(table, chain) + if err != nil { + return err + } + for _, ruleString := range rules { + rule := strings.Fields(ruleString) + id := getRuleRouteID(rule) + if id != "" { + i.rules[ipVersion][id] = rule[2:] + } + } + } + + table = iptablesNatTable + for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} { + rules, err := iptablesClient.List(table, chain) + if err != nil { + return err + } + for _, ruleString := range rules { + rule := strings.Fields(ruleString) + id := getRuleRouteID(rule) + if id != "" { + i.rules[ipVersion][id] = rule[2:] + } + } + } + + return nil +} + +// createChain create NetBird chains +func createChain(iptables *iptables.IPTables, table, newChain string) error { + chains, err := iptables.ListChains(table) + if err != nil { + return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err) + } + + shouldCreateChain := true + for _, chain := range chains { + if chain == newChain { + shouldCreateChain = false + } + } + + if shouldCreateChain { + err = iptables.NewChain(table, newChain) + if err != nil { + return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err) + } + + if table == iptablesNatTable { + err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...) + } else { + err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...) + } + if err != nil { + return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err) + } + + } + return nil +} + +// genRuleSpec generates rule specification with comment identifier +func genRuleSpec(jump, id, source, destination string) []string { + return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} +} + +// getRuleRouteID returns the rule ID if matches our prefix +func getRuleRouteID(rule []string) string { + for i, flag := range rule { + if flag == "--comment" { + id := rule[i+1] + if strings.HasPrefix(id, "netbird-") { + return id + } + } + } + return "" +} + +// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain +func (i *iptablesManager) InsertRoutingRules(pair routerPair) error { + i.mux.Lock() + defer i.mux.Unlock() + + var err error + prefix := netip.MustParsePrefix(pair.source) + ipVersion := ipv4 + iptablesClient := i.ipv4Client + if prefix.Addr().Unmap().Is6() { + iptablesClient = i.ipv6Client + ipVersion = ipv6 + } + + forwardRuleKey := genKey(forwardingFormat, pair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) + existingRule, found := i.rules[ipVersion][forwardRuleKey] + if found { + err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) + if err != nil { + return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) + } + delete(i.rules[ipVersion], forwardRuleKey) + } + err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) + if err != nil { + return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err) + } + + i.rules[ipVersion][forwardRuleKey] = forwardRule + + if !pair.masquerade { + return nil + } + + natRuleKey := genKey(natFormat, pair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination) + existingRule, found = i.rules[ipVersion][natRuleKey] + if found { + err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...) + if err != nil { + return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err) + } + delete(i.rules[ipVersion], natRuleKey) + } + err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) + if err != nil { + return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err) + } + + i.rules[ipVersion][natRuleKey] = natRule + + return nil +} + +// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains +func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error { + i.mux.Lock() + defer i.mux.Unlock() + + var err error + prefix := netip.MustParsePrefix(pair.source) + ipVersion := ipv4 + iptablesClient := i.ipv4Client + if prefix.Addr().Unmap().Is6() { + iptablesClient = i.ipv6Client + ipVersion = ipv6 + } + + forwardRuleKey := genKey(forwardingFormat, pair.ID) + existingRule, found := i.rules[ipVersion][forwardRuleKey] + if found { + err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) + if err != nil { + return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) + } + } + delete(i.rules[ipVersion], forwardRuleKey) + + if !pair.masquerade { + return nil + } + + natRuleKey := genKey(natFormat, pair.ID) + existingRule, found = i.rules[ipVersion][natRuleKey] + if found { + err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...) + if err != nil { + return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err) + } + } + delete(i.rules[ipVersion], natRuleKey) + + return nil +} diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go new file mode 100644 index 000000000..8b469b3a3 --- /dev/null +++ b/client/internal/routemanager/iptables_linux_test.go @@ -0,0 +1,247 @@ +package routemanager + +import ( + "context" + "github.com/coreos/go-iptables/iptables" + "github.com/stretchr/testify/require" + "testing" +) + +func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + ctx, cancel := context.WithCancel(context.TODO()) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + + manager := &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") + + require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4") + + exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain) + require.True(t, exists, "forwarding rule should exist") + + exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain) + require.True(t, exists, "postrouting rule should exist") + + require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6") + + exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain) + require.True(t, exists, "forwarding rule should exist") + + exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain) + require.True(t, exists, "postrouting rule should exist") + + pair := routerPair{ + ID: "abc", + source: "100.100.100.1/32", + destination: "100.100.100.0/24", + masquerade: true, + } + forward4RuleKey := genKey(forwardingFormat, pair.ID) + forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination) + + err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) + require.NoError(t, err, "inserting rule should not return error") + + nat4RuleKey := genKey(natFormat, pair.ID) + nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination) + + err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) + require.NoError(t, err, "inserting rule should not return error") + + pair = routerPair{ + ID: "abc", + source: "fc00::1/128", + destination: "fc11::/64", + masquerade: true, + } + + forward6RuleKey := genKey(forwardingFormat, pair.ID) + forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination) + + err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) + require.NoError(t, err, "inserting rule should not return error") + + nat6RuleKey := genKey(natFormat, pair.ID) + nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination) + + err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) + require.NoError(t, err, "inserting rule should not return error") + + delete(manager.rules, ipv4) + delete(manager.rules, ipv6) + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4") + + foundRule, found := manager.rules[ipv4][forward4RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match") + + foundRule, found = manager.rules[ipv4][nat4RuleKey] + require.True(t, found, "nat rule should exist in the map") + require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match") + + require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6") + + foundRule, found = manager.rules[ipv6][forward6RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match") + + foundRule, found = manager.rules[ipv6][nat6RuleKey] + require.True(t, found, "nat rule should exist in the map") + require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match") +} + +func TestIptablesManager_InsertRoutingRules(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + for _, testCase := range insertRuleTestCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + iptablesClient := ipv4Client + if testCase.ipVersion == ipv6 { + iptablesClient = ipv6Client + } + + manager := &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.InsertRoutingRules(testCase.inputPair) + require.NoError(t, err, "forwarding pair should be inserted") + + forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) + require.True(t, exists, "forwarding rule should exist") + + foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey] + require.True(t, found, "forwarding rule should exist in the manager map") + require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") + + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) + if testCase.inputPair.masquerade { + require.True(t, exists, "nat rule should be created") + foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey] + require.True(t, foundNat, "nat rule should exist in the map") + require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match") + } else { + require.False(t, exists, "nat rule should not be created") + _, foundNat := manager.rules[testCase.ipVersion][natRuleKey] + require.False(t, foundNat, "nat rule should exist in the map") + } + }) + } +} + +func TestIptablesManager_RemoveRoutingRules(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + for _, testCase := range removeRuleTestCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + iptablesClient := ipv4Client + if testCase.ipVersion == ipv6 { + iptablesClient = ipv6Client + } + + manager := &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) + require.NoError(t, err, "inserting rule should not return error") + + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) + require.NoError(t, err, "inserting rule should not return error") + + delete(manager.rules, ipv4) + delete(manager.rules, ipv6) + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.RemoveRoutingRules(testCase.inputPair) + require.NoError(t, err, "shouldn't return error") + + exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) + require.False(t, exists, "forwarding rule should not exist") + + _, found := manager.rules[testCase.ipVersion][forwardRuleKey] + require.False(t, found, "forwarding rule should exist in the manager map") + + exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) + require.False(t, exists, "nat rule should not exist") + + _, found = manager.rules[testCase.ipVersion][natRuleKey] + require.False(t, found, "forwarding rule should exist in the manager map") + + }) + } +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go new file mode 100644 index 000000000..4527ae0cb --- /dev/null +++ b/client/internal/routemanager/manager.go @@ -0,0 +1,181 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "runtime" + "sync" +) + +// Manager is a route manager interface +type Manager interface { + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error + Stop() +} + +// DefaultManager is the default instance of a route manager +type DefaultManager struct { + ctx context.Context + stop context.CancelFunc + mux sync.Mutex + clientNetworks map[string]*clientNetwork + serverRoutes map[string]*route.Route + serverRouter *serverRouter + statusRecorder *status.Status + wgInterface *iface.WGIface + pubKey string +} + +// NewManager returns a new route manager +func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager { + mCTX, cancel := context.WithCancel(ctx) + return &DefaultManager{ + ctx: mCTX, + stop: cancel, + clientNetworks: make(map[string]*clientNetwork), + serverRoutes: make(map[string]*route.Route), + serverRouter: &serverRouter{ + routes: make(map[string]*route.Route), + netForwardHistoryEnabled: isNetForwardHistoryEnabled(), + firewall: NewFirewall(ctx), + }, + statusRecorder: statusRecorder, + wgInterface: wgInterface, + pubKey: pubKey, + } +} + +// Stop stops the manager watchers and clean firewall rules +func (m *DefaultManager) Stop() { + m.stop() + m.serverRouter.firewall.CleanRoutingRules() +} + +func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { + // removing routes that do not exist as per the update from the Management service. + for id, client := range m.clientNetworks { + _, found := networks[id] + if !found { + log.Debugf("stopping client network watcher, %s", id) + client.stop() + delete(m.clientNetworks, id) + } + } + + for id, routes := range networks { + clientNetworkWatcher, found := m.clientNetworks[id] + if !found { + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + m.clientNetworks[id] = clientNetworkWatcher + go clientNetworkWatcher.peersStateAndUpdateWatcher() + } + update := routesUpdate{ + updateSerial: updateSerial, + routes: routes, + } + + clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) + } +} + +func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error { + serverRoutesToRemove := make([]string, 0) + + if len(routesMap) > 0 { + err := m.serverRouter.firewall.RestoreOrCreateContainers() + if err != nil { + return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) + } + } + + for routeID := range m.serverRoutes { + update, found := routesMap[routeID] + if !found || !update.IsEqual(m.serverRoutes[routeID]) { + serverRoutesToRemove = append(serverRoutesToRemove, routeID) + continue + } + } + + for _, routeID := range serverRoutesToRemove { + oldRoute := m.serverRoutes[routeID] + err := m.removeFromServerNetwork(oldRoute) + if err != nil { + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + oldRoute.ID, oldRoute.Network, err) + } + delete(m.serverRoutes, routeID) + } + + for id, newRoute := range routesMap { + _, found := m.serverRoutes[id] + if found { + continue + } + + err := m.addToServerNetwork(newRoute) + if err != nil { + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + continue + } + m.serverRoutes[id] = newRoute + } + + if len(m.serverRoutes) > 0 { + err := enableIPForwarding() + if err != nil { + return err + } + } + + return nil +} + +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { + select { + case <-m.ctx.Done(): + log.Infof("not updating routes as context is closed") + return m.ctx.Err() + default: + m.mux.Lock() + defer m.mux.Unlock() + + newClientRoutesIDMap := make(map[string][]*route.Route) + newServerRoutesMap := make(map[string]*route.Route) + + for _, newRoute := range newRoutes { + // only linux is supported for now + if newRoute.Peer == m.pubKey { + if runtime.GOOS != "linux" { + log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) + continue + } + newServerRoutesMap[newRoute.ID] = newRoute + } else { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < 7 { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", + system.NetbirdVersion(), newRoute.Network) + continue + } + clientNetworkID := getClientNetworkID(newRoute) + newClientRoutesIDMap[clientNetworkID] = append(newClientRoutesIDMap[clientNetworkID], newRoute) + } + } + + m.updateClientNetworks(updateSerial, newClientRoutesIDMap) + + err := m.updateServerRoutes(newServerRoutesMap) + if err != nil { + return err + } + + return nil + } +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go new file mode 100644 index 000000000..f88aeb53d --- /dev/null +++ b/client/internal/routemanager/manager_test.go @@ -0,0 +1,370 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + "github.com/stretchr/testify/require" + "net/netip" + "runtime" + "testing" +) + +// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small +// if linux host, should have one for server in map +// we should have 2 client manager +// 2 ranges in our routing table + +const localPeerKey = "local" +const remotePeerKey1 = "remote1" +const remotePeerKey2 = "remote1" + +func TestManagerUpdateRoutes(t *testing.T) { + testCases := []struct { + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + shouldCheckServerRoutes bool + serverRoutesExpected int + clientNetworkWatchersExpected int + }{ + { + name: "Should create 2 client networks", + inputInitRoutes: []*route.Route{}, + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 2, + }, + { + name: "Should Create 2 Server Routes", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.252.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: localPeerKey, + Network: netip.MustParsePrefix("8.8.8.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + shouldCheckServerRoutes: runtime.GOOS == "linux", + serverRoutesExpected: 2, + clientNetworkWatchersExpected: 0, + }, + { + name: "Should Create 1 Route For Client And Server", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.30.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.9.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + shouldCheckServerRoutes: runtime.GOOS == "linux", + serverRoutesExpected: 1, + clientNetworkWatchersExpected: 1, + }, + { + name: "Should Create 1 HA Route and 1 Standalone", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.20.0/24"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeA", + Peer: remotePeerKey2, + Network: netip.MustParsePrefix("8.8.20.0/24"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "c", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.9.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 2, + }, + { + name: "No Small Client Route Should Be Added", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("0.0.0.0/0"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + }, + { + name: "No Server Routes Should Be Added To Non Linux", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("1.2.3.4/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + shouldCheckServerRoutes: runtime.GOOS != "linux", + serverRoutesExpected: 0, + clientNetworkWatchersExpected: 0, + }, + { + name: "Remove 1 Client Route", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 1, + }, + { + name: "Update Route to HA", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeA", + Peer: remotePeerKey2, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 1, + }, + { + name: "Remove Client Routes", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{}, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + }, + { + name: "Remove All Routes", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{}, + inputSerial: 1, + shouldCheckServerRoutes: true, + serverRoutesExpected: 0, + clientNetworkWatchersExpected: 0, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + statusRecorder := status.NewRecorder() + ctx := context.TODO() + routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder) + defer routeManager.Stop() + + if len(testCase.inputInitRoutes) > 0 { + err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + require.NoError(t, err, "should update routes with init routes") + } + + err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + require.NoError(t, err, "should update routes") + + require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") + + if testCase.shouldCheckServerRoutes { + require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match") + } + }) + } +} diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go new file mode 100644 index 000000000..4d9a714d3 --- /dev/null +++ b/client/internal/routemanager/mock.go @@ -0,0 +1,27 @@ +package routemanager + +import ( + "fmt" + "github.com/netbirdio/netbird/route" +) + +// MockManager is the mock instance of a route manager +type MockManager struct { + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + StopFunc func() +} + +// UpdateRoutes mock implementation of UpdateRoutes from Manager interface +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { + if m.UpdateRoutesFunc != nil { + return m.UpdateRoutesFunc(updateSerial, newRoutes) + } + return fmt.Errorf("method UpdateRoutes is not implemented") +} + +// Stop mock implementation of Stop from Manager interface +func (m *MockManager) Stop() { + if m.StopFunc != nil { + m.StopFunc() + } +} diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go new file mode 100644 index 000000000..6201301fc --- /dev/null +++ b/client/internal/routemanager/nftables_linux.go @@ -0,0 +1,384 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + "net" + "net/netip" + "sync" +) +import "github.com/google/nftables" + +// +const ( + nftablesTable = "netbird-rt" + nftablesRoutingForwardingChain = "netbird-rt-fwd" + nftablesRoutingNatChain = "netbird-rt-nat" +) + +// constants needed to create nftable rules +const ( + ipv4Len = 4 + ipv4SrcOffset = 12 + ipv4DestOffset = 16 + ipv6Len = 16 + ipv6SrcOffset = 8 + ipv6DestOffset = 24 + exprDirectionSource = "source" + exprDirectionDestination = "destination" +) + +// some presets for building nftable rules +var ( + zeroXor = binaryutil.NativeEndian.PutUint32(0) + + zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) + + exprAllowRelatedEstablished = []expr.Any{ + &expr.Ct{ + Register: 1, + SourceRegister: false, + Key: 0, + }, + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: []uint8{0x6, 0x0, 0x0, 0x0}, + Xor: zeroXor, + }, + &expr.Cmp{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + exprCounterAccept = []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +) + +type nftablesManager struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + tableIPv4 *nftables.Table + tableIPv6 *nftables.Table + chains map[string]map[string]*nftables.Chain + rules map[string]*nftables.Rule + mux sync.Mutex +} + +// CleanRoutingRules cleans existing nftables rules from the system +func (n *nftablesManager) CleanRoutingRules() { + n.mux.Lock() + defer n.mux.Unlock() + log.Debug("flushing tables") + n.conn.FlushTable(n.tableIPv6) + n.conn.FlushTable(n.tableIPv4) + log.Debugf("flushing tables result in: %v error", n.conn.Flush()) +} + +// RestoreOrCreateContainers restores existing nftables containers (tables and chains) +// if they don't exist, we create them +func (n *nftablesManager) RestoreOrCreateContainers() error { + n.mux.Lock() + defer n.mux.Unlock() + + if n.tableIPv6 != nil && n.tableIPv4 != nil { + log.Debugf("nftables: containers already restored, skipping") + return nil + } + + tables, err := n.conn.ListTables() + if err != nil { + return fmt.Errorf("nftables: unable to list tables: %v", err) + } + + for _, table := range tables { + if table.Name == nftablesTable { + if table.Family == nftables.TableFamilyIPv4 { + n.tableIPv4 = table + continue + } + n.tableIPv6 = table + } + } + + if n.tableIPv4 == nil { + n.tableIPv4 = n.conn.AddTable(&nftables.Table{ + Name: nftablesTable, + Family: nftables.TableFamilyIPv4, + }) + } + + if n.tableIPv6 == nil { + n.tableIPv6 = n.conn.AddTable(&nftables.Table{ + Name: nftablesTable, + Family: nftables.TableFamilyIPv6, + }) + } + + chains, err := n.conn.ListChains() + if err != nil { + return fmt.Errorf("nftables: unable to list chains: %v", err) + } + + n.chains[ipv4] = make(map[string]*nftables.Chain) + n.chains[ipv6] = make(map[string]*nftables.Chain) + + for _, chain := range chains { + switch { + case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: + n.chains[ipv4][chain.Name] = chain + case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: + n.chains[ipv6][chain.Name] = chain + } + } + + if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found { + n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingForwardingChain, + Table: n.tableIPv4, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityNATDest + 1, + Type: nftables.ChainTypeFilter, + }) + } + + if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found { + n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingNatChain, + Table: n.tableIPv4, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + } + + if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found { + n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingForwardingChain, + Table: n.tableIPv6, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityNATDest + 1, + Type: nftables.ChainTypeFilter, + }) + } + + if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found { + n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingNatChain, + Table: n.tableIPv6, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + } + + err = n.refreshRulesMap() + if err != nil { + return err + } + + n.checkOrCreateDefaultForwardingRules() + err = n.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + return nil +} + +// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules +func (n *nftablesManager) refreshRulesMap() error { + for _, registeredChains := range n.chains { + for _, chain := range registeredChains { + rules, err := n.conn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("nftables: unable to list rules: %v", err) + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + n.rules[string(rule.UserData)] = rule + } + } + } + } + return nil +} + +// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled +func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { + _, foundIPv4 := n.rules[ipv4Forwarding] + if !foundIPv4 { + n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[ipv4][nftablesRoutingForwardingChain], + Exprs: exprAllowRelatedEstablished, + UserData: []byte(ipv4Forwarding), + }) + } + + _, foundIPv6 := n.rules[ipv6Forwarding] + if !foundIPv6 { + n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[ipv6][nftablesRoutingForwardingChain], + Exprs: exprAllowRelatedEstablished, + UserData: []byte(ipv6Forwarding), + }) + } +} + +// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain +func (n *nftablesManager) InsertRoutingRules(pair routerPair) error { + n.mux.Lock() + defer n.mux.Unlock() + + prefix := netip.MustParsePrefix(pair.source) + + sourceExp := generateCIDRMatcherExpressions("source", pair.source) + destExp := generateCIDRMatcherExpressions("destination", pair.destination) + + forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) + fwdKey := genKey(forwardingFormat, pair.ID) + if prefix.Addr().Unmap().Is4() { + n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[ipv4][nftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(fwdKey), + }) + } else { + n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[ipv6][nftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(fwdKey), + }) + } + + if pair.masquerade { + natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + natKey := genKey(natFormat, pair.ID) + + if prefix.Addr().Unmap().Is4() { + n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[ipv4][nftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natKey), + }) + } else { + n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[ipv6][nftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natKey), + }) + } + } + + err := n.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) + } + return nil +} + +// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains +func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { + n.mux.Lock() + defer n.mux.Unlock() + + err := n.refreshRulesMap() + if err != nil { + return err + } + + fwdKey := genKey(forwardingFormat, pair.ID) + natKey := genKey(natFormat, pair.ID) + fwdRule, found := n.rules[fwdKey] + if found { + err = n.conn.DelRule(fwdRule) + if err != nil { + return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err) + } + log.Debugf("nftables: removing forwarding rule for %s", pair.destination) + delete(n.rules, fwdKey) + } + natRule, found := n.rules[natKey] + if found { + err = n.conn.DelRule(natRule) + if err != nil { + return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err) + } + log.Debugf("nftables: removing nat rule for %s", pair.destination) + delete(n.rules, natKey) + } + err = n.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) + } + log.Debugf("nftables: removed rules for %s", pair.destination) + return nil +} + +// getPayloadDirectives get expression directives based on ip version and direction +func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { + switch { + case direction == exprDirectionSource && isIPv4: + return ipv4SrcOffset, ipv4Len, zeroXor + case direction == exprDirectionDestination && isIPv4: + return ipv4DestOffset, ipv4Len, zeroXor + case direction == exprDirectionSource && isIPv6: + return ipv6SrcOffset, ipv6Len, zeroXor6 + case direction == exprDirectionDestination && isIPv6: + return ipv6DestOffset, ipv6Len, zeroXor6 + default: + panic("no matched payload directive") + } +} + +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR +func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any { + ip, network, _ := net.ParseCIDR(cidr) + ipToAdd, _ := netip.AddrFromSlice(ip) + add := ipToAdd.Unmap() + + offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6()) + + return []expr.Any{ + // fetch src add + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offSet, + Len: packetLen, + }, + // net mask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: packetLen, + Mask: network.Mask, + Xor: zeroXor, + }, + // net address + &expr.Cmp{ + Register: 1, + Data: add.AsSlice(), + }, + } +} diff --git a/client/internal/routemanager/nftables_linux_test.go b/client/internal/routemanager/nftables_linux_test.go new file mode 100644 index 000000000..c84df6993 --- /dev/null +++ b/client/internal/routemanager/nftables_linux_test.go @@ -0,0 +1,270 @@ +package routemanager + +import ( + "context" + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { + + ctx, cancel := context.WithCancel(context.TODO()) + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + nftablesTestingClient := &nftables.Conn{} + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6") + require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6") + + pair := routerPair{ + ID: "abc", + source: "100.100.100.1/32", + destination: "100.100.100.0/24", + masquerade: true, + } + + sourceExp := generateCIDRMatcherExpressions("source", pair.source) + destExp := generateCIDRMatcherExpressions("destination", pair.destination) + + forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) + forward4RuleKey := genKey(forwardingFormat, pair.ID) + inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv4, + Chain: manager.chains[ipv4][nftablesRoutingForwardingChain], + Exprs: forward4Exp, + UserData: []byte(forward4RuleKey), + }) + + nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + nat4RuleKey := genKey(natFormat, pair.ID) + + inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv4, + Chain: manager.chains[ipv4][nftablesRoutingNatChain], + Exprs: nat4Exp, + UserData: []byte(nat4RuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + pair = routerPair{ + ID: "xyz", + source: "fc00::1/128", + destination: "fc11::/64", + masquerade: true, + } + + sourceExp = generateCIDRMatcherExpressions("source", pair.source) + destExp = generateCIDRMatcherExpressions("destination", pair.destination) + + forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) + forward6RuleKey := genKey(forwardingFormat, pair.ID) + inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv6, + Chain: manager.chains[ipv6][nftablesRoutingForwardingChain], + Exprs: forward6Exp, + UserData: []byte(forward6RuleKey), + }) + + nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + nat6RuleKey := genKey(natFormat, pair.ID) + + inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv6, + Chain: manager.chains[ipv6][nftablesRoutingNatChain], + Exprs: nat6Exp, + UserData: []byte(nat6RuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + manager.tableIPv4 = nil + manager.tableIPv6 = nil + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6") + require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6") + + foundRule, found := manager.rules[forward4RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match") + + foundRule, found = manager.rules[nat4RuleKey] + require.True(t, found, "nat rule should exist in the map") + // match len of output as nftables client doesn't return expressions with masquerade expression + assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match") + + foundRule, found = manager.rules[forward6RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match") + + foundRule, found = manager.rules[nat6RuleKey] + require.True(t, found, "nat rule should exist in the map") + // match len of output as nftables client doesn't return expressions with masquerade expression + assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match") +} + +func TestNftablesManager_InsertRoutingRules(t *testing.T) { + + for _, testCase := range insertRuleTestCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + nftablesTestingClient := &nftables.Conn{} + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.InsertRoutingRules(testCase.inputPair) + require.NoError(t, err, "forwarding pair should be inserted") + + sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source) + destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination) + testingExpression := append(sourceExp, destExp...) + fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + + found := 0 + for _, registeredChains := range manager.chains { + for _, chain := range registeredChains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") + found = 1 + } + } + } + } + + require.Equal(t, 1, found, "should find at least 1 rule to test") + + if testCase.inputPair.masquerade { + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + found := 0 + for _, registeredChains := range manager.chains { + for _, chain := range registeredChains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") + found = 1 + } + } + } + } + require.Equal(t, 1, found, "should find at least 1 rule to test") + } + }) + } +} + +func TestNftablesManager_RemoveRoutingRules(t *testing.T) { + + for _, testCase := range removeRuleTestCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + nftablesTestingClient := &nftables.Conn{} + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + table := manager.tableIPv4 + if testCase.ipVersion == ipv6 { + table = manager.tableIPv6 + } + + sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source) + destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination) + + forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) + forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: table, + Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(forwardRuleKey), + }) + + natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + + insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: table, + Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natRuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + manager.tableIPv4 = nil + manager.tableIPv6 = nil + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.RemoveRoutingRules(testCase.inputPair) + require.NoError(t, err, "shouldn't return error") + + for _, registeredChains := range manager.chains { + for _, chain := range registeredChains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 { + require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist") + require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist") + } + } + } + } + }) + } +} diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go new file mode 100644 index 000000000..0bfd1cec5 --- /dev/null +++ b/client/internal/routemanager/server.go @@ -0,0 +1,67 @@ +package routemanager + +import ( + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "net/netip" + "sync" +) + +type serverRouter struct { + routes map[string]*route.Route + // best effort to keep net forward configuration as it was + netForwardHistoryEnabled bool + mux sync.Mutex + firewall firewallManager +} + +type routerPair struct { + ID string + source string + destination string + masquerade bool +} + +func routeToRouterPair(source string, route *route.Route) routerPair { + parsed := netip.MustParsePrefix(source).Masked() + return routerPair{ + ID: route.ID, + source: parsed.String(), + destination: route.Network.Masked().String(), + masquerade: route.Masquerade, + } +} + +func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error { + select { + case <-m.ctx.Done(): + log.Infof("not removing from server network because context is done") + return m.ctx.Err() + default: + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + delete(m.serverRouter.routes, route.ID) + return nil + } +} + +func (m *DefaultManager) addToServerNetwork(route *route.Route) error { + select { + case <-m.ctx.Done(): + log.Infof("not adding to server network because context is done") + return m.ctx.Err() + default: + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + m.serverRouter.routes[route.ID] = route + return nil + } +} diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 000000000..595425b94 --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,55 @@ +package routemanager + +import ( + "fmt" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + "net" + "net/netip" +) + +var errRouteNotFound = fmt.Errorf("route not found") + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { + gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil && err != errRouteNotFound { + return err + } + prefixGateway, err := getExistingRIBRouteGateway(prefix) + if err != nil && err != errRouteNotFound { + return err + } + + if prefixGateway != nil && !prefixGateway.Equal(gateway) { + log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway) + return nil + } + return addToRouteTable(prefix, addr) +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + addrIP := net.ParseIP(addr) + prefixGateway, err := getExistingRIBRouteGateway(prefix) + if err != nil { + return err + } + if prefixGateway != nil && !prefixGateway.Equal(addrIP) { + log.Warnf("route for network %s is pointing to a different gateway: %s, should be pointing to: %s, not removing", prefix, prefixGateway, addrIP) + return nil + } + return removeFromRouteTable(prefix) +} + +func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { + r, err := netroute.New() + if err != nil { + return nil, err + } + _, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice()) + if err != nil { + log.Errorf("getting routes returned an error: %v", err) + return nil, errRouteNotFound + } + + return localGatewayAddress, nil +} diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go new file mode 100644 index 000000000..f891b461f --- /dev/null +++ b/client/internal/routemanager/systemops_linux.go @@ -0,0 +1,73 @@ +package routemanager + +import ( + "github.com/vishvananda/netlink" + "io/ioutil" + "net" + "net/netip" +) + +const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" + +func addToRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err + } + + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" + } + + ip, _, err := net.ParseCIDR(addr + addrMask) + if err != nil { + return err + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, + } + + err = netlink.RouteAdd(route) + if err != nil { + return err + } + + return nil +} + +func removeFromRouteTable(prefix netip.Prefix) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + } + + err = netlink.RouteDel(route) + if err != nil { + return err + } + + return nil +} + +func enableIPForwarding() error { + err := ioutil.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) + return err +} + +func isNetForwardHistoryEnabled() bool { + out, err := ioutil.ReadFile(ipv4ForwardingPath) + if err != nil { + // todo + panic(err) + } + return string(out) == "1" +} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go new file mode 100644 index 000000000..aad8a1202 --- /dev/null +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -0,0 +1,41 @@ +//go:build !linux +// +build !linux + +package routemanager + +import ( + log "github.com/sirupsen/logrus" + "net/netip" + "os/exec" + "runtime" +) + +func addToRouteTable(prefix netip.Prefix, addr string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil +} + +func removeFromRouteTable(prefix netip.Prefix) error { + cmd := exec.Command("route", "delete", prefix.String()) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil +} + +func enableIPForwarding() error { + log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func isNetForwardHistoryEnabled() bool { + log.Infof("check netforwad history is not implemented on %s", runtime.GOOS) + return false +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go new file mode 100644 index 000000000..821e9a46e --- /dev/null +++ b/client/internal/routemanager/systemops_test.go @@ -0,0 +1,68 @@ +package routemanager + +import ( + "fmt" + "github.com/netbirdio/netbird/iface" + "github.com/stretchr/testify/require" + "net/netip" + "testing" +) + +func TestAddRemoveRoutes(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + shouldRouteToWireguard bool + shouldBeRemoved bool + }{ + { + name: "Should Add And Remove Route", + prefix: netip.MustParsePrefix("100.66.120.0/24"), + shouldRouteToWireguard: true, + shouldBeRemoved: true, + }, + { + name: "Should Not Add Or Remove Route", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + shouldRouteToWireguard: false, + shouldBeRemoved: false, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String()) + require.NoError(t, err, "should not return err") + + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "should not return err") + if testCase.shouldRouteToWireguard { + require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } else { + require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface") + } + + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String()) + require.NoError(t, err, "should not return err") + + prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "should not return err") + + internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + require.NoError(t, err) + + if testCase.shouldBeRemoved { + require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") + } else { + require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + } + }) + } +} diff --git a/client/status/status.go b/client/status/status.go index 3b96a8098..a337df6c0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -47,17 +47,19 @@ type FullStatus struct { // Status holds a state of peers, signal and management connections type Status struct { - mux sync.Mutex - peers map[string]PeerState - signal SignalState - management ManagementState - localPeer LocalPeerState + mux sync.Mutex + peers map[string]PeerState + changeNotify map[string]chan struct{} + signal SignalState + management ManagementState + localPeer LocalPeerState } // NewRecorder returns a new Status instance func NewRecorder() *Status { return &Status{ - peers: make(map[string]PeerState), + peers: make(map[string]PeerState), + changeNotify: make(map[string]chan struct{}), } } @@ -74,6 +76,18 @@ func (d *Status) AddPeer(peerPubKey string) error { return nil } +// GetPeer adds peer to Daemon status map +func (d *Status) GetPeer(peerPubKey string) (PeerState, error) { + d.mux.Lock() + defer d.mux.Unlock() + + state, ok := d.peers[peerPubKey] + if !ok { + return PeerState{}, errors.New("peer not found") + } + return state, nil +} + // RemovePeer removes peer from Daemon status map func (d *Status) RemovePeer(peerPubKey string) error { d.mux.Lock() @@ -113,9 +127,27 @@ func (d *Status) UpdatePeerState(receivedState PeerState) error { d.peers[receivedState.PubKey] = peerState + ch, found := d.changeNotify[receivedState.PubKey] + if found && ch != nil { + close(ch) + d.changeNotify[receivedState.PubKey] = nil + } + return nil } +// GetPeerStateChangeNotifier returns a change notifier channel for a peer +func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { + d.mux.Lock() + defer d.mux.Unlock() + ch, found := d.changeNotify[peer] + if !found || ch == nil { + ch = make(chan struct{}) + d.changeNotify[peer] = ch + } + return ch +} + // UpdateLocalPeerState updates local peer status func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.mux.Lock() diff --git a/client/status/status_test.go b/client/status/status_test.go index 02abfbfe0..00161dbd0 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -19,6 +19,21 @@ func TestAddPeer(t *testing.T) { assert.Error(t, err, "should return error on duplicate") } +func TestGetPeer(t *testing.T) { + key := "abc" + status := NewRecorder() + err := status.AddPeer(key) + assert.NoError(t, err, "shouldn't return error") + + peerStatus, err := status.GetPeer(key) + assert.NoError(t, err, "shouldn't return error on getting peer") + + assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match") + + _, err = status.GetPeer("non_existing_key") + assert.Error(t, err, "should return error when peer doesn't exist") +} + func TestUpdatePeerState(t *testing.T) { key := "abc" ip := "10.10.10.10" @@ -39,6 +54,31 @@ func TestUpdatePeerState(t *testing.T) { assert.Equal(t, ip, state.IP, "ip should be equal") } +func TestGetPeerStateChangeNotifierLogic(t *testing.T) { + key := "abc" + ip := "10.10.10.10" + status := NewRecorder() + peerState := PeerState{ + PubKey: key, + } + + status.peers[key] = peerState + + ch := status.GetPeerStateChangeNotifier(key) + assert.NotNil(t, ch, "channel shouldn't be nil") + + peerState.IP = ip + + err := status.UpdatePeerState(peerState) + assert.NoError(t, err, "shouldn't return error") + + select { + case <-ch: + default: + t.Errorf("channel wasn't closed after update") + } +} + func TestRemovePeer(t *testing.T) { key := "abc" status := NewRecorder() diff --git a/go.mod b/go.mod index f630b9795..a2538197c 100644 --- a/go.mod +++ b/go.mod @@ -30,10 +30,13 @@ require ( require ( fyne.io/fyne/v2 v2.1.4 github.com/c-robinson/iplib v1.0.3 + github.com/coreos/go-iptables v0.6.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/v2 v2.3.1 github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 + github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/xid v1.3.0 @@ -67,6 +70,7 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/google/go-cmp v0.5.7 // indirect + github.com/google/gopacket v1.1.19 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect diff --git a/go.sum b/go.sum index 282f69e88..2e36b1a4b 100644 --- a/go.sum +++ b/go.sum @@ -115,6 +115,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U= +github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= +github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -283,10 +285,14 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8 github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= +github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= +github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -399,6 +405,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= +github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= @@ -748,6 +756,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= +golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -870,6 +879,7 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/iface/configuration.go b/iface/configuration.go index 9f49cf6ee..07adc1555 100644 --- a/iface/configuration.go +++ b/iface/configuration.go @@ -9,6 +9,16 @@ import ( "time" ) +// GetName returns the interface name +func (w *WGIface) GetName() string { + return w.Name +} + +// GetAddress returns the interface address +func (w *WGIface) GetAddress() WGAddress { + return w.Address +} + // configureDevice configures the wireguard device func (w *WGIface) configureDevice(config wgtypes.Config) error { wg, err := wgctrl.New() @@ -112,6 +122,114 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D return nil } +// AddAllowedIP adds a prefix to the allowed IPs list of peer +func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) + + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: false, + AllowedIPs: []net.IPNet{*ipNet}, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf("received error \"%v\" while adding allowed Ip to peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP) + } + return nil +} + +// RemoveAllowedIP removes a prefix from the allowed IPs list of peer +func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) + + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + existingPeer, err := getPeer(w.Name, peerKey) + if err != nil { + return err + } + + newAllowedIPs := existingPeer.AllowedIPs + + for i, existingAllowedIP := range existingPeer.AllowedIPs { + if existingAllowedIP.String() == ipNet.String() { + newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) + break + } + } + + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: true, + AllowedIPs: newAllowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf("received error \"%v\" while removing allowed IP from peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP) + } + return nil +} + +func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { + wg, err := wgctrl.New() + if err != nil { + return wgtypes.Peer{}, err + } + defer func() { + err = wg.Close() + if err != nil { + log.Errorf("got error while closing wgctl: %v", err) + } + }() + + wgDevice, err := wg.Device(ifaceName) + if err != nil { + return wgtypes.Peer{}, err + } + for _, peer := range wgDevice.Peers { + if peer.PublicKey.String() == peerPubKey { + return peer, nil + } + } + return wgtypes.Peer{}, fmt.Errorf("peer not found") +} + // RemovePeer removes a Wireguard Peer from the interface iface func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() diff --git a/iface/iface_test.go b/iface/iface_test.go index d4791950f..0c7aa3f3d 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -229,7 +229,7 @@ func Test_UpdatePeer(t *testing.T) { if err != nil { t.Fatal(err) } - peer, err := getPeer(ifaceName, peerPubKey, t) + peer, err := getPeer(ifaceName, peerPubKey) if err != nil { t.Fatal(err) } @@ -289,7 +289,7 @@ func Test_RemovePeer(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = getPeer(ifaceName, peerPubKey, t) + _, err = getPeer(ifaceName, peerPubKey) if err.Error() != "peer not found" { t.Fatal(err) } @@ -378,7 +378,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatalf("waiting for peer handshake timeout after %s", timeout.String()) default: } - peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String(), t) + peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String()) if gpErr != nil { t.Fatal(gpErr) } @@ -389,28 +389,3 @@ func Test_ConnectPeers(t *testing.T) { } } - -func getPeer(ifaceName, peerPubKey string, t *testing.T) (wgtypes.Peer, error) { - emptyPeer := wgtypes.Peer{} - wg, err := wgctrl.New() - if err != nil { - return emptyPeer, err - } - defer func() { - err = wg.Close() - if err != nil { - t.Error(err) - } - }() - - wgDevice, err := wg.Device(ifaceName) - if err != nil { - return emptyPeer, err - } - for _, peer := range wgDevice.Peers { - if peer.PublicKey.String() == peerPubKey { - return peer, nil - } - } - return emptyPeer, fmt.Errorf("peer not found") -}