From 06125acb8d9957d9ae720baca5d4416df3a87575 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 10 Aug 2023 21:10:12 +0200 Subject: [PATCH 01/42] Update new release banner (#1072) --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9b79b1856..bf83435f9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

- :hatching_chick: New Release! Peer expiration. - + :hatching_chick: New Release! Self-hosting in under 5 min. + Learn more

@@ -40,9 +40,13 @@ **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. -**Secure.** NetBird isolates every machine and device by applying granular access policies, while allowing you to manage them intuitively from a single place. +**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. -**Key features:** +### Secure peer-to-peer VPN with SSO and MFA in minutes + +https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov + +### Key features | Connectivity | Management | Automation | Platforms | |-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------| @@ -57,10 +61,6 @@ | | | | | -### Secure peer-to-peer VPN with SSO and MFA in minutes - -https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov - ### Quickstart with NetBird Cloud - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) From 2dec016201d22bd17c84173d344ca34942870e9c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 11 Aug 2023 11:51:39 +0200 Subject: [PATCH 02/42] Fix/always on boot (#1062) In case of 'always-on' feature has switched on, after the reboot the service do not start properly in all cases. If the device is in offline state (no internet connection) the auth login steps will fail and the service will stop. For the auth steps make no sense in this case because if the OS start the service we do not have option for the user interaction. --- client/android/client.go | 26 ++++++++++++++++++++++++-- client/internal/connect.go | 3 +-- client/internal/peer/notifier.go | 4 ++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/client/android/client.go b/client/android/client.go index d8f561e18..bb15268eb 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -55,7 +55,6 @@ type Client struct { ctxCancelLock *sync.Mutex deviceName string routeListener routemanager.RouteListener - onHostDnsFn func([]string) } // NewClient instantiate a new Client @@ -97,7 +96,30 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.onHostDnsFn = func([]string) {} + return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener) +} + +// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). +// In this case make no sense handle registration steps. +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { + cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ + ConfigPath: c.cfgFile, + }) + if err != nil { + return err + } + c.recorder.UpdateManagementAddress(cfg.ManagementURL.String()) + + var ctx context.Context + //nolint + ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) + c.ctxCancelLock.Lock() + ctx, c.ctxCancel = context.WithCancel(ctxWithValues) + defer c.ctxCancel() + c.ctxCancelLock.Unlock() + + // todo do not throw error in case of cancelled context + ctx = internal.CtxInitState(ctx) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener) } diff --git a/client/internal/connect.go b/client/internal/connect.go index 87aab0b54..6eecf4207 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -179,8 +179,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, log.Print("Netbird engine started, my IP is: ", peerConfig.Address) state.Set(StatusConnected) - statusRecorder.ClientStart() - <-engineCtx.Done() statusRecorder.ClientTeardown() @@ -201,6 +199,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, return nil } + statusRecorder.ClientStart() err = backoff.Retry(operation, backOff) if err != nil { log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go index eb15bdeeb..f1175c2c4 100644 --- a/client/internal/peer/notifier.go +++ b/client/internal/peer/notifier.go @@ -61,7 +61,7 @@ func (n *notifier) clientStart() { n.serverStateLock.Lock() defer n.serverStateLock.Unlock() n.currentClientState = true - n.lastNotification = stateConnected + n.lastNotification = stateConnecting n.notify(n.lastNotification) } @@ -114,7 +114,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int { return stateConnected } - if !managementConn && !signalConn { + if !managementConn && !signalConn && !n.currentClientState { return stateDisconnected } From 0f0c7ec2ed987f3a6f3d63f6d32839013b42287a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 12 Aug 2023 11:42:36 +0200 Subject: [PATCH 03/42] Routemgr error handling (#1073) In case the route management feature is not supported then do not create unnecessary firewall and manager instances. This can happen if the nftables nor iptables is not available on the host OS. - Move the error handling to upper layer - Remove fake, useless implementations of interfaces - Update go-iptables because In Docker the old version can not determine well the path of executable file - update lib to 0.70 --- .../internal/routemanager/firewall_linux.go | 13 ++++++--- .../routemanager/firewall_nonlinux.go | 26 +++++------------ .../internal/routemanager/iptables_linux.go | 19 ++++++++---- .../routemanager/iptables_linux_test.go | 2 +- client/internal/routemanager/manager.go | 23 ++++++++++----- client/internal/routemanager/manager_test.go | 29 ++----------------- client/internal/routemanager/server.go | 9 ++++++ .../internal/routemanager/server_android.go | 15 ++-------- .../routemanager/server_nonandroid.go | 23 +++++++++------ client/internal/stdnet/filter.go | 2 +- go.mod | 2 +- go.sum | 4 +-- 12 files changed, 79 insertions(+), 88 deletions(-) create mode 100644 client/internal/routemanager/server.go diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index 8b27c8967..19a5a4cde 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -27,14 +27,19 @@ func genKey(format string, input string) string { } // NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager -func NewFirewall(parentCTX context.Context) firewallManager { +func NewFirewall(parentCTX context.Context) (firewallManager, error) { manager, err := newNFTablesManager(parentCTX) if err == nil { log.Debugf("nftables firewall manager will be used") - return manager + return manager, nil } - log.Debugf("fallback to iptables firewall manager: %s", err) - return newIptablesManager(parentCTX) + fMgr, err := newIptablesManager(parentCTX) + if err != nil { + log.Debugf("failed to initialize iptables for root mgr: %s", err) + return nil, err + } + log.Debugf("iptables firewall manager will be used") + return fMgr, nil } func getInPair(pair routerPair) routerPair { diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go index 4691f15f8..1b52a1e85 100644 --- a/client/internal/routemanager/firewall_nonlinux.go +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -3,24 +3,12 @@ package routemanager -import "context" +import ( + "context" + "fmt" +) -type unimplementedFirewall struct{} - -func (unimplementedFirewall) RestoreOrCreateContainers() error { - return nil -} -func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error { - return nil -} -func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error { - return nil -} - -func (unimplementedFirewall) CleanRoutingRules() { -} - -// NewFirewall returns an unimplemented Firewall manager -func NewFirewall(parentCtx context.Context) firewallManager { - return unimplementedFirewall{} +// NewFirewall returns a nil manager +func NewFirewall(context.Context) (firewallManager, error) { + return nil, fmt.Errorf("firewall not supported on this OS") } diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index 3e3c16919..a87d4f4a3 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -49,14 +49,12 @@ type iptablesManager struct { mux sync.Mutex } -func newIptablesManager(parentCtx context.Context) *iptablesManager { - ctx, cancel := context.WithCancel(parentCtx) +func newIptablesManager(parentCtx context.Context) (*iptablesManager, error) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { - log.Debugf("failed to initialize iptables for ipv4: %s", err) + return nil, err } else if !isIptablesClientAvailable(ipv4Client) { - log.Infof("iptables is missing for ipv4") - ipv4Client = nil + return nil, fmt.Errorf("iptables is missing for ipv4") } ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) if err != nil { @@ -66,13 +64,14 @@ func newIptablesManager(parentCtx context.Context) *iptablesManager { ipv6Client = nil } + ctx, cancel := context.WithCancel(parentCtx) return &iptablesManager{ ctx: ctx, stop: cancel, ipv4Client: ipv4Client, ipv6Client: ipv6Client, rules: make(map[string]map[string][]string), - } + }, nil } // CleanRoutingRules cleans existing iptables resources that we created by the agent @@ -395,6 +394,10 @@ func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string ipVersion = ipv6 } + if iptablesClient == nil { + return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized") + } + ruleKey := genKey(keyFormat, pair.ID) rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination) existingRule, found := i.rules[ipVersion][ruleKey] @@ -459,6 +462,10 @@ func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair ipVersion = ipv6 } + if iptablesClient == nil { + return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized") + } + ruleKey := genKey(keyFormat, pair.ID) existingRule, found := i.rules[ipVersion][ruleKey] if found { diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go index c26355e56..dbe153f7b 100644 --- a/client/internal/routemanager/iptables_linux_test.go +++ b/client/internal/routemanager/iptables_linux_test.go @@ -16,7 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { t.SkipNow() } - manager := newIptablesManager(context.TODO()) + manager, _ := newIptablesManager(context.TODO()) defer manager.CleanRoutingRules() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 7324759f9..13d9d1f38 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -27,7 +27,7 @@ type DefaultManager struct { stop context.CancelFunc mux sync.Mutex clientNetworks map[string]*clientNetwork - serverRouter *serverRouter + serverRouter serverRouter statusRecorder *peer.Status wgInterface *iface.WGIface pubKey string @@ -36,13 +36,17 @@ type DefaultManager struct { // NewManager returns a new route manager func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { - mCTX, cancel := context.WithCancel(ctx) + serverRouter, err := newServerRouter(ctx, wgInterface) + if err != nil { + log.Errorf("server router is not supported: %s", err) + } + mCTX, cancel := context.WithCancel(ctx) dm := &DefaultManager{ ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), - serverRouter: newServerRouter(ctx, wgInterface), + serverRouter: serverRouter, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, @@ -59,7 +63,9 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop() { m.stop() - m.serverRouter.cleanUp() + if m.serverRouter != nil { + m.serverRouter.cleanUp() + } m.ctx = nil } @@ -77,9 +83,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.updateClientNetworks(updateSerial, newClientRoutesIDMap) m.notifier.onNewRoutes(newClientRoutesIDMap) - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return err + + if m.serverRouter != nil { + err := m.serverRouter.updateRoutes(newServerRoutesMap) + if err != nil { + return err + } } return nil diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 6291b4996..6f2ac294d 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -30,7 +30,6 @@ func TestManagerUpdateRoutes(t *testing.T) { inputInitRoutes []*route.Route inputRoutes []*route.Route inputSerial uint64 - shouldCheckServerRoutes bool serverRoutesExpected int clientNetworkWatchersExpected int }{ @@ -87,7 +86,6 @@ func TestManagerUpdateRoutes(t *testing.T) { }, }, inputSerial: 1, - shouldCheckServerRoutes: runtime.GOOS == "linux", serverRoutesExpected: 2, clientNetworkWatchersExpected: 0, }, @@ -116,7 +114,6 @@ func TestManagerUpdateRoutes(t *testing.T) { }, }, inputSerial: 1, - shouldCheckServerRoutes: runtime.GOOS == "linux", serverRoutesExpected: 1, clientNetworkWatchersExpected: 1, }, @@ -174,25 +171,6 @@ func TestManagerUpdateRoutes(t *testing.T) { inputSerial: 1, clientNetworkWatchersExpected: 0, }, - { - name: "No Server Routes Should Be Added To Non Linux", - inputRoutes: []*route.Route{ - { - ID: "a", - NetID: "routeA", - Peer: localPeerKey, - Network: netip.MustParsePrefix("1.2.3.4/32"), - NetworkType: route.IPv4Network, - Metric: 9999, - Masquerade: false, - Enabled: true, - }, - }, - inputSerial: 1, - shouldCheckServerRoutes: runtime.GOOS != "linux", - serverRoutesExpected: 0, - clientNetworkWatchersExpected: 0, - }, { name: "Remove 1 Client Route", inputInitRoutes: []*route.Route{ @@ -335,7 +313,6 @@ func TestManagerUpdateRoutes(t *testing.T) { }, inputRoutes: []*route.Route{}, inputSerial: 1, - shouldCheckServerRoutes: true, serverRoutesExpected: 0, clientNetworkWatchersExpected: 0, }, @@ -384,7 +361,6 @@ func TestManagerUpdateRoutes(t *testing.T) { }, }, inputSerial: 1, - shouldCheckServerRoutes: runtime.GOOS == "linux", serverRoutesExpected: 2, clientNetworkWatchersExpected: 1, }, @@ -419,8 +395,9 @@ func TestManagerUpdateRoutes(t *testing.T) { require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") - if testCase.shouldCheckServerRoutes { - require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") + if runtime.GOOS == "linux" { + sr := routeManager.serverRouter.(*defaultServerRouter) + require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match") } }) } diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go new file mode 100644 index 000000000..c9a13a904 --- /dev/null +++ b/client/internal/routemanager/server.go @@ -0,0 +1,9 @@ +package routemanager + +import "github.com/netbirdio/netbird/route" + +type serverRouter interface { + updateRoutes(map[string]*route.Route) error + removeFromServerNetwork(*route.Route) error + cleanUp() +} diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index c5e79a1a8..d130acc00 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -2,20 +2,11 @@ package routemanager import ( "context" + "fmt" "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/route" ) -type serverRouter struct { +func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) { + return nil, fmt.Errorf("server route not supported on this os") } - -func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter { - return &serverRouter{} -} - -func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error { - return nil -} - -func (r *serverRouter) cleanUp() {} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 4b85149fa..bf7a1dfd4 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -13,7 +13,7 @@ import ( "github.com/netbirdio/netbird/route" ) -type serverRouter struct { +type defaultServerRouter struct { mux sync.Mutex ctx context.Context routes map[string]*route.Route @@ -21,16 +21,21 @@ type serverRouter struct { wgInterface *iface.WGIface } -func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter { - return &serverRouter{ +func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) { + firewall, err := NewFirewall(ctx) + if err != nil { + return nil, err + } + + return &defaultServerRouter{ ctx: ctx, routes: make(map[string]*route.Route), - firewall: NewFirewall(ctx), + firewall: firewall, wgInterface: wgInterface, - } + }, nil } -func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error { +func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error { serverRoutesToRemove := make([]string, 0) if len(routesMap) > 0 { @@ -81,7 +86,7 @@ func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error { return nil } -func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { +func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): log.Infof("not removing from server network because context is done") @@ -98,7 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { } } -func (m *serverRouter) addToServerNetwork(route *route.Route) error { +func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): log.Infof("not adding to server network because context is done") @@ -115,6 +120,6 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error { } } -func (m *serverRouter) cleanUp() { +func (m *defaultServerRouter) cleanUp() { m.firewall.CleanRoutingRules() } diff --git a/client/internal/stdnet/filter.go b/client/internal/stdnet/filter.go index da35a623f..8bbb93a25 100644 --- a/client/internal/stdnet/filter.go +++ b/client/internal/stdnet/filter.go @@ -20,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool { for _, s := range disallowList { if strings.HasPrefix(iFace, s) { - log.Debugf("ignoring interface %s - it is not allowed", iFace) + log.Tracef("ignoring interface %s - it is not allowed", iFace) return false } } diff --git a/go.mod b/go.mod index 32e25c1a0..7ecf61584 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( fyne.io/fyne/v2 v2.1.4 github.com/c-robinson/iplib v1.0.3 github.com/cilium/ebpf v0.10.0 - github.com/coreos/go-iptables v0.6.0 + github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/v3 v3.1.1 github.com/getlantern/systray v1.2.1 diff --git a/go.sum b/go.sum index da913db52..1eb9d243d 100644 --- a/go.sum +++ b/go.sum @@ -131,8 +131,8 @@ github.com/containerd/typeurl v1.0.2/go.mod h1:9trJWW2sRlGub4wZJRTW83VtbOLS6hwcD github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= -github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= +github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= From 6c2b364966c7dd0f6eaed0cd3f8fa378d757f319 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 12 Aug 2023 16:12:09 +0200 Subject: [PATCH 04/42] Update client Dockerfile to use Alpine as base image and install necessary packages (#1078) --- client/Dockerfile | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/client/Dockerfile b/client/Dockerfile index aa4578848..63b4a3320 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,7 +1,5 @@ -FROM gcr.io/distroless/base:debug +FROM alpine:3 +RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true -ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox -SHELL ["/busybox/sh","-c"] -RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd ENTRYPOINT [ "/go/bin/netbird","up"] COPY netbird /go/bin/netbird \ No newline at end of file From 442ba7cbc857e58be81bfdd7c30381bd92156644 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 16 Aug 2023 12:25:38 +0300 Subject: [PATCH 05/42] Add domain validation for nameserver groups (#1077) This change ensures that domain names with uppercase letters are also considered valid, providing more flexibility in domain naming. --- management/server/nameserver.go | 37 +++++++++++++--- management/server/nameserver_test.go | 66 ++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5569172c4..eb2127945 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -1,14 +1,18 @@ package server import ( + "errors" + "regexp" + "strconv" + "unicode/utf8" + "github.com/miekg/dns" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "strconv" - "unicode/utf8" ) const ( @@ -26,6 +30,8 @@ const ( UpdateNameServerGroupPrimary // UpdateNameServerGroupDomains indicates a nameserver group' domains update operation UpdateNameServerGroupDomains + + domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` ) // NameServerGroupUpdateOperationType operation type @@ -364,9 +370,8 @@ func validateDomainInput(primary bool, domains []string) error { " you should set either primary or domain") } for _, domain := range domains { - _, valid := dns.IsDomainName(domain) - if !valid { - return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) + if err := validateDomain(domain); err != nil { + return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err) } } return nil @@ -417,3 +422,21 @@ func validateGroups(list []string, groups map[string]*Group) error { return nil } + +func validateDomain(domain string) error { + domainMatcher := regexp.MustCompile(domainPattern) + if !domainMatcher.MatchString(domain) { + return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") + } + + labels, valid := dns.IsDomainName(domain) + if !valid { + return errors.New("invalid domain name") + } + + if labels < 2 { + return errors.New("domain should consists of a minimum of two labels") + } + + return nil +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 3a5c34431..9d4425056 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1160,3 +1160,69 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error return account, nil } + +func TestValidateDomain(t *testing.T) { + testCases := []struct { + name string + domain string + errFunc require.ErrorAssertionFunc + }{ + { + name: "Valid domain name with multiple labels", + domain: "123.example.com", + errFunc: require.NoError, + }, + { + name: "Valid domain name with hyphen", + domain: "test-example.com", + errFunc: require.NoError, + }, + { + name: "Invalid domain name with double hyphen", + domain: "test--example.com", + errFunc: require.Error, + }, + { + name: "Invalid domain name with only one label", + domain: "com", + errFunc: require.Error, + }, + { + name: "Invalid domain name with a label exceeding 63 characters", + domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com", + errFunc: require.Error, + }, + { + name: "Invalid domain name starting with a hyphen", + domain: "-example.com", + errFunc: require.Error, + }, + { + name: "Invalid domain name ending with a hyphen", + domain: "example.com-", + errFunc: require.Error, + }, + { + name: "Invalid domain with unicode", + domain: "example?,.com", + errFunc: require.Error, + }, + { + name: "Invalid domain with space before top-level domain", + domain: "space .example.com", + errFunc: require.Error, + }, + { + name: "Invalid domain with trailing space", + domain: "example.com ", + errFunc: require.Error, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.errFunc(t, validateDomain(testCase.domain)) + }) + } + +} From 01f2b0ecb70be5cc8e403e00c0b570696b88c2e3 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 16 Aug 2023 15:10:57 +0200 Subject: [PATCH 06/42] Add support to force using binary install (#1082) Check if the USE_BIN_INSTALL variable is set to true and skip package manager discovery --- .github/workflows/install-script-test.yml | 35 +++++++++++ .github/workflows/install-test-darwin.yml | 60 ------------------- .github/workflows/install-test-linux.yml | 38 ------------ .../workflows/test-infrastructure-files.yml | 1 - release_files/install.sh | 30 ++++++++-- 5 files changed, 60 insertions(+), 104 deletions(-) create mode 100644 .github/workflows/install-script-test.yml delete mode 100644 .github/workflows/install-test-darwin.yml delete mode 100644 .github/workflows/install-test-linux.yml diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml new file mode 100644 index 000000000..ab07899b5 --- /dev/null +++ b/.github/workflows/install-script-test.yml @@ -0,0 +1,35 @@ +name: Test installation + +on: + push: + branches: + - main + pull_request: + paths: + - "release_files/install.sh" +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true +jobs: + test-install-script: + strategy: + max-parallel: 2 + matrix: + os: [ubuntu-latest, macos-latest] + skip_ui_mode: [true, false] + install_binary: [true, false] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: run install script + env: + SKIP_UI_APP: ${{ matrix.skip_ui_mode }} + USE_BIN_INSTALL: ${{ matrix.install_binary }} + run: | + [ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none" + cat release_files/install.sh | sh -x + + - name: check cli binary + run: command -v netbird diff --git a/.github/workflows/install-test-darwin.yml b/.github/workflows/install-test-darwin.yml deleted file mode 100644 index 9fede3438..000000000 --- a/.github/workflows/install-test-darwin.yml +++ /dev/null @@ -1,60 +0,0 @@ -name: Test installation Darwin - -on: - push: - branches: - - main - pull_request: - paths: - - "release_files/install.sh" -concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} - cancel-in-progress: true -jobs: - install-cli-only: - runs-on: macos-latest - steps: - - name: Checkout code - uses: actions/checkout@v2 - - - name: Rename brew package - if: ${{ matrix.check_bin_install }} - run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak - - - name: Run install script - run: | - sh ./release_files/install.sh - env: - SKIP_UI_APP: true - - - name: Run tests - run: | - if ! command -v netbird &> /dev/null; then - echo "Error: netbird is not installed" - exit 1 - fi - install-all: - runs-on: macos-latest - steps: - - name: Checkout code - uses: actions/checkout@v2 - - - name: Rename brew package - if: ${{ matrix.check_bin_install }} - run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak - - - name: Run install script - run: | - sh ./release_files/install.sh - - - name: Run tests - run: | - if ! command -v netbird &> /dev/null; then - echo "Error: netbird is not installed" - exit 1 - fi - - if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then - echo "Error: NetBird UI is not installed" - exit 1 - fi diff --git a/.github/workflows/install-test-linux.yml b/.github/workflows/install-test-linux.yml deleted file mode 100644 index 4ce30a937..000000000 --- a/.github/workflows/install-test-linux.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Test installation Linux - -on: - push: - branches: - - main - pull_request: - paths: - - "release_files/install.sh" -concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} - cancel-in-progress: true -jobs: - install-cli-only: - runs-on: ubuntu-latest - strategy: - matrix: - check_bin_install: [true, false] - steps: - - name: Checkout code - uses: actions/checkout@v2 - - - name: Rename apt package - if: ${{ matrix.check_bin_install }} - run: | - sudo mv /usr/bin/apt /usr/bin/apt.bak - sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak - - - name: Run install script - run: | - sh ./release_files/install.sh - - - name: Run tests - run: | - if ! command -v netbird &> /dev/null; then - echo "Error: netbird is not installed" - exit 1 - fi diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index 3861487c2..fdebc882e 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -9,7 +9,6 @@ on: - 'infrastructure_files/**' - '.github/workflows/test-infrastructure-files.yml' - concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} cancel-in-progress: true diff --git a/release_files/install.sh b/release_files/install.sh index 63b8d81c3..971c074b6 100644 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -24,13 +24,21 @@ download_release_binary() { VERSION=$(get_latest_release) BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" - + # for Darwin, download the signed Netbird-UI if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}_signed.zip" fi - BINARY_NAME="$1_${BINARY_BASE_NAME}" + if [ "$1" = "$UI_APP" ]; then + BINARY_NAME="$1-${OS_TYPE}_${BINARY_BASE_NAME}" + if [ "$OS_TYPE" = "darwin" ]; then + BINARY_NAME="$1_${BINARY_BASE_NAME}" + fi + else + BINARY_NAME="$1_${BINARY_BASE_NAME}" + fi + DOWNLOAD_URL="${BASE_URL}/${VERSION}/${BINARY_NAME}" echo "Installing $1 from $DOWNLOAD_URL" @@ -128,6 +136,14 @@ install_native_binaries() { fi } +check_use_bin_variable() { + if [ "${USE_BIN_INSTALL}-x" = "true-x" ]; then + echo "The installation will be performed using binary files" + return 0 + fi + return 1 +} + install_netbird() { # Check if netbird CLI is installed if [ -x "$(command -v netbird)" ]; then @@ -170,8 +186,10 @@ install_netbird() { echo "Netbird UI installation will be omitted as Linux does not run desktop environment" fi - # Check the availability of a compactible package manager - if [ -x "$(command -v apt)" ]; then + # Check the availability of a compatible package manager + if check_use_bin_variable; then + PACKAGE_MANAGER="bin" + elif [ -x "$(command -v apt)" ]; then PACKAGE_MANAGER="apt" echo "The installation will be performed using apt package manager" elif [ -x "$(command -v dnf)" ]; then @@ -191,7 +209,9 @@ install_netbird() { INSTALL_DIR="/usr/local/bin" # Check the availability of a compatible package manager - if [ -x "$(command -v brew)" ]; then + if check_use_bin_variable; then + PACKAGE_MANAGER="bin" + elif [ -x "$(command -v brew)" ]; then PACKAGE_MANAGER="brew" echo "The installation will be performed using brew package manager" fi From 4572c6c1f8362c100716490bc91b52b6a8b7e659 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 16 Aug 2023 16:11:26 +0200 Subject: [PATCH 07/42] Avoid categorization on incoming claim (#1086) This prevents domain categorization on claims of invited users --- management/server/account.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/account.go b/management/server/account.go index 7efcd2f1e..785058987 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1420,7 +1420,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla if _, ok := accountFromID.Users[claims.UserId]; !ok { return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory { + if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { return accountFromID, nil } } From 8e3bcd57a2b851a3caf90b9d70eaeadf4bff0065 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 16 Aug 2023 23:05:22 +0200 Subject: [PATCH 08/42] Specify invited by email when inviting a user (#1087) --- management/server/idp/auth0.go | 7 ++++--- management/server/idp/auth0_test.go | 6 +++--- management/server/idp/authentik.go | 2 +- management/server/idp/azure.go | 2 +- management/server/idp/google_workspace.go | 2 +- management/server/idp/idp.go | 3 ++- management/server/idp/keycloak.go | 2 +- management/server/idp/okta.go | 2 +- management/server/idp/zitadel.go | 2 +- management/server/user.go | 8 +++++++- 10 files changed, 22 insertions(+), 14 deletions(-) diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 517e169d0..64ec88e9f 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -461,7 +461,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta return nil } -func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) { +func buildCreateUserRequestPayload(email, name, accountID, invitedByEmail string) (string, error) { invite := true req := &createUserRequest{ Email: email, @@ -469,6 +469,7 @@ func buildCreateUserRequestPayload(email string, name string, accountID string) AppMeta: AppMetadata{ WTAccountID: accountID, WTPendingInvite: &invite, + WTInvitedBy: invitedByEmail, }, Connection: "Username-Password-Authentication", Password: GeneratePassword(8, 1, 1, 1), @@ -634,9 +635,9 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { } // CreateUser creates a new user in Auth0 Idp and sends an invite -func (am *Auth0Manager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - payloadString, err := buildCreateUserRequestPayload(email, name, accountID) + payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail) if err != nil { return nil, err } diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index fecee936b..0814b4b69 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -343,7 +343,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ name: "Bad Status Code", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null,\"wt_invited_by_email\":\"\"}}", appMetadata.WTAccountID), appMetadata: appMetadata, statusCode: 400, helper: JsonParser{}, @@ -366,7 +366,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ name: "Good request", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null,\"wt_invited_by_email\":\"\"}}", appMetadata.WTAccountID), appMetadata: appMetadata, statusCode: 200, helper: JsonParser{}, @@ -378,7 +378,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{ name: "Update Pending Invite", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true,\"wt_invited_by_email\":\"\"}}", appMetadata.WTAccountID), appMetadata: AppMetadata{ WTAccountID: "ok", WTPendingInvite: &invite, diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 396d390e2..586348fee 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -362,7 +362,7 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { } // CreateUser creates a new user in authentik Idp and sends an invitation. -func (am *AuthentikManager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { ctx, err := am.authenticationContext() if err != nil { return nil, err diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index b70e87be1..7cff7d8fc 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -236,7 +236,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in azure AD Idp. -func (am *AzureManager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID) if err != nil { return nil, err diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 9a5d73f75..efe457fdd 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -185,7 +185,7 @@ func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, erro } // CreateUser creates a new user in Google Workspace and sends an invitation. -func (gm *GoogleWorkspaceManager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (gm *GoogleWorkspaceManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { invite := true metadata := AppMetadata{ WTAccountID: accountID, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index a4d7c9bdf..48afd5c32 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -15,7 +15,7 @@ type Manager interface { GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) GetAccount(accountId string) ([]*UserData, error) GetAllAccounts() (map[string][]*UserData, error) - CreateUser(email string, name string, accountID string) (*UserData, error) + CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) GetUserByEmail(email string) ([]*UserData, error) InviteUserByID(userID string) error } @@ -72,6 +72,7 @@ type AppMetadata struct { // maps to wt_account_id when json.marshal WTAccountID string `json:"wt_account_id,omitempty"` WTPendingInvite *bool `json:"wt_pending_invite"` + WTInvitedBy string `json:"wt_invited_by_email"` } // JWTToken a JWT object that holds information of a token diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index d44396571..12ed87389 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -230,7 +230,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in keycloak Idp and sends an invite. -func (km *KeycloakManager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (km *KeycloakManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { jwtToken, err := km.credentials.Authenticate() if err != nil { return nil, err diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index 9d00bc6ee..c6b5055d4 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -103,7 +103,7 @@ func (oc *OktaCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in okta Idp and sends an invitation. -func (om *OktaManager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (om *OktaManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { var ( sendEmail = true activate = true diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 4de5659be..fce2c7b37 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -234,7 +234,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in zitadel Idp and sends an invite. -func (zm *ZitadelManager) CreateUser(email string, name string, accountID string) (*UserData, error) { +func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { payload, err := buildZitadelCreateUserRequestPayload(email, name) if err != nil { return nil, err diff --git a/management/server/user.go b/management/server/user.go index 34a879328..0b61dd60a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -215,6 +215,12 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } + // initiator is the one who is inviting the new user + initiatorUser, err := am.lookupUserInCache(userID, account) + if err != nil { + return nil, status.Errorf(status.NotFound, "user %s doesn't exist in IdP", userID) + } + // check if the user is already registered with this email => reject user, err := am.lookupUserInCacheByEmail(invite.Email, accountID) if err != nil { @@ -234,7 +240,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID) + idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID, initiatorUser.Email) if err != nil { return nil, err } From da8447a67d2e6f8601e5465cd2a3a42f9cd3c10e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 17 Aug 2023 12:27:04 +0200 Subject: [PATCH 09/42] Update the link to the doc page (#1088) --- client/cmd/login.go | 2 +- client/cmd/status.go | 2 +- client/internal/dns/server.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 566f661a3..c61c0a93a 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -202,6 +202,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { "If your browser didn't open automatically, use this URL to log in:\n\n" + " " + verificationURIComplete + " " + codeMsg + " \n\n") if err != nil { - cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n") + cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://docs.netbird.io/how-to/register-machines-using-setup-keys\n") } } diff --git a/client/cmd/status.go b/client/cmd/status.go index 119944d08..5d741462b 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { " netbird up \n\n"+ "If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+ "you can use a setup-key:\n\n netbird up --management-url --setup-key \n\n"+ - "More info: https://www.netbird.io/docs/overview/setup-keys\n\n", + "More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n", resp.GetStatus(), ) return nil diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6dd8f1904..31946c13e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -238,7 +238,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { hostUpdate := s.currentConfig if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + - "Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") + "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") hostUpdate.routeAll = false } From d4e9087f941a9e175651f391b3fc4929b3b2b123 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 17 Aug 2023 14:04:04 +0200 Subject: [PATCH 10/42] Add peer login and expiration activity events (#1090) Track the even of a user logging in their peer. Track the event of a peer login expiration. --- management/server/account.go | 1 + management/server/activity/codes.go | 378 +++++------------------- management/server/http/api/openapi.yml | 4 +- management/server/http/api/types.gen.go | 2 + management/server/peer.go | 2 + 5 files changed, 82 insertions(+), 305 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 785058987..442066bf9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -872,6 +872,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() log.Errorf("failed saving peer status while expiring peer %s", peer.ID) return account.GetNextPeerExpiration() } + am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain())) } log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index e571c3a0c..7c6b55218 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -1,5 +1,14 @@ package activity +// Activity that triggered an Event +type Activity int + +// Code is an activity string representation +type Code struct { + message string + code string +} + const ( // PeerAddedByUser indicates that a user added a new peer to the system PeerAddedByUser Activity = iota @@ -97,314 +106,77 @@ const ( UserUnblocked // GroupDeleted indicates that a user deleted group GroupDeleted + // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login + UserLoggedInPeer + // PeerLoginExpired indicates that the user peer login has been expired and peer disconnected + PeerLoginExpired ) -const ( - // PeerAddedByUserMessage is a human-readable text message of the PeerAddedByUser activity - PeerAddedByUserMessage string = "Peer added" - // PeerAddedWithSetupKeyMessage is a human-readable text message of the PeerAddedWithSetupKey activity - PeerAddedWithSetupKeyMessage = PeerAddedByUserMessage - // UserJoinedMessage is a human-readable text message of the UserJoined activity - UserJoinedMessage string = "User joined" - // UserInvitedMessage is a human-readable text message of the UserInvited activity - UserInvitedMessage string = "User invited" - // AccountCreatedMessage is a human-readable text message of the AccountCreated activity - AccountCreatedMessage string = "Account created" - // PeerRemovedByUserMessage is a human-readable text message of the PeerRemovedByUser activity - PeerRemovedByUserMessage string = "Peer deleted" - // RuleAddedMessage is a human-readable text message of the RuleAdded activity - RuleAddedMessage string = "Rule added" - // RuleRemovedMessage is a human-readable text message of the RuleRemoved activity - RuleRemovedMessage string = "Rule deleted" - // RuleUpdatedMessage is a human-readable text message of the RuleRemoved activity - RuleUpdatedMessage string = "Rule updated" - // PolicyAddedMessage is a human-readable text message of the PolicyAdded activity - PolicyAddedMessage string = "Policy added" - // PolicyRemovedMessage is a human-readable text message of the PolicyRemoved activity - PolicyRemovedMessage string = "Policy deleted" - // PolicyUpdatedMessage is a human-readable text message of the PolicyRemoved activity - PolicyUpdatedMessage string = "Policy updated" - // SetupKeyCreatedMessage is a human-readable text message of the SetupKeyCreated activity - SetupKeyCreatedMessage string = "Setup key created" - // SetupKeyUpdatedMessage is a human-readable text message of the SetupKeyUpdated activity - SetupKeyUpdatedMessage string = "Setup key updated" - // SetupKeyRevokedMessage is a human-readable text message of the SetupKeyRevoked activity - SetupKeyRevokedMessage string = "Setup key revoked" - // SetupKeyOverusedMessage is a human-readable text message of the SetupKeyOverused activity - SetupKeyOverusedMessage string = "Setup key overused" - // GroupCreatedMessage is a human-readable text message of the GroupCreated activity - GroupCreatedMessage string = "Group created" - // GroupUpdatedMessage is a human-readable text message of the GroupUpdated activity - GroupUpdatedMessage string = "Group updated" - // GroupAddedToPeerMessage is a human-readable text message of the GroupAddedToPeer activity - GroupAddedToPeerMessage string = "Group added to peer" - // GroupRemovedFromPeerMessage is a human-readable text message of the GroupRemovedFromPeer activity - GroupRemovedFromPeerMessage string = "Group removed from peer" - // GroupAddedToUserMessage is a human-readable text message of the GroupAddedToUser activity - GroupAddedToUserMessage string = "Group added to user" - // GroupRemovedFromUserMessage is a human-readable text message of the GroupRemovedFromUser activity - GroupRemovedFromUserMessage string = "Group removed from user" - // UserRoleUpdatedMessage is a human-readable text message of the UserRoleUpdatedMessage activity - UserRoleUpdatedMessage string = "User role updated" - // GroupAddedToSetupKeyMessage is a human-readable text message of the GroupAddedToSetupKey activity - GroupAddedToSetupKeyMessage string = "Group added to setup key" - // GroupRemovedFromSetupKeyMessage is a human-readable text message of the GroupRemovedFromSetupKey activity - GroupRemovedFromSetupKeyMessage string = "Group removed from user setup key" - // GroupAddedToDisabledManagementGroupsMessage is a human-readable text message of the GroupAddedToDisabledManagementGroups activity - GroupAddedToDisabledManagementGroupsMessage string = "Group added to disabled management DNS setting" - // GroupRemovedFromDisabledManagementGroupsMessage is a human-readable text message of the GroupRemovedFromDisabledManagementGroups activity - GroupRemovedFromDisabledManagementGroupsMessage string = "Group removed from disabled management DNS setting" - // RouteCreatedMessage is a human-readable text message of the RouteCreated activity - RouteCreatedMessage string = "Route created" - // RouteRemovedMessage is a human-readable text message of the RouteRemoved activity - RouteRemovedMessage string = "Route deleted" - // RouteUpdatedMessage is a human-readable text message of the RouteUpdated activity - RouteUpdatedMessage string = "Route updated" - // PeerSSHEnabledMessage is a human-readable text message of the PeerSSHEnabled activity - PeerSSHEnabledMessage string = "Peer SSH server enabled" - // PeerSSHDisabledMessage is a human-readable text message of the PeerSSHDisabled activity - PeerSSHDisabledMessage string = "Peer SSH server disabled" - // PeerRenamedMessage is a human-readable text message of the PeerRenamed activity - PeerRenamedMessage string = "Peer renamed" - // PeerLoginExpirationDisabledMessage is a human-readable text message of the PeerLoginExpirationDisabled activity - PeerLoginExpirationDisabledMessage string = "Peer login expiration disabled" - // PeerLoginExpirationEnabledMessage is a human-readable text message of the PeerLoginExpirationEnabled activity - PeerLoginExpirationEnabledMessage string = "Peer login expiration enabled" - // NameserverGroupCreatedMessage is a human-readable text message of the NameserverGroupCreated activity - NameserverGroupCreatedMessage string = "Nameserver group created" - // NameserverGroupDeletedMessage is a human-readable text message of the NameserverGroupDeleted activity - NameserverGroupDeletedMessage string = "Nameserver group deleted" - // NameserverGroupUpdatedMessage is a human-readable text message of the NameserverGroupUpdated activity - NameserverGroupUpdatedMessage string = "Nameserver group updated" - // AccountPeerLoginExpirationEnabledMessage is a human-readable text message of the AccountPeerLoginExpirationEnabled activity - AccountPeerLoginExpirationEnabledMessage string = "Peer login expiration enabled for the account" - // AccountPeerLoginExpirationDisabledMessage is a human-readable text message of the AccountPeerLoginExpirationDisabled activity - AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account" - // AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity - AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated" - // PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity - PersonalAccessTokenCreatedMessage string = "Personal access token created" - // PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity - PersonalAccessTokenDeletedMessage string = "Personal access token deleted" - // ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity - ServiceUserCreatedMessage string = "Service user created" - // ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity - ServiceUserDeletedMessage string = "Service user deleted" - // UserBlockedMessage is a human-readable text message of the UserBlocked activity - UserBlockedMessage string = "User blocked" - // UserUnblockedMessage is a human-readable text message of the UserUnblocked activity - UserUnblockedMessage string = "User unblocked" - // GroupDeletedMessage is a human-readable text message of the GroupDeleted activity - GroupDeletedMessage string = "Group deleted" -) - -// Activity that triggered an Event -type Activity int - -// Message returns a string representation of an activity -func (a Activity) Message() string { - switch a { - case PeerAddedByUser: - return PeerAddedByUserMessage - case PeerRemovedByUser: - return PeerRemovedByUserMessage - case PeerAddedWithSetupKey: - return PeerAddedWithSetupKeyMessage - case UserJoined: - return UserJoinedMessage - case UserInvited: - return UserInvitedMessage - case AccountCreated: - return AccountCreatedMessage - case RuleAdded: - return RuleAddedMessage - case RuleRemoved: - return RuleRemovedMessage - case RuleUpdated: - return RuleUpdatedMessage - case PolicyAdded: - return PolicyAddedMessage - case PolicyRemoved: - return PolicyRemovedMessage - case PolicyUpdated: - return PolicyUpdatedMessage - case SetupKeyCreated: - return SetupKeyCreatedMessage - case SetupKeyUpdated: - return SetupKeyUpdatedMessage - case SetupKeyRevoked: - return SetupKeyRevokedMessage - case SetupKeyOverused: - return SetupKeyOverusedMessage - case GroupCreated: - return GroupCreatedMessage - case GroupUpdated: - return GroupUpdatedMessage - case GroupAddedToPeer: - return GroupAddedToPeerMessage - case GroupRemovedFromPeer: - return GroupRemovedFromPeerMessage - case GroupRemovedFromUser: - return GroupRemovedFromUserMessage - case GroupAddedToUser: - return GroupAddedToUserMessage - case UserRoleUpdated: - return UserRoleUpdatedMessage - case GroupAddedToSetupKey: - return GroupAddedToSetupKeyMessage - case GroupRemovedFromSetupKey: - return GroupRemovedFromSetupKeyMessage - case GroupAddedToDisabledManagementGroups: - return GroupAddedToDisabledManagementGroupsMessage - case GroupRemovedFromDisabledManagementGroups: - return GroupRemovedFromDisabledManagementGroupsMessage - case RouteCreated: - return RouteCreatedMessage - case RouteRemoved: - return RouteRemovedMessage - case RouteUpdated: - return RouteUpdatedMessage - case PeerSSHEnabled: - return PeerSSHEnabledMessage - case PeerSSHDisabled: - return PeerSSHDisabledMessage - case PeerLoginExpirationEnabled: - return PeerLoginExpirationEnabledMessage - case PeerLoginExpirationDisabled: - return PeerLoginExpirationDisabledMessage - case PeerRenamed: - return PeerRenamedMessage - case NameserverGroupCreated: - return NameserverGroupCreatedMessage - case NameserverGroupDeleted: - return NameserverGroupDeletedMessage - case NameserverGroupUpdated: - return NameserverGroupUpdatedMessage - case AccountPeerLoginExpirationEnabled: - return AccountPeerLoginExpirationEnabledMessage - case AccountPeerLoginExpirationDisabled: - return AccountPeerLoginExpirationDisabledMessage - case AccountPeerLoginExpirationDurationUpdated: - return AccountPeerLoginExpirationDurationUpdatedMessage - case PersonalAccessTokenCreated: - return PersonalAccessTokenCreatedMessage - case PersonalAccessTokenDeleted: - return PersonalAccessTokenDeletedMessage - case ServiceUserCreated: - return ServiceUserCreatedMessage - case ServiceUserDeleted: - return ServiceUserDeletedMessage - case UserBlocked: - return UserBlockedMessage - case UserUnblocked: - return UserUnblockedMessage - case GroupDeleted: - return GroupDeletedMessage - default: - return "UNKNOWN_ACTIVITY" - } +var activityMap = map[Activity]Code{ + PeerAddedByUser: {"Peer added", "user.peer.add"}, + PeerAddedWithSetupKey: {"Peer added", "setupkey.peer.add"}, + UserJoined: {"User joined", "user.join"}, + UserInvited: {"User invited", "user.invite"}, + AccountCreated: {"Account created", "account.create"}, + PeerRemovedByUser: {"Peer deleted", "user.peer.delete"}, + RuleAdded: {"Rule added", "rule.add"}, + RuleUpdated: {"Rule updated", "rule.update"}, + RuleRemoved: {"Rule deleted", "rule.delete"}, + PolicyAdded: {"Policy added", "policy.add"}, + PolicyUpdated: {"Policy updated", "policy.update"}, + PolicyRemoved: {"Policy deleted", "policy.delete"}, + SetupKeyCreated: {"Setup key created", "setupkey.add"}, + SetupKeyUpdated: {"Setup key updated", "setupkey.update"}, + SetupKeyRevoked: {"Setup key revoked", "setupkey.revoke"}, + SetupKeyOverused: {"Setup key overused", "setupkey.overuse"}, + GroupCreated: {"Group created", "group.add"}, + GroupUpdated: {"Group updated", "group.update"}, + GroupAddedToPeer: {"Group added to peer", "peer.group.add"}, + GroupRemovedFromPeer: {"Group removed from peer", "peer.group.delete"}, + GroupAddedToUser: {"Group added to user", "user.group.add"}, + GroupRemovedFromUser: {"Group removed from user", "user.group.delete"}, + UserRoleUpdated: {"User role updated", "user.role.update"}, + GroupAddedToSetupKey: {"Group added to setup key", "setupkey.group.add"}, + GroupRemovedFromSetupKey: {"Group removed from user setup key", "setupkey.group.delete"}, + GroupAddedToDisabledManagementGroups: {"Group added to disabled management DNS setting", "dns.setting.disabled.management.group.add"}, + GroupRemovedFromDisabledManagementGroups: {"Group removed from disabled management DNS setting", "dns.setting.disabled.management.group.delete"}, + RouteCreated: {"Route created", "route.add"}, + RouteRemoved: {"Route deleted", "route.delete"}, + RouteUpdated: {"Route updated", "route.update"}, + PeerSSHEnabled: {"Peer SSH server enabled", "peer.ssh.enable"}, + PeerSSHDisabled: {"Peer SSH server disabled", "peer.ssh.disable"}, + PeerRenamed: {"Peer renamed", "peer.rename"}, + PeerLoginExpirationEnabled: {"Peer login expiration enabled", "peer.login.expiration.enable"}, + PeerLoginExpirationDisabled: {"Peer login expiration disabled", "peer.login.expiration.disable"}, + NameserverGroupCreated: {"Nameserver group created", "nameserver.group.add"}, + NameserverGroupDeleted: {"Nameserver group deleted", "nameserver.group.delete"}, + NameserverGroupUpdated: {"Nameserver group updated", "nameserver.group.update"}, + AccountPeerLoginExpirationDurationUpdated: {"Account peer login expiration duration updated", "account.setting.peer.login.expiration.update"}, + AccountPeerLoginExpirationEnabled: {"Account peer login expiration enabled", "account.setting.peer.login.expiration.enable"}, + AccountPeerLoginExpirationDisabled: {"Account peer login expiration disabled", "account.setting.peer.login.expiration.disable"}, + PersonalAccessTokenCreated: {"Personal access token created", "personal.access.token.create"}, + PersonalAccessTokenDeleted: {"Personal access token deleted", "personal.access.token.delete"}, + ServiceUserCreated: {"Service user created", "service.user.create"}, + ServiceUserDeleted: {"Service user deleted", "service.user.delete"}, + UserBlocked: {"User blocked", "user.block"}, + UserUnblocked: {"User unblocked", "user.unblock"}, + GroupDeleted: {"Group deleted", "group.delete"}, + UserLoggedInPeer: {"User logged in peer", "user.peer.login"}, + PeerLoginExpired: {"Peer login expired", "peer.login.expire"}, } // StringCode returns a string code of the activity func (a Activity) StringCode() string { - switch a { - case PeerAddedByUser: - return "user.peer.add" - case PeerRemovedByUser: - return "user.peer.delete" - case PeerAddedWithSetupKey: - return "setupkey.peer.add" - case UserJoined: - return "user.join" - case UserInvited: - return "user.invite" - case UserBlocked: - return "user.block" - case UserUnblocked: - return "user.unblock" - case AccountCreated: - return "account.create" - case RuleAdded: - return "rule.add" - case RuleRemoved: - return "rule.delete" - case RuleUpdated: - return "rule.update" - case PolicyAdded: - return "policy.add" - case PolicyRemoved: - return "policy.delete" - case PolicyUpdated: - return "policy.update" - case SetupKeyCreated: - return "setupkey.add" - case SetupKeyRevoked: - return "setupkey.revoke" - case SetupKeyOverused: - return "setupkey.overuse" - case SetupKeyUpdated: - return "setupkey.update" - case GroupCreated: - return "group.add" - case GroupUpdated: - return "group.update" - case GroupDeleted: - return "group.delete" - case GroupRemovedFromPeer: - return "peer.group.delete" - case GroupAddedToPeer: - return "peer.group.add" - case GroupAddedToUser: - return "user.group.add" - case GroupRemovedFromUser: - return "user.group.delete" - case UserRoleUpdated: - return "user.role.update" - case GroupAddedToSetupKey: - return "setupkey.group.add" - case GroupRemovedFromSetupKey: - return "setupkey.group.delete" - case GroupAddedToDisabledManagementGroups: - return "dns.setting.disabled.management.group.add" - case GroupRemovedFromDisabledManagementGroups: - return "dns.setting.disabled.management.group.delete" - case RouteCreated: - return "route.add" - case RouteRemoved: - return "route.delete" - case RouteUpdated: - return "route.update" - case PeerRenamed: - return "peer.rename" - case PeerSSHEnabled: - return "peer.ssh.enable" - case PeerSSHDisabled: - return "peer.ssh.disable" - case PeerLoginExpirationDisabled: - return "peer.login.expiration.disable" - case PeerLoginExpirationEnabled: - return "peer.login.expiration.enable" - case NameserverGroupCreated: - return "nameserver.group.add" - case NameserverGroupDeleted: - return "nameserver.group.delete" - case NameserverGroupUpdated: - return "nameserver.group.update" - case AccountPeerLoginExpirationDurationUpdated: - return "account.setting.peer.login.expiration.update" - case AccountPeerLoginExpirationEnabled: - return "account.setting.peer.login.expiration.enable" - case AccountPeerLoginExpirationDisabled: - return "account.setting.peer.login.expiration.disable" - case PersonalAccessTokenCreated: - return "personal.access.token.create" - case PersonalAccessTokenDeleted: - return "personal.access.token.delete" - case ServiceUserCreated: - return "service.user.create" - case ServiceUserDeleted: - return "service.user.delete" - default: - return "UNKNOWN_ACTIVITY" + if code, ok := activityMap[a]; ok { + return code.code } + return "UNKNOWN_ACTIVITY" +} + +// Message returns a string representation of an activity +func (a Activity) Message() string { + if code, ok := activityMap[a]; ok { + return code.message + } + return "UNKNOWN_ACTIVITY" } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2b18bc295..1fb54c4f7 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -892,7 +892,7 @@ components: description: The string code of the activity that occurred during the event type: string enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete", - "user.role.update", "user.block", "user.unblock", + "user.role.update", "user.block", "user.unblock", "user.peer.login", "setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse", "setupkey.group.delete", "setupkey.group.add", "rule.add", "rule.delete", "rule.update", @@ -901,7 +901,7 @@ components: "account.create", "account.setting.peer.login.expiration.update", "account.setting.peer.login.expiration.disable", "account.setting.peer.login.expiration.enable", "route.add", "route.delete", "route.update", "nameserver.group.add", "nameserver.group.delete", "nameserver.group.update", - "peer.ssh.disable", "peer.ssh.enable", "peer.rename", "peer.login.expiration.disable", "peer.login.expiration.enable", + "peer.ssh.disable", "peer.ssh.enable", "peer.rename", "peer.login.expiration.disable", "peer.login.expiration.enable", "peer.login.expire", "service.user.create", "personal.access.token.create", "service.user.delete", "personal.access.token.delete" ] example: route.add initiator_id: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index c11ed9efa..93d371a17 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -27,6 +27,7 @@ const ( EventActivityCodeNameserverGroupUpdate EventActivityCode = "nameserver.group.update" EventActivityCodePeerLoginExpirationDisable EventActivityCode = "peer.login.expiration.disable" EventActivityCodePeerLoginExpirationEnable EventActivityCode = "peer.login.expiration.enable" + EventActivityCodePeerLoginExpire EventActivityCode = "peer.login.expire" EventActivityCodePeerRename EventActivityCode = "peer.rename" EventActivityCodePeerSshDisable EventActivityCode = "peer.ssh.disable" EventActivityCodePeerSshEnable EventActivityCode = "peer.ssh.enable" @@ -57,6 +58,7 @@ const ( EventActivityCodeUserJoin EventActivityCode = "user.join" EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add" EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" + EventActivityCodeUserPeerLogin EventActivityCode = "user.peer.login" EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update" EventActivityCodeUserUnblock EventActivityCode = "user.unblock" ) diff --git a/management/server/peer.go b/management/server/peer.go index b2fe0955a..b2d4e436f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -695,6 +695,8 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, updatePeerLastLogin(peer, account) updateRemotePeers = true shouldStoreAccount = true + + am.storeEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } peer, updated := updatePeerMeta(peer, login.Meta, account) From 3aa657599b7e908609387135f866705c9ae574ce Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 17 Aug 2023 15:10:03 +0300 Subject: [PATCH 11/42] Switch OAuth flow initialization order (#1089) Switches the order of initialization in the OAuth flow within the NewOAuthFlow method. Instead of initializing the Device Authorization Flow first, it now initializes the PKCE Authorization Flow first, and falls back to the Device Authorization Flow if the PKCE initialization fails. --- client/internal/auth/oauth.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index d7365df60..794fe0958 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -59,19 +59,17 @@ func (t TokenInfo) GetTokenToUse() string { // NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { - log.Debug("getting device authorization flow info") + log.Debug("loading pkce authorization flow info") - // Try to initialize the Device Authorization Flow - deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err == nil { - return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) } - log.Debugf("getting device authorization flow info failed with error: %v", err) - log.Debugf("falling back to pkce authorization flow info") + log.Debugf("loading pkce authorization flow info failed with error: %v", err) + log.Debugf("falling back to device authorization flow info") - // If Device Authorization Flow failed, try the PKCE Authorization Flow - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { s, ok := gstatus.FromError(err) if ok && s.Code() == codes.NotFound { @@ -82,9 +80,9 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, erro return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ "please update your server or use Setup Keys to login", config.ManagementURL) } else { - return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) + return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err) } } - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) } From 3ac32fd78abb67139a6b978637490b9cd32bdcf0 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Fri, 18 Aug 2023 17:36:05 +0400 Subject: [PATCH 12/42] Send network update when propagate user auto-groups (#1084) For peer propagation this commit triggers network map update in two cases: 1) peer login 2) user AutoGroups update Also it issues new activity message about new user group for peer login process. Previous implementation only adds JWT groups to user. This fix also removes JWT groups from user auto assign groups. Pelase note, it also happen when user works with dashboard. --- management/server/account.go | 83 ++++++++++++++++++++++++------- management/server/account_test.go | 10 ++-- management/server/user.go | 16 ++++-- 3 files changed, 83 insertions(+), 26 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 442066bf9..dc8f4d0fd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -629,8 +629,8 @@ func (a *Account) GetPeer(peerID string) *Peer { return a.Peers[peerID] } -// AddJWTGroups to account and to user autoassigned groups -func (a *Account) AddJWTGroups(userID string, groups []string) bool { +// SetJWTGroups to account and to user autoassigned groups +func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { user, ok := a.Users[userID] if !ok { return false @@ -641,13 +641,21 @@ func (a *Account) AddJWTGroups(userID string, groups []string) bool { existedGroupsByName[group.Name] = group } - autoGroups := make(map[string]struct{}) - for _, groupID := range user.AutoGroups { - autoGroups[groupID] = struct{}{} + // remove JWT groups from the autogroups, to sync them again + removed := 0 + jwtAutoGroups := make(map[string]struct{}) + for i, id := range user.AutoGroups { + if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT { + jwtAutoGroups[group.Name] = struct{}{} + user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) + removed++ + } } + // create JWT groups if they doesn't exist + // and all of them to the autogroups var modified bool - for _, name := range groups { + for _, name := range groupsNames { group, ok := existedGroupsByName[name] if !ok { group = &Group{ @@ -656,16 +664,22 @@ func (a *Account) AddJWTGroups(userID string, groups []string) bool { Issued: GroupIssuedJWT, } a.Groups[group.ID] = group - modified = true } - if _, ok := autoGroups[group.ID]; !ok { - if group.Issued == GroupIssuedJWT { - user.AutoGroups = append(user.AutoGroups, group.ID) + // only JWT groups will be synced + if group.Issued == GroupIssuedJWT { + user.AutoGroups = append(user.AutoGroups, group.ID) + if _, ok := jwtAutoGroups[name]; !ok { modified = true } + delete(jwtAutoGroups, name) } } + // if not empty it means we removed some groups + if len(jwtAutoGroups) > 0 { + modified = true + } + return modified } @@ -1358,23 +1372,58 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { if slice, ok := claim.([]interface{}); ok { - var groups []string + var groupsNames []string for _, item := range slice { if g, ok := item.(string); ok { - groups = append(groups, g) + groupsNames = append(groupsNames, g) } else { log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) } } + + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) // if groups were added or modified, save the account - if account.AddJWTGroups(claims.UserId, groups) { + if account.SetJWTGroups(claims.UserId, groupsNames) { if account.Settings.GroupsPropagationEnabled { if user, err := account.FindUser(claims.UserId); err == nil { - account.UserGroupsAddToPeers(claims.UserId, append(user.AutoGroups, groups...)...) + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) + account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) + account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) + account.Network.IncSerial() + if err := am.Store.SaveAccount(account); err != nil { + log.Errorf("failed to save account: %v", err) + } else { + if err := am.updateAccountPeers(account); err != nil { + log.Errorf("failed updating account peers while updating user %s", account.Id) + } + for _, g := range addNewGroups { + if group := account.GetGroup(g); group != nil { + am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + for _, g := range removeOldGroups { + if group := account.GetGroup(g); group != nil { + am.storeEvent(user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + } + } + } else { + if err := am.Store.SaveAccount(account); err != nil { + log.Errorf("failed to save account: %v", err) } - } - if err := am.Store.SaveAccount(account); err != nil { - log.Errorf("failed to save account: %v", err) } } } else { diff --git a/management/server/account_test.go b/management/server/account_test.go index 828fa8536..119828e20 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1930,7 +1930,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } } -func TestAccount_AddJWTGroups(t *testing.T) { +func TestAccount_SetJWTGroups(t *testing.T) { // create a new account account := &Account{ Peers: map[string]*Peer{ @@ -1951,13 +1951,13 @@ func TestAccount_AddJWTGroups(t *testing.T) { } t.Run("api group already exists", func(t *testing.T) { - updated := account.AddJWTGroups("user1", []string{"group1"}) + updated := account.SetJWTGroups("user1", []string{"group1"}) assert.False(t, updated, "account should not be updated") assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") }) t.Run("add jwt group", func(t *testing.T) { - updated := account.AddJWTGroups("user1", []string{"group1", "group2"}) + updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) assert.True(t, updated, "account should be updated") assert.Len(t, account.Groups, 2, "new group should be added") assert.Len(t, account.Users["user1"].AutoGroups, 1, "new group should be added") @@ -1965,13 +1965,13 @@ func TestAccount_AddJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - updated := account.AddJWTGroups("user1", []string{"group2"}) + updated := account.SetJWTGroups("user1", []string{"group2"}) assert.False(t, updated, "account should not be updated") assert.Len(t, account.Groups, 2, "groups count should not be changed") }) t.Run("add new group", func(t *testing.T) { - updated := account.AddJWTGroups("user2", []string{"group1", "group3"}) + updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) assert.True(t, updated, "account should be updated") assert.Len(t, account.Groups, 3, "new group should be added") assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") diff --git a/management/server/user.go b/management/server/user.go index 0b61dd60a..19cffb840 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -610,10 +610,19 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd // need force update all auto groups in any case they will not be dublicated account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) - } - if err = am.Store.SaveAccount(account); err != nil { - return nil, err + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return nil, err + } + + if err := am.updateAccountPeers(account); err != nil { + log.Errorf("failed updating account peers while updating user %s", accountID) + } + } else { + if err = am.Store.SaveAccount(account); err != nil { + return nil, err + } } defer func() { @@ -641,7 +650,6 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd } else { log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) } - } for _, g := range addedGroups { From da75a76d41408b263b653e6c3d1bb0721329e1f5 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 18 Aug 2023 19:23:11 +0200 Subject: [PATCH 13/42] Adding dashboard login activity (#1092) For better auditing this PR adds a dashboard login event to the management service. For that the user object was extended with a field for last login that is not actively saved to the database but kept in memory until next write. The information about the last login can be extracted from the JWT claims nb_last_login. This timestamp will be stored and compared on each API request. If the value changes we generate an event to inform about a login. --- management/server/account.go | 17 ++++++------ management/server/activity/codes.go | 3 +++ management/server/file_store.go | 20 ++++++++++++++ management/server/http/api/openapi.yml | 5 ++++ management/server/http/api/types.gen.go | 3 +++ management/server/http/users_handler.go | 1 + management/server/jwtclaims/claims.go | 3 +++ management/server/jwtclaims/extractor.go | 18 +++++++++++++ management/server/jwtclaims/extractor_test.go | 11 ++++++++ management/server/store.go | 3 +++ management/server/updatechannel.go | 10 ++++--- management/server/user.go | 26 +++++++++++++++++++ management/server/user_test.go | 3 ++- 13 files changed, 110 insertions(+), 13 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index dc8f4d0fd..d9b73f713 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -189,14 +189,15 @@ type Account struct { } type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + LastLogin time.Time `json:"last_login"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 7c6b55218..4de667ded 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -110,6 +110,8 @@ const ( UserLoggedInPeer // PeerLoginExpired indicates that the user peer login has been expired and peer disconnected PeerLoginExpired + // DashboardLogin indicates that the user logged in to the dashboard + DashboardLogin ) var activityMap = map[Activity]Code{ @@ -163,6 +165,7 @@ var activityMap = map[Activity]Code{ GroupDeleted: {"Group deleted", "group.delete"}, UserLoggedInPeer: {"User logged in peer", "user.peer.login"}, PeerLoginExpired: {"Peer login expired", "peer.login.expire"}, + DashboardLogin: {"Dashboard login", "dashboard.login"}, } // StringCode returns a string code of the activity diff --git a/management/server/file_store.go b/management/server/file_store.go index 0e95e3a05..ecd02ba99 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -570,6 +570,26 @@ func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStat return nil } +// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. +func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + peer := account.Users[userID] + if peer == nil { + return status.Errorf(status.NotFound, "user %s not found", userID) + } + + peer.LastLogin = lastLogin + + return nil +} + // Close the FileStore persisting data to disk func (s *FileStore) Close() error { s.mux.Lock() diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 1fb54c4f7..a09b9f6a6 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -100,6 +100,11 @@ components: type: string enum: [ "active","invited","blocked" ] example: active + last_login: + description: Last time this user performed a login to the dashboard + type: string + format: date-time + example: 2023-05-05T09:00:35.477782Z auto_groups: description: Groups to auto-assign to peers registered by this user type: array diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 93d371a17..5b629cc0e 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -767,6 +767,9 @@ type User struct { // IsServiceUser Is true if this user is a service user IsServiceUser *bool `json:"is_service_user,omitempty"` + // LastLogin Last time this user performed a login to the dashboard + LastLogin *time.Time `json:"last_login,omitempty"` + // Name User's name from idp provider Name string `json:"name"` diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 45b2a7618..d215e1510 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -270,5 +270,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { IsCurrent: &isCurrent, IsServiceUser: &user.IsServiceUser, IsBlocked: user.IsBlocked, + LastLogin: &user.LastLogin, } } diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go index 946c0b8be..1fa00b2fe 100644 --- a/management/server/jwtclaims/claims.go +++ b/management/server/jwtclaims/claims.go @@ -1,6 +1,8 @@ package jwtclaims import ( + "time" + "github.com/golang-jwt/jwt" ) @@ -10,6 +12,7 @@ type AuthorizationClaims struct { AccountId string Domain string DomainCategory string + LastLogin time.Time Raw jwt.MapClaims } diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 466856d77..42a41f140 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -2,6 +2,7 @@ package jwtclaims import ( "net/http" + "time" "github.com/golang-jwt/jwt" ) @@ -17,6 +18,8 @@ const ( DomainCategorySuffix = "wt_account_domain_category" // UserIDClaim claim for the user id UserIDClaim = "sub" + // LastLoginSuffix claim for the last login + LastLoginSuffix = "nb_last_login" ) // ExtractClaims Extract function type @@ -93,9 +96,24 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } + LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix] + if ok { + jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string)) + } return jwtClaims } +func parseTime(timeString string) time.Time { + if timeString == "" { + return time.Time{} + } + parsedTime, err := time.Parse(time.RFC3339, timeString) + if err != nil { + return time.Time{} + } + return parsedTime +} + // fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims { if r.Context().Value(TokenUserProperty) == nil { diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index 9bececac6..f7eeb82e5 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -4,12 +4,15 @@ import ( "context" "net/http" "testing" + "time" "github.com/golang-jwt/jwt" "github.com/stretchr/testify/require" ) func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { + const layout = "2006-01-02T15:04:05.999Z" + claimMaps := jwt.MapClaims{} if claims.UserId != "" { claimMaps[UserIDClaim] = claims.UserId @@ -23,6 +26,9 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st if claims.DomainCategory != "" { claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory } + if claims.LastLogin != (time.Time{}) { + claimMaps[audiance+LastLoginSuffix] = claims.LastLogin.Format(layout) + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") @@ -40,6 +46,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { expectedMSG string } + const layout = "2006-01-02T15:04:05.999Z" + lastLogin, _ := time.Parse(layout, "2023-08-17T09:30:40.465Z") + testCase1 := test{ name: "All Claim Fields", inputAudiance: "https://login/", @@ -47,11 +56,13 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { UserId: "test", Domain: "test.com", AccountId: "testAcc", + LastLogin: lastLogin, DomainCategory: "public", Raw: jwt.MapClaims{ "https://login/wt_account_domain": "test.com", "https://login/wt_account_domain_category": "public", "https://login/wt_account_id": "testAcc", + "https://login/nb_last_login": lastLogin.Format(layout), "sub": "test", }, }, diff --git a/management/server/store.go b/management/server/store.go index daad30eaa..9ebe41235 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -1,5 +1,7 @@ package server +import "time" + type Store interface { GetAllAccounts() []*Account GetAccount(accountID string) (*Account, error) @@ -20,6 +22,7 @@ type Store interface { // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock() func() SavePeerStatus(accountID, peerID string, status PeerStatus) error + SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error // Close should close the store persisting all unsaved data. Close() error } diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 6cc10ad24..744386547 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -1,9 +1,11 @@ package server import ( - "github.com/netbirdio/netbird/management/proto" - log "github.com/sirupsen/logrus" "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/proto" ) const channelBufferSize = 100 @@ -33,7 +35,7 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) er if channel, ok := p.peerChannels[peerID]; ok { select { case channel <- update: - log.Infof("update was sent to channel for peer %s", peerID) + log.Debugf("update was sent to channel for peer %s", peerID) default: log.Warnf("channel for peer %s is %d full", peerID, len(channel)) } @@ -52,7 +54,7 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { delete(p.peerChannels, peerID) close(channel) } - //mbragin: todo shouldn't it be more? or configurable? + // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel diff --git a/management/server/user.go b/management/server/user.go index 19cffb840..3d0c0313e 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -3,6 +3,7 @@ package server import ( "fmt" "strings" + "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" @@ -53,6 +54,8 @@ type User struct { PATs map[string]*PersonalAccessToken // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin time.Time } // IsBlocked returns true if the user is blocked, false otherwise @@ -60,6 +63,10 @@ func (u *User) IsBlocked() bool { return u.Blocked } +func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { + return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() +} + // IsAdmin returns true if the user is an admin, false otherwise func (u *User) IsAdmin() bool { return u.Role == UserRoleAdmin @@ -82,6 +89,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { Status: string(UserStatusActive), IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, + LastLogin: u.LastLogin, }, nil } if userData.ID != u.Id { @@ -102,6 +110,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { Status: string(userStatus), IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, + LastLogin: u.LastLogin, }, nil } @@ -123,6 +132,7 @@ func (u *User) Copy() *User { ServiceUserName: u.ServiceUserName, PATs: pats, Blocked: u.Blocked, + LastLogin: u.LastLogin, } } @@ -186,6 +196,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs AutoGroups: newUser.AutoGroups, Status: string(UserStatusActive), IsServiceUser: true, + LastLogin: time.Time{}, }, nil } @@ -280,6 +291,21 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( if !ok { return nil, status.Errorf(status.NotFound, "user not found") } + + // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC + // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. + unlock := am.Store.AcquireAccountLock(account.Id) + newLogin := user.LastDashboardLoginChanged(claims.LastLogin) + err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) + unlock() + if newLogin { + meta := map[string]any{"timestamp": claims.LastLogin} + am.storeEvent(claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + if err != nil { + log.Errorf("failed saving user last login: %v", err) + } + } + return user, nil } diff --git a/management/server/user_test.go b/management/server/user_test.go index d6226f76d..b07154663 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -266,7 +266,8 @@ func TestUser_Copy(t *testing.T) { LastUsed: time.Now(), }, }, - Blocked: false, + Blocked: false, + LastLogin: time.Now(), } err := validateStruct(user) From 892db250219de68c46dcdc6fb223f00f0e908691 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Mon, 21 Aug 2023 09:11:52 +0200 Subject: [PATCH 14/42] docs: change get started link (#1098) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bf83435f9..ef391e90f 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@

- Start using NetBird at app.netbird.io + Start using NetBird at netbird.io
See Documentation
From e586eca16cf4aa95b18afd4585e042cc693a5aa9 Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Tue, 22 Aug 2023 17:56:39 +0200 Subject: [PATCH 15/42] Improve account copying (#1069) With this fix, all nested slices and pointers will be copied by value. Also, this fixes tests to compare the original and copy account by their values by marshaling them to JSON strings. Before that, they were copying the pointers that also passed the simple `=` compassion (as the addresses match). --- dns/nameserver.go | 14 +++++--- management/server/account_test.go | 34 ++++++++++++++++--- management/server/group.go | 6 ++-- .../server/http/nameservers_handler_test.go | 1 + management/server/peer.go | 8 +++-- management/server/personal_access_token.go | 12 +++++++ management/server/policy.go | 17 ++++++---- management/server/rule.go | 9 +++-- management/server/setupkey.go | 4 +-- management/server/user.go | 4 +-- route/route.go | 9 +++-- 11 files changed, 88 insertions(+), 30 deletions(-) diff --git a/dns/nameserver.go b/dns/nameserver.go index 807df5907..7751f8e1c 100644 --- a/dns/nameserver.go +++ b/dns/nameserver.go @@ -130,16 +130,22 @@ func ParseNameServerURL(nsURL string) (NameServer, error) { // Copy copies a nameserver group object func (g *NameServerGroup) Copy() *NameServerGroup { - return &NameServerGroup{ + nsGroup := &NameServerGroup{ ID: g.ID, Name: g.Name, Description: g.Description, - NameServers: g.NameServers, - Groups: g.Groups, + NameServers: make([]NameServer, len(g.NameServers)), + Groups: make([]string, len(g.Groups)), Enabled: g.Enabled, Primary: g.Primary, - Domains: g.Domains, + Domains: make([]string, len(g.Domains)), } + + copy(nsGroup.NameServers, g.NameServers) + copy(nsGroup.Groups, g.Groups) + copy(nsGroup.Domains, g.Domains) + + return nsGroup } // IsEqual compares one nameserver group with the other diff --git a/management/server/account_test.go b/management/server/account_test.go index 119828e20..29af8514a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3,6 +3,7 @@ package server import ( "crypto/sha256" b64 "encoding/base64" + "encoding/json" "fmt" "net" "reflect" @@ -1348,6 +1349,11 @@ func TestAccount_Copy(t *testing.T) { Peers: map[string]*Peer{ "peer1": { Key: "key1", + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, }, }, Users: map[string]*User{ @@ -1370,28 +1376,36 @@ func TestAccount_Copy(t *testing.T) { }, Groups: map[string]*Group{ "group1": { - ID: "group1", + ID: "group1", + Peers: []string{"peer1"}, }, }, Rules: map[string]*Rule{ "rule1": { - ID: "rule1", + ID: "rule1", + Destination: []string{}, + Source: []string{}, }, }, Policies: []*Policy{ { ID: "policy1", Enabled: true, + Rules: make([]*PolicyRule, 0), }, }, Routes: map[string]*route.Route{ "route1": { - ID: "route1", + ID: "route1", + Groups: []string{"group1"}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ "nsGroup1": { - ID: "nsGroup1", + ID: "nsGroup1", + Domains: []string{}, + Groups: []string{}, + NameServers: []nbdns.NameServer{}, }, }, DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}}, @@ -1402,10 +1416,20 @@ func TestAccount_Copy(t *testing.T) { t.Fatal(err) } accountCopy := account.Copy() - assert.Equal(t, account, accountCopy, "account copy returned a different value than expected") + accBytes, err := json.Marshal(account) + if err != nil { + t.Fatal(err) + } + account.Peers["peer1"].Status.Connected = false // we change original object to confirm that copy wont change + accCopyBytes, err := json.Marshal(accountCopy) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, string(accBytes), string(accCopyBytes), "account copy returned a different value than expected") } // hasNilField validates pointers, maps and slices if they are nil +// TODO: make it check nested fields too func hasNilField(x interface{}) error { rv := reflect.ValueOf(x) rv = rv.Elem() diff --git a/management/server/group.go b/management/server/group.go index 53571e099..5b1d2ac9f 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -59,12 +59,14 @@ func (g *Group) EventMeta() map[string]any { } func (g *Group) Copy() *Group { - return &Group{ + group := &Group{ ID: g.ID, Name: g.Name, Issued: g.Issued, - Peers: g.Peers[:], + Peers: make([]string, len(g.Peers)), } + copy(group.Peers, g.Peers) + return group } // GetGroup object of the peers diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 01c3cbe79..75fcb4c1c 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -54,6 +54,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ }, }, Groups: []string{"testing"}, + Domains: []string{"domain"}, Enabled: true, } diff --git a/management/server/peer.go b/management/server/peer.go index b2d4e436f..90377b1e8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -108,6 +108,10 @@ func (p *Peer) AddedWithSSOLogin() bool { // Copy copies Peer object func (p *Peer) Copy() *Peer { + peerStatus := p.Status + if peerStatus != nil { + peerStatus = p.Status.Copy() + } return &Peer{ ID: p.ID, Key: p.Key, @@ -115,11 +119,11 @@ func (p *Peer) Copy() *Peer { IP: p.IP, Meta: p.Meta, Name: p.Name, - Status: p.Status, + DNSLabel: p.DNSLabel, + Status: peerStatus, UserID: p.UserID, SSHKey: p.SSHKey, SSHEnabled: p.SSHEnabled, - DNSLabel: p.DNSLabel, LoginExpirationEnabled: p.LoginExpirationEnabled, LastLogin: p.LastLogin, } diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 0a55f3237..c7deca9de 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -36,6 +36,18 @@ type PersonalAccessToken struct { LastUsed time.Time } +func (t *PersonalAccessToken) Copy() *PersonalAccessToken { + return &PersonalAccessToken{ + ID: t.ID, + Name: t.Name, + HashedToken: t.HashedToken, + ExpirationDate: t.ExpirationDate, + CreatedBy: t.CreatedBy, + CreatedAt: t.CreatedAt, + LastUsed: t.LastUsed, + } +} + // PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it type PersonalAccessTokenGenerated struct { PlainToken string diff --git a/management/server/policy.go b/management/server/policy.go index 54158eeac..dde0b46d8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -95,18 +95,22 @@ type PolicyRule struct { // Copy returns a copy of a policy rule func (pm *PolicyRule) Copy() *PolicyRule { - return &PolicyRule{ + rule := &PolicyRule{ ID: pm.ID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, Action: pm.Action, - Destinations: pm.Destinations[:], - Sources: pm.Sources[:], + Destinations: make([]string, len(pm.Destinations)), + Sources: make([]string, len(pm.Sources)), Bidirectional: pm.Bidirectional, Protocol: pm.Protocol, - Ports: pm.Ports[:], + Ports: make([]string, len(pm.Ports)), } + copy(rule.Destinations, pm.Destinations) + copy(rule.Sources, pm.Sources) + copy(rule.Ports, pm.Ports) + return rule } // ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility) @@ -147,9 +151,10 @@ func (p *Policy) Copy() *Policy { Name: p.Name, Description: p.Description, Enabled: p.Enabled, + Rules: make([]*PolicyRule, len(p.Rules)), } - for _, r := range p.Rules { - c.Rules = append(c.Rules, r.Copy()) + for i, r := range p.Rules { + c.Rules[i] = r.Copy() } return c } diff --git a/management/server/rule.go b/management/server/rule.go index 68b1cc4fb..cb85d633d 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -45,15 +45,18 @@ type Rule struct { } func (r *Rule) Copy() *Rule { - return &Rule{ + rule := &Rule{ ID: r.ID, Name: r.Name, Description: r.Description, Disabled: r.Disabled, - Source: r.Source[:], - Destination: r.Destination[:], + Source: make([]string, len(r.Source)), + Destination: make([]string, len(r.Destination)), Flow: r.Flow, } + copy(rule.Source, r.Source) + copy(rule.Destination, r.Destination) + return rule } // EventMeta returns activity event meta related to this rule diff --git a/management/server/setupkey.go b/management/server/setupkey.go index bfba05839..ffdd822e3 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -90,8 +90,8 @@ type SetupKey struct { // Copy copies SetupKey to a new object func (key *SetupKey) Copy() *SetupKey { - autoGroups := make([]string, 0) - autoGroups = append(autoGroups, key.AutoGroups...) + autoGroups := make([]string, len(key.AutoGroups)) + copy(autoGroups, key.AutoGroups) if key.UpdatedAt.IsZero() { key.UpdatedAt = key.CreatedAt } diff --git a/management/server/user.go b/management/server/user.go index 3d0c0313e..b3556957d 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -120,9 +120,7 @@ func (u *User) Copy() *User { copy(autoGroups, u.AutoGroups) pats := make(map[string]*PersonalAccessToken, len(u.PATs)) for k, v := range u.PATs { - patCopy := new(PersonalAccessToken) - *patCopy = *v - pats[k] = patCopy + pats[k] = v.Copy() } return &User{ Id: u.Id, diff --git a/route/route.go b/route/route.go index fbd077bc2..5c45e2cf5 100644 --- a/route/route.go +++ b/route/route.go @@ -1,8 +1,9 @@ package route import ( - "github.com/netbirdio/netbird/management/server/status" "net/netip" + + "github.com/netbirdio/netbird/management/server/status" ) // Windows has some limitation regarding metric size that differ from Unix-like systems. @@ -83,7 +84,7 @@ func (r *Route) EventMeta() map[string]any { // Copy copies a route object func (r *Route) Copy() *Route { - return &Route{ + route := &Route{ ID: r.ID, Description: r.Description, NetID: r.NetID, @@ -93,8 +94,10 @@ func (r *Route) Copy() *Route { Metric: r.Metric, Masquerade: r.Masquerade, Enabled: r.Enabled, - Groups: r.Groups, + Groups: make([]string, len(r.Groups)), } + copy(route.Groups, r.Groups) + return route } // IsEqual compares one route with the other From ac0b7dc8cb85de0e3c4072ff5bab524b099a8ac2 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 23 Aug 2023 21:03:34 +0300 Subject: [PATCH 16/42] Enhance linux client authentication (#1093) The change clarifies the message usage, indicating that setup keys can alternatively be used in the authentication process. This approach adds flexibility in scenarios where automated authentication is unachievable, especially in non-desktop Linux environments. --- client/cmd/login.go | 74 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index c61c0a93a..fc0ef82e5 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -3,21 +3,21 @@ package cmd import ( "context" "fmt" - "github.com/netbirdio/netbird/client/internal/auth" + "os" + "runtime" "strings" "time" "github.com/skratchdot/open-golang/open" + "github.com/spf13/cobra" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/util" - - "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/util" ) var loginCmd = &cobra.Command{ @@ -191,17 +191,63 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { var codeMsg string - if userCode != "" { - if !strings.Contains(verificationURIComplete, userCode) { - codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) + if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { + codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) + } + + browserAuthMsg := "Please do the SSO login in your browser. \n" + + "If your browser didn't open automatically, use this URL to log in:\n\n" + + verificationURIComplete + " " + codeMsg + + setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys" + + authenticateUsingBrowser := func() { + cmd.Println(browserAuthMsg) + if err := open.Run(verificationURIComplete); err != nil { + cmd.Println(setupKeyAuthMsg) } } - err := open.Run(verificationURIComplete) - cmd.Printf("Please do the SSO login in your browser. \n" + - "If your browser didn't open automatically, use this URL to log in:\n\n" + - " " + verificationURIComplete + " " + codeMsg + " \n\n") - if err != nil { - cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://docs.netbird.io/how-to/register-machines-using-setup-keys\n") + switch runtime.GOOS { + case "windows", "darwin": + authenticateUsingBrowser() + case "linux": + if isLinuxRunningDesktop() { + authenticateUsingBrowser() + } else { + // If current flow is PKCE, it implies the server is anticipating the redirect to localhost. + // Devices lacking browser support are incompatible with this flow.Therefore, + // these devices will need to resort to setup keys instead. + if isPKCEFlow(verificationURIComplete) { + cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") + } else { + cmd.Println(browserAuthMsg) + } + } } } + +// isLinuxRunningDesktop checks if a Linux OS is running desktop environment. +func isLinuxRunningDesktop() bool { + for _, env := range os.Environ() { + values := strings.Split(env, "=") + if len(values) == 2 { + key, value := values[0], values[1] + if key == "XDG_CURRENT_DESKTOP" && value != "" { + return true + } + } + } + return false +} + +// isPKCEFlow determines if the PKCE flow is active or not, +// by checking the existence of redirect_uri inside the verification URL. +func isPKCEFlow(verificationURL string) bool { + if verificationURL == "" { + return false + } + return strings.Contains(verificationURL, "redirect_uri") +} From 80d9b5fca5e9207670c49ec6a700398afa88c586 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 28 Aug 2023 17:21:04 +0300 Subject: [PATCH 17/42] Add auto-update feature in netbird script for binary installation (#1106) This pull request addresses the need to enhance the installer script by introducing a new parameter --update to trigger updates. The goal is to streamline the update process for binary installations and provide a better experience for users. --- release_files/install.sh | 159 ++++++++++++++++++++++++++------------- 1 file changed, 108 insertions(+), 51 deletions(-) mode change 100644 => 100755 release_files/install.sh diff --git a/release_files/install.sh b/release_files/install.sh old mode 100644 new mode 100755 index 971c074b6..99cab9e26 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -3,6 +3,9 @@ # Source: https://github.com/physk/netbird-installer set -e +CONFIG_FOLDER="/etc/netbird" +CONFIG_FILE="$CONFIG_FOLDER/install.conf" + OWNER="netbirdio" REPO="netbird" CLI_APP="netbird" @@ -12,7 +15,7 @@ UI_APP="netbird-ui" OS_NAME="" OS_TYPE="" ARCH="$(uname -m)" -PACKAGE_MANAGER="" +PACKAGE_MANAGER="bin" INSTALL_DIR="" get_latest_release() { @@ -25,7 +28,7 @@ download_release_binary() { BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" - # for Darwin, download the signed Netbird-UI + # for Darwin, download the signed NetBird-UI if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}_signed.zip" fi @@ -42,12 +45,12 @@ download_release_binary() { DOWNLOAD_URL="${BASE_URL}/${VERSION}/${BINARY_NAME}" echo "Installing $1 from $DOWNLOAD_URL" - cd /tmp && curl -LO "$DOWNLOAD_URL" - - + cd /tmp && curl -LO "$DOWNLOAD_URL" + + if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then INSTALL_DIR="/Applications/NetBird UI.app" - + # Unzip the app and move to INSTALL_DIR unzip -q -o "$BINARY_NAME" mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" @@ -61,7 +64,7 @@ download_release_binary() { add_apt_repo() { sudo apt-get update sudo apt-get install ca-certificates gnupg -y - + curl -sSL https://pkgs.wiretrustee.com/debian/public.key \ | sudo gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg @@ -73,15 +76,15 @@ add_apt_repo() { add_rpm_repo() { cat <<-EOF | sudo tee /etc/yum.repos.d/netbird.repo -[Netbird] -name=Netbird +[NetBird] +name=NetBird baseurl=https://pkgs.netbird.io/yum/ enabled=1 gpgcheck=0 gpgkey=https://pkgs.netbird.io/yum/repodata/repomd.xml.key repo_gpgcheck=1 EOF -} +} add_aur_repo() { INSTALL_PKGS="git base-devel go" @@ -99,10 +102,10 @@ add_aur_repo() { done # Build package from AUR - cd /tmp && git clone https://aur.archlinux.org/netbird.git + cd /tmp && git clone https://aur.archlinux.org/netbird.git cd netbird && makepkg -sri --noconfirm - if ! $SKIP_UI_APP; then + if ! $SKIP_UI_APP; then cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git cd netbird-ui && makepkg -sri --noconfirm fi @@ -131,9 +134,9 @@ install_native_binaries() { # download and copy binaries to INSTALL_DIR download_release_binary "$CLI_APP" - if ! $SKIP_UI_APP; then + if ! $SKIP_UI_APP; then download_release_binary "$UI_APP" - fi + fi } check_use_bin_variable() { @@ -145,23 +148,26 @@ check_use_bin_variable() { } install_netbird() { - # Check if netbird CLI is installed if [ -x "$(command -v netbird)" ]; then - if netbird status > /dev/null 2>&1; then - echo "Netbird service is running, please stop it before proceeding" + status_output=$(netbird status) + if echo "$status_output" | grep -q 'Management: Connected' && echo "$status_output" | grep -q 'Signal: Connected'; then + echo "NetBird service is running, please stop it before proceeding" + exit 1 fi - echo "Netbird seems to be installed already, please remove it before proceeding" - exit 1 + if [ -n "$status_output" ]; then + echo "NetBird seems to be installed already, please remove it before proceeding" + exit 1 + fi fi # Checks if SKIP_UI_APP env is set if [ -z "$SKIP_UI_APP" ]; then SKIP_UI_APP=false else - if $SKIP_UI_APP; then + if $SKIP_UI_APP; then echo "SKIP_UI_APP has been set to true in the environment" - echo "Netbird UI installation will be omitted based on your preference" + echo "NetBird UI installation will be omitted based on your preference" fi fi @@ -169,21 +175,21 @@ install_netbird() { if type uname >/dev/null 2>&1; then case "$(uname)" in Linux) - OS_NAME="$(. /etc/os-release && echo "$ID")" + OS_NAME="$(. /etc/os-release && echo "$ID")" OS_TYPE="linux" INSTALL_DIR="/usr/bin" - + # Allow netbird UI installation for x64 arch only if [ "$ARCH" != "amd64" ] && [ "$ARCH" != "arm64" ] \ && [ "$ARCH" != "x86_64" ];then SKIP_UI_APP=true - echo "Netbird UI installation will be omitted as $ARCH is not a compactible architecture" + echo "NetBird UI installation will be omitted as $ARCH is not a compactible architecture" fi # Allow netbird UI installation for linux running desktop enviroment if [ -z "$XDG_CURRENT_DESKTOP" ];then - SKIP_UI_APP=true - echo "Netbird UI installation will be omitted as Linux does not run desktop environment" + SKIP_UI_APP=true + echo "NetBird UI installation will be omitted as Linux does not run desktop environment" fi # Check the availability of a compatible package manager @@ -207,7 +213,7 @@ install_netbird() { OS_NAME="macos" OS_TYPE="darwin" INSTALL_DIR="/usr/local/bin" - + # Check the availability of a compatible package manager if check_use_bin_variable; then PACKAGE_MANAGER="bin" @@ -225,15 +231,15 @@ install_netbird() { apt) add_apt_repo sudo apt-get install netbird -y - - if ! $SKIP_UI_APP; then + + if ! $SKIP_UI_APP; then sudo apt-get install netbird-ui -y fi ;; yum) add_rpm_repo sudo yum -y install netbird - if ! $SKIP_UI_APP; then + if ! $SKIP_UI_APP; then sudo yum -y install netbird-ui fi ;; @@ -243,7 +249,7 @@ install_netbird() { sudo dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo sudo dnf -y install netbird - if ! $SKIP_UI_APP; then + if ! $SKIP_UI_APP; then sudo dnf -y install netbird-ui fi ;; @@ -255,46 +261,50 @@ install_netbird() { # Remove Wiretrustee if it had been installed using Homebrew before if brew ls --versions wiretrustee >/dev/null 2>&1; then echo "Removing existing wiretrustee client" - + # Stop and uninstall daemon service: wiretrustee service stop - wiretrustee service uninstall + wiretrustee service uninstall # Unlik the app brew unlink wiretrustee fi brew install netbirdio/tap/netbird - if ! $SKIP_UI_APP; then + if ! $SKIP_UI_APP; then brew install --cask netbirdio/tap/netbird-ui fi ;; *) - if [ "$OS_NAME" = "nixos" ];then - echo "Please add Netbird to your NixOS configuration.nix directly:" - echo - echo "services.netbird.enable = true;" + if [ "$OS_NAME" = "nixos" ];then + echo "Please add NetBird to your NixOS configuration.nix directly:" + echo "" + echo "services.netbird.enable = true;" - if ! $SKIP_UI_APP; then - echo "environment.systemPackages = [ pkgs.netbird-ui ];" - fi + if ! $SKIP_UI_APP; then + echo "environment.systemPackages = [ pkgs.netbird-ui ];" + fi - echo "Build and apply new configuration:" - echo - echo "sudo nixos-rebuild switch" - exit 0 - fi + echo "Build and apply new configuration:" + echo "" + echo "sudo nixos-rebuild switch" + exit 0 + fi install_native_binaries ;; esac + # Add package manager to config + sudo mkdir -p "$CONFIG_FOLDER" + echo "package_manager=$PACKAGE_MANAGER" | sudo tee "$CONFIG_FILE" > /dev/null + # Load and start netbird service - if ! sudo netbird service install 2>&1; then - echo "Netbird service has already been loaded" + if ! sudo netbird service install 2>&1; then + echo "NetBird service has already been loaded" fi - if ! sudo netbird service start 2>&1; then - echo "Netbird service has already been started" + if ! sudo netbird service start 2>&1; then + echo "NetBird service has already been started" fi @@ -303,4 +313,51 @@ install_netbird() { echo "sudo netbird up" } -install_netbird +version_greater_equal() { + printf '%s\n%s\n' "$2" "$1" | sort -V -C +} + +is_bin_package_manager() { + if sudo test -f "$1" && sudo grep -q "package_manager=bin" "$1" ; then + return 0 + else + return 1 + fi +} + +update_netbird() { + if is_bin_package_manager "$CONFIG_FILE"; then + latest_release=$(get_latest_release) + latest_version=${latest_release#v} + installed_version=$(netbird version) + + if [ "$latest_version" = "$installed_version" ]; then + echo "Installed netbird version ($installed_version) is up-to-date" + exit 0 + fi + + if version_greater_equal "$latest_version" "$installed_version"; then + echo "NetBird new version ($latest_version) available. Updating..." + echo "" + echo "Initiating NetBird update. This will stop the netbird service and restart it after the update" + + sudo netbird service stop + sudo netbird service uninstall + install_native_binaries + + sudo netbird service install + sudo netbird service start + fi + else + echo "NetBird installation was done using a package manager. Please use your system's package manager to update" + fi +} + + +case "$1" in + --update) + update_netbird + ;; + *) + install_netbird +esac \ No newline at end of file From 1a9301b6840e1401782735d840800dd0300b2a07 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 29 Aug 2023 10:13:27 +0300 Subject: [PATCH 18/42] Close PKCE Listening Port After Authorization (#1110) Addresses the issue of an open listening port persisting after the PKCE authorization flow is completed. --- client/cmd/login.go | 1 + client/internal/auth/pkce_flow.go | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index fc0ef82e5..794a599fd 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -204,6 +204,7 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { authenticateUsingBrowser := func() { cmd.Println(browserAuthMsg) + cmd.Println("") if err := open.Run(verificationURIComplete); err != nil { cmd.Println(setupKeyAuthMsg) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index eec41b2dd..d15d49373 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -5,12 +5,14 @@ import ( "crypto/sha256" "crypto/subtle" "encoding/base64" + "errors" "fmt" "html/template" "net" "net/http" "net/url" "strings" + "sync" "time" log "github.com/sirupsen/logrus" @@ -125,21 +127,25 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( } func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { + var wg sync.WaitGroup + parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) if err != nil { errChan <- fmt.Errorf("failed to parse redirect URL: %v", err) return } - port := parsedURL.Port() - server := http.Server{Addr: fmt.Sprintf(":%s", port)} - defer func() { - if err := server.Shutdown(context.Background()); err != nil { - log.Errorf("error while shutting down pkce flow server: %v", err) + server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err } }() + wg.Add(1) http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + defer wg.Done() + tokenValidatorFunc := func() (*oauth2.Token, error) { query := req.URL.Query() @@ -176,8 +182,9 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC tokenChan <- token }) - if err := server.ListenAndServe(); err != nil { - errChan <- err + wg.Wait() + if err := server.Shutdown(context.Background()); err != nil { + log.Errorf("error while shutting down pkce flow server: %v", err) } } From 00dddb94588cda5052e2fd0c2facfae8f268a1ce Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 30 Aug 2023 11:42:03 +0200 Subject: [PATCH 19/42] Fix log formatter initialization in mgm cmd (#1112) The log format was mixed in the management command. In this commit put to earlier state the log preparation. --- management/cmd/management.go | 38 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/management/cmd/management.go b/management/cmd/management.go index 2b13565a3..a4d67fc0b 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -19,28 +19,25 @@ import ( "github.com/google/uuid" "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" - - "github.com/netbirdio/netbird/management/server/activity/sqlite" - httpapi "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server/metrics" - "github.com/netbirdio/netbird/management/server/telemetry" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/util" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity/sqlite" + httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/util" ) // ManagementLegacyPort is the port that was used before by the Management gRPC server. @@ -72,10 +69,15 @@ var ( Use: "management", Short: "start NetBird Management Server", PreRunE: func(cmd *cobra.Command, args []string) error { + flag.Parse() + err := util.InitLog(logLevel, logFile) + if err != nil { + return fmt.Errorf("failed initializing log %v", err) + } + // detect whether user specified a port userPort := cmd.Flag("port").Changed - var err error config, err = loadMgmtConfig(mgmtConfig) if err != nil { return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) @@ -104,13 +106,7 @@ var ( return nil }, RunE: func(cmd *cobra.Command, args []string) error { - flag.Parse() - err := util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - - err = handleRebrand(cmd) + err := handleRebrand(cmd) if err != nil { return fmt.Errorf("failed to migrate files %v", err) } From d51dc4fd33f75122847e528faa34f9f67bb29779 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 31 Aug 2023 17:01:32 +0200 Subject: [PATCH 20/42] Add sharedsock example (#1116) --- sharedsock/example/README.md | 35 ++++++++++++ sharedsock/example/main.go | 56 +++++++++++++++++++ sharedsock/filter.go | 2 - sharedsock/sock_linux.go | 6 +- sharedsock/sock_nolinux.go | 2 +- .../{filter_linux.go => stun_filter_linux.go} | 2 + ...lter_nolinux.go => stun_filter_nolinux.go} | 0 7 files changed, 97 insertions(+), 6 deletions(-) create mode 100644 sharedsock/example/README.md create mode 100644 sharedsock/example/main.go rename sharedsock/{filter_linux.go => stun_filter_linux.go} (98%) rename sharedsock/{filter_nolinux.go => stun_filter_nolinux.go} (100%) diff --git a/sharedsock/example/README.md b/sharedsock/example/README.md new file mode 100644 index 000000000..0632d6e1d --- /dev/null +++ b/sharedsock/example/README.md @@ -0,0 +1,35 @@ +### How to run + +This will only work on Linux + +1. Run netcat listening on the UDP port 51820. This is going to be our external process: +```bash +nc -kluvw 1 51820 +``` + +2. Build and run the example Go code: + +```bash + go build -o sharedsock && sudo ./sharedsock +``` + +3. Test the logic by sending a STUN binding request + +```bash +STUN_PACKET="000100002112A4425454" +echo -n $STUN_PACKET | xxd -r -p | nc -u -w 1 localhost 51820 +``` + +4. You should see a similar output of the Go program. Note that you'll see some binary output in the netcat server too. This is due to the fact that kernel copies packets to both processes. + +```bash + read a STUN packet of size 18 from ... +``` + +5. Send a non-STUN packet + +```bash +echo -n 'hello' | nc -u -w 1 localhost 51820 +``` + +6. The Go program won't print anything. diff --git a/sharedsock/example/main.go b/sharedsock/example/main.go new file mode 100644 index 000000000..7c879b4c9 --- /dev/null +++ b/sharedsock/example/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "github.com/netbirdio/netbird/sharedsock" + log "github.com/sirupsen/logrus" + "os" + "os/signal" +) + +func main() { + + port := 51820 + rawSock, err := sharedsock.Listen(port, sharedsock.NewIncomingSTUNFilter()) + if err != nil { + panic(err) + } + + log.Infof("attached to to the raw socket on port %d", port) + + ctx, cancel := context.WithCancel(context.Background()) + // read packets + go func() { + buf := make([]byte, 1500) + for { + select { + case <-ctx.Done(): + log.Debugf("stopped reading from the shared socket") + return + default: + size, addr, err := rawSock.ReadFrom(buf) + if err != nil { + log.Errorf("error while reading packet from the shared socket: %s", err) + continue + } + log.Infof("read a STUN packet of size %d from %s", size, addr.String()) + } + } + }() + + // terminate the program on ^C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + for range c { + log.Infof("received ^C signal, stopping the program") + cancel() + err = rawSock.Close() + if err != nil { + log.Errorf("failed closing raw socket") + } + } + }() + + <-ctx.Done() +} diff --git a/sharedsock/filter.go b/sharedsock/filter.go index da27639fb..53339e93f 100644 --- a/sharedsock/filter.go +++ b/sharedsock/filter.go @@ -2,8 +2,6 @@ package sharedsock import "golang.org/x/net/bpf" -const magicCookie uint32 = 0x2112A442 - // BPFFilter is a generic filter that provides ipv4 and ipv6 BPF instructions type BPFFilter interface { // GetInstructions returns raw BPF instructions for ipv4 and ipv6 diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 5d2b5a528..c9e35dfa2 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -67,17 +67,17 @@ func Listen(port int, filter BPFFilter) (net.PacketConn, error) { rawSock.router, err = netroute.New() if err != nil { - return nil, fmt.Errorf("failed to create router: %rawSock", err) + return nil, fmt.Errorf("failed to create raw socket router: %v", err) } rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { - return nil, fmt.Errorf("socket.Socket for ipv4 failed with: %rawSock", err) + return nil, fmt.Errorf("failed to create ipv4 raw socket: %v", err) } rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if err != nil { - log.Errorf("socket.Socket for ipv6 failed with: %rawSock", err) + log.Errorf("failed to create ipv6 raw socket: %v", err) } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) diff --git a/sharedsock/sock_nolinux.go b/sharedsock/sock_nolinux.go index 94a061b90..93ac6b96f 100644 --- a/sharedsock/sock_nolinux.go +++ b/sharedsock/sock_nolinux.go @@ -8,7 +8,7 @@ import ( "runtime" ) -// Listen is not supported on other platforms +// Listen is not supported on other platforms then Linux func Listen(port int, filter BPFFilter) (net.PacketConn, error) { return nil, fmt.Errorf(fmt.Sprintf("Not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS)) } diff --git a/sharedsock/filter_linux.go b/sharedsock/stun_filter_linux.go similarity index 98% rename from sharedsock/filter_linux.go rename to sharedsock/stun_filter_linux.go index 2dd3eaded..a9ece622d 100644 --- a/sharedsock/filter_linux.go +++ b/sharedsock/stun_filter_linux.go @@ -2,6 +2,8 @@ package sharedsock import "golang.org/x/net/bpf" +const magicCookie uint32 = 0x2112A442 + // IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port. // Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard). type IncomingSTUNFilter struct { diff --git a/sharedsock/filter_nolinux.go b/sharedsock/stun_filter_nolinux.go similarity index 100% rename from sharedsock/filter_nolinux.go rename to sharedsock/stun_filter_nolinux.go From f89c200ce900611ecace8be3b0885a623b8e1ca6 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:09:59 +0200 Subject: [PATCH 21/42] Fix api Auth with PAT when a custom UserIDClaim is configured in management.json (#1120) The API authentication with PATs was not considering different userIDClaim that some of the IdPs are using. In this PR we read the userIDClaim from the config file instead of using the fixed default and only keep it as a fallback if none in defined. --- management/server/http/handler.go | 3 ++- .../server/http/middleware/auth_middleware.go | 9 +++++++-- .../http/middleware/auth_middleware_test.go | 19 ++++++++++--------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c99b9b51f..6e9b029c7 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -36,7 +36,8 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid accountManager.GetAccountFromPAT, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, - authCfg.Audience) + authCfg.Audience, + authCfg.UserIDClaim) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 898ad0875..710723124 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -32,6 +32,7 @@ type AuthMiddleware struct { validateAndParseToken ValidateAndParseTokenFunc markPATUsed MarkPATUsedFunc audience string + userIDClaim string } const ( @@ -39,12 +40,16 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware { +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware { + if userIdClaim == "" { + userIdClaim = jwtclaims.UserIDClaim + } return &AuthMiddleware{ getAccountFromPAT: getAccountFromPAT, validateAndParseToken: validateAndParseToken, markPATUsed: markPATUsed, audience: audience, + userIDClaim: userIdClaim, } } @@ -127,7 +132,7 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.userIDClaim] = user.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index b041b12d5..8c8c941b0 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -13,14 +13,15 @@ import ( ) const ( - audience = "audience" - accountID = "accountID" - domain = "domain" - userID = "userID" - tokenID = "tokenID" - PAT = "PAT" - JWT = "JWT" - wrongToken = "wrongToken" + audience = "audience" + userIDClaim = "userIDClaim" + accountID = "accountID" + domain = "domain" + userID = "userID" + tokenID = "tokenID" + PAT = "PAT" + JWT = "JWT" + wrongToken = "wrongToken" ) var testAccount = &server.Account{ @@ -102,7 +103,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { // do nothing }) - authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience) + authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim) handlerToTest := authMiddleware.Handler(nextHandler) From 4e2d07541399cbfb85a4aa93d56138569e6614ba Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 4 Sep 2023 11:15:39 +0200 Subject: [PATCH 22/42] Add Wix file for MSI builds (#1099) This adds a basic wxs file to build MSI installer This file was created using docs from https://wixtoolset.org/docs/schema/wxs/ and examples from gsudo, qemu-shoggoth, and many others. The main difference between this and the .exe installer is that we don't use the netbird service command to install the daemon --- .github/workflows/release.yml | 2 +- client/netbird.wxs | 77 +++++++++++++++++++++++++++++++ infrastructure_files/configure.sh | 2 +- 3 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 client/netbird.wxs diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aaae51dde..f682fe274 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ on: - '**/Dockerfile.*' env: - SIGN_PIPE_VER: "v0.0.8" + SIGN_PIPE_VER: "v0.0.9" GORELEASER_VER: "v1.14.1" concurrency: diff --git a/client/netbird.wxs b/client/netbird.wxs new file mode 100644 index 000000000..f9b2449ba --- /dev/null +++ b/client/netbird.wxs @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index 477528696..4e568b2fe 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -168,4 +168,4 @@ env | grep NETBIRD envsubst docker-compose.yml envsubst management.json -envsubst turnserver.conf +envsubst turnserver.conf \ No newline at end of file From c1f164c9cb722b25f7a1a7a48c81755011978d9c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 4 Sep 2023 11:37:39 +0200 Subject: [PATCH 23/42] Feature/ephemeral peers (#1100) The ephemeral manager keep the inactive ephemeral peers in a linked list. The manager schedule a cleanup procedure to the head of the linked list (to the most deprecated peer). At the end of cleanup schedule the next cleanup to the new head. If a device connect back to the server the manager will remote it from the peers list. --- client/cmd/testutil.go | 2 +- client/internal/engine_test.go | 2 +- management/client/client_test.go | 2 +- management/cmd/management.go | 6 +- management/server/account.go | 2 +- management/server/account_test.go | 7 +- management/server/activity/event.go | 4 + management/server/ephemeral.go | 224 ++++++++++++++++++ management/server/ephemeral_test.go | 142 +++++++++++ management/server/grpcserver.go | 25 +- management/server/http/api/openapi.yml | 9 + management/server/http/api/types.gen.go | 6 + management/server/http/setupkeys_handler.go | 6 +- .../server/http/setupkeys_handler_test.go | 4 +- management/server/management_proto_test.go | 4 +- management/server/management_test.go | 2 +- management/server/mock_server/account_mock.go | 5 +- management/server/peer.go | 6 + management/server/peer_test.go | 6 +- management/server/setupkey.go | 12 +- management/server/setupkey_test.go | 18 +- 21 files changed, 455 insertions(+), 39 deletions(-) create mode 100644 management/server/ephemeral.go create mode 100644 management/server/ephemeral_test.go diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 988ef8cc0..678072f0b 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -81,7 +81,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) + mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 60c07c0c9..d1c46181c 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1054,7 +1054,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { return nil, "", err } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) + mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { return nil, "", err } diff --git a/management/client/client_test.go b/management/client/client_test.go index d3d99dc85..deef57329 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -66,7 +66,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) + mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index a4d67fc0b..5c3816715 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -198,8 +198,11 @@ var ( return fmt.Errorf("failed creating HTTP API handler: %v", err) } + ephemeralManager := server.NewEphemeralManager(store, accountManager) + ephemeralManager.LoadInitialPeers() + gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics) + srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } @@ -268,6 +271,7 @@ var ( SetupCloseHandler() <-stopCh + ephemeralManager.Stop() _ = appMetrics.Close() _ = listener.Close() if certManager != nil { diff --git a/management/server/account.go b/management/server/account.go index d9b73f713..26aeed3c5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -49,7 +49,7 @@ func cacheEntryExpiration() time.Duration { type AccountManager interface { GetOrCreateAccountByUser(userId, domain string) (*Account, error) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, - autoGroups []string, usageLimit int, userID string) (*SetupKey, error) + autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) DeleteUser(accountID, initiatorUserID string, targetUserID string) error diff --git a/management/server/account_test.go b/management/server/account_test.go index 29af8514a..6002b7a3a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/golang-jwt/jwt" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/route" @@ -782,7 +783,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID) + setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { return } @@ -929,7 +930,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID) + setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { return } @@ -1113,7 +1114,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID) + setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { return } diff --git a/management/server/activity/event.go b/management/server/activity/event.go index 668449176..17ec4a0b0 100644 --- a/management/server/activity/event.go +++ b/management/server/activity/event.go @@ -4,6 +4,10 @@ import ( "time" ) +const ( + SystemInitiator = "sys" +) + // Event represents a network/system activity event. type Event struct { // Timestamp of the event diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go new file mode 100644 index 000000000..a7b423983 --- /dev/null +++ b/management/server/ephemeral.go @@ -0,0 +1,224 @@ +package server + +import ( + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" +) + +const ( + ephemeralLifeTime = 10 * time.Minute +) + +var ( + timeNow = time.Now +) + +type ephemeralPeer struct { + id string + account *Account + deadline time.Time + next *ephemeralPeer +} + +// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it +// in worst case we will get invalid error message in this manager. + +// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted +// automatically. Inactivity means the peer disconnected from the Management server. +type EphemeralManager struct { + store Store + accountManager AccountManager + + headPeer *ephemeralPeer + tailPeer *ephemeralPeer + peersLock sync.Mutex + timer *time.Timer +} + +// NewEphemeralManager instantiate new EphemeralManager +func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager { + return &EphemeralManager{ + store: store, + accountManager: accountManager, + } +} + +// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head +// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new +// head. +func (e *EphemeralManager) LoadInitialPeers() { + e.peersLock.Lock() + defer e.peersLock.Unlock() + + e.loadEphemeralPeers() + if e.headPeer != nil { + e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup) + } +} + +// Stop timer +func (e *EphemeralManager) Stop() { + e.peersLock.Lock() + defer e.peersLock.Unlock() + + if e.timer != nil { + e.timer.Stop() + } +} + +// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer +// is active the manager will not delete it while it is active. +func (e *EphemeralManager) OnPeerConnected(peer *Peer) { + if !peer.Ephemeral { + return + } + + log.Tracef("remove peer from ephemeral list: %s", peer.ID) + + e.peersLock.Lock() + defer e.peersLock.Unlock() + + e.removePeer(peer.ID) + + // stop the unnecessary timer + if e.headPeer == nil && e.timer != nil { + e.timer.Stop() + e.timer = nil + } +} + +// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer +// is inactive it will be deleted after the ephemeralLifeTime period. +func (e *EphemeralManager) OnPeerDisconnected(peer *Peer) { + if !peer.Ephemeral { + return + } + + log.Tracef("add peer to ephemeral list: %s", peer.ID) + + a, err := e.store.GetAccountByPeerID(peer.ID) + if err != nil { + log.Errorf("failed to add peer to ephemeral list: %s", err) + return + } + + e.peersLock.Lock() + defer e.peersLock.Unlock() + + if e.isPeerOnList(peer.ID) { + return + } + + e.addPeer(peer.ID, a, newDeadLine()) + if e.timer == nil { + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + } +} + +func (e *EphemeralManager) loadEphemeralPeers() { + accounts := e.store.GetAllAccounts() + t := newDeadLine() + count := 0 + for _, a := range accounts { + for id, p := range a.Peers { + if p.Ephemeral { + count++ + e.addPeer(id, a, t) + } + } + } + log.Debugf("loaded ephemeral peer(s): %d", count) +} + +func (e *EphemeralManager) cleanup() { + log.Tracef("on ephemeral cleanup") + deletePeers := make(map[string]*ephemeralPeer) + + e.peersLock.Lock() + now := timeNow() + for p := e.headPeer; p != nil; p = p.next { + if now.Before(p.deadline) { + break + } + + deletePeers[p.id] = p + e.headPeer = p.next + if p.next == nil { + e.tailPeer = nil + } + } + + if e.headPeer != nil { + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + } else { + e.timer = nil + } + + e.peersLock.Unlock() + + for id, p := range deletePeers { + log.Debugf("delete ephemeral peer: %s", id) + _, err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) + if err != nil { + log.Tracef("failed to delete ephemeral peer: %s", err) + } + } +} + +func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { + ep := &ephemeralPeer{ + id: id, + account: account, + deadline: deadline, + } + + if e.headPeer == nil { + e.headPeer = ep + } + if e.tailPeer != nil { + e.tailPeer.next = ep + } + e.tailPeer = ep +} + +func (e *EphemeralManager) removePeer(id string) { + if e.headPeer == nil { + return + } + + if e.headPeer.id == id { + e.headPeer = e.headPeer.next + if e.tailPeer.id == id { + e.tailPeer = nil + } + return + } + + for p := e.headPeer; p.next != nil; p = p.next { + if p.next.id == id { + // if we remove the last element from the chain then set the last-1 as tail + if e.tailPeer.id == id { + e.tailPeer = p + } + p.next = p.next.next + return + } + } +} + +func (e *EphemeralManager) isPeerOnList(id string) bool { + for p := e.headPeer; p != nil; p = p.next { + if p.id == id { + return true + } + } + return false +} + +func newDeadLine() time.Time { + return timeNow().Add(ephemeralLifeTime) +} diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go new file mode 100644 index 000000000..554d2a028 --- /dev/null +++ b/management/server/ephemeral_test.go @@ -0,0 +1,142 @@ +package server + +import ( + "fmt" + "testing" + "time" +) + +type MockStore struct { + Store + account *Account +} + +func (s *MockStore) GetAllAccounts() []*Account { + return []*Account{s.account} +} + +func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) { + _, ok := s.account.Peers[peerId] + if ok { + return s.account, nil + } + + return nil, fmt.Errorf("account not found") +} + +type MocAccountManager struct { + AccountManager + store *MockStore +} + +func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) (*Peer, error) { + delete(a.store.account.Peers, peerID) + return nil, nil +} + +func TestNewManager(t *testing.T) { + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + + store := &MockStore{} + am := MocAccountManager{ + store: store, + } + + numberOfPeers := 5 + numberOfEphemeralPeers := 3 + seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + + mgr := NewEphemeralManager(store, am) + mgr.loadEphemeralPeers() + startTime = startTime.Add(ephemeralLifeTime + 1) + mgr.cleanup() + + if len(store.account.Peers) != numberOfPeers { + t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) + } +} + +func TestNewManagerPeerConnected(t *testing.T) { + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + + store := &MockStore{} + am := MocAccountManager{ + store: store, + } + + numberOfPeers := 5 + numberOfEphemeralPeers := 3 + seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + + mgr := NewEphemeralManager(store, am) + mgr.loadEphemeralPeers() + mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"]) + + startTime = startTime.Add(ephemeralLifeTime + 1) + mgr.cleanup() + + expected := numberOfPeers + 1 + if len(store.account.Peers) != expected { + t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) + } +} + +func TestNewManagerPeerDisconnected(t *testing.T) { + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + + store := &MockStore{} + am := MocAccountManager{ + store: store, + } + + numberOfPeers := 5 + numberOfEphemeralPeers := 3 + seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + + mgr := NewEphemeralManager(store, am) + mgr.loadEphemeralPeers() + for _, v := range store.account.Peers { + mgr.OnPeerConnected(v) + + } + mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"]) + + startTime = startTime.Add(ephemeralLifeTime + 1) + mgr.cleanup() + + expected := numberOfPeers + numberOfEphemeralPeers - 1 + if len(store.account.Peers) != expected { + t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) + } +} + +func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { + store.account = newAccountWithId("my account", "", "") + + for i := 0; i < numberOfPeers; i++ { + peerId := fmt.Sprintf("peer_%d", i) + p := &Peer{ + ID: peerId, + Ephemeral: false, + } + store.account.Peers[p.ID] = p + } + + for i := 0; i < numberOfEphemeralPeers; i++ { + peerId := fmt.Sprintf("ephemeral_peer_%d", i) + p := &Peer{ + ID: peerId, + Ephemeral: true, + } + store.account.Peers[p.ID] = p + } +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 94cb1de9d..32b553f9b 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,11 +7,6 @@ import ( "time" pb "github.com/golang/protobuf/proto" // nolint - - "github.com/netbirdio/netbird/management/server/telemetry" - - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/golang/protobuf/ptypes/timestamp" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -21,7 +16,9 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/jwtclaims" internalStatus "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" ) // GRPCServer an instance of a Management gRPC API server @@ -35,12 +32,11 @@ type GRPCServer struct { jwtValidator *jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager } // NewServer creates a new Management server -func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, - turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, -) (*GRPCServer, error) { +func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -92,6 +88,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager jwtValidator: jwtValidator, jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, + ephemeralManager: ephemeralManager, }, nil } @@ -141,6 +138,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi } updates := s.peersUpdateManager.CreateChannel(peer.ID) + + s.ephemeralManager.OnPeerConnected(peer) + err = s.accountManager.MarkPeerConnected(peerKey.String(), true) if err != nil { log.Warnf("failed marking peer as connected %s %v", peerKey, err) @@ -168,6 +168,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { + s.cancelPeerRoutines(peer) return status.Errorf(codes.Internal, "failed processing update message") } @@ -176,6 +177,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi Body: encryptedResp, }) if err != nil { + s.cancelPeerRoutines(peer) return status.Errorf(codes.Internal, "failed sending update message") } log.Debugf("sent an update to peer %s", peerKey.String()) @@ -193,6 +195,7 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { s.peersUpdateManager.CloseChannel(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) _ = s.accountManager.MarkPeerConnected(peer.Key, false) + s.ephemeralManager.OnPeerDisconnected(peer) } func (s *GRPCServer) validateToken(jwtToken string) (string, error) { @@ -318,11 +321,17 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p UserID: userID, SetupKey: loginReq.GetSetupKey(), }) + if err != nil { log.Warnf("failed logging in peer %s", peerKey) return nil, mapError(err) } + // if the login request contains setup key then it is a registration request + if loginReq.GetSetupKey() != "" { + s.ephemeralManager.OnPeerDisconnected(peer) + } + // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index a09b9f6a6..06da0ede3 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -350,6 +350,10 @@ components: description: A number of times this key can be used. The value of 0 indicates the unlimited usage. type: integer example: 0 + ephemeral: + description: Indicate that the peer will be ephemeral or not + type: boolean + example: true required: - id - key @@ -364,6 +368,7 @@ components: - auto_groups - updated_at - usage_limit + - ephemeral SetupKeyRequest: type: object properties: @@ -395,6 +400,10 @@ components: description: A number of times this key can be used. The value of 0 indicates the unlimited usage. type: integer example: 0 + ephemeral: + description: Indicate that the peer will be ephemeral or not + type: boolean + example: true required: - name - type diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 5b629cc0e..402aae635 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -689,6 +689,9 @@ type SetupKey struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key AutoGroups []string `json:"auto_groups"` + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + // Expires Setup Key expiration date Expires time.Time `json:"expires"` @@ -731,6 +734,9 @@ type SetupKeyRequest struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key AutoGroups []string `json:"auto_groups"` + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral *bool `json:"ephemeral,omitempty"` + // ExpiresIn Expiration time in seconds ExpiresIn int `json:"expires_in"` diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 58a3c1091..392cebdbd 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -71,8 +71,12 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request req.AutoGroups = []string{} } + var ephemeral bool + if req.Ephemeral != nil { + ephemeral = *req.Ephemeral + } setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, user.Id) + req.AutoGroups, req.UsageLimit, user.Id, ephemeral) if err != nil { util.WriteError(err, w) return diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 4a5a9af62..afc9deb15 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -51,7 +51,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup }, user, nil }, CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, - _ int, _ string, + _ int, _ string, _ bool, ) (*server.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { return newKey, nil @@ -99,7 +99,7 @@ func TestSetupKeysHandlers(t *testing.T) { adminUser := server.NewAdminUser("test_user") newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, - server.SetupKeyUnlimitedUsage) + server.SetupKeyUnlimitedUsage, false) updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.Name = updatedSetupKeyName diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 6855c84bd..792d05187 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -422,7 +422,9 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) return nil, "", err } turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil) + + ephemeralMgr := NewEphemeralManager(store, accountManager) + mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 7af2535f8..6c93765f4 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -508,7 +508,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { log.Fatalf("failed creating a manager: %v", err) } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) + mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index eb31d2a79..4bfa922c7 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -16,7 +16,7 @@ import ( type MockAccountManager struct { GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error) + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) @@ -122,9 +122,10 @@ func (am *MockAccountManager) CreateSetupKey( autoGroups []string, usageLimit int, userID string, + ephemeral bool, ) (*server.SetupKey, error) { if am.CreateSetupKeyFunc != nil { - return am.CreateSetupKeyFunc(accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID) + return am.CreateSetupKeyFunc(accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index 90377b1e8..f9631719f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -99,6 +99,8 @@ type Peer struct { LoginExpirationEnabled bool // LastLogin the time when peer performed last login operation LastLogin time.Time + // Indicate ephemeral peer attribute + Ephemeral bool } // AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user. @@ -126,6 +128,7 @@ func (p *Peer) Copy() *Peer { SSHEnabled: p.SSHEnabled, LoginExpirationEnabled: p.LoginExpirationEnabled, LastLogin: p.LastLogin, + Ephemeral: p.Ephemeral, } } @@ -514,6 +517,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* AccountID: account.Id, } + var ephemeral bool if !addedByUser { // validate the setup key if adding with a key sk, err := account.FindSetupKey(upperKey) @@ -528,6 +532,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* account.SetupKeys[sk.Key] = sk.IncrementUsage() opEvent.InitiatorID = sk.Id opEvent.Activity = activity.PeerAddedWithSetupKey + ephemeral = sk.Ephemeral } else { opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser @@ -562,6 +567,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* SSHKey: peer.SSHKey, LastLogin: time.Now().UTC(), LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, } // add peer to 'All' group diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 875aaeaba..822856e6a 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -78,7 +78,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId) + setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { return } @@ -331,7 +331,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId) + setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { return } @@ -405,7 +405,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser) + setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { return } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index ffdd822e3..e857230a5 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -86,6 +86,8 @@ type SetupKey struct { // UsageLimit indicates the number of times this key can be used to enroll a machine. // The value of 0 indicates the unlimited usage. UsageLimit int + // Ephemeral indicate if the peers will be ephemeral or not + Ephemeral bool } // Copy copies SetupKey to a new object @@ -108,6 +110,7 @@ func (key *SetupKey) Copy() *SetupKey { LastUsed: key.LastUsed, AutoGroups: autoGroups, UsageLimit: key.UsageLimit, + Ephemeral: key.Ephemeral, } } @@ -162,7 +165,7 @@ func (key *SetupKey) IsOverUsed() bool { // GenerateSetupKey generates a new setup key func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int) *SetupKey { + usageLimit int, ephemeral bool) *SetupKey { key := strings.ToUpper(uuid.New().String()) limit := usageLimit if t == SetupKeyOneOff { @@ -180,13 +183,14 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoG UsedTimes: 0, AutoGroups: autoGroups, UsageLimit: limit, + Ephemeral: ephemeral, } } // GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration func GenerateDefaultSetupKey() *SetupKey { return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, - SetupKeyUnlimitedUsage) + SetupKeyUnlimitedUsage, false) } func Hash(s string) uint32 { @@ -201,7 +205,7 @@ func Hash(s string) uint32 { // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*SetupKey, error) { + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -221,7 +225,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string } } - setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit) + setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) account.SetupKeys[setupKey.Key] = setupKey err = am.Store.SaveAccount(account) if err != nil { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index a6f318ab9..6da01bd82 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -37,7 +37,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { keyName := "my-test-key" key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, - SetupKeyUnlimitedUsage, userID) + SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } @@ -136,7 +136,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2} { t.Run(tCase.name, func(t *testing.T) { key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, - tCase.expectedGroups, SetupKeyUnlimitedUsage, userID) + tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { if err == nil { @@ -193,7 +193,7 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage) + key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) @@ -201,33 +201,33 @@ func TestGenerateSetupKey(t *testing.T) { } func TestSetupKey_IsValid(t *testing.T) { - validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage) + validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage) + expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage) + revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage) + overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage) + reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) @@ -282,7 +282,7 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke func TestSetupKey_Copy(t *testing.T) { - key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage) + key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, From 8524cc75d66025e41d06b815e683f5a81c5caef6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A1bio=20C=2E=20Barrionuevo=20da=20Luz?= Date: Mon, 4 Sep 2023 10:49:07 -0300 Subject: [PATCH 24/42] Add safe security headers (#1121) This pull-request add/changes the HTTP headers to include safe defaults to Caddy and get the A+ score on the https://observatory.mozilla.org/ test --- .../getting-started-with-zitadel.sh | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 74f9b6398..d00c2719c 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -487,7 +487,48 @@ renderCaddyfile() { } } +(security_headers) { + header * { + # enable HSTS + # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#strict-transport-security-hsts + # NOTE: Read carefully how this header works before using it. + # If the HSTS header is misconfigured or if there is a problem with + # the SSL/TLS certificate being used, legitimate users might be unable + # to access the website. For example, if the HSTS header is set to a + # very long duration and the SSL/TLS certificate expires or is revoked, + # legitimate users might be unable to access the website until + # the HSTS header duration has expired. + # The recommended value for the max-age is 2 year (63072000 seconds). + # But we are using 1 hour (3600 seconds) for testing purposes + # and ensure that the website is working properly before setting + # to two years. + + Strict-Transport-Security "max-age=3600; includeSubDomains; preload" + + # disable clients from sniffing the media type + # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-content-type-options + X-Content-Type-Options "nosniff" + + # clickjacking protection + # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-frame-options + X-Frame-Options "DENY" + + # xss protection + # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-xss-protection + X-XSS-Protection "1; mode=block" + + # Remove -Server header, which is an information leak + # Remove Caddy from Headers + -Server + + # keep referrer data off of HTTP connections + # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#referrer-policy + Referrer-Policy strict-origin-when-cross-origin + } +} + :80${CADDY_SECURE_DOMAIN} { + import security_headers # Signal reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management From bb403259770d661dbbb493893523f327fa78a800 Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Mon, 4 Sep 2023 17:03:44 +0200 Subject: [PATCH 25/42] Update GitHub Actions and Enhance golangci-lint (#1075) This PR showcases the implementation of additional linter rules. I've updated the golangci-lint GitHub Actions to the latest available version. This update makes sure that the tool works the same way locally - assuming being updated regularly - and with the GitHub Actions. I've also taken care of keeping all the GitHub Actions up to date, which helps our code stay current. But there's one part, goreleaser that's a bit tricky to test on our computers. So, it's important to take a close look at that. To make it easier to understand what I've done, I've made separate changes for each thing that the new linters found. This should help the people reviewing the changes see what's going on more clearly. Some of the changes might not be obvious at first glance. Things to consider for the future CI runs on Ubuntu so the static analysis only happens for Linux. Consider running it for the rest: Darwin, Windows --- .github/workflows/golang-test-darwin.yml | 6 +- .github/workflows/golang-test-linux.yml | 12 ++-- .github/workflows/golangci-lint.yml | 9 ++- .github/workflows/release.yml | 48 +++++++++------- .../workflows/test-infrastructure-files.yml | 4 +- .golangci.yaml | 54 ++++++++++++++++++ base62/base62.go | 3 +- client/cmd/status.go | 6 +- client/internal/acl/manager.go | 3 +- client/internal/config_test.go | 3 - client/internal/dns/server_test.go | 2 +- client/internal/engine.go | 12 ++-- client/internal/engine_test.go | 2 +- client/internal/routemanager/client.go | 5 +- client/server/server.go | 5 +- iface/module_linux.go | 2 +- iface/wg_configurer_nonandroid.go | 3 - management/server/account.go | 4 +- management/server/account_test.go | 12 ---- management/server/activity/sqlite/sqlite.go | 56 ++++++++++++------- management/server/ephemeral_test.go | 2 +- management/server/http/groups_handler.go | 6 +- .../http/middleware/auth_middleware_test.go | 6 +- management/server/idp/auth0_test.go | 1 + management/server/idp/authentik.go | 19 +++++-- management/server/idp/authentik_test.go | 8 ++- management/server/idp/google_workspace.go | 10 ++-- management/server/idp/idp.go | 2 +- management/server/idp/keycloak_test.go | 1 + management/server/idp/zitadel_test.go | 1 + management/server/jwtclaims/jwtValidator.go | 4 +- management/server/management_proto_test.go | 5 -- management/server/peer_test.go | 13 +---- management/server/route.go | 13 ++--- management/server/user.go | 34 +++++------ 35 files changed, 215 insertions(+), 161 deletions(-) create mode 100644 .golangci.yaml diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 6cdaa239b..97fdeabe8 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -15,14 +15,14 @@ jobs: runs-on: macos-latest steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Cache Go modules - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: macos-go-${{ hashFiles('**/go.sum') }} diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index d21ecd784..13061f6eb 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -18,13 +18,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: Cache Go modules - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -32,7 +32,7 @@ jobs: ${{ runner.os }}-go- - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib @@ -47,13 +47,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: Cache Go modules - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -61,7 +61,7 @@ jobs: ${{ runner.os }}-go- - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index f5d1835ce..2a5c51c8a 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -8,14 +8,13 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - name: Checkout code + uses: actions/checkout@v3 - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v2 - with: - args: --timeout=6m \ No newline at end of file + uses: golangci/golangci-lint-action@v3 \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f682fe274..3feefdd49 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,20 +29,24 @@ concurrency: jobs: release: runs-on: ubuntu-latest + env: + flags: "" steps: + - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} + run: echo "flags=--snapshot" >> $GITHUB_ENV - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 # It is required for GoReleaser to work properly - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20" - name: Cache Go modules - uses: actions/cache@v1 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -56,10 +60,10 @@ jobs: run: git --no-pager diff --exit-code - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v2 - name: Login to Docker hub if: github.event_name != 'pull_request' @@ -82,10 +86,10 @@ jobs: run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v2 + uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --rm-dist + args: release --rm-dist ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} @@ -93,7 +97,7 @@ jobs: UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - name: upload non tags for debug purposes - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: release path: dist/ @@ -102,17 +106,19 @@ jobs: release_ui: runs-on: ubuntu-latest steps: + - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} + run: echo "flags=--snapshot" >> $GITHUB_ENV - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 # It is required for GoReleaser to work properly - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20" - name: Cache Go modules - uses: actions/cache@v1 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }} @@ -132,17 +138,17 @@ jobs: - name: Generate windows rsrc run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v2 + uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui.yaml --rm-dist + args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - name: upload non tags for debug purposes - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: release-ui path: dist/ @@ -151,19 +157,21 @@ jobs: release_ui_darwin: runs-on: macos-11 steps: + - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} + run: echo "flags=--snapshot" >> $GITHUB_ENV - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 # It is required for GoReleaser to work properly - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20" - name: Cache Go modules - uses: actions/cache@v1 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }} @@ -175,15 +183,15 @@ jobs: - name: Run GoReleaser id: goreleaser - uses: goreleaser/goreleaser-action@v2 + uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui_darwin.yaml --rm-dist + args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: upload non tags for debug purposes - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: release-ui-darwin path: dist/ diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index fdebc882e..d196ce0e0 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -24,12 +24,12 @@ jobs: run: sudo apt-get install -y curl - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: Cache Go modules - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 000000000..5034db708 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,54 @@ +run: + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 6m + +# This file contains only configs which differ from defaults. +# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml +linters-settings: + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: false + + govet: + # Enable all analyzers. + # Default: false + enable-all: false + enable: + - nilness + +linters: + disable-all: true + enable: + ## enabled by default + - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases + - gosimple # specializes in simplifying a code + - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - ineffassign # detects when assignments to existing variables are not used + - staticcheck # is a go vet on steroids, applying a ton of static analysis checks + - typecheck # like the front-end of a Go compiler, parses and type-checks Go code + - unused # checks for unused constants, variables, functions and types + ## disable by default but the have interesting results so lets add them + - bodyclose # checks whether HTTP response body is closed successfully + - nilerr # finds the code that returns nil even if it checks that the error is not nil + - nilnil # checks that there is no simultaneous return of nil error and an invalid value + - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed + - wastedassign # wastedassign finds wasted assignment statements +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. + # Default: 3 + max-same-issues: 5 + + exclude-rules: + - path: sharedsock/filter.go + linters: + - unused + - path: client/firewall/iptables/rule.go + linters: + - unused + - path: mock.go + linters: + - nilnil \ No newline at end of file diff --git a/base62/base62.go b/base62/base62.go index d2525f704..efafbc768 100644 --- a/base62/base62.go +++ b/base62/base62.go @@ -18,10 +18,9 @@ func Encode(num uint32) string { } var encoded strings.Builder - remainder := uint32(0) for num > 0 { - remainder = num % base + remainder := num % base encoded.WriteByte(alphabet[remainder]) num /= base } diff --git a/client/cmd/status.go b/client/cmd/status.go index 5d741462b..9dfd042f8 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(context.Background()) - resp, _ := getStatus(ctx, cmd) + resp, err := getStatus(ctx, cmd) if err != nil { - return nil + return err } if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { @@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { outputInformationHolder := convertToStatusOutputOverview(resp) - statusOutputString := "" + var statusOutputString string switch { case detailFlag: statusOutputString = parseToFullDetailSummary(outputInformationHolder) diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index d4b6930a0..9a9e624d6 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -146,12 +146,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { // if this rule is member of rule selection with more than DefaultIPsCountForSet // it's IP address can be used in the ipset for firewall manager which supports it ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)] - ipsetName := "" if ipset.name == "" { d.ipsetCounter++ ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter) } - ipsetName = ipset.name + ipsetName := ipset.name pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) if err != nil { log.Errorf("failed to apply firewall rule: %+v, %v", r, err) diff --git a/client/internal/config_test.go b/client/internal/config_test.go index 25e8f7b2e..8bd8d8d61 100644 --- a/client/internal/config_test.go +++ b/client/internal/config_test.go @@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL) assert.Equal(t, config.AdminURL.String(), DefaultAdminURL) - if err != nil { - return - } managementURL := "https://test.management.url:33071" adminURL := "https://app.admin.url:443" path := filepath.Join(t.TempDir(), "config.json") diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 73349de89..119ac684c 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -777,7 +777,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { newNet, err := stdnet.NewNet(nil) if err != nil { t.Fatalf("create stdnet: %v", err) - return nil, nil + return nil, err } wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet) diff --git a/client/internal/engine.go b/client/internal/engine.go index 038f39e5c..8a6c08642 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -992,14 +992,12 @@ func (e *Engine) parseNATExternalIPMappings() []string { log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping) break } - if externalIP != nil { - mappedIP := externalIP.String() - if internalIP != nil { - mappedIP = mappedIP + "/" + internalIP.String() - } - mappedIPs = append(mappedIPs, mappedIP) - log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP) + mappedIP := externalIP.String() + if internalIP != nil { + mappedIP = mappedIP + "/" + internalIP.String() } + mappedIPs = append(mappedIPs, mappedIP) + log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP) } if len(mappedIPs) != len(e.config.NATExternalIPs) { log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings") diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d1c46181c..9f17ff36b 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1046,7 +1046,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { peersUpdateManager := server.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} if err != nil { - return nil, "", nil + return nil, "", err } accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index cf4cbe91b..62fe4dfc1 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -155,7 +155,10 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) - if err != nil || state.ConnStatus != peer.StatusConnected { + if err != nil { + return err + } + if state.ConnStatus != peer.StatusConnected { return nil } diff --git a/client/server/server.go b/client/server/server.go index b7cca947f..6748f62ab 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -3,10 +3,11 @@ package server import ( "context" "fmt" - "github.com/netbirdio/netbird/client/internal/auth" "sync" "time" + "github.com/netbirdio/netbird/client/internal/auth" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -315,7 +316,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo) if err != nil { if err == context.Canceled { - return nil, nil + return nil, nil //nolint:nilnil } s.mutex.Lock() s.oauthAuthFlow.expiresAt = time.Now() diff --git a/iface/module_linux.go b/iface/module_linux.go index 5f244d2c3..e943c0ba7 100644 --- a/iface/module_linux.go +++ b/iface/module_linux.go @@ -161,7 +161,7 @@ func getModulePath(name string) (string, error) { } if err != nil { // skip broken files - return nil + return nil //nolint:nilerr } if !info.Type().IsRegular() { diff --git a/iface/wg_configurer_nonandroid.go b/iface/wg_configurer_nonandroid.go index 70ec5dc04..6749c0966 100644 --- a/iface/wg_configurer_nonandroid.go +++ b/iface/wg_configurer_nonandroid.go @@ -146,9 +146,6 @@ func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { } } - if err != nil { - return err - } peer := wgtypes.PeerConfig{ PublicKey: peerKeyParsed, UpdateOnly: true, diff --git a/management/server/account.go b/management/server/account.go index 26aeed3c5..4c707af3a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1022,7 +1022,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI } } - return nil, nil + return nil, nil //nolint:nilnil } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil @@ -1045,7 +1045,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou } } - return nil, nil + return nil, nil //nolint:nilnil } func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 6002b7a3a..64fd90524 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -784,10 +784,6 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - return - } - if err != nil { t.Fatal("error creating setup key") return @@ -931,10 +927,6 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - return - } - if err != nil { t.Fatal("error creating setup key") return @@ -1115,10 +1107,6 @@ func TestAccountManager_DeletePeer(t *testing.T) { } setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - return - } - if err != nil { t.Fatal("error creating setup key") return diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index 02c2143e4..a4c85cf60 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -3,13 +3,14 @@ package sqlite import ( "database/sql" "encoding/json" - "fmt" + "github.com/netbirdio/netbird/management/server/activity" // sqlite driver - _ "github.com/mattn/go-sqlite3" "path/filepath" "time" + + _ "github.com/mattn/go-sqlite3" ) const ( @@ -24,15 +25,20 @@ const ( "meta TEXT," + " target_id TEXT);" - selectStatement = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp %s LIMIT ? OFFSET ?;" - insertStatement = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + + selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + + " FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;" + selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + + " FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;" + insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + "VALUES(?, ?, ?, ?, ?, ?)" ) // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { - db *sql.DB + db *sql.DB + insertStatement *sql.Stmt + selectAscStatement *sql.Stmt + selectDescStatement *sql.Stmt } // NewSQLiteStore creates a new Store with an event table if not exists. @@ -48,7 +54,27 @@ func NewSQLiteStore(dataDir string) (*Store, error) { return nil, err } - return &Store{db: db}, nil + insertStmt, err := db.Prepare(insertQuery) + if err != nil { + return nil, err + } + + selectDescStmt, err := db.Prepare(selectDescQuery) + if err != nil { + return nil, err + } + + selectAscStmt, err := db.Prepare(selectAscQuery) + if err != nil { + return nil, err + } + + return &Store{ + db: db, + insertStatement: insertStmt, + selectDescStatement: selectDescStmt, + selectAscStatement: selectAscStmt, + }, nil } func processResult(result *sql.Rows) ([]*activity.Event, error) { @@ -90,13 +116,9 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { // Get returns "limit" number of events from index ordered descending or ascending by a timestamp func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { - order := "DESC" + stmt := store.selectDescStatement if !descending { - order = "ASC" - } - stmt, err := store.db.Prepare(fmt.Sprintf(selectStatement, order)) - if err != nil { - return nil, err + stmt = store.selectAscStatement } result, err := stmt.Query(accountID, limit, offset) @@ -110,12 +132,6 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ // Save an event in the SQLite events table func (store *Store) Save(event *activity.Event) (*activity.Event, error) { - - stmt, err := store.db.Prepare(insertStatement) - if err != nil { - return nil, err - } - var jsonMeta string if event.Meta != nil { metaBytes, err := json.Marshal(event.Meta) @@ -125,7 +141,7 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) { jsonMeta = string(metaBytes) } - result, err := stmt.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta) + result, err := store.insertStatement.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta) if err != nil { return nil, err } diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 554d2a028..a763f4cef 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -31,7 +31,7 @@ type MocAccountManager struct { func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) (*Peer, error) { delete(a.store.account.Peers, peerID) - return nil, nil + return nil, nil //nolint:nilnil } func TestNewManager(t *testing.T) { diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 30c78e21b..d409623df 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -231,10 +231,8 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(w, toGroupResponse(account, group)) default: - if err != nil { - util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w) - return - } + util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w) + return } } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 8c8c941b0..608bf42fa 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -115,8 +115,10 @@ func TestAuthMiddleware_Handler(t *testing.T) { handlerToTest.ServeHTTP(rec, req) - if rec.Result().StatusCode != tc.expectedStatusCode { - t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode) + result := rec.Result() + defer result.Body.Close() + if result.StatusCode != tc.expectedStatusCode { + t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode) } }) } diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 0814b4b69..0be401e65 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -133,6 +133,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) { t.Fatal(err) } } + defer res.Body.Close() body, err := io.ReadAll(res.Body) assert.NoError(t, err, "unable to read the response body") diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 586348fee..0898f1c94 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -3,10 +3,6 @@ package idp import ( "context" "fmt" - "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/telemetry" - log "github.com/sirupsen/logrus" - "goauthentik.io/api/v3" "io" "net/http" "net/url" @@ -14,6 +10,11 @@ import ( "strings" "sync" "time" + + "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/telemetry" + log "github.com/sirupsen/logrus" + "goauthentik.io/api/v3" ) // AuthentikManager authentik manager client instance. @@ -236,6 +237,7 @@ func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata App if err != nil { return err } + defer resp.Body.Close() if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() @@ -267,6 +269,7 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada if err != nil { return nil, err } + defer resp.Body.Close() if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetUserDataByID() @@ -294,6 +297,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { if err != nil { return nil, err } + defer resp.Body.Close() if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetAccount() @@ -330,6 +334,7 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { if err != nil { return nil, err } + defer resp.Body.Close() if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetAllAccounts() @@ -389,6 +394,7 @@ func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail st if err != nil { return nil, err } + defer resp.Body.Close() if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountCreateUser() @@ -416,6 +422,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { if err != nil { return nil, err } + defer resp.Body.Close() if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetUserByEmail() @@ -469,10 +476,11 @@ func (am *AuthentikManager) getUserGroupByName(name string) (string, error) { return "", err } - groupList, _, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute() + groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute() if err != nil { return "", err } + defer resp.Body.Close() if groupList != nil { if len(groupList.Results) > 0 { @@ -485,6 +493,7 @@ func (am *AuthentikManager) getUserGroupByName(name string) (string, error) { if err != nil { return "", err } + defer resp.Body.Close() if resp.StatusCode != http.StatusCreated { return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode) diff --git a/management/server/idp/authentik_test.go b/management/server/idp/authentik_test.go index 5cf8f2b2c..c70a84efd 100644 --- a/management/server/idp/authentik_test.go +++ b/management/server/idp/authentik_test.go @@ -2,13 +2,14 @@ package idp import ( "fmt" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "io" "strings" "testing" "time" + + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAuthentikManager(t *testing.T) { @@ -133,6 +134,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) { t.Fatal(err) } } else { + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) assert.NoError(t, err, "unable to read the response body") diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index efe457fdd..2e65497dc 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -4,15 +4,16 @@ import ( "context" "encoding/base64" "fmt" + "net/http" + "strings" + "time" + "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" "google.golang.org/api/googleapi" "google.golang.org/api/option" - "net/http" - "strings" - "time" ) // GoogleWorkspaceManager Google Workspace manager client instance. @@ -271,7 +272,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) admin.AdminDirectoryUserScope, ) if err == nil { - return creds, err + // No need to fallback to the default Google credentials path + return creds, nil } log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 48afd5c32..a1b55b183 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -92,7 +92,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) switch strings.ToLower(config.ManagerType) { case "none", "": - return nil, nil + return nil, nil //nolint:nilnil case "auth0": auth0ClientConfig := config.Auth0ClientCredentials if config.ClientConfig != nil { diff --git a/management/server/idp/keycloak_test.go b/management/server/idp/keycloak_test.go index 115306d7d..0c33fc137 100644 --- a/management/server/idp/keycloak_test.go +++ b/management/server/idp/keycloak_test.go @@ -145,6 +145,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) { t.Fatal(err) } } else { + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) assert.NoError(t, err, "unable to read the response body") diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index 28d26aedd..b558bba73 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -124,6 +124,7 @@ func TestZitadelRequestJWTToken(t *testing.T) { t.Fatal(err) } } else { + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) assert.NoError(t, err, "unable to read the response body") diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index d15327566..b564e4f4e 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -141,7 +141,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { if m.options.CredentialsOptional { log.Debugf("no credentials found (CredentialsOptional=true)") // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil, nil + return nil, nil //nolint:nilnil } // If we get here, the required token is missing @@ -219,7 +219,7 @@ func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) { return generatePemFromJWK(jwks.Keys[k]) } - return "", errors.New("unable to find appropriate key") + return cert, errors.New("unable to find appropriate key") } func generatePemFromJWK(jwk JSONWebKey) (string, error) { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 792d05187..66661dbf8 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -141,11 +141,6 @@ func Test_SyncProtocol(t *testing.T) { return } - if err != nil { - t.Fatal(err) - return - } - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: message, diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 822856e6a..36e96df43 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -79,10 +79,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { } setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) - if err != nil { - return - } - if err != nil { t.Fatal("error creating setup key") return @@ -332,10 +328,6 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { } setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) - if err != nil { - return - } - if err != nil { t.Fatal("error creating setup key") return @@ -406,14 +398,11 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { // two peers one added by a regular user and one with a setup key setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) - if err != nil { - return - } - if err != nil { t.Fatal("error creating setup key") return } + peerKey1, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) diff --git a/management/server/route.go b/management/server/route.go index c02729a72..f51b7c2db 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -1,15 +1,16 @@ package server import ( + "net/netip" + "strconv" + "unicode/utf8" + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" - "net/netip" - "strconv" - "unicode/utf8" ) const ( @@ -104,12 +105,6 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peerID string, routesWithPrefix := account.GetRoutesByPrefix(prefix) - if err != nil { - if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - return nil - } - return status.Errorf(status.InvalidArgument, "failed to parse prefix %s", prefix.String()) - } for _, prefixRoute := range routesWithPrefix { if prefixRoute.Peer == peerID { return status.Errorf(status.AlreadyExists, "failed to add route with prefix %s - peer already has this route", prefix.String()) diff --git a/management/server/user.go b/management/server/user.go index b3556957d..8ee036df7 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -405,13 +405,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str return nil, err } - targetUser := account.Users[targetUserID] - if targetUser == nil { - return nil, status.Errorf(status.NotFound, "targetUser not found") + targetUser, ok := account.Users[targetUserID] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[initiatorUserID] - if targetUser == nil { + executingUser, ok := account.Users[initiatorUserID] + if !ok { return nil, status.Errorf(status.NotFound, "user not found") } @@ -447,13 +447,13 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str return status.Errorf(status.NotFound, "account not found: %s", err) } - targetUser := account.Users[targetUserID] - if targetUser == nil { + targetUser, ok := account.Users[targetUserID] + if !ok { return status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[initiatorUserID] - if targetUser == nil { + executingUser, ok := account.Users[initiatorUserID] + if !ok { return status.Errorf(status.NotFound, "user not found") } @@ -497,13 +497,13 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string return nil, status.Errorf(status.NotFound, "account not found: %s", err) } - targetUser := account.Users[targetUserID] - if targetUser == nil { + targetUser, ok := account.Users[targetUserID] + if !ok { return nil, status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[initiatorUserID] - if targetUser == nil { + executingUser, ok := account.Users[initiatorUserID] + if !ok { return nil, status.Errorf(status.NotFound, "user not found") } @@ -529,13 +529,13 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID st return nil, status.Errorf(status.NotFound, "account not found: %s", err) } - targetUser := account.Users[targetUserID] - if targetUser == nil { + targetUser, ok := account.Users[targetUserID] + if !ok { return nil, status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[initiatorUserID] - if targetUser == nil { + executingUser, ok := account.Users[initiatorUserID] + if !ok { return nil, status.Errorf(status.NotFound, "user not found") } From bdb8383485e4484420bd3af9730f6ccd90597c5a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 5 Sep 2023 14:40:40 +0200 Subject: [PATCH 26/42] Use github token to read api (#1125) prevent failing tests by using a github token to perform requests in our CI/CD --- .github/workflows/install-script-test.yml | 1 + release_files/install.sh | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml index ab07899b5..dfb8a279b 100644 --- a/.github/workflows/install-script-test.yml +++ b/.github/workflows/install-script-test.yml @@ -27,6 +27,7 @@ jobs: env: SKIP_UI_APP: ${{ matrix.skip_ui_mode }} USE_BIN_INSTALL: ${{ matrix.install_binary }} + GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }} run: | [ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none" cat release_files/install.sh | sh -x diff --git a/release_files/install.sh b/release_files/install.sh index 99cab9e26..3df085016 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -19,8 +19,14 @@ PACKAGE_MANAGER="bin" INSTALL_DIR="" get_latest_release() { - curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + if [ -n "$GITHUB_TOKEN" ]; then + curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ + | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + else + curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ + | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + fi + } download_release_binary() { @@ -45,7 +51,11 @@ download_release_binary() { DOWNLOAD_URL="${BASE_URL}/${VERSION}/${BINARY_NAME}" echo "Installing $1 from $DOWNLOAD_URL" - cd /tmp && curl -LO "$DOWNLOAD_URL" + if [ -n "$GITHUB_TOKEN" ]; then + cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL" + else + cd /tmp && curl -LO "$DOWNLOAD_URL" + fi if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then From e4bc76c4de414b140885b1d6ea19c06281357ca8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Tue, 5 Sep 2023 15:41:50 +0300 Subject: [PATCH 27/42] Ignore empty fields in the app metadata when storing on IDP (#1122) --- management/server/idp/auth0_test.go | 6 +++--- management/server/idp/idp.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 0be401e65..63c634d4e 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -344,7 +344,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ name: "Bad Status Code", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null,\"wt_invited_by_email\":\"\"}}", appMetadata.WTAccountID), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID), appMetadata: appMetadata, statusCode: 400, helper: JsonParser{}, @@ -367,7 +367,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ name: "Good request", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null,\"wt_invited_by_email\":\"\"}}", appMetadata.WTAccountID), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID), appMetadata: appMetadata, statusCode: 200, helper: JsonParser{}, @@ -379,7 +379,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{ name: "Update Pending Invite", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true,\"wt_invited_by_email\":\"\"}}", appMetadata.WTAccountID), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID), appMetadata: AppMetadata{ WTAccountID: "ok", WTPendingInvite: &invite, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index a1b55b183..3c1f4c327 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -71,8 +71,8 @@ type AppMetadata struct { // WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP // maps to wt_account_id when json.marshal WTAccountID string `json:"wt_account_id,omitempty"` - WTPendingInvite *bool `json:"wt_pending_invite"` - WTInvitedBy string `json:"wt_invited_by_email"` + WTPendingInvite *bool `json:"wt_pending_invite,omitempty"` + WTInvitedBy string `json:"wt_invited_by_email,omitempty"` } // JWTToken a JWT object that holds information of a token From 246abda46d9047b4ed6c2bd70bcb05b29fd050c3 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Tue, 5 Sep 2023 23:07:32 +0400 Subject: [PATCH 28/42] Add default firewall rule to allow netbird traffic (#1056) Add a default firewall rule to allow netbird traffic to be handled by the access control managers. Userspace manager behavior: - When running on Windows, a default rule is add on Windows firewall - For Linux, we are using one of the Kernel managers to add a single rule - This PR doesn't handle macOS Kernel manager behavior: - For NFtables, if there is a filter table, an INPUT rule is added - Iptables follows the previous flow if running on kernel mode. If running on userspace mode, it adds a single rule for INPUT and OUTPUT chains A new checkerFW package has been introduced to consolidate checks across route and access control managers. It supports a new environment variable to skip nftables and allow iptables tests --- .gitignore | 1 + client/firewall/firewall.go | 3 + client/firewall/iptables/manager_linux.go | 64 ++++++--- .../firewall/iptables/manager_linux_test.go | 8 +- client/firewall/nftables/manager_linux.go | 129 ++++++++++++++++-- client/firewall/uspfilter/allow_netbird.go | 19 +++ .../firewall/uspfilter/allow_netbird_linux.go | 21 +++ .../uspfilter/allow_netbird_windows.go | 91 ++++++++++++ client/firewall/uspfilter/uspfilter.go | 20 ++- client/firewall/uspfilter/uspfilter_test.go | 8 ++ client/internal/acl/manager_create.go | 5 + client/internal/acl/manager_create_linux.go | 62 +++++++-- client/internal/acl/manager_test.go | 40 ++++-- client/internal/checkfw/check.go | 3 + client/internal/checkfw/check_linux.go | 56 ++++++++ .../internal/routemanager/firewall_linux.go | 28 ++-- .../routemanager/firewall_nonlinux.go | 7 +- .../internal/routemanager/iptables_linux.go | 32 ++--- .../routemanager/iptables_linux_test.go | 5 +- client/internal/routemanager/manager.go | 4 +- client/internal/routemanager/manager_test.go | 39 +++++- .../internal/routemanager/nftables_linux.go | 43 +----- .../routemanager/nftables_linux_test.go | 31 +++-- .../routemanager/server_nonandroid.go | 2 +- 24 files changed, 568 insertions(+), 153 deletions(-) create mode 100644 client/firewall/uspfilter/allow_netbird.go create mode 100644 client/firewall/uspfilter/allow_netbird_linux.go create mode 100644 client/firewall/uspfilter/allow_netbird_windows.go create mode 100644 client/internal/checkfw/check.go create mode 100644 client/internal/checkfw/check_linux.go diff --git a/.gitignore b/.gitignore index 50bbbbe3f..dc62780ad 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ client/.distfiles/ infrastructure_files/setup.env infrastructure_files/setup-*.env .vscode +.DS_Store \ No newline at end of file diff --git a/client/firewall/firewall.go b/client/firewall/firewall.go index 5d003e2f0..59e672a45 100644 --- a/client/firewall/firewall.go +++ b/client/firewall/firewall.go @@ -40,6 +40,9 @@ const ( // It declares methods which handle actions required by the // Netbird client for ACL and routing functionality type Manager interface { + // AllowNetbird allows netbird interface traffic + AllowNetbird() error + // AddFiltering rule to the firewall // // If comment argument is empty firewall manager should set diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index fa51122af..753282d87 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -44,6 +44,7 @@ type Manager struct { type iFaceMapper interface { Name() string Address() iface.WGAddress + IsUserspaceBind() bool } type ruleset struct { @@ -52,7 +53,7 @@ type ruleset struct { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) { m := &Manager{ wgIface: wgIface, inputDefaultRuleSpecs: []string{ @@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) { rulesets: make(map[string]ruleset), } - if err := ipset.Init(); err != nil { + err := ipset.Init() + if err != nil { return nil, fmt.Errorf("init ipset: %w", err) } // init clients for booth ipv4 and ipv6 - ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("iptables is not installed in the system or not supported") } - if isIptablesClientAvailable(ipv4Client) { - m.ipv4Client = ipv4Client + + if ipv6Supported { + m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err) + } } - ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - log.Errorf("ip6tables is not installed in the system or not supported: %v", err) - } else { - if isIptablesClientAvailable(ipv6Client) { - m.ipv6Client = ipv6Client - } + if m.ipv4Client == nil && m.ipv6Client == nil { + return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it") } if err := m.Reset(); err != nil { @@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) { return m, nil } -func isIptablesClientAvailable(client *iptables.IPTables) bool { - _, err := client.ListChains("filter") - return err == nil -} - // AddFiltering rule to the firewall // // If comment is empty rule ID is used as comment @@ -276,6 +272,38 @@ func (m *Manager) Reset() error { return nil } +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + if m.wgIface.IsUserspaceBind() { + _, err := m.AddFiltering( + net.ParseIP("0.0.0.0"), + "all", + nil, + nil, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "allow netbird interface traffic", + ) + if err != nil { + return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + } + _, err = m.AddFiltering( + net.ParseIP("0.0.0.0"), + "all", + nil, + nil, + fw.RuleDirectionOUT, + fw.ActionAccept, + "", + "allow netbird interface traffic", + ) + return err + } + + return nil +} + // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) { return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err) } - if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { + if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil { return nil, fmt.Errorf("failed to create input chain jump rule: %w", err) } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 84e27ed14..2d2013aa2 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress { panic("AddressFunc is not set") } +func (i *iFaceMock) IsUserspaceBind() bool { return false } + func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) @@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, true) require.NoError(t, err) time.Sleep(time.Second) @@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, true) require.NoError(t, err) time.Sleep(time.Second) @@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, true) require.NoError(t, err) time.Sleep(time.Second) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 081aee48d..2273f4edc 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -29,6 +29,8 @@ const ( // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets FilterOutputChainName = "netbird-acl-output-filter" + + AllowNetbirdInputRuleID = "allow Netbird incoming traffic" ) var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} @@ -379,7 +381,7 @@ func (m *Manager) chain( if c != nil { return c, nil } - return m.createChainIfNotExists(tf, name, hook, priority, cType) + return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType) } if ip.To4() != nil { @@ -399,13 +401,20 @@ func (m *Manager) chain( } // table returns the table for the given family of the IP address -func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { +func (m *Manager) table( + family nftables.TableFamily, tableName string, +) (*nftables.Table, error) { + // we cache access to Netbird ACL table only + if tableName != FilterTableName { + return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName) + } + if family == nftables.TableFamilyIPv4 { if m.tableIPv4 != nil { return m.tableIPv4, nil } - table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4) + table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName) if err != nil { return nil, err } @@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { return m.tableIPv6, nil } - table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6) + table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName) if err != nil { return nil, err } @@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { return m.tableIPv6, nil } -func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { +func (m *Manager) createTableIfNotExists( + family nftables.TableFamily, tableName string, +) (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } for _, t := range tables { - if t.Name == FilterTableName { + if t.Name == tableName { return t, nil } } - table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) if err := m.rConn.Flush(); err != nil { return nil, err } @@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables func (m *Manager) createChainIfNotExists( family nftables.TableFamily, + tableName string, name string, hooknum nftables.ChainHook, priority nftables.ChainPriority, chainType nftables.ChainType, ) (*nftables.Chain, error) { - table, err := m.table(family) + table, err := m.table(family, tableName) if err != nil { return nil, err } @@ -638,6 +650,22 @@ func (m *Manager) Reset() error { return fmt.Errorf("list of chains: %w", err) } for _, c := range chains { + // delete Netbird allow input traffic rule if it exists + if c.Table.Name == "filter" && c.Name == "INPUT" { + rules, err := m.rConn.GetRules(c.Table, c) + if err != nil { + log.Errorf("get rules for chain %q: %v", c.Name, err) + continue + } + for _, r := range rules { + if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) { + if err := m.rConn.DelRule(r); err != nil { + log.Errorf("delete rule: %v", err) + } + } + } + } + if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { m.rConn.DelChain(c) } @@ -702,6 +730,53 @@ func (m *Manager) Flush() error { return nil } +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + tf := nftables.TableFamilyIPv4 + if m.wgIface.Address().IP.To4() == nil { + tf = nftables.TableFamilyIPv6 + } + + chains, err := m.rConn.ListChainsOfTableFamily(tf) + if err != nil { + return fmt.Errorf("list of chains: %w", err) + } + + var chain *nftables.Chain + for _, c := range chains { + if c.Table.Name == "filter" && c.Name == "INPUT" { + chain = c + break + } + } + + if chain == nil { + log.Debugf("chain INPUT not found. Skiping add allow netbird rule") + return nil + } + + rules, err := m.rConn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) + } + + if rule := m.detectAllowNetbirdRule(rules); rule != nil { + log.Debugf("allow netbird rule already exists: %v", rule) + return nil + } + + m.applyAllowNetbirdRules(chain) + + err = m.rConn.Flush() + if err != nil { + return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + } + return nil +} + func (m *Manager) flushWithBackoff() (err error) { backoff := 4 backoffTime := 1000 * time.Millisecond @@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai return nil } +func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { + rule := &nftables.Rule{ + Table: chain.Table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: []byte(AllowNetbirdInputRuleID), + } + _ = m.rConn.InsertRule(rule) +} + +func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { + ifName := ifname(m.wgIface.Name()) + for _, rule := range existedRules { + if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { + if len(rule.Exprs) < 4 { + if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { + continue + } + if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { + continue + } + return rule + } + } + } + return nil +} + func encodePort(port fw.Port) []byte { bs := make([]byte, 2) binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go new file mode 100644 index 000000000..ccfef1861 --- /dev/null +++ b/client/firewall/uspfilter/allow_netbird.go @@ -0,0 +1,19 @@ +//go:build !windows && !linux + +package uspfilter + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) + + return nil +} + +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird_linux.go b/client/firewall/uspfilter/allow_netbird_linux.go new file mode 100644 index 000000000..5df48c756 --- /dev/null +++ b/client/firewall/uspfilter/allow_netbird_linux.go @@ -0,0 +1,21 @@ +package uspfilter + +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + return nil +} + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) + + if m.resetHook != nil { + return m.resetHook() + } + + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go new file mode 100644 index 000000000..05a6d22ae --- /dev/null +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -0,0 +1,91 @@ +package uspfilter + +import ( + "errors" + "fmt" + "os/exec" + "strings" + "syscall" +) + +type action string + +const ( + addRule action = "add" + deleteRule action = "delete" + + firewallRuleName = "Netbird" + noRulesMatchCriteria = "No rules match the specified criteria" +) + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) + + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + return fmt.Errorf("couldn't remove windows firewall: %w", err) + } + + return nil +} + +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + return manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ) +} + +func manageFirewallRule(ruleName string, action action, args ...string) error { + active, err := isFirewallRuleActive(ruleName) + if err != nil { + return err + } + + if (action == addRule && !active) || (action == deleteRule && active) { + baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName} + args := append(baseArgs, args...) + + cmd := exec.Command("netsh", args...) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + return cmd.Run() + } + + return nil +} + +func isFirewallRuleActive(ruleName string) (bool, error) { + args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName} + + cmd := exec.Command("netsh", args...) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + output, err := cmd.Output() + if err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + // if the firewall rule is not active, we expect last exit code to be 1 + exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus() + if exitStatus == 1 { + if strings.Contains(string(output), noRulesMatchCriteria) { + return false, nil + } + } + } + return false, err + } + + if strings.Contains(string(output), noRulesMatchCriteria) { + return false, nil + } + + return true, nil +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 3dead1db4..50170b46c 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -19,6 +19,7 @@ const layerTypeAll = 0 // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { SetFilter(iface.PacketFilter) error + Address() iface.WGAddress } // RuleSet is a set of rules grouped by a string key @@ -30,6 +31,8 @@ type Manager struct { incomingRules map[string]RuleSet wgNetwork *net.IPNet decoders sync.Pool + wgIface IFaceMapper + resetHook func() error mutex sync.RWMutex } @@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) { }, outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), + wgIface: iface, } if err := iface.SetFilter(m); err != nil { @@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { return nil } -// Reset firewall to the default state -func (m *Manager) Reset() error { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.outgoingRules = make(map[string]RuleSet) - m.incomingRules = make(map[string]RuleSet) - - return nil -} - // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error { } return fmt.Errorf("hook with given id not found") } + +// SetResetHook which will be executed in the end of Reset method +func (m *Manager) SetResetHook(hook func() error) { + m.resetHook = hook +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index bc94f59c1..6b3d334a8 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -16,6 +16,7 @@ import ( type IFaceMock struct { SetFilterFunc func(iface.PacketFilter) error + AddressFunc func() iface.WGAddress } func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { @@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { return i.SetFilterFunc(iface) } +func (i *IFaceMock) Address() iface.WGAddress { + if i.AddressFunc == nil { + return iface.WGAddress{} + } + return i.AddressFunc() +} + func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(iface.PacketFilter) error { return nil }, diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go index 7d9e6b430..c573d2c64 100644 --- a/client/internal/acl/manager_create.go +++ b/client/internal/acl/manager_create.go @@ -6,6 +6,8 @@ import ( "fmt" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter" ) @@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) { if err != nil { return nil, err } + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } return newDefaultManager(fm), nil } return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) diff --git a/client/internal/acl/manager_create_linux.go b/client/internal/acl/manager_create_linux.go index de4e8adb9..4342463d3 100644 --- a/client/internal/acl/manager_create_linux.go +++ b/client/internal/acl/manager_create_linux.go @@ -7,26 +7,68 @@ import ( "github.com/netbirdio/netbird/client/firewall/iptables" "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/checkfw" ) // Create creates a firewall manager instance for the Linux -func Create(iface IFaceMapper) (manager *DefaultManager, err error) { +func Create(iface IFaceMapper) (*DefaultManager, error) { + // on the linux system we try to user nftables or iptables + // in any case, because we need to allow netbird interface traffic + // so we use AllowNetbird traffic from these firewall managers + // for the userspace packet filtering firewall var fm firewall.Manager + var err error + + checkResult := checkfw.Check() + switch checkResult { + case checkfw.IPTABLES, checkfw.IPTABLESWITHV6: + log.Debug("creating an iptables firewall manager for access control") + ipv6Supported := checkResult == checkfw.IPTABLESWITHV6 + if fm, err = iptables.Create(iface, ipv6Supported); err != nil { + log.Infof("failed to create iptables manager for access control: %s", err) + } + case checkfw.NFTABLES: + log.Debug("creating an nftables firewall manager for access control") + if fm, err = nftables.Create(iface); err != nil { + log.Debugf("failed to create nftables manager for access control: %s", err) + } + } + + var resetHookForUserspace func() error + if fm != nil && err == nil { + // err shadowing is used here, to ignore this error + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } + resetHookForUserspace = fm.Reset + } + if iface.IsUserspaceBind() { // use userspace packet filtering firewall - if fm, err = uspfilter.Create(iface); err != nil { + usfm, err := uspfilter.Create(iface) + if err != nil { log.Debugf("failed to create userspace filtering firewall: %s", err) return nil, err } - } else { - if fm, err = nftables.Create(iface); err != nil { - log.Debugf("failed to create nftables manager: %s", err) - // fallback to iptables - if fm, err = iptables.Create(iface); err != nil { - log.Errorf("failed to create iptables manager: %s", err) - return nil, err - } + + // set kernel space firewall Reset as hook for userspace firewall + // manager Reset method, to clean up + if resetHookForUserspace != nil { + usfm.SetResetHook(resetHookForUserspace) } + + // to be consistent for any future extensions. + // ignore this error + if err := usfm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } + fm = usfm + } + + if fm == nil || err != nil { + log.Errorf("failed to create firewall manager: %s", err) + // no firewall manager found or initialized correctly + return nil, err } return newDefaultManager(fm), nil diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index d765e5c6c..518e895cf 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,11 +1,13 @@ package acl import ( + "net" "testing" "github.com/golang/mock/gomock" "github.com/netbirdio/netbird/client/internal/acl/mocks" + "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - iface := mocks.NewMockIFaceMapper(ctrl) - iface.EXPECT().IsUserspaceBind().Return(true) - // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFilter(gomock.Any()) + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true) + ifaceMock.EXPECT().SetFilter(gomock.Any()) + ip, network, err := net.ParseCIDR("172.0.0.1/32") + if err != nil { + t.Fatalf("failed to parse IP address: %v", err) + } + + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(iface.WGAddress{ + IP: ip, + Network: network, + }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - acl, err := Create(iface) + acl, err := Create(ifaceMock) if err != nil { t.Errorf("create ACL manager: %v", err) return @@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - iface := mocks.NewMockIFaceMapper(ctrl) - iface.EXPECT().IsUserspaceBind().Return(true) - // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFilter(gomock.Any()) + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true) + ifaceMock.EXPECT().SetFilter(gomock.Any()) + ip, network, err := net.ParseCIDR("172.0.0.1/32") + if err != nil { + t.Fatalf("failed to parse IP address: %v", err) + } + + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(iface.WGAddress{ + IP: ip, + Network: network, + }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - acl, err := Create(iface) + acl, err := Create(ifaceMock) if err != nil { t.Errorf("create ACL manager: %v", err) return diff --git a/client/internal/checkfw/check.go b/client/internal/checkfw/check.go new file mode 100644 index 000000000..59626cbc3 --- /dev/null +++ b/client/internal/checkfw/check.go @@ -0,0 +1,3 @@ +//go:build !linux + +package checkfw diff --git a/client/internal/checkfw/check_linux.go b/client/internal/checkfw/check_linux.go new file mode 100644 index 000000000..552d5698c --- /dev/null +++ b/client/internal/checkfw/check_linux.go @@ -0,0 +1,56 @@ +//go:build !android + +package checkfw + +import ( + "os" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/nftables" +) + +const ( + // UNKNOWN is the default value for the firewall type for unknown firewall type + UNKNOWN FWType = iota + // IPTABLES is the value for the iptables firewall type + IPTABLES + // IPTABLESWITHV6 is the value for the iptables firewall type with ipv6 + IPTABLESWITHV6 + // NFTABLES is the value for the nftables firewall type + NFTABLES +) + +// SKIP_NFTABLES_ENV is the environment variable to skip nftables check +const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" + +// FWType is the type for the firewall type +type FWType int + +// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. +func Check() FWType { + nf := nftables.Conn{} + if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { + return NFTABLES + } + + ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err == nil { + if isIptablesClientAvailable(ip) { + ipSupport := IPTABLES + ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if ip6Err == nil { + if isIptablesClientAvailable(ipv6) { + ipSupport = IPTABLESWITHV6 + } + } + return ipSupport + } + } + + return UNKNOWN +} + +func isIptablesClientAvailable(client *iptables.IPTables) bool { + _, err := client.ListChains("filter") + return err == nil +} diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index 19a5a4cde..50d451a88 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -7,6 +7,8 @@ import ( "fmt" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/checkfw" ) const ( @@ -26,20 +28,20 @@ func genKey(format string, input string) string { return fmt.Sprintf(format, input) } -// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager -func NewFirewall(parentCTX context.Context) (firewallManager, error) { - manager, err := newNFTablesManager(parentCTX) - if err == nil { - log.Debugf("nftables firewall manager will be used") - return manager, nil +// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager +func newFirewall(parentCTX context.Context) (firewallManager, error) { + checkResult := checkfw.Check() + switch checkResult { + case checkfw.IPTABLES, checkfw.IPTABLESWITHV6: + log.Debug("creating an iptables firewall manager for route rules") + ipv6Supported := checkResult == checkfw.IPTABLESWITHV6 + return newIptablesManager(parentCTX, ipv6Supported) + case checkfw.NFTABLES: + log.Info("creating an nftables firewall manager for route rules") + return newNFTablesManager(parentCTX), nil } - fMgr, err := newIptablesManager(parentCTX) - if err != nil { - log.Debugf("failed to initialize iptables for root mgr: %s", err) - return nil, err - } - log.Debugf("iptables firewall manager will be used") - return fMgr, nil + + return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules") } func getInPair(pair routerPair) routerPair { diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go index 1b52a1e85..ae0627048 100644 --- a/client/internal/routemanager/firewall_nonlinux.go +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -6,9 +6,10 @@ package routemanager import ( "context" "fmt" + "runtime" ) -// NewFirewall returns a nil manager -func NewFirewall(context.Context) (firewallManager, error) { - return nil, fmt.Errorf("firewall not supported on this OS") +// newFirewall returns a nil manager +func newFirewall(context.Context) (firewallManager, error) { + return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS) } diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index a87d4f4a3..9f6019305 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -49,29 +49,28 @@ type iptablesManager struct { mux sync.Mutex } -func newIptablesManager(parentCtx context.Context) (*iptablesManager, error) { +func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { - return nil, err - } else if !isIptablesClientAvailable(ipv4Client) { - return nil, fmt.Errorf("iptables is missing for ipv4") - } - ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - log.Debugf("failed to initialize iptables for ipv6: %s", err) - } else if !isIptablesClientAvailable(ipv6Client) { - log.Infof("iptables is missing for ipv6") - ipv6Client = nil + return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err) } ctx, cancel := context.WithCancel(parentCtx) - return &iptablesManager{ + manager := &iptablesManager{ ctx: ctx, stop: cancel, ipv4Client: ipv4Client, - ipv6Client: ipv6Client, rules: make(map[string]map[string][]string), - }, nil + } + + if ipv6Supported { + manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err) + } + } + + return manager, nil } // CleanRoutingRules cleans existing iptables resources that we created by the agent @@ -486,8 +485,3 @@ func getIptablesRuleType(table string) string { } return ruleType } - -func isIptablesClientAvailable(client *iptables.IPTables) bool { - _, err := client.ListChains("filter") - return err == nil -} diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go index dbe153f7b..4f733de34 100644 --- a/client/internal/routemanager/iptables_linux_test.go +++ b/client/internal/routemanager/iptables_linux_test.go @@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { t.SkipNow() } - manager, _ := newIptablesManager(context.TODO()) + manager, err := newIptablesManager(context.TODO(), true) + require.NoError(t, err, "should return a valid iptables manager") defer manager.CleanRoutingRules() - err := manager.RestoreOrCreateContainers() + err = manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 13d9d1f38..b31fe6327 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,7 +36,7 @@ type DefaultManager struct { // NewManager returns a new route manager func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { - serverRouter, err := newServerRouter(ctx, wgInterface) + srvRouter, err := newServerRouter(ctx, wgInterface) if err != nil { log.Errorf("server router is not supported: %s", err) } @@ -46,7 +46,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), - serverRouter: serverRouter, + serverRouter: srvRouter, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 6f2ac294d..f6f5f359e 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -3,11 +3,12 @@ package routemanager import ( "context" "fmt" - "github.com/pion/transport/v2/stdnet" "net/netip" "runtime" "testing" + "github.com/pion/transport/v2/stdnet" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal/peer" @@ -30,6 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) { inputInitRoutes []*route.Route inputRoutes []*route.Route inputSerial uint64 + removeSrvRouter bool serverRoutesExpected int clientNetworkWatchersExpected int }{ @@ -117,6 +119,35 @@ func TestManagerUpdateRoutes(t *testing.T) { serverRoutesExpected: 1, clientNetworkWatchersExpected: 1, }, + { + name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.30.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.9.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + removeSrvRouter: true, + serverRoutesExpected: 0, + clientNetworkWatchersExpected: 1, + }, { name: "Should Create 1 HA Route and 1 Standalone", inputRoutes: []*route.Route{ @@ -385,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) { routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) defer routeManager.Stop() + if testCase.removeSrvRouter { + routeManager.serverRouter = nil + } + if len(testCase.inputInitRoutes) > 0 { err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) require.NoError(t, err, "should update routes with init routes") @@ -395,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) { require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match") } diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index ca7d74f2a..25dc6e7db 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -86,10 +86,10 @@ type nftablesManager struct { mux sync.Mutex } -func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { +func newNFTablesManager(parentCtx context.Context) *nftablesManager { ctx, cancel := context.WithCancel(parentCtx) - mgr := &nftablesManager{ + return &nftablesManager{ ctx: ctx, stop: cancel, conn: &nftables.Conn{}, @@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { rules: make(map[string]*nftables.Rule), defaultForwardRules: make([]*nftables.Rule, 2), } - - err := mgr.isSupported() - if err != nil { - return nil, err - } - - err = mgr.readFilterTable() - if err != nil { - return nil, err - } - - return mgr, nil } // CleanRoutingRules cleans existing nftables rules from the system @@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { } for _, table := range tables { + if table.Name == "filter" { + n.filterTable = table + continue + } if table.Name == nftablesTable { if table.Family == nftables.TableFamilyIPv4 { n.tableIPv4 = table @@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error { return nil } -func (n *nftablesManager) readFilterTable() error { - tables, err := n.conn.ListTables() - if err != nil { - return err - } - - for _, t := range tables { - if t.Name == "filter" { - n.filterTable = t - return nil - } - } - return nil -} - func (n *nftablesManager) eraseDefaultForwardRule() error { if n.defaultForwardRules[0] == nil { return nil @@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro return nil } -func (n *nftablesManager) isSupported() error { - _, err := n.conn.ListChains() - if err != nil { - return fmt.Errorf("nftables is not supported: %s", err) - } - return nil -} - // getPayloadDirectives get expression directives based on ip version and direction func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { switch { diff --git a/client/internal/routemanager/nftables_linux_test.go b/client/internal/routemanager/nftables_linux_test.go index 01fc38885..dec800156 100644 --- a/client/internal/routemanager/nftables_linux_test.go +++ b/client/internal/routemanager/nftables_linux_test.go @@ -10,20 +10,23 @@ import ( "github.com/google/nftables/expr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/checkfw" ) func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { - manager, err := newNFTablesManager(context.TODO()) - if err != nil { - t.Fatalf("failed to create nftables manager: %s", err) + if checkfw.Check() != checkfw.NFTABLES { + t.Skip("nftables not supported on this OS") } + manager := newNFTablesManager(context.TODO()) + nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err = manager.RestoreOrCreateContainers() + err := manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") @@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { } func TestNftablesManager_InsertRoutingRules(t *testing.T) { + if checkfw.Check() != checkfw.NFTABLES { + t.Skip("nftables not supported on this OS") + } for _, testCase := range insertRuleTestCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := newNFTablesManager(context.TODO()) - if err != nil { - t.Fatalf("failed to create nftables manager: %s", err) - } + manager := newNFTablesManager(context.TODO()) nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err = manager.RestoreOrCreateContainers() + err := manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") err = manager.InsertRoutingRules(testCase.inputPair) @@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { } func TestNftablesManager_RemoveRoutingRules(t *testing.T) { + if checkfw.Check() != checkfw.NFTABLES { + t.Skip("nftables not supported on this OS") + } for _, testCase := range removeRuleTestCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := newNFTablesManager(context.TODO()) - if err != nil { - t.Fatalf("failed to create nftables manager: %s", err) - } + manager := newNFTablesManager(context.TODO()) nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err = manager.RestoreOrCreateContainers() + err := manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") table := manager.tableIPv4 diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index bf7a1dfd4..6df632329 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -22,7 +22,7 @@ type defaultServerRouter struct { } func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) { - firewall, err := NewFirewall(ctx) + firewall, err := newFirewall(ctx) if err != nil { return nil, err } From c9b2ce08eb4ac7d8e22aa46f7f8ce90b9f841c18 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 5 Sep 2023 21:14:02 +0200 Subject: [PATCH 29/42] DNS forwarder and common ebpf loader (#1083) In case the 53 UDP port is not an option to bind then we hijack the DNS traffic with eBPF, and we forward the traffic to the listener on a custom port. With this implementation, we should be able to listen to DNS queries on any address and still set the local host system to send queries to the custom address on port 53. Because we tried to attach multiple XDP programs to the same interface, I did a refactor in the WG traffic forward code also. --- client/internal/dns/service_listener.go | 64 ++++++++-- .../{wgproxy => ebpf}/ebpf/bpf_bpfeb.go | 15 ++- client/internal/ebpf/ebpf/bpf_bpfeb.o | Bin 0 -> 15960 bytes .../{wgproxy => ebpf}/ebpf/bpf_bpfel.go | 15 ++- client/internal/ebpf/ebpf/bpf_bpfel.o | Bin 0 -> 15960 bytes client/internal/ebpf/ebpf/dns_fwd_linux.go | 51 ++++++++ client/internal/ebpf/ebpf/manager_linux.go | 116 ++++++++++++++++++ .../internal/ebpf/ebpf/manager_linux_test.go | 40 ++++++ client/internal/ebpf/ebpf/src/dns_fwd.c | 64 ++++++++++ client/internal/ebpf/ebpf/src/prog.c | 66 ++++++++++ client/internal/ebpf/ebpf/src/wg_proxy.c | 54 ++++++++ client/internal/ebpf/ebpf/wg_proxy_linux.go | 41 +++++++ client/internal/ebpf/instantiater_linux.go | 15 +++ client/internal/ebpf/instantiater_nonlinux.go | 10 ++ client/internal/ebpf/manager/manager.go | 9 ++ client/internal/wgproxy/ebpf/bpf_bpfeb.o | Bin 6264 -> 0 bytes client/internal/wgproxy/ebpf/bpf_bpfel.o | Bin 6264 -> 0 bytes client/internal/wgproxy/ebpf/loader.go | 84 ------------- client/internal/wgproxy/ebpf/loader_test.go | 18 --- .../internal/wgproxy/ebpf/src/portreplace.c | 90 -------------- client/internal/wgproxy/proxy_ebpf.go | 12 +- iface/iface.go | 2 +- 22 files changed, 553 insertions(+), 213 deletions(-) rename client/internal/{wgproxy => ebpf}/ebpf/bpf_bpfeb.go (84%) create mode 100644 client/internal/ebpf/ebpf/bpf_bpfeb.o rename client/internal/{wgproxy => ebpf}/ebpf/bpf_bpfel.go (84%) create mode 100644 client/internal/ebpf/ebpf/bpf_bpfel.o create mode 100644 client/internal/ebpf/ebpf/dns_fwd_linux.go create mode 100644 client/internal/ebpf/ebpf/manager_linux.go create mode 100644 client/internal/ebpf/ebpf/manager_linux_test.go create mode 100644 client/internal/ebpf/ebpf/src/dns_fwd.c create mode 100644 client/internal/ebpf/ebpf/src/prog.c create mode 100644 client/internal/ebpf/ebpf/src/wg_proxy.c create mode 100644 client/internal/ebpf/ebpf/wg_proxy_linux.go create mode 100644 client/internal/ebpf/instantiater_linux.go create mode 100644 client/internal/ebpf/instantiater_nonlinux.go create mode 100644 client/internal/ebpf/manager/manager.go delete mode 100644 client/internal/wgproxy/ebpf/bpf_bpfeb.o delete mode 100644 client/internal/wgproxy/ebpf/bpf_bpfel.o delete mode 100644 client/internal/wgproxy/ebpf/loader.go delete mode 100644 client/internal/wgproxy/ebpf/loader_test.go delete mode 100644 client/internal/wgproxy/ebpf/src/portreplace.c diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 687ca2459..232f6ebc2 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -11,6 +11,9 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/ebpf" + ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ) const ( @@ -24,10 +27,11 @@ type serviceViaListener struct { dnsMux *dns.ServeMux customAddr *netip.AddrPort server *dns.Server - runtimeIP string - runtimePort int + listenIP string + listenPort int listenerIsRunning bool listenerFlagLock sync.Mutex + ebpfService ebpfMgr.Manager } func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener { @@ -43,6 +47,7 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service UDPSize: 65535, }, } + return s } @@ -55,13 +60,21 @@ func (s *serviceViaListener) Listen() error { } var err error - s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress() + s.listenIP, s.listenPort, err = s.evalListenAddress() if err != nil { log.Errorf("failed to eval runtime address: %s", err) return err } - s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) + s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) + if s.shouldApplyPortFwd() { + s.ebpfService = ebpf.GetEbpfManagerInstance() + err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort) + if err != nil { + log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err) + s.ebpfService = nil + } + } log.Debugf("starting dns on %s", s.server.Addr) go func() { s.setListenerStatus(true) @@ -69,9 +82,10 @@ func (s *serviceViaListener) Listen() error { err := s.server.ListenAndServe() if err != nil { - log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) + log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err) } }() + return nil } @@ -90,6 +104,13 @@ func (s *serviceViaListener) Stop() { if err != nil { log.Errorf("stopping dns server listener returned an error: %v", err) } + + if s.ebpfService != nil { + err = s.ebpfService.FreeDNSFwd() + if err != nil { + log.Errorf("stopping traffic forwarder returned an error: %v", err) + } + } } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { @@ -101,11 +122,18 @@ func (s *serviceViaListener) DeregisterMux(pattern string) { } func (s *serviceViaListener) RuntimePort() int { - return s.runtimePort + s.listenerFlagLock.Lock() + defer s.listenerFlagLock.Unlock() + + if s.ebpfService != nil { + return defaultPort + } else { + return s.listenPort + } } func (s *serviceViaListener) RuntimeIP() string { - return s.runtimeIP + return s.listenIP } func (s *serviceViaListener) setListenerStatus(running bool) { @@ -136,10 +164,30 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) { return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) } -func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) { +func (s *serviceViaListener) evalListenAddress() (string, int, error) { if s.customAddr != nil { return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil } return s.getFirstListenerAvailable() } + +// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53. +// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name +// resolution won't work. +// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture +// traffic on port 53 and forward it to a local DNS server running on 5053. +func (s *serviceViaListener) shouldApplyPortFwd() bool { + if runtime.GOOS != "linux" { + return false + } + + if s.customAddr != nil { + return false + } + + if s.listenPort == defaultPort { + return false + } + return true +} diff --git a/client/internal/wgproxy/ebpf/bpf_bpfeb.go b/client/internal/ebpf/ebpf/bpf_bpfeb.go similarity index 84% rename from client/internal/wgproxy/ebpf/bpf_bpfeb.go rename to client/internal/ebpf/ebpf/bpf_bpfeb.go index c4875c3ae..5d2765862 100644 --- a/client/internal/wgproxy/ebpf/bpf_bpfeb.go +++ b/client/internal/ebpf/ebpf/bpf_bpfeb.go @@ -54,13 +54,16 @@ type bpfSpecs struct { // // It can be passed ebpf.CollectionSpec.Assign. type bpfProgramSpecs struct { - NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"` + NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"` } // bpfMapSpecs contains maps before they are loaded into the kernel. // // It can be passed ebpf.CollectionSpec.Assign. type bpfMapSpecs struct { + NbFeatures *ebpf.MapSpec `ebpf:"nb_features"` + NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"` + NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"` NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"` } @@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfMaps struct { + NbFeatures *ebpf.Map `ebpf:"nb_features"` + NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"` + NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"` NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"` } func (m *bpfMaps) Close() error { return _BpfClose( + m.NbFeatures, + m.NbMapDnsIp, + m.NbMapDnsPort, m.NbWgProxySettingsMap, ) } @@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfPrograms struct { - NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"` + NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"` } func (p *bpfPrograms) Close() error { return _BpfClose( - p.NbWgProxy, + p.NbXdpProg, ) } diff --git a/client/internal/ebpf/ebpf/bpf_bpfeb.o b/client/internal/ebpf/ebpf/bpf_bpfeb.o new file mode 100644 index 0000000000000000000000000000000000000000..0559da486ebeb13a7f0dbf21ed0d3bf4d40e76c3 GIT binary patch literal 15960 zcmds8e{5XGao%@7C`y(SI*DVocI+qHl8>@Pk&Ae(@+nvA{wvmdV_xVp1JuxR-Jc&n{5wqj8^%3PZl_Fs(a|5m z?m@Pw47zrgKkjOo*Nw-Id3r3~QJnnd^=oXw{_3?xi*&Ktj9>od^{b>WF7j?n^IO;d zo%Q9P<=vP~(7%5D`qhiEg}2K%d!YAIaxOw{%#Px8V82p-EG9R!KR!&3G4BrAD&D~U zZQj0~7r!2rohC&KpWkk_u-5A^em}XG|AWSQ_W6A#OT3Z2>2YTdO1D|PnD~Gk%wgZ_L|P z1lp_GPd~5NFScO6Y$EXidUjt;PdI&_SN#6a@w-vp+-SWn7o6QY3+8u7!G9YMESsyiy5Br@g~ z?XEE?#NV^T8;u|B{=ek;B)(LhU*q_LJb#6FgE$NG{6E@ub{=h>PwcBazec^AjbVd$4fCRn2khq$&WjJpI^QfWN*qo%oPLrQC7yI% zl$l0Kwk5(Tz~)sT)?8~|_UsbyDEm;nN9^f}JDjIG3%DIY@@c{+i<9h*`84usPmeJP z{`mK-d)q*wWPi(DTXIA;**ae%8AJq!dm-Ai)x9rRwhRi}DVjO!e8-b9YCx}t0|fn& z(16jelg0%*(Qmk#&#&k=g^ocpZ?m64&|gDM{SnZ=EA%MnH$n59HinA@E|~)O-(o`) zJmaE2_I?Aizwd}_GwAOMjk%7#Cwd$TeOc}3VMg2>2+L1Pv<6*yGq4NEw$O>*L;Vq4 zDA5(%NXMZ6Sqyv>^uGzszOGB(oGR$uiY@05RB`f8ZAC-2g~=4(1jKZTIf*p-NZluN zjQ;$*4T8tP&~I$*6!$jlq;{hI80uZ#FVr#cfULPa^>kt*{U^e=&)kjeWxg$XKLh)I zBiw_fJ&I1bzKFq|UqFrdVd?p0+(^fui$afpt_ocM%~+z&JZQ!Y=_Ono7v1O4_6^i` zqh6B!KSjB0FzEOGsVo-f2yJ@bQTPY&4w`{lMsF^SufW2xdk3BZ9ex#7`)ERIvEHc|zfaVBg=D1ta#Yy{~oSWKJO8 z{0))q5#6^H{+?jp_umTkeZ2>VS(rC$JfZp?Cb}IXiF~^ zh%%NPuCkah%EH)v*iw$(-J!SKj<#an$cd8rTRslHJ^w;C7Jc@yJ)m;|-5<~cLbKn! z0d4mZFzjr(JLvZj(INi<@cmf0*I2lh9<-FBcLnrNLgx?HYBQ!$8ZYLXrB2&D!1Jm>W^QxJf{G&5fHcX{FV_qsXnw47Bv}#RUt^7=>YAQuD)fhjK zuT4#vR%?dO*HYcp9cO-`G9es5fH?2 zfOE}OV|KEI2*DuZ7z?;N7M8Oc??Ub8sxp=1io5B0JRbCXgqRJidIRRCbN53jIaS=h zJFd2BPPKv6 z*{*h@o<5PcM>5%|H`}IOJ=yi#xIxHUTXunhptnZ7w;c^c{4OhU6YyW{nyz&=o1o6F za!wqw!RSYkM5AYX=x_{2gMf@*c?Bohn+K?EmV{{6DHy@mF-l1|Du!Su!nXUow{K|w z;K9L15A5IH@25ElRtXB|dQ&cSx-<}1OQ-Y2W=mYQd9<19I3TMSC?mUjD%E0X&NSxor)EpDrM!cs zZs}xe&ftNO$8o52U%gI@Qqx+Xn@z?Wt-E|5B5PRKtSr{g zjl>m9(cBrQw`=NrE7QfLFMJ&PZv4cvpM;j~$vKqA@)ca4G$uWU8t-5k{uY3D$DX%P`|~q@N8x)6+dhYj1@BZmd$-6R6kIoE zj(rRZM&D<7q#wccCa$(It9k% z0jB-7;4c_+o<6TQoO;Zd3n0y!;HEKe)84giY=;?)C(1dwsk`w+gMwu|5yuydCmIRl ze9b&ASjH30JM703`SHnkq9u_NV?0sY;nbInd56DMh`-?4!q*SZpO>fl1hj~cVS4Y3GFf6~P0KiwPP?Z7f#D~HW=E|7yx4+a?hr-ucFnmd$7YsYU!uBhI z(Z0a;7%%N-zG2J^UIWwX0fv5tI*`L3GY5aT2d)0HF>u9Xat>#)@Moc$83{0SGi8Up z{~7O}_@6l^_Qe0pQh?!q=8|CXKeHxS{Licl7XLHX0u2AN@PYnj@&5ff<~uv*FkH7K zJFn#8U$$k;@3Z|mkrT81l3?`rHf(1v1sL&=Z3me8%YrW>o)I_MHNjV~e_~+qzvm95 z23*-|PTqsq`99X8$2iRSU&Z{}-1iKaC_NW&{TcEcL9gI{Fy=Y5_`@?ycsBj!Qaisk}L`{x|?c`sUV7V6&u#aSQcF1^-0yVj;lrCshtG^I|=~%!>;F9sphpF#VD7Bzf_IlUvNZ z*cOcSVqQ%0CNcY83*_5@WxPpVgqY*8IET77!012SC-^nVi|E6~#}|$H7s}%R!yo3^ zL=Icbvv%K0UP_M!7`D=Nhkag3W4t!6A+J2e_Dh0&UOgXR_?B)99+SMfE*R}?UcD;V z=Oqls+9zgSh5un*%FH_qHCr+ZN-qBSyp&lIx#Xpc_?MNulxd5c80(Q)4lw2~vm#jX zlFy@A$x9jWKgmn*ivIe%l z@Ht1mM1RoXrxAa?M2KgvpBnXC`16+mza8+u(&00PeTn|2!{;+KDR1zDNrD`)g6gmF zu-T;ep8)=I@Vf`<{QGcKrO~nQROF{KKGMKc8&>k$~Ul)VZ$Ffq;*$Y$5+X@c9ig zETJFN_4}6g_zd=V6g2zxxB!~(S02|v<1C)wCD3;Tc&!`TWMt6Be$bzN_6Ge$ zWypj6q9x#q6`0=x*C-|D>lJm^sW z67ao&oc+o7C{Xsy9JbGlDtQ?=u$P$w4*Jh50SEo_jNGNa3>@^IK^%ttXO|SFeqhh; z^ZUV*@tJL)9|$n}3+CUf0tf4B)`5d~vS+Pee(hN+T;CDshwD2B9N4#cFJg+;#nY^f4*aN%|Bo`|9n<=Vd@9| zWZnX%zp`KL`f_|8za#n{vpxOynB(Pq3C@DF>w!lWOP2YJ^X)O$m*eyJLC}mJj|V}M z=kaGi?+EZH=m!E^0KF5J@1N_*`0)5FXs*A z@G)SV`IGi#C8vL3`vry1DSTexOA0S5ysGfJ!q>WS)T{6|g>woID?F<3afQnY&ndj1 z@HvIgD||`eWrbH2URU^9H%|2`yiMVp!ov!WDtuhwvchu;FDQIY;qwY#Qg~V6RfX3T zzSfQHyHF61_PiM4oYEgwcvRuz3YQh0Q+Pq)a|)kV_>#iQ3a=`>uJE;PY|rb#cmla-BcwOOZ-Ppc21mopzSYdw3<`f=Qcnmnm zciDo%b%o~@WE4--i zlEN1hZY#W^@S4I`yRm%@2*$suPswA22NfPscue7f!gYn`6<$<$N#P3$w-sJdcunD} z-MAMpjJMuCg=2*W6&_J|OyPpUb%o~@UQ~EV;R_146<$$zP2sEE*uLHd>%X~A$zz2F z6&_J|OyPpUb%o~@UQ~EV;R_146<+c9;ZMc)+_y8gW6$ov?mxKb`UshPnjC19X5_EO zNF6@**#Z73ng2NXKA%OCZ>hnTSNm<%xAEUX?H5Y>Ezz6h$HNC$Maa*#3YOdQTyMaMuaC-f58!{7&U%&Nq%qh?fbx6$L0M|{r8MP zhGL)dj&yh5p6ExoHt7H1c`CsCJZ|jYzbEJYgJ)&-ZSL6DN zy=1!XsH<@&k)JfK@NxR$E{BXx<>exxT@iuMchGy$ZRb#RVq(JpN_Kf4f z&dg+HoH(>>y^xi%EY;?ZvV^KCMg=LLf`U|3OACr6D&a-@he5k2K#{y+RatB#0zx9J zK;iqn_nh(diJf#=fe=@Hx%YF<{c-NS=brc8Jooj(FMeTFQ$ux=lXMNxP4bCy>~7Lz{iIPuzd23j-9Rep{C5 zcU1bNJ*;|lAdK_tA#+#TL*eQJA-D4@2J_wVrcPhM>g_yq{;lqiuYcRTGhrP4cc?9t zK6$?fe1qcZ(C%t?$hod}oImtWCOh*(TW{WC3yzmB)e@E~bkVL$9?_AdYJPy`XisO?V-P5ruzuFKUZNJy^jBkJ{yW3K#pVO z;&XNOWKy9P$LdM-d(Ib_P`)$28|Pu6aSkih@&x7Ql1lm2t@B=9lyefFk9j$N51}sS zPY=wU4*qbwX}x9EkbvmV{mYRK2Zj9n-`zWGfrIDI;tr_)tNALmX&&#gWp@%{L!uX{dy`T?&$ z-t>mQ^AvL)-OphCIkVQL7rkJm`Tzfk^NaqqG|mVAT8#78y}vEzXG`P!Kl*y~_YmiA z??1%((Epv{y!!Dciu2VSe!uP>=RX_j-7U^P5c<1ooUd(@{_Ym%w@N)M4ASTWpkl#QFzRwJu(OX#(;R+%)19<>ve9xRcPwQAVAae zpGLWgvbbCG87X7F&0mr-NGKD*jbMvSpuLWL? zI?b%}G2|P)#JTHIPQd>necp%ize|~8-3n#@M8tYbck3(Ea;vtWu%4A=>@6s(om-{M zF&>t3g0ik#7s^afZD>Wl9R>PhJ96GmcOgeoyP7(|Fm*lmR+9IPKUy8+;f}jlW$4|9hQAAjZ#Q$^-$R+}$@H<56O{i+%DpK6P|5=+ zuYsen{v@uJ@J_XH8_FE#vxYl`={wXvA&l-kKV$fau-ZLnxGJpUeNC7d`~QYvh?wrB z_#M&lshV})H~cMO9s6&Eb*zsA>$zhWOz;%v_Tt))i@x>)smF4H>z4Gh4`nvW;BZ3c zU1BGgiFMn3$@`jnEGqAg${WxYefYMIM)EC?BQ|(>B69ZmNL0>7<&LNv-u}b5dqj?7 z;q)xAp9iCHpNQISkLI#nQF&KXmREY`&s%z^Tpn}PLN1@y3iW!ibh73;>xGHB z>)cbT4(&nsPlwOuJBM6?RIH?Da%1C#S$}yQ`r;==O})s(Xs1p|SCNVb5T( zUfVNxu3o5}%~d)_(6(Brk5@}cnhsKw4wXx_dahJY_w7xx6i&#fyg7Yh<5}65Pf;|1 z#qQ!MP%4wCWp_n&4!U1*#mY#&>WU*{?o6RtE0#;HUasM)r(=bZE9TvBHFq*C4-dO~ zeT?_ldU*(YGc-~d8g*$p-s8qgwc^QAA)gG5!24gPJYY9p`#!rJ&Lh;#XT zwdQKfT;Bi7R{Qp{75;6sDwgVQJns#xmB*_?1(z?>>dqThp~kVsM84v;5)MI5BF@(8 z)$yS^JOqtKx8$J06zn>7+~La53*#l`6U)&2B@b##6<%-ASol z9;r22cF1-t*C;%sNRu(&q*&1g#o;6iV;+Bob4;FoI(h!^ONWm>`(pasOJ7X;j~qRg zw6|j-&~wFDLHt*ol>9#?-* z&z?;EBbl7kyB$-l-Z=F!y8rbBfDcf8#rY0+G}{Cy}N_@W(^lz@)Qi`TkMpYcvf`5MGt#K%OpsscmLT3mANvUu)&DZK;vX7&+?4kj(i&14ncdguT=CFYt8$QU6 zwQi&jXJ^Tr4C77y6fYB7N4^ z!<~s$U+C{YdgR!V^!RiAy4Hy8*luHeex zSL`0-T;-6mR@8X)lb?-Ij`jD%BJ@f;@A0yv<(YXu4-FSe{yy+Z7zW(#zmMOf; zMZ6I(Bm4`{H(d^MV*gL_;rt5T)kR+eXWn=2reVG)u)fYGE3EVB64v?jfaCe}MtaWY zgs{$MQdsA6LAWuWi2ZzozXp5PeF%GnVPZ&-(_YoL2+y$oEY=r!Y~KmdYx4n9zasq4 z$nXCS)<5En;ENGI3ceKa0q~WGJHgWtuLILA+S|PvK7{?h7V#GFg76XO+ZIp}IqQ=* zf-@060&a=88=Q?e3+{?|9k@sMtH@ig!ykpczhLv}`yb-^F0Ku|&P@upH)HJQupYt( zz?+X@JtA%aUl#6%e#6hh|Kj!on6L0V(BC(P`9{0}oPnIP z%z4-!aT|C$h}VHvHtK!#=aNC_A4m=DK1yX+9M3#>F|qSPHM#xQL_TxdJ)kd`9>@m~p-Z{4UBfxESZ#BK>;k zHI9A^J>$H_NsWsGQeWfZM8u4XmxML$%?YDBpRO7Gf-vI`<6?q5wy#U#j<)X+*7}U2 zTpx{#0}(SWjzrA37{^h@#W;>KF2-?``=xPImsNB#_4ZFPCtp$+knl{`ZJC){WY#l z{IyxqrzQV$!ly9 ze^%=2N@#s9KkaFgiZ}g!;QNf9m$nYiBkb~wCvj$^_r1FQo{8J)be6Lag&0`zwkIZj&eYlv&v}phr6PfQ=6#^2O?`BiDn8^Ij zHjj%**%yTUcvJ{;yxxd8-ie4g-bloLyofm-?`upH^ZOm|TTEp3Hxn`IYn-5-+#>#? z_0^xqydMoj`iH<5Bj)!y-WQm({&d8AH|2eSiF$sM%F2e4*}gAg>NPHDe-n|O<4;A* z_R|s5-r0yb{>_Nl|CGcfwTI8AOlpt%HF+DZHY$)&a={rJ&-aQ))_=K543J=b$G zVy=I@-)uh<=~;g+V*eK&X8kW4=6#rn^J9PS2{WU;y1!)3@54yX`r4ji+P@I#{r7Cy zUz)}JSRXOhyDehQZ%4$mzboP$;HM(q3GRz{J9t2tX0+p)h?xFyK4R{_-fvl->-Tn~ zr~mE7#Y8>VXg>i-`Kf>ne$Mc*i1C@zzsFFd?I$8V`@dlLlHtpS^*m^QvqrDyL-l$- zly92)85)74dObhNyidmZ9>aZxPZ-vCs{QGCQ@&v8Uow2z@U-Ds!`BQi7`|yZBk@}8 zX)~NP++$edg|v(}rgaUo*U5_@?2E#1FM!pBI(0M&Dz&&+rMu zBZemoUod>h@MXi(hGz|5GrVBSYayE*O)b|+fkC^c;j>qJR(N7wV<1y<`8T}Q*Glu63&l|pB z7>~3*aef@H#c*P{%W$vZe!~NXD~2ZxUo<>r_=@2f!*hn`4c{=VuV9V&%PX03Vz|q2 zui<{f1BNSxCkjlR74Z>!oyzlHiQ zl>S?!|6?`QX?#C5f4`u5P z%&FA>g;nw7;-^&V^{ecXRZ?Lm{w}JX8+4tubEkYRA6&!d*W11+`!&T%x)%=}+Mj-& z*}k+VP`A&!(*r^Jgh+kfo$e2~*XMiFeZ;Pi?MwGcUCH-$>xW;O*WKAob$8m`70&aI z`|k_5*~z2AlCJ*lKtO$tRHZytD*E4pD3kEpkoQWd#VF*MHtOv_iLTw-kc7fON5~U- zj^p25v#9O$_mu_6mK(oc#?||L=xwpR^7m0#Zv07p0706R33GhL4m;3M6MAaXJrRd%Fjd(!M{gjIwS2I%KrTwx5&sy!Xe+w$F+EE48=tviv=1zc(HP z<-d~g^}SNnoGX*&$qM7Iksk=Hz1lAH;@oS}URm??(*E1`sU+mKzX`NoZvW+fFZAEH z*pmCtF75tv-qt3F+%>**Bh9n(DB1D(_r=SN-$Dh_jI`(Y{H@ZCzX^k{e(@H6R-MGv iNON+2Wg8lDYwJ%zVE2@jr+0=X`d-3zv_Z#Lzy2?@yHG6v literal 0 HcmV?d00001 diff --git a/client/internal/ebpf/ebpf/dns_fwd_linux.go b/client/internal/ebpf/ebpf/dns_fwd_linux.go new file mode 100644 index 000000000..1b6493692 --- /dev/null +++ b/client/internal/ebpf/ebpf/dns_fwd_linux.go @@ -0,0 +1,51 @@ +package ebpf + +import ( + "encoding/binary" + "net" + + log "github.com/sirupsen/logrus" +) + +const ( + mapKeyDNSIP uint32 = 0 + mapKeyDNSPort uint32 = 1 +) + +func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { + log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort) + tf.lock.Lock() + defer tf.lock.Unlock() + + err := tf.loadXdp() + if err != nil { + return err + } + + err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip)) + if err != nil { + return err + } + + err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort)) + if err != nil { + return err + } + + tf.setFeatureFlag(featureFlagDnsForwarder) + err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags) + if err != nil { + return err + } + return nil +} + +func (tf *GeneralManager) FreeDNSFwd() error { + log.Debugf("free ebpf DNS forwarder") + return tf.unsetFeatureFlag(featureFlagDnsForwarder) +} + +func ip2int(ipString string) uint32 { + ip := net.ParseIP(ipString) + return binary.BigEndian.Uint32(ip.To4()) +} diff --git a/client/internal/ebpf/ebpf/manager_linux.go b/client/internal/ebpf/ebpf/manager_linux.go new file mode 100644 index 000000000..9dfdc0ad1 --- /dev/null +++ b/client/internal/ebpf/ebpf/manager_linux.go @@ -0,0 +1,116 @@ +package ebpf + +import ( + _ "embed" + "net" + "sync" + + "github.com/cilium/ebpf/link" + "github.com/cilium/ebpf/rlimit" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/ebpf/manager" +) + +const ( + mapKeyFeatures uint32 = 0 + + featureFlagWGProxy = 0b00000001 + featureFlagDnsForwarder = 0b00000010 +) + +var ( + singleton manager.Manager + singletonLock = &sync.Mutex{} +) + +// required packages libbpf-dev, libc6-dev-i386-amd64-cross + +// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c +// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel. +// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c). +// +//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include +type GeneralManager struct { + lock sync.Mutex + link link.Link + featureFlags uint16 + bpfObjs bpfObjects +} + +// GetEbpfManagerInstance return a static eBpf Manager instance +func GetEbpfManagerInstance() manager.Manager { + singletonLock.Lock() + defer singletonLock.Unlock() + if singleton != nil { + return singleton + } + singleton = &GeneralManager{} + return singleton +} + +func (tf *GeneralManager) setFeatureFlag(feature uint16) { + tf.featureFlags = tf.featureFlags | feature +} + +func (tf *GeneralManager) loadXdp() error { + if tf.link != nil { + return nil + } + // it required for Docker + err := rlimit.RemoveMemlock() + if err != nil { + return err + } + + iFace, err := net.InterfaceByName("lo") + if err != nil { + return err + } + + // load pre-compiled programs into the kernel. + err = loadBpfObjects(&tf.bpfObjs, nil) + if err != nil { + return err + } + + tf.link, err = link.AttachXDP(link.XDPOptions{ + Program: tf.bpfObjs.NbXdpProg, + Interface: iFace.Index, + }) + + if err != nil { + _ = tf.bpfObjs.Close() + tf.link = nil + return err + } + return nil +} + +func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error { + tf.lock.Lock() + defer tf.lock.Unlock() + tf.featureFlags &^= feature + + if tf.link == nil { + return nil + } + + if tf.featureFlags == 0 { + return tf.close() + } + + return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags) +} + +func (tf *GeneralManager) close() error { + log.Debugf("detach ebpf program ") + err := tf.bpfObjs.Close() + if err != nil { + log.Warnf("failed to close eBpf objects: %s", err) + } + + err = tf.link.Close() + tf.link = nil + return err +} diff --git a/client/internal/ebpf/ebpf/manager_linux_test.go b/client/internal/ebpf/ebpf/manager_linux_test.go new file mode 100644 index 000000000..956499e5b --- /dev/null +++ b/client/internal/ebpf/ebpf/manager_linux_test.go @@ -0,0 +1,40 @@ +package ebpf + +import ( + "testing" +) + +func TestManager_setFeatureFlag(t *testing.T) { + mgr := GeneralManager{} + mgr.setFeatureFlag(featureFlagWGProxy) + if mgr.featureFlags != 1 { + t.Errorf("invalid faeture state") + } + + mgr.setFeatureFlag(featureFlagDnsForwarder) + if mgr.featureFlags != 3 { + t.Errorf("invalid faeture state") + } +} + +func TestManager_unsetFeatureFlag(t *testing.T) { + mgr := GeneralManager{} + mgr.setFeatureFlag(featureFlagWGProxy) + mgr.setFeatureFlag(featureFlagDnsForwarder) + + err := mgr.unsetFeatureFlag(featureFlagWGProxy) + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if mgr.featureFlags != 2 { + t.Errorf("invalid faeture state, expected: %d, got: %d", 2, mgr.featureFlags) + } + + err = mgr.unsetFeatureFlag(featureFlagDnsForwarder) + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if mgr.featureFlags != 0 { + t.Errorf("invalid faeture state, expected: %d, got: %d", 0, mgr.featureFlags) + } +} diff --git a/client/internal/ebpf/ebpf/src/dns_fwd.c b/client/internal/ebpf/ebpf/src/dns_fwd.c new file mode 100644 index 000000000..5228c7e75 --- /dev/null +++ b/client/internal/ebpf/ebpf/src/dns_fwd.c @@ -0,0 +1,64 @@ +const __u32 map_key_dns_ip = 0; +const __u32 map_key_dns_port = 1; + +struct bpf_map_def SEC("maps") nb_map_dns_ip = { + .type = BPF_MAP_TYPE_ARRAY, + .key_size = sizeof(__u32), + .value_size = sizeof(__u32), + .max_entries = 10, +}; + +struct bpf_map_def SEC("maps") nb_map_dns_port = { + .type = BPF_MAP_TYPE_ARRAY, + .key_size = sizeof(__u32), + .value_size = sizeof(__u16), + .max_entries = 10, +}; + +__be32 dns_ip = 0; +__be16 dns_port = 0; + +// 13568 is 53 in big endian +__be16 GENERAL_DNS_PORT = 13568; + +bool read_settings() { + __u16 *port_value; + __u32 *ip_value; + + // read dns ip + ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip); + if(!ip_value) { + return false; + } + dns_ip = htonl(*ip_value); + + // read dns port + port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port); + if (!port_value) { + return false; + } + dns_port = htons(*port_value); + return true; +} + +int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) { + if (dns_port == 0) { + if(!read_settings()){ + return XDP_PASS; + } + bpf_printk("dns port: %d", ntohs(dns_port)); + bpf_printk("dns ip: %d", ntohl(dns_ip)); + } + + if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) { + udp->dest = dns_port; + return XDP_PASS; + } + + if (udp->source == dns_port && ip->saddr == dns_ip) { + udp->source = GENERAL_DNS_PORT; + return XDP_PASS; + } + + return XDP_PASS; +} \ No newline at end of file diff --git a/client/internal/ebpf/ebpf/src/prog.c b/client/internal/ebpf/ebpf/src/prog.c new file mode 100644 index 000000000..09b649370 --- /dev/null +++ b/client/internal/ebpf/ebpf/src/prog.c @@ -0,0 +1,66 @@ +#include +#include // ETH_P_IP +#include +#include +#include +#include +#include +#include "dns_fwd.c" +#include "wg_proxy.c" + +#define bpf_printk(fmt, ...) \ + ({ \ + char ____fmt[] = fmt; \ + bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \ + }) + +const __u16 flag_feature_wg_proxy = 0b01; +const __u16 flag_feature_dns_fwd = 0b10; + +const __u32 map_key_features = 0; +struct bpf_map_def SEC("maps") nb_features = { + .type = BPF_MAP_TYPE_ARRAY, + .key_size = sizeof(__u32), + .value_size = sizeof(__u16), + .max_entries = 10, +}; + +SEC("xdp") +int nb_xdp_prog(struct xdp_md *ctx) { + __u16 *features; + features = bpf_map_lookup_elem(&nb_features, &map_key_features); + if (!features) { + return XDP_PASS; + } + + void *data = (void *)(long)ctx->data; + void *data_end = (void *)(long)ctx->data_end; + struct ethhdr *eth = data; + struct iphdr *ip = (data + sizeof(struct ethhdr)); + struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr)); + + // return early if not enough data + if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){ + return XDP_PASS; + } + + // skip non IPv4 packages + if (eth->h_proto != htons(ETH_P_IP)) { + return XDP_PASS; + } + + // skip non UPD packages + if (ip->protocol != IPPROTO_UDP) { + return XDP_PASS; + } + + if (*features & flag_feature_dns_fwd) { + xdp_dns_fwd(ip, udp); + } + + if (*features & flag_feature_wg_proxy) { + xdp_wg_proxy(ip, udp); + } + return XDP_PASS; +} +char _license[] SEC("license") = "GPL"; \ No newline at end of file diff --git a/client/internal/ebpf/ebpf/src/wg_proxy.c b/client/internal/ebpf/ebpf/src/wg_proxy.c new file mode 100644 index 000000000..ecfedc6b3 --- /dev/null +++ b/client/internal/ebpf/ebpf/src/wg_proxy.c @@ -0,0 +1,54 @@ +const __u32 map_key_proxy_port = 0; +const __u32 map_key_wg_port = 1; + +struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = { + .type = BPF_MAP_TYPE_ARRAY, + .key_size = sizeof(__u32), + .value_size = sizeof(__u16), + .max_entries = 10, +}; + +__u16 proxy_port = 0; +__u16 wg_port = 0; + +bool read_port_settings() { + __u16 *value; + value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port); + if (!value) { + return false; + } + + proxy_port = *value; + + value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port); + if (!value) { + return false; + } + wg_port = htons(*value); + + return true; +} + +int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) { + if (proxy_port == 0 || wg_port == 0) { + if (!read_port_settings()){ + return XDP_PASS; + } + bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port); + } + + // 2130706433 = 127.0.0.1 + if (ip->daddr != htonl(2130706433)) { + return XDP_PASS; + } + + if (udp->source != wg_port){ + return XDP_PASS; + } + + __be16 new_src_port = udp->dest; + __be16 new_dst_port = htons(proxy_port); + udp->dest = new_dst_port; + udp->source = new_src_port; + return XDP_PASS; +} \ No newline at end of file diff --git a/client/internal/ebpf/ebpf/wg_proxy_linux.go b/client/internal/ebpf/ebpf/wg_proxy_linux.go new file mode 100644 index 000000000..4e0df7329 --- /dev/null +++ b/client/internal/ebpf/ebpf/wg_proxy_linux.go @@ -0,0 +1,41 @@ +package ebpf + +import log "github.com/sirupsen/logrus" + +const ( + mapKeyProxyPort uint32 = 0 + mapKeyWgPort uint32 = 1 +) + +func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error { + log.Debugf("load ebpf WG proxy") + tf.lock.Lock() + defer tf.lock.Unlock() + + err := tf.loadXdp() + if err != nil { + return err + } + + err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort)) + if err != nil { + return err + } + + err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort)) + if err != nil { + return err + } + + tf.setFeatureFlag(featureFlagWGProxy) + err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags) + if err != nil { + return err + } + return nil +} + +func (tf *GeneralManager) FreeWGProxy() error { + log.Debugf("free ebpf WG proxy") + return tf.unsetFeatureFlag(featureFlagWGProxy) +} diff --git a/client/internal/ebpf/instantiater_linux.go b/client/internal/ebpf/instantiater_linux.go new file mode 100644 index 000000000..20d8145b4 --- /dev/null +++ b/client/internal/ebpf/instantiater_linux.go @@ -0,0 +1,15 @@ +//go:build !android + +package ebpf + +import ( + "github.com/netbirdio/netbird/client/internal/ebpf/ebpf" + "github.com/netbirdio/netbird/client/internal/ebpf/manager" +) + +// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal +// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant +// panic on older Android version. +func GetEbpfManagerInstance() manager.Manager { + return ebpf.GetEbpfManagerInstance() +} diff --git a/client/internal/ebpf/instantiater_nonlinux.go b/client/internal/ebpf/instantiater_nonlinux.go new file mode 100644 index 000000000..b7c38733a --- /dev/null +++ b/client/internal/ebpf/instantiater_nonlinux.go @@ -0,0 +1,10 @@ +//go:build !linux || android + +package ebpf + +import "github.com/netbirdio/netbird/client/internal/ebpf/manager" + +// GetEbpfManagerInstance return error because ebpf is not supported on all os +func GetEbpfManagerInstance() manager.Manager { + panic("unsupported os") +} diff --git a/client/internal/ebpf/manager/manager.go b/client/internal/ebpf/manager/manager.go new file mode 100644 index 000000000..af10142d5 --- /dev/null +++ b/client/internal/ebpf/manager/manager.go @@ -0,0 +1,9 @@ +package manager + +// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy +type Manager interface { + LoadDNSFwd(ip string, dnsPort int) error + FreeDNSFwd() error + LoadWgProxy(proxyPort, wgPort int) error + FreeWGProxy() error +} diff --git a/client/internal/wgproxy/ebpf/bpf_bpfeb.o b/client/internal/wgproxy/ebpf/bpf_bpfeb.o deleted file mode 100644 index 82d7fc35b4c64f55f312a2f8c963a72e40cec6b7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6264 zcmb_gU2K$D89v{1c42>T3kA1PVonioicmW33RNu5;<8(k8oF#rSjBuyJ6}7KPG^Rh zDGV+fdf_G**95bx>4k}G!Uc&YV%7vMH1xs?61^a1HDsen;}zw@CdPw8v;mKx`i|C@1f&D(dYvbo_YvyKCAb*e&O$l$MGleqq|R9|39*wx|?4=UTLf4@90)(QsH=< z{B*b4y`DUrm&*gj`(^64=U*f0w7K(^4&!}-zk{?W<<_^G+iv-{k#D`Hpr+qjeMAoZ zSL@BMkiIypYj3=_`jO-BDeTvzA=cnObo@)YRwZ4lS6kq(kS-nOTdBLBvHd3aU0q#W zxvdS}Df50Zzt`q5|7ojtH!J#srr&p9NEylbKDC9M^I7#Q`&ldBd~%3LsndgsE3+AT z&l{}Ge%hMvQ=b#O1)KLusZZ@9?zQ$LO)CxC{sYVJ{`Tk9CUV03mW(y#m$mCV*JJ&C zppAX^+&0U^$_D{h8?HN2! zO@AT4hvAIp9J@LFb@XPv4zQ6X&f&M52=MpqcmGd3c6Hj zs_AA6PMXTg6JwI6Jzl2KtS?MAwH=P$2Q!x(WXsRbs26E2@!`*p`DuO*kKr!G`W1-*~?U@DNSjOFxa%GJtjKL&Q-sFNO1o;ss%tnOA` zXK1Xys2ip4B-I%=@1$8Q7V7;UozENFl zRx6Fn@yW4xB0e$E@3KYJQtiM|qc>eGQ}x8e#Ho{$C*#vEO@wNNe4$Wx?#h{=?C{{> z!9y<$53g~EfgL#7s4mo}6Q|qAxa`Up<9I5`9@3TMVq9o6Wf^6Zokfwe;wj^FI%N~4 zkTjY$t$x%|+of4A zB@N?XrkuyU;%)FXmxIDV$sXz)ltvMU)P5_?3{jrge!7^iL;1y1r(S$Le(l6JkH?d* zPaKb96_-oXNu`mffvE=WuX>gHKshF6)llx+s^$LMjUG?x2A+Gu;~mZ_#@(T~&ZF;n zn4c?t3#fx{oM{?wd^cPCJhP%8M#*!TzD5N9SgE&pujf2WKW7C0LaBFYzbF{?OSHc# z_&-YB;Jtm#!_>cJafI*0EsWWWRfET-fq$UXuSwGv@%u_G3*E-7HO8#GJ&&4TH;OO0S36Y~?izH8Nc@a#Fb=#xV%3jE8 zf+v;wGqH>}bw#O1qrg{1PE7yT1h1&*S$vA4TOOwTj>TQiplrhUqGbz7Cc`cy{npHy@fbxLYP z@V}56?2A;+V(f#c5P{yXVMDt*!`JyOfa5b4?FRjH8+3@@fpq+_4qN@E!;gv{{|_C0 zD(CTE#i4L<9eWu>8IzRy&aqkdsomg94cq2(9Q(m%4*f>Ihdq5aUiM$}^j&=R5ys>6 z#8+wp#5>TJQdQV%^%&dby&E*vw2gD1Y3t-?K>Pa4lhnMR>=|c$=k%%T+h>4%`4}*D zoId^g@fU%8`z2ss|0b{>pL^N!Z}zT-&0h8VrI;^b$yRo85?cy)rHSY8 zI55j$r(Xcg@d;M@L33^ne-$)s9p?J@@^b;M1$Z&Q*8+S!z_$Xt9N_x_ehBQ3@6lQu zac-VG>J2dS<;zC`JQmts z=e4;*n%~a6o7i&Z-F`hcBIG&r-FV!0!r7Bo#N);j#CL4-{eW7#IeRjA*KHJ^`;zTC z9>#L_owFzJ=XlzIKzprgLV(kEGUtbwcdxS#)}iZ+81TF6H{bph2%P`WzAQ$*K3Oi0 zq(8*t;t!G0X-R?-GUBdvCd}mp9|X0-wj}YJw4Se_bo7WAZ=Vd=i>MN7mSVPZ~y=R diff --git a/client/internal/wgproxy/ebpf/bpf_bpfel.o b/client/internal/wgproxy/ebpf/bpf_bpfel.o deleted file mode 100644 index 2fc0ff8fca8b8aa1a6bfd2b6f7079195dfd478c6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6264 zcmbtYU2I%O6`pn6#%^QBKQ*=72;HPfSew}KPe=^G#tBX9NJiPFw&X(4^?LW(yYhN> zv%5}gk}7@xDTpEk6_6j0NTobPDiT3OiXTAn0}qzqAwmL0Rgf&60zA+s$O|&xH*;n^ zp4e5S9{KK^@0>H|oH_Gz=laIEu{UdCF=Z*H{;Rf-l`1v($`(JO*%8$Y%XZJWy=Kdf z+z;^3;0uQF$B}I|^0rv5 zW_Ht0#18MM_s3=H(5Aks?W^P8YvS17xmOM6$J^cCapTXV`ApsB&KQ69G5Fv}TBSOk zZc;-3S7(DF`9i7kwQC)pR^@;JX5s`6r~ut@n*P%K>Q*wFH9nG)3XL(}es z-;K@9&HGpf2O#3OIrYhDlVCDw>>jk5cJCggy1;h7ZYJV#%Ra5R+}fbqS?KG)U3gtX zHL<5zF24 zw$;q|5>j|u0XOwZoY+u4dV*OY<)|a(q{RODE5@YI@|H6JU zrfM7=`TX0IjnW4CLU~Dg$D%H^%pS6RBB!o~JE9E&)V-Ge4j7&~bx`;Ke3$Se;770QuX(kAS}<{1o_%@XO#!;6aUz_+gwM8C=Kp^Ap3z zSYyubZLuX!+mAf_vDj0G@vtH3)y6!znrX5sc5DJXe=@kS4Lm#j1sE~)p(sb6UNf>i zH2;ZXV|It}ZFY#^{|+1Y(|+U#(eeN3$j!-%dLyn1&lU}7kc%)ZN0mx8KU-Ek zm1wb|dXAS%spHg?qGB$Yih5Ei09{Gu7NTG%m@dwQ^T}eED-`AyieZ$C=Ho~6)8WO0=AZ% zEKX9f1n+b4L!WW@96Ar{WGkWB|&tOrYhTE&f|GoCBtXFHKt_psHTR*pHN z?{02qK8Uk(@a3SK{a#dOFa4U|^d=B-TA#E-aK%k+m(yZFMEbFr1G1P6qktYM3rpWEv&Qtx1u+;uvFh zT4Nn19hEDltZvjH?bzBndYWT?+_0zBvwO5ElB|{8_UULwRhEj8;-2R0DZM)lJD*&{ za;cQEQCT~f$t7{FI17B+;h?Zl`c7_blyU}})VwS82ysm@Z7P#2!T8L@i)Y>q-#P#7 zbK&^AiF0A7!dx~L<;#)knJ(k}Diyd6loeuBxTEmh8t~i$9hT<~4&HSh_#nQMI`HsJ zCh;KuA%Mkvhx1)P8@}U60iF|h8gMJ>70mli-F4iT1^*mv%lk@=IJgOTQt+?QHvCkn zWx=}tic&WP{}1i{J4)Sha0Bo?!L7^@fVl>ckpetF23Ef_<0AeTxNgnpEN%xLg^t*c zqbS&p14Gk#JC0=sa~w!j>+Lw!1lw_}3%28U=wOcHQ^9{je)Ui$&<2oMniOpHymtgp z9>KgF-N2nnFz-r7A1k#e_y+LMQ`FJH^}x3TZ$RIOThl!U*8{H#ei;j(0T-t=2iF6y z3+|`BH>MssxE}bC;7iaqj>Oc4gWG}mLTCQe*8|rH{yy+)<5;7De*+vOVk$29ufWZh zVro$Ee}MO+I3t4fI!MF+XV4x7_vjBA*XZ@+sjvrXeUbX3melHX<0^kIa=8~)Reu&a z^K|KLfEFj8HBsi!Hdk+Wa#w-uu$J?-Xvk)E4A6WL7`Y8|O3PoYj=ki}f57Bin(!K^EBlY^Nb*DVRlRM+E# zgBfqkvFCj}?O=|b^Gc$B=J0g_5HWow2tZaprO_5HdU)Bxw%)e?ZBKvC!)qRX=;0?0 zW_~zoG}(OXq@EVHdpPdlK@X357+Zgvf4lFj|3y!K)59wszU$#t53hUpk%ymn_^F5O z`my=h?`@0ida;=29LeHQ4^MhHVcEKo<1`*pX>J*=*@lN)3@Zq_aB}@mQz*A zZ+6PI-UvoDVj8t3t!;RwxWsQ~_nfx6LHS6Y5QdDZUm$XWt*RCE5N4 z;k*O0Q~VW)e_IS0k9)xOxAA`k8RN?`1ZOBAMYZ-;P9H9f0kFS62a_Vu|08fZ`2F94 zjxqfHIT85wex9G8FXQ*)amn^S;0nRhmSkI8nRm?p<}UGVKkWoRtBx-#WBb1J|Bn49 zlPLZ1{1)<^=5J6YXiEAK@_lFL!_MDpkoh^W%M-p#*?I-@CpUu6H2zQq+bI<)2a+>A;4iwDH@ul~Vr)+SlQB diff --git a/client/internal/wgproxy/ebpf/loader.go b/client/internal/wgproxy/ebpf/loader.go deleted file mode 100644 index e154c0d9e..000000000 --- a/client/internal/wgproxy/ebpf/loader.go +++ /dev/null @@ -1,84 +0,0 @@ -//go:build linux && !android - -package ebpf - -import ( - _ "embed" - "net" - - "github.com/cilium/ebpf/link" - "github.com/cilium/ebpf/rlimit" -) - -const ( - mapKeyProxyPort uint32 = 0 - mapKeyWgPort uint32 = 1 -) - -//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/portreplace.c -- - -// EBPF is a wrapper for eBPF program -type EBPF struct { - link link.Link -} - -// NewEBPF create new EBPF instance -func NewEBPF() *EBPF { - return &EBPF{} -} - -// Load load ebpf program -func (l *EBPF) Load(proxyPort, wgPort int) error { - // it required for Docker - err := rlimit.RemoveMemlock() - if err != nil { - return err - } - - ifce, err := net.InterfaceByName("lo") - if err != nil { - return err - } - - // Load pre-compiled programs into the kernel. - objs := bpfObjects{} - err = loadBpfObjects(&objs, nil) - if err != nil { - return err - } - defer func() { - _ = objs.Close() - }() - - err = objs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort)) - if err != nil { - return err - } - - err = objs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort)) - if err != nil { - return err - } - - defer func() { - _ = objs.NbWgProxySettingsMap.Close() - }() - - l.link, err = link.AttachXDP(link.XDPOptions{ - Program: objs.NbWgProxy, - Interface: ifce.Index, - }) - if err != nil { - return err - } - - return err -} - -// Free ebpf program -func (l *EBPF) Free() error { - if l.link != nil { - return l.link.Close() - } - return nil -} diff --git a/client/internal/wgproxy/ebpf/loader_test.go b/client/internal/wgproxy/ebpf/loader_test.go deleted file mode 100644 index 6ce323e70..000000000 --- a/client/internal/wgproxy/ebpf/loader_test.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build linux - -package ebpf - -import ( - "testing" -) - -func Test_newEBPF(t *testing.T) { - ebpf := NewEBPF() - err := ebpf.Load(1234, 51892) - defer func() { - _ = ebpf.Free() - }() - if err != nil { - t.Errorf("%s", err) - } -} diff --git a/client/internal/wgproxy/ebpf/src/portreplace.c b/client/internal/wgproxy/ebpf/src/portreplace.c deleted file mode 100644 index dc95ee53f..000000000 --- a/client/internal/wgproxy/ebpf/src/portreplace.c +++ /dev/null @@ -1,90 +0,0 @@ -#include -#include // ETH_P_IP -#include -#include -#include -#include -#include - -#define bpf_printk(fmt, ...) \ - ({ \ - char ____fmt[] = fmt; \ - bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \ - }) - -const __u32 map_key_proxy_port = 0; -const __u32 map_key_wg_port = 1; - -struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = { - .type = BPF_MAP_TYPE_ARRAY, - .key_size = sizeof(__u32), - .value_size = sizeof(__u16), - .max_entries = 10, -}; - -__u16 proxy_port = 0; -__u16 wg_port = 0; - -bool read_port_settings() { - __u16 *value; - value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port); - if(!value) { - return false; - } - - proxy_port = *value; - - value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port); - if(!value) { - return false; - } - wg_port = *value; - - return true; -} - -SEC("xdp") -int nb_wg_proxy(struct xdp_md *ctx) { - if(proxy_port == 0 || wg_port == 0) { - if(!read_port_settings()){ - return XDP_PASS; - } - bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port); - } - - void *data = (void *)(long)ctx->data; - void *data_end = (void *)(long)ctx->data_end; - struct ethhdr *eth = data; - struct iphdr *ip = (data + sizeof(struct ethhdr)); - struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr)); - - // return early if not enough data - if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){ - return XDP_PASS; - } - - // skip non IPv4 packages - if (eth->h_proto != htons(ETH_P_IP)) { - return XDP_PASS; - } - - if (ip->protocol != IPPROTO_UDP) { - return XDP_PASS; - } - - // 2130706433 = 127.0.0.1 - if (ip->daddr != htonl(2130706433)) { - return XDP_PASS; - } - - if (udp->source != htons(wg_port)){ - return XDP_PASS; - } - - __be16 new_src_port = udp->dest; - __be16 new_dst_port = htons(proxy_port); - udp->dest = new_dst_port; - udp->source = new_src_port; - return XDP_PASS; -} -char _license[] SEC("license") = "GPL"; \ No newline at end of file diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 8be6b0c19..ff8ff665b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -12,15 +12,15 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" - log "github.com/sirupsen/logrus" - ebpf2 "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/internal/ebpf" + ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ) // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { - ebpf *ebpf2.EBPF + ebpfManager ebpfMgr.Manager lastUsedPort uint16 localWGListenPort int @@ -36,7 +36,7 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { log.Debugf("instantiate ebpf proxy") wgProxy := &WGEBPFProxy{ localWGListenPort: wgPort, - ebpf: ebpf2.NewEBPF(), + ebpfManager: ebpf.GetEbpfManagerInstance(), lastUsedPort: 0, turnConnStore: make(map[uint16]net.Conn), } @@ -56,7 +56,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - err = p.ebpf.Load(wgPorxyPort, p.localWGListenPort) + err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort) if err != nil { return err } @@ -110,7 +110,7 @@ func (p *WGEBPFProxy) Free() error { err1 = p.conn.Close() } - err2 = p.ebpf.Free() + err2 = p.ebpfManager.FreeWGProxy() if p.rawConn != nil { err3 = p.rawConn.Close() } diff --git a/iface/iface.go b/iface/iface.go index 58167d1e8..55891d047 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -112,7 +112,7 @@ func (w *WGIface) Close() error { return w.tun.Close() } -// SetFilter sets packet filters for the userspace impelemntation +// SetFilter sets packet filters for the userspace implementation func (w *WGIface) SetFilter(filter PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() From 7682fe2e459a368bacbf6fb5f2d57d35f1d51eaf Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 5 Sep 2023 23:04:14 +0200 Subject: [PATCH 30/42] Account ephemeral setup keys metrics (#1128) --- management/server/metrics/selfhosted.go | 52 ++++++++++++++----------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index b2402118d..696df5f3c 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -157,28 +157,30 @@ func (w *Worker) generatePayload(apiKey string) pushPayload { func (w *Worker) generateProperties() properties { var ( - uptime float64 - accounts int - expirationEnabled int - users int - serviceUsers int - pats int - peers int - peersSSHEnabled int - setupKeysUsage int - activePeersLastDay int - osPeers map[string]int - userPeers int - rules int - rulesProtocol map[string]int - rulesDirection map[string]int - groups int - routes int - nameservers int - uiClient int - version string - peerActiveVersions []string - osUIClients map[string]int + uptime float64 + accounts int + expirationEnabled int + users int + serviceUsers int + pats int + peers int + peersSSHEnabled int + setupKeysUsage int + ephemeralPeersSKs int + ephemeralPeersSKUsage int + activePeersLastDay int + osPeers map[string]int + userPeers int + rules int + rulesProtocol map[string]int + rulesDirection map[string]int + groups int + routes int + nameservers int + uiClient int + version string + peerActiveVersions []string + osUIClients map[string]int ) start := time.Now() metricsProperties := make(properties) @@ -224,6 +226,10 @@ func (w *Worker) generateProperties() properties { for _, key := range account.SetupKeys { setupKeysUsage = setupKeysUsage + key.UsedTimes + if key.Ephemeral { + ephemeralPeersSKs++ + ephemeralPeersSKUsage = ephemeralPeersSKUsage + key.UsedTimes + } } for _, peer := range account.Peers { @@ -269,6 +275,8 @@ func (w *Worker) generateProperties() properties { metricsProperties["peers_ssh_enabled"] = peersSSHEnabled metricsProperties["peers_login_expiration_enabled"] = expirationEnabled metricsProperties["setup_keys_usage"] = setupKeysUsage + metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs + metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage metricsProperties["active_peers_last_day"] = activePeersLastDay metricsProperties["user_peers"] = userPeers metricsProperties["rules"] = rules From fa4b8c1d422a51bdba74d66bdb3c373050363a3b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 6 Sep 2023 10:40:45 +0200 Subject: [PATCH 31/42] Update ephemeral field on the API response (#1129) --- management/server/http/setupkeys_handler.go | 1 + management/server/http/setupkeys_handler_test.go | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 392cebdbd..cddae672c 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -216,5 +216,6 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey { AutoGroups: key.AutoGroups, UpdatedAt: key.UpdatedAt, UsageLimit: key.UsageLimit, + Ephemeral: key.Ephemeral, } } diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index afc9deb15..d931a5e0b 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -51,10 +51,12 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup }, user, nil }, CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, - _ int, _ string, _ bool, + _ int, _ string, ephemeral bool, ) (*server.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { - return newKey, nil + nk := newKey.Copy() + nk.Ephemeral = ephemeral + return nk, nil } return nil, fmt.Errorf("failed creating setup key") }, @@ -99,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) { adminUser := server.NewAdminUser("test_user") newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, - server.SetupKeyUnlimitedUsage, false) + server.SetupKeyUnlimitedUsage, true) updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.Name = updatedSetupKeyName @@ -143,7 +145,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/setup-keys", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400}", newSetupKey.Name, newSetupKey.Type))), + []byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))), expectedStatus: http.StatusOK, expectedBody: true, expectedSetupKey: toResponseBody(newSetupKey), @@ -229,4 +231,5 @@ func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) { assert.Equal(t, got.UsedTimes, expected.UsedTimes) assert.Equal(t, got.Revoked, expected.Revoked) assert.ElementsMatch(t, got.AutoGroups, expected.AutoGroups) + assert.Equal(t, got.Ephemeral, expected.Ephemeral) } From 5c8541ef4245b0ee0cfad621251378b41adebf3b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 8 Sep 2023 18:24:19 +0200 Subject: [PATCH 32/42] Set not found ebpf log to Info (#1134) added an additional log event --- client/internal/wgproxy/proxy_ebpf.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index ff8ff665b..6ca19c973 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) } p.removeTurnConn(endpointPort) + log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err) return } err = p.sendPkg(buf[:n], endpointPort) @@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() { conn, ok := p.turnConnStore[uint16(addr.Port)] p.turnConnMutex.Unlock() if !ok { - log.Errorf("turn conn not found by port: %d", addr.Port) + log.Infof("turn conn not found by port: %d", addr.Port) continue } From 30f1c54ed1dfe43adea2498afa7fa9db904e1518 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 8 Sep 2023 19:28:34 +0200 Subject: [PATCH 33/42] Fix: docker test for infrastructure files (#1136) * Fix: docker test for infrastructure files * Fix: docker test for infrastructure files --- .github/workflows/test-infrastructure-files.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index d196ce0e0..f7d6766fc 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -120,7 +120,7 @@ jobs: - name: test running containers run: | - count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running) + count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running) test $count -eq 4 working-directory: infrastructure_files From bb791d59f3cf55437d2e7749f9df2759ae019b80 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 8 Sep 2023 21:08:02 +0300 Subject: [PATCH 34/42] update check for linux running desktop (#1137) --- client/cmd/login.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 794a599fd..a5cc3215c 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -232,16 +232,7 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { // isLinuxRunningDesktop checks if a Linux OS is running desktop environment. func isLinuxRunningDesktop() bool { - for _, env := range os.Environ() { - values := strings.Split(env, "=") - if len(values) == 2 { - key, value := values[0], values[1] - if key == "XDG_CURRENT_DESKTOP" && value != "" { - return true - } - } - } - return false + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" } // isPKCEFlow determines if the PKCE flow is active or not, From 2135533f1d631f0d9bb6ea4edcf9d9a64bad9855 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 13 Sep 2023 17:36:24 +0200 Subject: [PATCH 35/42] Fix Android build (#1142) The source code files related to the Android firewall had incorrect build tags. --- client/internal/acl/manager_create.go | 2 +- client/internal/acl/manager_create_linux.go | 2 ++ client/internal/checkfw/check.go | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go index c573d2c64..2fdca02ae 100644 --- a/client/internal/acl/manager_create.go +++ b/client/internal/acl/manager_create.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !linux || android package acl diff --git a/client/internal/acl/manager_create_linux.go b/client/internal/acl/manager_create_linux.go index 4342463d3..05b042351 100644 --- a/client/internal/acl/manager_create_linux.go +++ b/client/internal/acl/manager_create_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package acl import ( diff --git a/client/internal/checkfw/check.go b/client/internal/checkfw/check.go index 59626cbc3..edfd8a5b3 100644 --- a/client/internal/checkfw/check.go +++ b/client/internal/checkfw/check.go @@ -1,3 +1,3 @@ -//go:build !linux +//go:build !linux || android package checkfw From 06bec61be9b6febd47cbca7fbe0aec7949c7df4b Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 13 Sep 2023 17:58:12 +0200 Subject: [PATCH 36/42] Add Android test build (#1144) Extend the CI with gomobile build. With this step we can validate that the code can run on Android --- .../workflows/android-build-validation.yml | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 .github/workflows/android-build-validation.yml diff --git a/.github/workflows/android-build-validation.yml b/.github/workflows/android-build-validation.yml new file mode 100644 index 000000000..57cbbacb4 --- /dev/null +++ b/.github/workflows/android-build-validation.yml @@ -0,0 +1,41 @@ +name: Android build validation + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v4 + with: + go-version: "1.20.x" + - name: Setup Android SDK + uses: android-actions/setup-android@v2 + - name: NDK Cache + id: ndk-cache + uses: actions/cache@v3 + with: + path: /usr/local/lib/android/sdk/ndk + key: ndk-cache-23.1.7779620 + - name: Setup NDK + run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620" + - name: install gomobile + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda + - name: gomobile init + run: gomobile init + - name: build android nebtird lib + run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android + env: + CGO_ENABLED: 0 + ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 \ No newline at end of file From 8d18190c9412268aa0279a4e89d4e02cbd1bdb46 Mon Sep 17 00:00:00 2001 From: Fabio Fantoni Date: Thu, 14 Sep 2023 15:58:28 +0200 Subject: [PATCH 37/42] fix NETBIRD_SIGNAL_PORT not working with custom port (#1143) (#1145) Use NETBIRD_SIGNAL_PORT variable instead of the static port for signal container in the docker-compose template to make setting of custom signal port working Signed-off-by: Fabio Fantoni --- infrastructure_files/docker-compose.yml.tmpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 3e7eb1df6..b70e4cb6e 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -36,7 +36,7 @@ services: volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird ports: - - 10000:80 + - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] From c34e53477f7ff2429cad13f0bdf6a87ca62bece0 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 14 Sep 2023 17:01:14 +0200 Subject: [PATCH 38/42] Add signal port tests to CI workflow (#1148) --- .github/workflows/test-infrastructure-files.yml | 2 ++ infrastructure_files/tests/setup.env | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index f7d6766fc..2987c04b4 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -80,6 +80,7 @@ jobs: CI_NETBIRD_MGMT_IDP: "none" CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret + CI_NETBIRD_SIGNAL_PORT: 12345 run: | grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID @@ -91,6 +92,7 @@ jobs: grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073" grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$' + grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80' grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM diff --git a/infrastructure_files/tests/setup.env b/infrastructure_files/tests/setup.env index 6cf1acdf4..b0999eb51 100644 --- a/infrastructure_files/tests/setup.env +++ b/infrastructure_files/tests/setup.env @@ -21,4 +21,5 @@ NETBIRD_AUTH_USER_ID_CLAIM="email" NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email" NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID -NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET \ No newline at end of file +NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET +NETBIRD_SIGNAL_PORT=12345 \ No newline at end of file From 0be8c726018c9edc5730defe403da406e773364b Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Mon, 18 Sep 2023 12:25:12 +0200 Subject: [PATCH 39/42] Remove unused methods from AccountManager interface (#1149) This PR removes the following unused methods from the AccountManager interface: * `UpdateGroup` * `UpdateNameServerGroup` * `UpdateRoute` --- management/server/account.go | 19 -- management/server/group.go | 71 ---- management/server/http/groups_handler_test.go | 16 - .../server/http/nameservers_handler_test.go | 25 -- management/server/http/routes_handler_test.go | 33 -- management/server/mock_server/account_mock.go | 27 -- management/server/nameserver.go | 153 +-------- management/server/nameserver_test.go | 317 ------------------ management/server/route.go | 155 --------- management/server/route_test.go | 259 -------------- 10 files changed, 1 insertion(+), 1074 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 4c707af3a..a0d4568ec 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -80,7 +80,6 @@ type AccountManager interface { GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) SaveGroup(accountID, userID string, group *Group) error - UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) DeleteGroup(accountId, userId, groupID string) error ListGroups(accountId string) ([]*Group, error) GroupAddPeer(accountId, groupID, peerID string) error @@ -93,13 +92,11 @@ type AccountManager interface { GetRoute(accountID, routeID, userID string) (*route.Route, error) CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) SaveRoute(accountID, userID string, route *route.Route) error - UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) DeleteRoute(accountID, routeID, userID string) error ListRoutes(accountID, userID string) ([]*route.Route, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroup(accountID, nsGroupID, userID string) error ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string @@ -1605,19 +1602,3 @@ func newAccountWithId(accountID, userID, domain string) *Account { } return acc } - -func removeFromList(inputList []string, toRemove []string) []string { - toRemoveMap := make(map[string]struct{}) - for _, item := range toRemove { - toRemoveMap[item] = struct{}{} - } - - var resultList []string - for _, item := range inputList { - _, ok := toRemoveMap[item] - if !ok { - resultList = append(resultList, item) - } - } - return resultList -} diff --git a/management/server/group.go b/management/server/group.go index 5b1d2ac9f..697fe5d70 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -33,26 +33,6 @@ type Group struct { Peers []string } -const ( - // UpdateGroupName indicates a name update operation - UpdateGroupName GroupUpdateOperationType = iota - // InsertPeersToGroup indicates insert peers to group operation - InsertPeersToGroup - // RemovePeersFromGroup indicates a remove peers from group operation - RemovePeersFromGroup - // UpdateGroupPeers indicates a replacement of group peers list - UpdateGroupPeers -) - -// GroupUpdateOperationType operation type -type GroupUpdateOperationType int - -// GroupUpdateOperation operation object with type and values to be applied -type GroupUpdateOperation struct { - Type GroupUpdateOperationType - Values []string -} - // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -165,57 +145,6 @@ func difference(a, b []string) []string { return diff } -// UpdateGroup updates a group using a list of operations -func (am *DefaultAccountManager) UpdateGroup(accountID string, - groupID string, operations []GroupUpdateOperation, -) (*Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - groupToUpdate, ok := account.Groups[groupID] - if !ok { - return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID) - } - - group := groupToUpdate.Copy() - - for _, operation := range operations { - switch operation.Type { - case UpdateGroupName: - group.Name = operation.Values[0] - case UpdateGroupPeers: - group.Peers = operation.Values - case InsertPeersToGroup: - sourceList := group.Peers - resultList := removeFromList(sourceList, operation.Values) - group.Peers = append(resultList, operation.Values...) - case RemovePeersFromGroup: - sourceList := group.Peers - resultList := removeFromList(sourceList, operation.Values) - group.Peers = resultList - } - } - - account.Groups[groupID] = group - - account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - return nil, err - } - - return group, nil -} - // DeleteGroup object of the peers func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { unlock := am.Store.AcquireAccountLock(accountId) diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 44603059a..ddb1233bf 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -53,22 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Issued: server.GroupIssuedAPI, }, nil }, - UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { - var group server.Group - group.ID = groupID - for _, operation := range operations { - switch operation.Type { - case server.UpdateGroupName: - group.Name = operation.Values[0] - case server.UpdateGroupPeers, server.InsertPeersToGroup: - group.Peers = operation.Values - case server.RemovePeersFromGroup: - default: - return nil, fmt.Errorf("no operation") - } - } - return &group, nil - }, GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { for _, peer := range TestPeers { if peer.IP.String() == peerIP { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 75fcb4c1c..100f4b87a 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - nsGroupToUpdate := baseExistingNSGroup.Copy() - if nsGroupID != nsGroupToUpdate.ID { - return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) - } - for _, operation := range operations { - switch operation.Type { - case server.UpdateNameServerGroupName: - nsGroupToUpdate.Name = operation.Values[0] - case server.UpdateNameServerGroupDescription: - nsGroupToUpdate.Description = operation.Values[0] - case server.UpdateNameServerGroupNameServers: - var parsedNSList []nbdns.NameServer - for _, nsURL := range operation.Values { - parsed, err := nbdns.ParseNameServerURL(nsURL) - if err != nil { - return nil, err - } - parsedNSList = append(parsedNSList, parsed) - } - nsGroupToUpdate.NameServers = parsedNSList - } - } - return nsGroupToUpdate, nil - }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingNSAccount, testingAccount.Users["test_user"], nil }, diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index c4270284c..3f2b7b910 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/netip" - "strconv" "testing" "github.com/netbirdio/netbird/management/server/http/api" @@ -108,38 +107,6 @@ func initRoutesTestData() *RoutesHandler { IP: netip.MustParseAddr(existingPeerID).AsSlice(), }, nil }, - UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) { - routeToUpdate := baseExistingRoute - if routeID != routeToUpdate.ID { - return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) - } - for _, operation := range operations { - switch operation.Type { - case server.UpdateRouteNetwork: - routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0]) - case server.UpdateRouteDescription: - routeToUpdate.Description = operation.Values[0] - case server.UpdateRouteNetworkIdentifier: - routeToUpdate.NetID = operation.Values[0] - case server.UpdateRoutePeer: - routeToUpdate.Peer = operation.Values[0] - if routeToUpdate.Peer == notFoundPeerID { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer) - } - case server.UpdateRouteMetric: - routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0]) - case server.UpdateRouteMasquerade: - routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0]) - case server.UpdateRouteEnabled: - routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0]) - case server.UpdateRouteGroups: - routeToUpdate.Groups = operation.Values - default: - return nil, fmt.Errorf("no operation") - } - } - return routeToUpdate, nil - }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingAccount, testingAccount.Users["test_user"], nil }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4bfa922c7..24bf9f3c9 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,6 @@ type MockAccountManager struct { AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error) SaveGroupFunc func(accountID, userID string, group *server.Group) error - UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) DeleteGroupFunc func(accountID, userId, groupID string) error ListGroupsFunc func(accountID string) ([]*server.Group, error) GroupAddPeerFunc func(accountID, groupID, peerKey string) error @@ -54,7 +53,6 @@ type MockAccountManager struct { CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) SaveRouteFunc func(accountID, userID string, route *route.Route) error - UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) DeleteRouteFunc func(accountID, routeID, userID string) error ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) @@ -68,7 +66,6 @@ type MockAccountManager struct { GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) @@ -267,14 +264,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server. return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } -// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface -func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { - if am.UpdateGroupFunc != nil { - return am.UpdateGroupFunc(accountID, groupID, operations) - } - return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented") -} - // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error { if am.DeleteGroupFunc != nil { @@ -435,14 +424,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented") } -// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface -func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) { - if am.UpdateRouteFunc != nil { - return am.UpdateRouteFunc(accountID, ruleID, operations) - } - return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented") -} - // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { if am.DeleteRouteFunc != nil { @@ -533,14 +514,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr return nil } -// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - if am.UpdateNameServerGroupFunc != nil { - return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations) - } - return nil, nil -} - // DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { if am.DeleteNameServerGroupFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index eb2127945..7025388ba 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,7 +3,6 @@ package server import ( "errors" "regexp" - "strconv" "unicode/utf8" "github.com/miekg/dns" @@ -15,54 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -const ( - // UpdateNameServerGroupName indicates a nameserver group name update operation - UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota - // UpdateNameServerGroupDescription indicates a nameserver group description update operation - UpdateNameServerGroupDescription - // UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation - UpdateNameServerGroupNameServers - // UpdateNameServerGroupGroups indicates a nameserver group' groups update operation - UpdateNameServerGroupGroups - // UpdateNameServerGroupEnabled indicates a nameserver group status update operation - UpdateNameServerGroupEnabled - // UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation - UpdateNameServerGroupPrimary - // UpdateNameServerGroupDomains indicates a nameserver group' domains update operation - UpdateNameServerGroupDomains - - domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` -) - -// NameServerGroupUpdateOperationType operation type -type NameServerGroupUpdateOperationType int - -func (t NameServerGroupUpdateOperationType) String() string { - switch t { - case UpdateNameServerGroupDescription: - return "UpdateNameServerGroupDescription" - case UpdateNameServerGroupName: - return "UpdateNameServerGroupName" - case UpdateNameServerGroupNameServers: - return "UpdateNameServerGroupNameServers" - case UpdateNameServerGroupGroups: - return "UpdateNameServerGroupGroups" - case UpdateNameServerGroupEnabled: - return "UpdateNameServerGroupEnabled" - case UpdateNameServerGroupPrimary: - return "UpdateNameServerGroupPrimary" - case UpdateNameServerGroupDomains: - return "UpdateNameServerGroupDomains" - default: - return "InvalidOperation" - } -} - -// NameServerGroupUpdateOperation operation object with type and values to be applied -type NameServerGroupUpdateOperation struct { - Type NameServerGroupUpdateOperationType - Values []string -} +const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { @@ -172,109 +124,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n return nil } -// UpdateNameServerGroup updates existing nameserver group with set of operations -func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - if len(operations) == 0 { - return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty") - } - - nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID] - if !ok { - return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) - } - - newNSGroup := nsGroupToUpdate.Copy() - - for _, operation := range operations { - valuesCount := len(operation.Values) - if valuesCount < 1 { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) - } - - for _, value := range operation.Values { - if value == "" { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) - } - } - switch operation.Type { - case UpdateNameServerGroupDescription: - newNSGroup.Description = operation.Values[0] - case UpdateNameServerGroupName: - if valuesCount > 1 { - return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) - } - err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups) - if err != nil { - return nil, err - } - newNSGroup.Name = operation.Values[0] - case UpdateNameServerGroupNameServers: - var nsList []nbdns.NameServer - for _, url := range operation.Values { - ns, err := nbdns.ParseNameServerURL(url) - if err != nil { - return nil, err - } - nsList = append(nsList, ns) - } - err = validateNSList(nsList) - if err != nil { - return nil, err - } - newNSGroup.NameServers = nsList - case UpdateNameServerGroupGroups: - err = validateGroups(operation.Values, account.Groups) - if err != nil { - return nil, err - } - newNSGroup.Groups = operation.Values - case UpdateNameServerGroupEnabled: - enabled, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) - } - newNSGroup.Enabled = enabled - case UpdateNameServerGroupPrimary: - primary, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) - } - newNSGroup.Primary = primary - case UpdateNameServerGroupDomains: - err = validateDomainInput(false, operation.Values) - if err != nil { - return nil, err - } - newNSGroup.Domains = operation.Values - } - } - - account.NameServerGroups[nsGroupID] = newNSGroup - - account.Network.IncSerial() - err = am.Store.SaveAccount(account) - if err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - log.Error(err) - return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name) - } - - return newNSGroup.Copy(), nil -} - // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 9d4425056..ab3edaed4 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) { } } -func TestUpdateNameServerGroup(t *testing.T) { - nsGroupID := "testingNSGroup" - - existingNSGroup := &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "super", - Description: "super", - Primary: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("1.1.2.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID}, - Enabled: true, - } - - testCases := []struct { - name string - existingNSGroup *nbdns.NameServerGroup - nsGroupID string - operations []NameServerGroupUpdateOperation - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedNSGroup *nbdns.NameServerGroup - }{ - { - name: "Should Config Single Property", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"superNew"}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedNSGroup: &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "superNew", - Description: "super", - Primary: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("1.1.2.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID}, - Enabled: true, - }, - }, - { - name: "Should Config Multiple Properties", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"superNew"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDescription, - Values: []string{"superDescription"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupGroups, - Values: []string{group1ID, group2ID}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupEnabled, - Values: []string{"false"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupPrimary, - Values: []string{"false"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDomains, - Values: []string{validDomain}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedNSGroup: &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "superNew", - Description: "superDescription", - Primary: false, - Domains: []string{validDomain}, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("127.0.0.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID, group2ID}, - Enabled: false, - }, - }, - { - name: "Should Not Config On Invalid ID", - existingNSGroup: existingNSGroup, - nsGroupID: "nonExistingNSGroup", - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty Operations", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{}, - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty Values", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty String", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{""}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Name Large String", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid On Existing Name", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{existingNSGroupName}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid On Multiple Name Values", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"nameOne", "nameTwo"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Boolean", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupEnabled, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Nameservers Wrong Schema", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"https://127.0.0.1:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Nameservers Wrong IP", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://8.8.8.300:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Large Number Of Nameservers", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid GroupID", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupGroups, - Values: []string{"nonExistingGroupID"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Domains", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDomains, - Values: []string{invalidDomain}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Primary Status", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupPrimary, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - am, err := createNSManager(t) - if err != nil { - t.Error("failed to create account manager") - } - - account, err := initTestNSAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } - - account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - - err = am.Store.SaveAccount(account) - if err != nil { - t.Error("account should be saved") - } - - updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations) - testCase.errFunc(t, err) - - if !testCase.shouldCreate { - return - } - - testCase.expectedNSGroup.ID = updatedRoute.ID - - if !testCase.expectedNSGroup.IsEqual(updatedRoute) { - t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup) - } - - }) - } -} - func TestDeleteNameServerGroup(t *testing.T) { nsGroupID := "testingNSGroup" diff --git a/management/server/route.go b/management/server/route.go index f51b7c2db..b232c2bb6 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -2,7 +2,6 @@ package server import ( "net/netip" - "strconv" "unicode/utf8" "github.com/netbirdio/netbird/management/proto" @@ -13,57 +12,6 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - // UpdateRouteDescription indicates a route description update operation - UpdateRouteDescription RouteUpdateOperationType = iota - // UpdateRouteNetwork indicates a route IP update operation - UpdateRouteNetwork - // UpdateRoutePeer indicates a route peer update operation - UpdateRoutePeer - // UpdateRouteMetric indicates a route metric update operation - UpdateRouteMetric - // UpdateRouteMasquerade indicates a route masquerade update operation - UpdateRouteMasquerade - // UpdateRouteEnabled indicates a route enabled update operation - UpdateRouteEnabled - // UpdateRouteNetworkIdentifier indicates a route net ID update operation - UpdateRouteNetworkIdentifier - // UpdateRouteGroups indicates a group list update operation - UpdateRouteGroups -) - -// RouteUpdateOperationType operation type -type RouteUpdateOperationType int - -func (t RouteUpdateOperationType) String() string { - switch t { - case UpdateRouteDescription: - return "UpdateRouteDescription" - case UpdateRouteNetwork: - return "UpdateRouteNetwork" - case UpdateRoutePeer: - return "UpdateRoutePeer" - case UpdateRouteMetric: - return "UpdateRouteMetric" - case UpdateRouteMasquerade: - return "UpdateRouteMasquerade" - case UpdateRouteEnabled: - return "UpdateRouteEnabled" - case UpdateRouteNetworkIdentifier: - return "UpdateRouteNetworkIdentifier" - case UpdateRouteGroups: - return "UpdateRouteGroups" - default: - return "InvalidOperation" - } -} - -// RouteUpdateOperation operation object with type and values to be applied -type RouteUpdateOperation struct { - Type RouteUpdateOperationType - Values []string -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { unlock := am.Store.AcquireAccountLock(accountID) @@ -241,109 +189,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return nil } -// UpdateRoute updates existing route with set of operations -func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - routeToUpdate, ok := account.Routes[routeID] - if !ok { - return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) - } - - newRoute := routeToUpdate.Copy() - - for _, operation := range operations { - - if len(operation.Values) != 1 { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String()) - } - - switch operation.Type { - case UpdateRouteDescription: - newRoute.Description = operation.Values[0] - case UpdateRouteNetworkIdentifier: - if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - newRoute.NetID = operation.Values[0] - case UpdateRouteNetwork: - prefixType, prefix, err := route.ParseNetwork(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0]) - } - err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix) - if err != nil { - return nil, err - } - newRoute.Network = prefix - newRoute.NetworkType = prefixType - case UpdateRoutePeer: - if operation.Values[0] != "" { - peer := account.GetPeer(operation.Values[0]) - if peer == nil { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0]) - } - } - - err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network) - if err != nil { - return nil, err - } - newRoute.Peer = operation.Values[0] - case UpdateRouteMetric: - metric, err := strconv.Atoi(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0]) - } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d", - operation.Values[0], - route.MinMetric, - route.MaxMetric, - ) - } - newRoute.Metric = metric - case UpdateRouteMasquerade: - masquerade, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0]) - } - newRoute.Masquerade = masquerade - case UpdateRouteEnabled: - enabled, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) - } - newRoute.Enabled = enabled - case UpdateRouteGroups: - err = validateGroups(operation.Values, account.Groups) - if err != nil { - return nil, err - } - newRoute.Groups = operation.Values - } - } - - account.Routes[routeID] = newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - return nil, status.Errorf(status.Internal, "failed to update account peers") - } - return newRoute, nil -} - // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { unlock := am.Store.AcquireAccountLock(accountID) diff --git a/management/server/route_test.go b/management/server/route_test.go index c943aee0b..69241da31 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -524,265 +524,6 @@ func TestSaveRoute(t *testing.T) { } } -func TestUpdateRoute(t *testing.T) { - routeID := "testingRouteID" - - existingRoute := &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - } - - testCases := []struct { - name string - existingRoute *route.Route - operations []RouteUpdateOperation - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route - }{ - { - name: "Happy Path Single OPS", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{peer2ID}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: peer2ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - }, - }, - { - name: "Happy Path Multiple OPS", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteDescription, - Values: []string{"great"}, - }, - { - Type: UpdateRouteNetwork, - Values: []string{"192.168.0.0/24"}, - }, - { - Type: UpdateRoutePeer, - Values: []string{peer2ID}, - }, - { - Type: UpdateRouteMetric, - Values: []string{"3030"}, - }, - { - Type: UpdateRouteMasquerade, - Values: []string{"true"}, - }, - { - Type: UpdateRouteEnabled, - Values: []string{"false"}, - }, - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{"megaRoute"}, - }, - { - Type: UpdateRouteGroups, - Values: []string{routeGroup2}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/24"), - NetID: "megaRoute", - NetworkType: route.IPv4Network, - Peer: peer2ID, - Description: "great", - Masquerade: true, - Metric: 3030, - Enabled: false, - Groups: []string{routeGroup2}, - }, - }, - { - name: "Empty Values Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - }, - }, - errFunc: require.Error, - }, - { - name: "Multiple Values Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{peer2ID, peer1ID}, - }, - }, - errFunc: require.Error, - }, - { - name: "Bad Prefix Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetwork, - Values: []string{"192.168.0.0/34"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Bad Peer Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{"non existing Peer"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Empty Peer", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{""}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: "", - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - }, - }, - { - name: "Large Network ID Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Empty Network ID Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{""}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Metric Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteMetric, - Values: []string{"999999"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Boolean Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteMasquerade, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Group Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteGroups, - Values: []string{routeInvalidGroup1}, - }, - }, - errFunc: require.Error, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) - if err != nil { - t.Error("failed to create account manager") - } - - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } - - account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - - err = am.Store.SaveAccount(account) - if err != nil { - t.Error("account should be saved") - } - - updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations) - - testCase.errFunc(t, err) - - if !testCase.shouldCreate { - return - } - - testCase.expectedRoute.ID = updatedRoute.ID - - if !testCase.expectedRoute.IsEqual(updatedRoute) { - t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute) - } - }) - } -} - func TestDeleteRoute(t *testing.T) { testingRoute := &route.Route{ ID: "testingRoute", From 34e2c6b9437b9432caac6c0f563c9506c9ae45b7 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 18 Sep 2023 16:04:53 +0200 Subject: [PATCH 40/42] Fix sso check (#1152) Fix SSO check - change the order of the PKCE and device auth flow check, prefer PKCE - fix error handling in PKCE check --- client/android/login.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/client/android/login.go b/client/android/login.go index ad334541c..8d2636c9a 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + s, ok := gstatus.FromError(err) + if !ok { + return err + } + if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { supportsSSO = false err = nil } From 8febab40765c31caac5d40f4f3d6c57097491ec1 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 19 Sep 2023 19:06:18 +0300 Subject: [PATCH 41/42] Improve Client Authentication (#1135) * shutdown the pkce server on user cancellation * Refactor openURL to exclusively manage authentication flow instructions and browser launching * Refactor authentication flow initialization based on client OS The NewOAuthFlow method now first checks the operating system and if it is a non-desktop Linux, it opts for Device Code Flow. PKCEFlow is tried first and if it fails, then it falls back on Device Code Flow. If both unsuccessful, the authentication process halts and error messages have been updated to provide more helpful feedback for troubleshooting authentication errors * Replace log-based Linux desktop check with process check To verify if a Linux OS is running a desktop environment in the Authentication utility, the log-based method that checks the XDG_CURRENT_DESKTOP env has been replaced with a method that checks directly if either X or Wayland display server processes are running. This method is more reliable as it directly checks for the display server process rather than relying on an environment variable that may not be set in all desktop environments. * Refactor PKCE Authorization Flow to improve server handling * refactor check for linux running desktop environment * Improve server shutdown handling and encapsulate handlers with new server multiplexer The changes enhance the way the server shuts down by specifying a context with timeout of 5 seconds, adding a safeguard to ensure the server halts even on potential hanging requests. Also, the server's root handler is now encapsulated within a new ServeMux instance, to support multiple registrations of a path --- client/cmd/login.go | 53 ++------------- client/internal/auth/oauth.go | 40 ++++++++---- client/internal/auth/pkce_flow.go | 105 +++++++++++++++--------------- client/internal/auth/util.go | 6 ++ 4 files changed, 92 insertions(+), 112 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index a5cc3215c..5433db522 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -3,8 +3,6 @@ package cmd import ( "context" "fmt" - "os" - "runtime" "strings" "time" @@ -195,51 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) } - browserAuthMsg := "Please do the SSO login in your browser. \n" + + cmd.Println("Please do the SSO login in your browser. \n" + "If your browser didn't open automatically, use this URL to log in:\n\n" + - verificationURIComplete + " " + codeMsg - - setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys" - - authenticateUsingBrowser := func() { - cmd.Println(browserAuthMsg) - cmd.Println("") - if err := open.Run(verificationURIComplete); err != nil { - cmd.Println(setupKeyAuthMsg) - } - } - - switch runtime.GOOS { - case "windows", "darwin": - authenticateUsingBrowser() - case "linux": - if isLinuxRunningDesktop() { - authenticateUsingBrowser() - } else { - // If current flow is PKCE, it implies the server is anticipating the redirect to localhost. - // Devices lacking browser support are incompatible with this flow.Therefore, - // these devices will need to resort to setup keys instead. - if isPKCEFlow(verificationURIComplete) { - cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys") - } else { - cmd.Println(browserAuthMsg) - } - } + verificationURIComplete + " " + codeMsg) + cmd.Println("") + if err := open.Run(verificationURIComplete); err != nil { + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } - -// isLinuxRunningDesktop checks if a Linux OS is running desktop environment. -func isLinuxRunningDesktop() bool { - return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" -} - -// isPKCEFlow determines if the PKCE flow is active or not, -// by checking the existence of redirect_uri inside the verification URL. -func isPKCEFlow(verificationURL string) bool { - if verificationURL == "" { - return false - } - return strings.Contains(verificationURL, "redirect_uri") -} diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 794fe0958..8731e4f0b 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -4,8 +4,8 @@ import ( "context" "fmt" "net/http" + "runtime" - log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } -// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. +// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration +// +// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, +// and if that also fails, the authentication process is deemed unsuccessful +// +// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { - log.Debug("loading pkce authorization flow info") - - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) - if err == nil { - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + if runtime.GOOS == "linux" && !isLinuxRunningDesktop() { + return authenticateWithDeviceCodeFlow(ctx, config) } - log.Debugf("loading pkce authorization flow info failed with error: %v", err) - log.Debugf("falling back to device authorization flow info") + pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + if err != nil { + // fallback to device code flow + return authenticateWithDeviceCodeFlow(ctx, config) + } + return pkceFlow, nil +} +// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow +func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { + pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + if err != nil { + return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) + } + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) +} + +// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow +func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { s, ok := gstatus.FromError(err) if ok && s.Code() == codes.NotFound { return nil, fmt.Errorf("no SSO provider returned from management. " + - "If you are using hosting Netbird see documentation at " + - "https://github.com/netbirdio/netbird/tree/main/management for details") + "Please proceed with setting up this device using setup keys " + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } else if ok && s.Code() == codes.Unimplemented { return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ "please update your server or use Setup Keys to login", config.ManagementURL) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index d15d49373..a3d0c1309 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -12,7 +12,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" log "github.com/sirupsen/logrus" @@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string { } // RequestAuthInfo requests a authorization code login flow information. -func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) { +func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { state, err := randomBytesInHex(24) if err != nil { return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err) @@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) - go p.startServer(tokenChan, errChan) + parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) + if err != nil { + return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err) + } + + server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} + defer func() { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Errorf("failed to close the server: %v", err) + } + }() + + go p.startServer(server, tokenChan, errChan) select { case <-ctx.Done(): return TokenInfo{}, ctx.Err() case token := <-tokenChan: - return p.handleOAuthToken(token) + return p.parseOAuthToken(token) case err := <-errChan: return TokenInfo{}, err } } -func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { - var wg sync.WaitGroup - - parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) - if err != nil { - errChan <- fmt.Errorf("failed to parse redirect URL: %v", err) - return - } - - server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} - go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err - } - }() - - wg.Add(1) - http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - defer wg.Done() - - tokenValidatorFunc := func() (*oauth2.Token, error) { - query := req.URL.Query() - - if authError := query.Get(queryError); authError != "" { - authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) - } - - // Prevent timing attacks on state - if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") - } - - code := query.Get(queryCode) - if code == "" { - return nil, fmt.Errorf("missing code") - } - - return p.oAuthConfig.Exchange( - req.Context(), - code, - oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), - ) - } - - token, err := tokenValidatorFunc() +func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + token, err := p.handleRequest(req) if err != nil { renderPKCEFlowTmpl(w, err) errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) @@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC tokenChan <- token }) - wg.Wait() - if err := server.Shutdown(context.Background()); err != nil { - log.Errorf("error while shutting down pkce flow server: %v", err) + server.Handler = mux + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err } } -func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) { +func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) { + query := req.URL.Query() + + if authError := query.Get(queryError); authError != "" { + authErrorDesc := query.Get(queryErrorDesc) + return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + } + + // Prevent timing attacks on the state + if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { + return nil, fmt.Errorf("invalid state") + } + + code := query.Get(queryCode) + if code == "" { + return nil, fmt.Errorf("missing code") + } + + return p.oAuthConfig.Exchange( + req.Context(), + code, + oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), + ) +} + +func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) { tokenInfo := TokenInfo{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, diff --git a/client/internal/auth/util.go b/client/internal/auth/util.go index 33a0e6e35..e61e0f175 100644 --- a/client/internal/auth/util.go +++ b/client/internal/auth/util.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "os" "reflect" "strings" ) @@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error { return fmt.Errorf("invalid JWT token audience field") } + +// isLinuxRunningDesktop checks if a Linux OS is running desktop environment +func isLinuxRunningDesktop() bool { + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" +} From d4b6d7646c0ff40193dd17d1061ff3d31e5dfce0 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Tue, 19 Sep 2023 20:08:40 +0400 Subject: [PATCH 42/42] Handle user delete (#1113) Implement user deletion across all IDP-ss. Expires all user peers when the user is deleted. Users are permanently removed from a local store, but in IDP, we remove Netbird attributes for the user untilUserDeleteFromIDPEnabled setting is not enabled. To test, an admin user should remove any additional users. Until the UI incorporates this feature, use a curl DELETE request targeting the /users/ management endpoint. Note that this request only removes user attributes and doesn't trigger a delete from the IDP. To enable user removal from the IdP, set UserDeleteFromIDPEnabled to true in account settings. Until we have a UI for this, make this change directly in the store file. Store the deleted email addresses in encrypted in activity store. --- client/cmd/testutil.go | 2 +- client/internal/engine_test.go | 2 +- management/client/client_test.go | 2 +- management/cmd/management.go | 35 ++++- management/cmd/root.go | 2 + management/server/account.go | 50 +++---- management/server/account_test.go | 2 +- management/server/activity/codes.go | 3 + management/server/activity/event.go | 18 ++- management/server/activity/sqlite/crypt.go | 81 +++++++++++ .../server/activity/sqlite/crypt_test.go | 63 +++++++++ management/server/activity/sqlite/sqlite.go | 126 +++++++++++++++--- .../server/activity/sqlite/sqlite_test.go | 3 +- management/server/config.go | 3 +- management/server/dns_test.go | 2 +- management/server/http/api/generate.sh | 0 management/server/http/api/openapi.yml | 5 + management/server/http/api/types.gen.go | 3 + management/server/http/events_handler.go | 56 ++++++-- management/server/http/events_handler_test.go | 3 + management/server/idp/auth0.go | 50 ++++++- management/server/idp/authentik.go | 35 ++++- management/server/idp/azure.go | 37 +++++ management/server/idp/google_workspace.go | 13 ++ management/server/idp/idp.go | 1 + management/server/idp/keycloak.go | 41 ++++++ management/server/idp/okta.go | 22 +++ management/server/idp/zitadel.go | 53 +++++++- management/server/management_proto_test.go | 2 +- management/server/management_test.go | 2 +- management/server/nameserver_test.go | 2 +- management/server/route_test.go | 2 +- management/server/telemetry/idp_metrics.go | 12 ++ management/server/user.go | 118 +++++++++++++--- management/server/user_test.go | 5 +- 35 files changed, 744 insertions(+), 112 deletions(-) create mode 100644 management/server/activity/sqlite/crypt.go create mode 100644 management/server/activity/sqlite/crypt_test.go mode change 100644 => 100755 management/server/http/api/generate.sh diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 678072f0b..6d47021dd 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste return nil, nil } accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9f17ff36b..ea4a23a8d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1049,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { return nil, "", err } accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { return nil, "", err } diff --git a/management/client/client_test.go b/management/client/client_test.go index deef57329..86c598adb 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -61,7 +61,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 5c3816715..ca333b931 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity/sqlite" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/idp" @@ -142,12 +143,22 @@ var ( if disableSingleAccMode { mgmtSingleAccModeDomain = "" } - eventStore, err := sqlite.NewSQLiteStore(config.Datadir) + eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey) if err != nil { - return err + return fmt.Errorf("failed to initialize database: %s", err) } + + if key != "" { + log.Debugf("update config with activity store key") + config.DataStoreEncryptionKey = key + err := updateMgmtConfig(mgmtConfig, config) + if err != nil { + return fmt.Errorf("failed to write out store encryption key: %s", err) + } + } + accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore) + dnsDomain, eventStore, userDeleteFromIDPEnabled) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -287,6 +298,20 @@ var ( } ) +func initEventStore(dataDir string, key string) (activity.Store, string, error) { + var err error + if key == "" { + log.Debugf("generate new activity store encryption key") + key, err = sqlite.GenerateKey() + if err != nil { + return nil, "", err + } + } + store, err := sqlite.NewSQLiteStore(dataDir, key) + return store, key, err + +} + func notifyStop(msg string) { select { case stopCh <- 1: @@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { return loadedConfig, err } +func updateMgmtConfig(path string, config *server.Config) error { + return util.WriteJson(path, config) +} + // OIDCConfigResponse used for parsing OIDC config response type OIDCConfigResponse struct { Issuer string `json:"issuer"` diff --git a/management/cmd/root.go b/management/cmd/root.go index a149841c5..2080a6b29 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -24,6 +24,7 @@ var ( disableMetrics bool disableSingleAccMode bool idpSignKeyRefreshEnabled bool + userDeleteFromIDPEnabled bool rootCmd = &cobra.Command{ Use: "netbird-mgmt", @@ -56,6 +57,7 @@ func init() { mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") + mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account") rootCmd.MarkFlagRequired("config") //nolint rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") diff --git a/management/server/account.go b/management/server/account.go index a0d4568ec..2dba658ec 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -130,6 +130,9 @@ type DefaultAccountManager struct { // dnsDomain is used for peer resolution. This is appended to the peer's name dnsDomain string peerLoginExpiry Scheduler + + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account + userDeleteFromIDPEnabled bool } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -735,18 +738,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, - singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, + singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ - Store: store, - peersUpdateManager: peersUpdateManager, - idpManager: idpManager, - ctx: context.Background(), - cacheMux: sync.Mutex{}, - cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, - eventStore: eventStore, - peerLoginExpiry: NewDefaultScheduler(), + Store: store, + peersUpdateManager: peersUpdateManager, + idpManager: idpManager, + ctx: context.Background(), + cacheMux: sync.Mutex{}, + cacheLoading: map[string]chan struct{}{}, + dnsDomain: dnsDomain, + eventStore: eventStore, + peerLoginExpiry: NewDefaultScheduler(), + userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -871,33 +875,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() return account.GetNextPeerExpiration() } + expiredPeers := account.GetExpiredPeers() var peerIDs []string - for _, peer := range account.GetExpiredPeers() { - if peer.Status.LoginExpired { - continue - } + for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) - peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) - if err != nil { - log.Errorf("failed saving peer status while expiring peer %s", peer.ID) - return account.GetNextPeerExpiration() - } - am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain())) } log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) - if len(peerIDs) != 0 { - // this will trigger peer disconnect from the management service - am.peersUpdateManager.CloseChannels(peerIDs) - err = am.updateAccountPeers(account) - if err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", accountID) - return account.GetNextPeerExpiration() - } + if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextPeerExpiration() } + return account.GetNextPeerExpiration() } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 64fd90524..204e98947 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2063,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4de667ded..ce36f520f 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -104,6 +104,8 @@ const ( UserBlocked // UserUnblocked indicates that a user unblocked another user UserUnblocked + // UserDeleted indicates that a user deleted another user + UserDeleted // GroupDeleted indicates that a user deleted group GroupDeleted // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login @@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{ ServiceUserDeleted: {"Service user deleted", "service.user.delete"}, UserBlocked: {"User blocked", "user.block"}, UserUnblocked: {"User unblocked", "user.unblock"}, + UserDeleted: {"User deleted", "user.delete"}, GroupDeleted: {"Group deleted", "group.delete"}, UserLoggedInPeer: {"User logged in peer", "user.peer.login"}, PeerLoginExpired: {"Peer login expired", "peer.login.expire"}, diff --git a/management/server/activity/event.go b/management/server/activity/event.go index 17ec4a0b0..1bf86ef2c 100644 --- a/management/server/activity/event.go +++ b/management/server/activity/event.go @@ -18,10 +18,13 @@ type Event struct { ID uint64 // InitiatorID is the ID of an object that initiated the event (e.g., a user) InitiatorID string + // InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only + InitiatorEmail string // TargetID is the ID of an object that was effected by the event (e.g., a peer) TargetID string // AccountID is the ID of an account where the event happened AccountID string + // Meta of the event, e.g. deleted peer information like name, IP, etc Meta map[string]any } @@ -35,12 +38,13 @@ func (e *Event) Copy() *Event { } return &Event{ - Timestamp: e.Timestamp, - Activity: e.Activity, - ID: e.ID, - InitiatorID: e.InitiatorID, - TargetID: e.TargetID, - AccountID: e.AccountID, - Meta: meta, + Timestamp: e.Timestamp, + Activity: e.Activity, + ID: e.ID, + InitiatorID: e.InitiatorID, + InitiatorEmail: e.InitiatorEmail, + TargetID: e.TargetID, + AccountID: e.AccountID, + Meta: meta, } } diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go new file mode 100644 index 000000000..8f2755604 --- /dev/null +++ b/management/server/activity/sqlite/crypt.go @@ -0,0 +1,81 @@ +package sqlite + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" +) + +var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} + +type EmailEncrypt struct { + block cipher.Block +} + +func GenerateKey() (string, error) { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + return "", err + } + readableKey := base64.StdEncoding.EncodeToString(key) + return readableKey, nil +} + +func NewEmailEncrypt(key string) (*EmailEncrypt, error) { + binKey, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(binKey) + if err != nil { + return nil, err + } + ec := &EmailEncrypt{ + block: block, + } + + return ec, nil +} + +func (ec *EmailEncrypt) Encrypt(payload string) string { + plainText := pkcs5Padding([]byte(payload)) + cipherText := make([]byte, len(plainText)) + cbc := cipher.NewCBCEncrypter(ec.block, iv) + cbc.CryptBlocks(cipherText, plainText) + return base64.StdEncoding.EncodeToString(cipherText) +} + +func (ec *EmailEncrypt) Decrypt(data string) (string, error) { + cipherText, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return "", err + } + cbc := cipher.NewCBCDecrypter(ec.block, iv) + cbc.CryptBlocks(cipherText, cipherText) + payload, err := pkcs5UnPadding(cipherText) + if err != nil { + return "", err + } + + return string(payload), nil +} + +func pkcs5Padding(ciphertext []byte) []byte { + padding := aes.BlockSize - len(ciphertext)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padText...) +} + +func pkcs5UnPadding(src []byte) ([]byte, error) { + srcLen := len(src) + paddingLen := int(src[srcLen-1]) + if paddingLen >= srcLen || paddingLen > aes.BlockSize { + return nil, fmt.Errorf("padding size error") + } + return src[:srcLen-paddingLen], nil +} diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go new file mode 100644 index 000000000..5fb59a692 --- /dev/null +++ b/management/server/activity/sqlite/crypt_test.go @@ -0,0 +1,63 @@ +package sqlite + +import ( + "testing" +) + +func TestGenerateKey(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewEmailEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.Encrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + decrypted, err := ee.Decrypt(encrypted) + if err != nil { + t.Fatalf("failed to decrypt data: %s", err) + } + + if decrypted != testData { + t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) + } +} + +func TestCorruptKey(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewEmailEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.Encrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + newKey, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + + ee, err = NewEmailEncrypt(newKey) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + res, err := ee.Decrypt(encrypted) + if err == nil || res == testData { + t.Fatalf("incorrect decryption, the result is: %s", res) + } +} diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index a4c85cf60..7ff59674d 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -3,14 +3,14 @@ package sqlite import ( "database/sql" "encoding/json" - - "github.com/netbirdio/netbird/management/server/activity" - - // sqlite driver + "fmt" "path/filepath" "time" - _ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" // sqlite driver + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" ) const ( @@ -25,35 +25,62 @@ const ( "meta TEXT," + " target_id TEXT);" - selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;" - selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;" + creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);` + + selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta + FROM events + LEFT JOIN deleted_users i ON events.initiator_id = i.id + LEFT JOIN deleted_users t ON events.target_id = t.id + WHERE account_id = ? + ORDER BY timestamp DESC LIMIT ? OFFSET ?;` + + selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta + FROM events + LEFT JOIN deleted_users i ON events.initiator_id = i.id + LEFT JOIN deleted_users t ON events.target_id = t.id + WHERE account_id = ? + ORDER BY timestamp ASC LIMIT ? OFFSET ?;` + insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + "VALUES(?, ?, ?, ?, ?, ?)" + + insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)` ) // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { - db *sql.DB + db *sql.DB + emailEncrypt *EmailEncrypt + insertStatement *sql.Stmt selectAscStatement *sql.Stmt selectDescStatement *sql.Stmt + deleteUserStmt *sql.Stmt } // NewSQLiteStore creates a new Store with an event table if not exists. -func NewSQLiteStore(dataDir string) (*Store, error) { +func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { dbFile := filepath.Join(dataDir, eventSinkDB) db, err := sql.Open("sqlite3", dbFile) if err != nil { return nil, err } + crypt, err := NewEmailEncrypt(encryptionKey) + if err != nil { + return nil, err + } + _, err = db.Exec(createTableQuery) if err != nil { return nil, err } + _, err = db.Exec(creatTableAccountEmailQuery) + if err != nil { + return nil, err + } + insertStmt, err := db.Prepare(insertQuery) if err != nil { return nil, err @@ -69,25 +96,35 @@ func NewSQLiteStore(dataDir string) (*Store, error) { return nil, err } - return &Store{ + deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) + if err != nil { + return nil, err + } + + s := &Store{ db: db, + emailEncrypt: crypt, insertStatement: insertStmt, selectDescStatement: selectDescStmt, selectAscStatement: selectAscStmt, - }, nil + deleteUserStmt: deleteUserStmt, + } + return s, nil } -func processResult(result *sql.Rows) ([]*activity.Event, error) { +func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { events := make([]*activity.Event, 0) for result.Next() { var id int64 var operation activity.Activity var timestamp time.Time var initiator string + var initiatorEmail *string var target string + var targetEmail *string var account string var jsonMeta string - err := result.Scan(&id, &operation, ×tamp, &initiator, &target, &account, &jsonMeta) + err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta) if err != nil { return nil, err } @@ -100,7 +137,17 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { } } - events = append(events, &activity.Event{ + if targetEmail != nil { + email, err := store.emailEncrypt.Decrypt(*targetEmail) + if err != nil { + log.Errorf("failed to decrypt email address for target id: %s", target) + meta["email"] = "" + } else { + meta["email"] = email + } + } + + event := &activity.Event{ Timestamp: timestamp, Activity: operation, ID: uint64(id), @@ -108,7 +155,18 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { TargetID: target, AccountID: account, Meta: meta, - }) + } + + if initiatorEmail != nil { + email, err := store.emailEncrypt.Decrypt(*initiatorEmail) + if err != nil { + log.Errorf("failed to decrypt email address of initiator: %s", initiator) + } else { + event.InitiatorEmail = email + } + } + + events = append(events, event) } return events, nil @@ -127,13 +185,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ } defer result.Close() //nolint - return processResult(result) + return store.processResult(result) } -// Save an event in the SQLite events table +// Save an event in the SQLite events table end encrypt the "email" element in meta map func (store *Store) Save(event *activity.Event) (*activity.Event, error) { var jsonMeta string - if event.Meta != nil { + meta, err := store.saveDeletedUserEmailInEncrypted(event) + if err != nil { + return nil, err + } + + if meta != nil { metaBytes, err := json.Marshal(event.Meta) if err != nil { return nil, err @@ -156,6 +219,29 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) { return eventCopy, nil } +// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from +// meta map +func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) { + email, ok := event.Meta["email"] + if !ok { + return event.Meta, nil + } + + delete(event.Meta, "email") + + encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email)) + _, err := store.deleteUserStmt.Exec(event.TargetID, encrypted) + if err != nil { + return nil, err + } + + if len(event.Meta) == 1 { + return nil, nil // nolint + } + delete(event.Meta, "email") + return event.Meta, nil +} + // Close the Store func (store *Store) Close() error { if store.db != nil { diff --git a/management/server/activity/sqlite/sqlite_test.go b/management/server/activity/sqlite/sqlite_test.go index 2ca9a1e64..f6a6f9467 100644 --- a/management/server/activity/sqlite/sqlite_test.go +++ b/management/server/activity/sqlite/sqlite_test.go @@ -12,7 +12,8 @@ import ( func TestNewSQLiteStore(t *testing.T) { dataDir := t.TempDir() - store, err := NewSQLiteStore(dataDir) + key, _ := GenerateKey() + store, err := NewSQLiteStore(dataDir, key) if err != nil { t.Fatal(err) return diff --git a/management/server/config.go b/management/server/config.go index ea0143988..31c1cf45c 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -35,7 +35,8 @@ type Config struct { TURNConfig *TURNConfig Signal *Host - Datadir string + Datadir string + DataStoreEncryptionKey string HttpConfig *HttpServerConfig diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 092c52afa..b089949b2 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false) } func createDNSStore(t *testing.T) (Store, error) { diff --git a/management/server/http/api/generate.sh b/management/server/http/api/generate.sh old mode 100644 new mode 100755 diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 06da0ede3..f2d1e26bf 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -922,6 +922,10 @@ components: description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. type: string example: google-oauth2|123456789012345678901 + initiator_email: + description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. + type: string + example: demo@netbird.io target_id: description: The ID of the target of the event. E.g., an ID of the peer that a user removed. type: string @@ -938,6 +942,7 @@ components: - activity - activity_code - initiator_id + - initiator_email - target_id - meta responses: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 402aae635..33c935a68 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -164,6 +164,9 @@ type Event struct { // Id Event unique identifier Id string `json:"id"` + // InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. + InitiatorEmail string `json:"initiator_email"` + // InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event. InitiatorId string `json:"initiator_id"` diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 1d1c176e5..cbca44364 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -45,14 +45,46 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteError(err, w) return } - events := make([]*api.Event, 0) - for _, e := range accountEvents { - events = append(events, toEventResponse(e)) + events := make([]*api.Event, len(accountEvents)) + for i, e := range accountEvents { + events[i] = toEventResponse(e) + } + + err = h.fillEventsWithInitiatorEmail(events, account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return } util.WriteJSONObject(w, events) } +func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accountId, userId string) error { + // build email map based on users + userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) + if err != nil { + log.Errorf("failed to get users from account: %s", err) + return err + } + + emails := make(map[string]string) + for _, ui := range userInfos { + emails[ui.ID] = ui.Email + } + + // fill event with email of initiator + var ok bool + for _, event := range events { + if event.InitiatorEmail == "" { + event.InitiatorEmail, ok = emails[event.InitiatorId] + if !ok { + log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) + } + } + } + return nil +} + func toEventResponse(event *activity.Event) *api.Event { meta := make(map[string]string) if event.Meta != nil { @@ -60,13 +92,15 @@ func toEventResponse(event *activity.Event) *api.Event { meta[s] = fmt.Sprintf("%v", a) } } - return &api.Event{ - Id: fmt.Sprint(event.ID), - InitiatorId: event.InitiatorID, - Activity: event.Activity.Message(), - ActivityCode: api.EventActivityCode(event.Activity.StringCode()), - TargetId: event.TargetID, - Timestamp: event.Timestamp, - Meta: meta, + e := &api.Event{ + Id: fmt.Sprint(event.ID), + InitiatorId: event.InitiatorID, + InitiatorEmail: event.InitiatorEmail, + Activity: event.Activity.Message(), + ActivityCode: api.EventActivityCode(event.Activity.StringCode()), + TargetId: event.TargetID, + Timestamp: event.Timestamp, + Meta: meta, } + return e } diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index a77e44f45..4cfad922b 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E }, }, user, nil }, + GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + return make([]*server.UserInfo, 0), nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 64ec88e9f..d3802d8ad 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) { return string(str), nil } -func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { +func (am *Auth0Manager) createRequest( + method string, endpoint string, body io.Reader, +) (*http.Request, error) { jwtToken, err := am.credentials.Authenticate() if err != nil { return nil, err @@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (* reqURL := am.authIssuer + endpoint - payload := strings.NewReader(payloadStr) - - req, err := http.NewRequest("POST", reqURL, payload) + req, err := http.NewRequest(method, reqURL, body) if err != nil { return nil, err } req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + + return req, nil +} + +func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { + req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr)) + if err != nil { + return nil, err + } req.Header.Add("content-type", "application/json") return req, nil - } // GetAllAccounts gets all registered accounts with corresponding user data. @@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { return nil } +// DeleteUser from Auth0 +func (am *Auth0Manager) DeleteUser(userID string) error { + req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) + if err != nil { + return err + } + + resp, err := am.httpClient.Do(req) + if err != nil { + log.Debugf("execute delete request: %v", err) + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + + defer func() { + err = resp.Body.Close() + if err != nil { + log.Errorf("close delete request body: %v", err) + } + }() + if resp.StatusCode != 204 { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + // checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // If the status is "completed", then return the downloadLink func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 0898f1c94..102222d0d 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -12,9 +12,10 @@ import ( "time" "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // AuthentikManager authentik manager client instance. @@ -453,6 +454,38 @@ func (am *AuthentikManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Authentik +func (am *AuthentikManager) DeleteUser(userID string) error { + ctx, err := am.authenticationContext() + if err != nil { + return err + } + + userPk, err := strconv.ParseInt(userID, 10, 32) + if err != nil { + return err + } + + resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute() + if err != nil { + return err + } + defer resp.Body.Close() // nolint + + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountDeleteUser() + } + + if resp.StatusCode != http.StatusNoContent { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode) + } + + return nil +} + func (am *AuthentikManager) authenticationContext() (context.Context, error) { jwtToken, err := am.credentials.Authenticate() if err != nil { diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 7cff7d8fc..22e6825ae 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -454,6 +454,43 @@ func (am *AzureManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Azure +func (am *AzureManager) DeleteUser(userID string) error { + jwtToken, err := am.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID)) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + log.Debugf("delete idp user %s", userID) + + resp, err := am.httpClient.Do(req) + if err != nil { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + defer resp.Body.Close() + + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountDeleteUser() + } + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + func (am *AzureManager) getUserExtensions() ([]azureExtension, error) { q := url.Values{} q.Add("$select", extensionFields) diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 2e65497dc..40854e598 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -254,6 +254,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from GoogleWorkspace. +func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { + if err := gm.usersService.Delete(userID).Do(); err != nil { + return err + } + + if gm.appMetrics != nil { + gm.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + // getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // If that fails, it falls back to using the default Google credentials path. diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 3c1f4c327..ea2231390 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -18,6 +18,7 @@ type Manager interface { CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) GetUserByEmail(email string) ([]*UserData, error) InviteUserByID(userID string) error + DeleteUser(userID string) error } // ClientConfig defines common client configuration for all IdP manager diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index 12ed87389..d65a78ae3 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -467,6 +467,47 @@ func (km *KeycloakManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Keycloack +func (km *KeycloakManager) DeleteUser(userID string) error { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID)) + + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountDeleteUser() + } + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + defer resp.Body.Close() // nolint + + // In the docs, they specified 200, but in the endpoints, they return 204 + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { attrs := keycloakUserAttributes{} attrs.Set(wtAccountID, appMetadata.WTAccountID) diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index c6b5055d4..0e93c494c 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -319,6 +319,28 @@ func (om *OktaManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Okta +func (om *OktaManager) DeleteUser(userID string) error { + resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil) + if err != nil { + fmt.Println(err.Error()) + return err + } + + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountDeleteUser() + } + + if resp.StatusCode != http.StatusOK { + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + // updateUserProfileSchema updates the Okta user schema to include custom fields, // wt_account_id and wt_pending_invite. func updateUserProfileSchema(client *okta.Client) error { diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index fce2c7b37..73958a69e 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -428,7 +428,7 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe return err } - resource := fmt.Sprintf("users/%s/metadata/_bulk", userID) + resource := fmt.Sprintf("users/%s", userID) _, err = zm.post(resource, string(payload)) if err != nil { return err @@ -447,6 +447,21 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Zitadel +func (zm *ZitadelManager) DeleteUser(userID string) error { + resource := fmt.Sprintf("users/%s", userID) + if err := zm.delete(resource); err != nil { + return err + } + + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil + +} + // getUserMetadata requests user metadata from zitadel via ID. func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) { resource := fmt.Sprintf("users/%s/metadata/_search", userID) @@ -500,6 +515,42 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { return io.ReadAll(resp.Body) } +// delete perform Delete requests. +func (zm *ZitadelManager) delete(resource string) error { + jwtToken, err := zm.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + resp, err := zm.httpClient.Do(req) + if err != nil { + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountRequestError() + } + + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode) + } + + return nil +} + // get perform Get requests. func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { jwtToken, err := zm.credentials.Authenticate() diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 66661dbf8..b4a527e46 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 6c93765f4..fa35cfdef 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index ab3edaed4..26977116b 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -744,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false) } func createNSStore(t *testing.T) (Store, error) { diff --git a/management/server/route_test.go b/management/server/route_test.go index 69241da31..81ce21a3f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -681,7 +681,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false) } func createRouterStore(t *testing.T) (Store, error) { diff --git a/management/server/telemetry/idp_metrics.go b/management/server/telemetry/idp_metrics.go index 67a1d9e85..e9eee17bd 100644 --- a/management/server/telemetry/idp_metrics.go +++ b/management/server/telemetry/idp_metrics.go @@ -2,6 +2,7 @@ package telemetry import ( "context" + "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/metric/instrument/syncint64" @@ -13,6 +14,7 @@ type IDPMetrics struct { getUserByEmailCounter syncint64.Counter getAllAccountsCounter syncint64.Counter createUserCounter syncint64.Counter + deleteUserCounter syncint64.Counter getAccountCounter syncint64.Counter getUserByIDCounter syncint64.Counter authenticateRequestCounter syncint64.Counter @@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) if err != nil { return nil, err } + deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1")) + if err != nil { + return nil, err + } getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1")) if err != nil { return nil, err @@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) getUserByEmailCounter: getUserByEmailCounter, getAllAccountsCounter: getAllAccountsCounter, createUserCounter: createUserCounter, + deleteUserCounter: deleteUserCounter, getAccountCounter: getAccountCounter, getUserByIDCounter: getUserByIDCounter, authenticateRequestCounter: authenticateRequestCounter, @@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() { idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1) } +// CountDeleteUser ... +func (idpMetrics *IDPMetrics) CountDeleteUser() { + idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1) +} + // CountGetAllAccounts ... func (idpMetrics *IDPMetrics) CountGetAllAccounts() { idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1) diff --git a/management/server/user.go b/management/server/user.go index 8ee036df7..ebebe1e0f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -327,15 +327,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin { - return status.Errorf(status.PermissionDenied, "only admins can delete service users") + return status.Errorf(status.PermissionDenied, "only admins can delete users") } - if !targetUser.IsServiceUser { - return status.Errorf(status.PermissionDenied, "regular users can not be deleted") + peers, err := account.FindUserPeers(targetUserID) + if err != nil { + return status.Errorf(status.Internal, "failed to find user peers") } - meta := map[string]any{"name": targetUser.ServiceUserName} - am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) + if err := am.expireAndUpdatePeers(account, peers); err != nil { + log.Errorf("failed update deleted peers expiration: %s", err) + return err + } + + targetUserEmail, err := am.getEmailOfTargetUser(account.Id, initiatorUserID, targetUserID) + if err != nil { + log.Errorf("failed to resolve email address: %s", err) + return err + } + + var meta map[string]any + var eventAction activity.Activity + if targetUser.IsServiceUser { + meta = map[string]any{"name": targetUser.ServiceUserName} + eventAction = activity.ServiceUserDeleted + } else { + meta = map[string]any{"email": targetUserEmail} + eventAction = activity.UserDeleted + + } + am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta) + + if !isNil(am.idpManager) { + err := am.deleteUserFromIDP(targetUserID, accountID) + if err != nil { + return err + } + } delete(account.Users, targetUserID) @@ -609,23 +637,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd if err != nil { return nil, err } - var peerIDs []string - for _, peer := range blockedPeers { - peerIDs = append(peerIDs, peer.ID) - peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) - if err != nil { - log.Errorf("failed saving peer status while expiring peer %s", peer.ID) - return nil, err - } - } - am.peersUpdateManager.CloseChannels(peerIDs) - err = am.updateAccountPeers(account) - if err != nil { - log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID) - return nil, err + if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil { + log.Errorf("failed update expired peers: %s", err) + return nil, err } } @@ -814,6 +829,67 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( return userInfos, nil } +// expireAndUpdatePeers expires all peers of the given user and updates them in the account +func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error { + var peerIDs []string + for _, peer := range peers { + peerIDs = append(peerIDs, peer.ID) + peer.MarkLoginExpired(true) + account.UpdatePeer(peer) + if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { + return err + } + am.storeEvent( + peer.UserID, peer.ID, account.Id, + activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), + ) + } + + if len(peerIDs) != 0 { + // this will trigger peer disconnect from the management service + am.peersUpdateManager.CloseChannels(peerIDs) + if err := am.updateAccountPeers(account); err != nil { + return err + } + } + return nil +} + +func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error { + if am.userDeleteFromIDPEnabled { + log.Debugf("user %s deleted from IdP", targetUserID) + err := am.idpManager.DeleteUser(targetUserID) + if err != nil { + return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err) + } + } else { + err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{}) + if err != nil { + return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err) + } + + _, err = am.refreshCache(accountID) + if err != nil { + log.Errorf("refresh account (%q) cache: %v", accountID, err) + } + } + return nil +} + +func (am *DefaultAccountManager) getEmailOfTargetUser(accountId string, initiatorId, targetId string) (string, error) { + userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) + if err != nil { + return "", err + } + for _, ui := range userInfos { + if ui.ID == targetId { + return ui.Email, nil + } + } + + return "", fmt.Errorf("email not found for user: %s", targetId) +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index b07154663..bd64074b9 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -439,8 +439,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) - - assert.Errorf(t, err, "Regular users can not be deleted (yet)") + if err != nil { + t.Errorf("unexpected error: %s", err) + } } func TestDefaultAccountManager_GetUser(t *testing.T) {