Merge branch 'main' into feature/relay-integration

This commit is contained in:
Zoltán Papp 2024-08-20 16:44:04 +02:00
commit 2e6c6cd47d
61 changed files with 2470 additions and 481 deletions

View File

@ -31,9 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**NetBird status -d output:**
**NetBird status -dA output:**
If applicable, add the `netbird status -d' command output.
If applicable, add the `netbird status -dA' command output.
**Do you face any client issues on desktop?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots**

View File

@ -11,8 +11,6 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows

View File

@ -17,7 +17,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
</p>

View File

@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)

View File

@ -39,6 +39,11 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service
if logFile == "console" {
err = handleRebrand(cmd)
@ -62,7 +67,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@ -81,7 +86,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{
SetupKey: setupKey,
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,

View File

@ -56,6 +56,7 @@ var (
managementURL string
adminURL string
setupKey string
setupKeyPath string
hostName string
preSharedKey string
natExternalIPs []string
@ -128,6 +129,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
@ -253,6 +256,21 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock,
}
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error {
var err error
if logFile == defaultLogFile {

View File

@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil {
t.Fatal(err)
}

View File

@ -147,6 +147,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@ -154,7 +159,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@ -202,8 +207,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{
SetupKey: setupKey,
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,

View File

@ -2,6 +2,7 @@ package cmd
import (
"context"
"os"
"testing"
"time"
@ -40,6 +41,36 @@ func TestUpDaemon(t *testing.T) {
return
}
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{
"up",
"--daemon-addr", "tcp://" + cliAddr,

View File

@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -180,7 +181,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
return TokenInfo{}, errors.New(tokenResponse.ErrorDescription)
}
tokenInfo := TokenInfo{

View File

@ -86,7 +86,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
@ -143,6 +144,18 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
cert := p.providerConfig.ClientCertPair
if cert != nil {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*cert},
},
}
sslClient := &http.Client{Transport: tr}
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
req = req.WithContext(ctx)
}
token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)

View File

@ -2,6 +2,7 @@ package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
"os"
@ -57,6 +58,8 @@ type ConfigInput struct {
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
}
// Config Configuration type
@ -102,6 +105,13 @@ type Config struct {
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication
ClientCertPath string
//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@ -385,6 +395,26 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true
}
if input.ClientCertPath != "" {
config.ClientCertPath = input.ClientCertPath
updated = true
}
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
if err != nil {
log.Error("Failed to load mTLS cert/key pair: ", err)
} else {
config.ClientCertKeyPair = &cert
log.Info("Loaded client mTLS cert/key pair")
}
}
return updated, nil
}

View File

@ -37,6 +37,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto"
@ -1097,7 +1098,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil {
return nil, "", err
}

View File

@ -2,6 +2,7 @@ package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
@ -36,10 +37,12 @@ type PKCEAuthProviderConfig struct {
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
//ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
@ -93,6 +96,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
},
}

View File

@ -1,4 +1,5 @@
// go:build !android
//go:build !android
package sysctl
import (

View File

@ -270,7 +270,14 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
}
routesMap := engine.GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector()
routeManager := engine.GetRouteManager()
if routeManager == nil {
return nil, fmt.Errorf("could not get route manager")
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
return nil, fmt.Errorf("could not get route selector")
}
var routes []*selectRoute
for id, rt := range routesMap {

View File

@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
supportsSSO = false
err = nil

View File

@ -19,6 +19,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil {
return nil, "", err
}

View File

@ -118,9 +118,9 @@ func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) b
func prepareUserEnv(user *user.User, shell string) []string {
return []string{
fmt.Sprintf("SHELL=" + shell),
fmt.Sprintf("USER=" + user.Username),
fmt.Sprintf("HOME=" + user.HomeDir),
fmt.Sprint("SHELL=" + shell),
fmt.Sprint("USER=" + user.Username),
fmt.Sprint("HOME=" + user.HomeDir),
}
}

View File

@ -22,8 +22,8 @@ import (
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
"fyne.io/systray"
"github.com/cenkalti/backoff/v4"
"github.com/getlantern/systray"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.0 KiB

After

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

43
go.mod
View File

@ -30,15 +30,15 @@ require (
)
require (
fyne.io/fyne/v2 v2.1.4
fyne.io/fyne/v2 v2.5.0
fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.6.0
github.com/getlantern/systray v1.2.1
github.com/fsnotify/fsnotify v1.7.0
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0
@ -83,7 +83,7 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.26.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
@ -102,7 +102,7 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/BurntSushi/toml v1.4.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
@ -121,27 +121,25 @@ require (
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/fredbi/uri v1.1.0 // indirect
github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a // indirect
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-text/render v0.1.0 // indirect
github.com/go-text/typesetting v0.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@ -149,9 +147,11 @@ require (
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
@ -163,11 +163,11 @@ require (
github.com/moby/sys/user v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect
@ -178,21 +178,24 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.15.0 // indirect
github.com/rymdport/portal v0.2.2 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/yuin/goldmark v1.4.13 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect

534
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,10 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/client/system"
@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil {
t.Fatal(err)
}

View File

@ -2,6 +2,7 @@ package client
import (
"context"
"errors"
"fmt"
"io"
"sync"
@ -267,7 +268,7 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
return nil, errors.New(errMsgNoMgmtConnection)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
@ -314,7 +315,7 @@ func (c *GrpcClient) IsHealthy() bool {
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
@ -452,7 +453,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() {
return fmt.Errorf(errMsgNoMgmtConnection)
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()

View File

@ -190,7 +190,7 @@ var (
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}

View File

@ -18,6 +18,8 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@ -37,6 +39,7 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@ -65,6 +68,7 @@ type AccountManager interface {
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
@ -98,6 +102,7 @@ type AccountManager interface {
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
@ -170,6 +175,8 @@ type DefaultAccountManager struct {
userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator
metrics telemetry.AppMetrics
}
// Settings represents Account settings structure that can be modified via API and Dashboard
@ -401,8 +408,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
return a.Groups[groupID]
}
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@ -438,7 +453,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
@ -446,7 +461,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
return &NetworkMap{
nm := &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
@ -454,6 +469,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
}
if metrics != nil {
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
}
return nm
}
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
var merr *multierror.Error
if dnsDomain == "" {
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
}
domainSuffix := "." + dnsDomain
var sb strings.Builder
for _, peer := range a.Peers {
if peer.DNSLabel == "" {
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
continue
}
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
sb.WriteString(peer.DNSLabel)
sb.WriteString(domainSuffix)
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: sb.String(),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
sb.Reset()
}
go func() {
if merr != nil {
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
}
}()
return customZone
}
// GetExpiredPeers returns peers that have been expired
@ -853,7 +922,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
for _, gid := range groups {
group, ok := a.Groups[gid]
if !ok {
if !ok || group.Name == "All" {
continue
}
update := make([]string, 0, len(group.Peers))
@ -871,10 +940,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
func BuildManager(
ctx context.Context,
store Store,
peersUpdateManager *PeersUpdateManager,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
@ -889,6 +966,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
}
allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
@ -1994,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
if err != nil {
return false, err
}
err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return false, err
}
if peerLoginExpired(ctx, peer, settings) {
err = am.handleExpiredPeer(ctx, user, peer)
if err != nil {
return false, err
}
return true, nil
}
return false, nil
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{}
}
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
func createManager(t *testing.T) (*DefaultAccountManager, error) {
type TB interface {
Cleanup(func())
Helper()
TempDir() string
}
func createManager(t TB) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
}
eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
return nil, err
}
@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t *testing.T) (Store, error) {
func createStore(t TB) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)

View File

@ -4,8 +4,8 @@ import (
"context"
"fmt"
"strconv"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
@ -17,6 +17,50 @@ import (
const defaultTTL = 300
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
type lookupMap map[string]struct{}
// DNSSettings defines dns settings at the account level
@ -113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return nil
}
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
}
for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
cacheKey := zone.Domain
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
} else {
protoZone := convertToProtoCustomZone(zone)
cache.SetCustomZone(cacheKey, protoZone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
return customZone
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {

View File

@ -2,9 +2,14 @@ package server
import (
"context"
"fmt"
"net/netip"
"reflect"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createDNSStore(t *testing.T) (Store, error) {
@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return am.Store.GetAccount(context.Background(), account.Id)
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache)
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache)
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache)
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache)
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache)
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
// Verify that the cache contains elements from both configs
if _, exists := cache.GetCustomZone("example.com"); !exists {
t.Errorf("Cache should contain custom zone for example.com")
}
if _, exists := cache.GetCustomZone("example.org"); !exists {
t.Errorf("Cache should contain custom zone for example.org")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}

View File

@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
for _, group := range account.Groups {
groupsSlice = append(groupsSlice, group)
}
return groupsSlice, nil
}
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()

View File

@ -2,8 +2,12 @@ package server
import (
"context"
"errors"
"fmt"
"slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@ -243,7 +247,7 @@ func difference(a, b []string) []string {
return diff
}
// DeleteGroup object of the peers
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock()
@ -253,96 +257,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return err
}
g, ok := account.Groups[groupID]
group, ok := account.Groups[groupID]
if !ok {
return nil
}
// disable a deleting integration group if the initiator is not an admin service user
if g.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userId]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
// check route links
for _, r := range account.Routes {
for _, g := range r.Groups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
for _, g := range r.PeerGroups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
}
// check DNS links
for _, dns := range account.NameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return &GroupLinkError{"name server groups", dns.Name}
}
}
}
// check ACL links
for _, policy := range account.Policies {
for _, rule := range policy.Rules {
for _, src := range rule.Sources {
if src == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
for _, dst := range rule.Destinations {
if dst == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
}
}
// check setup key links
for _, setupKey := range account.SetupKeys {
for _, grp := range setupKey.AutoGroups {
if grp == groupID {
return &GroupLinkError{"setup key", setupKey.Name}
}
}
}
// check user links
for _, user := range account.Users {
for _, grp := range user.AutoGroups {
if grp == groupID {
return &GroupLinkError{"user", user.Id}
}
}
}
// check DisabledManagementGroups
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
if disabledMgmGrp == groupID {
return &GroupLinkError{"disabled DNS management groups", g.Name}
}
}
// check integrated peer validator groups
if account.Settings.Extra != nil {
for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups {
if groupID == integratedPeerValidatorGroups {
return &GroupLinkError{"integrated validator", g.Name}
}
}
}
delete(account.Groups, groupID)
account.Network.IncSerial()
@ -350,13 +272,57 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
am.updateAccountPeers(ctx, account)
return nil
}
// DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
var allErrors error
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
}
am.updateAccountPeers(ctx, account)
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@ -440,3 +406,102 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return nil
}
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}

View File

@ -3,12 +3,14 @@ package server
import (
"context"
"errors"
"fmt"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
)
const (
@ -21,7 +23,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestGroupAccount(am)
_, account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
@ -56,7 +58,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestGroupAccount(am)
_, account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
@ -132,7 +134,136 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
assert.NoError(t, err, "Failed to init testing account")
groups := make([]*nbgroup.Group, 10)
for i := 0; i < 10; i++ {
groups[i] = &nbgroup.Group{
ID: fmt.Sprintf("group-%d", i+1),
AccountID: account.Id,
Name: fmt.Sprintf("group-%d", i+1),
Issued: nbgroup.GroupIssuedAPI,
}
}
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
name string
groupIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "route",
groupIDs: []string{"grp-for-route"},
expectedReasons: []string{"route"},
},
{
name: "route with peer groups",
groupIDs: []string{"grp-for-route2"},
expectedReasons: []string{"route"},
},
{
name: "name server groups",
groupIDs: []string{"grp-for-name-server-grp"},
expectedReasons: []string{"name server groups"},
},
{
name: "policy",
groupIDs: []string{"grp-for-policies"},
expectedReasons: []string{"policy"},
},
{
name: "setup keys",
groupIDs: []string{"grp-for-keys"},
expectedReasons: []string{"setup key"},
},
{
name: "users",
groupIDs: []string{"grp-for-users"},
expectedReasons: []string{"user"},
},
{
name: "integration",
groupIDs: []string{"grp-for-integration"},
expectedReasons: []string{"only service users with admin power can delete integration group"},
},
{
name: "successfully delete multiple groups",
groupIDs: []string{"group-1", "group-2"},
expectedDeleted: []string{"group-1", "group-2"},
},
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"},
},
{
name: "delete multiple groups with mixed results",
groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedDeleted: []string{"group-3", "group-4"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
{
name: "delete groups with multiple errors",
groupIDs: []string{"grp-for-policies", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
var sErr *status.Error
if errors.As(e, &sErr) {
assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message")
foundExpectedErrors++
}
var gErr *GroupLinkError
if errors.As(e, &gErr) {
assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource")
foundExpectedErrors++
}
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
for _, groupID := range tc.expectedDeleted {
_, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.Error(t, err, "group should have been deleted: %s", groupID)
}
for _, groupID := range tc.expectedNotDeleted {
group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.NoError(t, err, "group should not have been deleted: %s", groupID)
assert.NotNil(t, group, "group should exist: %s", groupID)
}
})
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
accountID := "testingAcc"
domain := "example.com"
@ -236,7 +367,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
return nil, nil, err
}
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
@ -247,5 +378,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(context.Background(), account.Id)
acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, nil, err
}
return am, acc, nil
}

View File

@ -256,7 +256,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
}
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error())
return "", status.Error(codes.PermissionDenied, err.Error())
}
return claims.UserId, nil
@ -267,15 +267,15 @@ func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, e.Message)
return status.Error(codes.PermissionDenied, e.Message)
case internalStatus.Unauthorized:
return status.Errorf(codes.PermissionDenied, e.Message)
return status.Error(codes.PermissionDenied, e.Message)
case internalStatus.Unauthenticated:
return status.Errorf(codes.PermissionDenied, e.Message)
return status.Error(codes.PermissionDenied, e.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.FailedPrecondition, e.Message)
return status.Error(codes.FailedPrecondition, e.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, e.Message)
return status.Error(codes.NotFound, e.Message)
default:
}
}
@ -550,53 +550,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
}
}
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers {
fqdn := rPeer.FQDN(dnsName)
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
response := &proto.SyncResponse{
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig,
RemotePeers: remotePeers,
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
},
Checks: toProtocolChecks(ctx, checks),
}
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
})
}
return dst
}
// IsHealthy indicates whether the service is healthy
@ -615,7 +608,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
if s.config.TURNConfig.TimeBasedCredentials {
turnCredentials = trt
}
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id]

View File

@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error())
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
return nil

View File

@ -3,6 +3,7 @@ package idp
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -44,14 +45,14 @@ type mockJsonParser struct {
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
if m.marshalErrorString != "" {
return nil, fmt.Errorf(m.marshalErrorString)
return nil, errors.New(m.marshalErrorString)
}
return m.jsonParser.Marshal(v)
}
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
if m.unmarshalErrorString != "" {
return fmt.Errorf(m.unmarshalErrorString)
return errors.New(m.unmarshalErrorString)
}
return m.jsonParser.Unmarshal(data, v)
}

View File

@ -150,7 +150,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
return nil, errors.New(errorMsg)
}
// Now parse the token
@ -173,7 +173,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.WithContext(ctx).Debugf(errorMsg)
log.WithContext(ctx).Debug(errorMsg)
return nil, errors.New(errorMsg)
}

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
return nil, nil, "", err
}

View File

@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("failed creating metrics: %v", err)
}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}

View File

@ -42,6 +42,7 @@ type MockAccountManager struct {
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@ -67,6 +68,7 @@ type MockAccountManager struct {
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
@ -326,6 +328,14 @@ func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
}
// DeleteGroups mock implementation of DeleteGroups from server.AccountManager interface
func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
if am.DeleteGroupsFunc != nil {
return am.DeleteGroupsFunc(ctx, accountId, userId, groupIDs)
}
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
@ -528,6 +538,14 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
}
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
if am.DeleteRegularUsersFunc != nil {
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
}
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
}
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.InviteUserFunc != nil {
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)

View File

@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createNSStore(t *testing.T) (Store, error) {

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/rs/xid"
@ -65,12 +66,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer)
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
regularUser := !user.HasAdminPower() && !user.IsServiceUser
if regularUser && account.Settings.RegularUsersViewBlocked {
return peers, nil
}
for _, peer := range account.Peers {
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID {
if regularUser && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
@ -79,6 +82,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peersMap[peer.ID] = p
}
if !regularUser {
return peers, nil
}
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
@ -316,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
if err != nil {
return nil, err
}
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
}
// GetPeerNetwork returns the Network for a given peer
@ -529,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
postureChecks := am.getPeerPostureChecks(account, peer)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
@ -540,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, nil, err
if peer.UserID != "" {
log.Infof("Peer has no userID")
user, err := account.FindUser(peer.UserID)
if err != nil {
return nil, nil, nil, err
}
err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return nil, nil, nil, err
}
}
if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
@ -585,7 +603,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
// LoginPeer logs in or registers a peer.
@ -614,31 +633,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login
if login.UserID == "" {
log.Info("Peer needs login")
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
}
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlockAccount()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
defer func() {
if unlock != nil {
unlock()
if unlockPeer != nil {
unlockPeer()
}
}()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err := am.Store.GetAccount(ctx, accountID)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
@ -646,21 +662,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
if peerLoginExpired(ctx, peer, account.Settings) {
err = am.handleExpiredPeer(ctx, login, account, peer)
if login.UserID != "" {
changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil {
return nil, nil, nil, err
}
updateRemotePeers = true
shouldStorePeer = true
if changed {
shouldStorePeer = true
updateRemotePeers = true
}
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
peer, updated := updatePeerMeta(peer, login.Meta, account)
var grps []string
for _, group := range groups {
for _, id := range group.Peers {
if id == peer.ID {
grps = append(grps, group.ID)
break
}
}
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
shouldStorePeer = true
}
@ -677,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
unlock()
unlock = nil
unlockPeer()
unlockPeer = nil
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
@ -732,39 +771,34 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
err := checkAuth(ctx, login.UserID, peer)
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
return err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account)
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return status.Errorf(status.Internal, "couldn't find user")
}
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
if err != nil {
return err
}
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil
}
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
if err != nil {
return status.Errorf(status.PermissionDenied, "user doesn't exist")
}
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
@ -794,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}
func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
peer.UpdateLastLogin()
account.UpdatePeer(peer)
}
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
@ -897,33 +926,48 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
}
func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
if peer.UpdateMetaIfNew(meta) {
account.UpdatePeer(peer)
return peer, true
}
return peer, false
}
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
start := time.Now()
defer func() {
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
}
}()
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
postureChecks := am.getPeerPostureChecks(account, peer)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
wg.Wait()
}

View File

@ -1,7 +1,6 @@
package peer
import (
"fmt"
"net"
"net/netip"
"slices"
@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
return p.DNSLabel + "." + dnsDomain
}
// EventMeta returns activity event meta related to the peer

View File

@ -0,0 +1,31 @@
package peer
import (
"fmt"
"testing"
)
// FQDNOld is the original implementation for benchmarking purposes
func (p *Peer) FQDNOld(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
func BenchmarkFQDN(b *testing.B) {
p := &Peer{DNSLabel: "test-peer"}
dnsDomain := "example.com"
b.Run("Old", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDNOld(dnsDomain)
}
})
b.Run("New", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDN(dnsDomain)
}
})
}

View File

@ -2,15 +2,26 @@ package server
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"os"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
nbroute "github.com/netbirdio/netbird/route"
)
func TestPeer_LoginExpired(t *testing.T) {
@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
if err != nil {
return nil, "", "", err
}
accountID := "test_account"
adminUser := "account_creator"
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[regularUser] = &User{
Id: regularUser,
Role: UserRoleUser,
}
// Create peers
for i := 0; i < peers; i++ {
peerKey, _ := wgtypes.GeneratePrivateKey()
peer := &nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
DNSLabel: fmt.Sprintf("peer-%d", i),
Key: peerKey.PublicKey().String(),
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
Status: &nbpeer.PeerStatus{},
UserID: regularUser,
}
account.Peers[peer.ID] = peer
}
// Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{
ID: groupID,
Name: fmt.Sprintf("Group %d", i),
}
for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
account.Groups[groupID] = group
// Create a policy for this group
policy := &Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true,
Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
},
},
}
account.Policies = append(account.Policies, policy)
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
}
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, "", "", err
}
return manager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID)
if err != nil {
b.Fatalf("GetPeers failed: %v", err)
}
}
})
}
}
func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
}
duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
b.ReportMetric(0, "ns/op")
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
t.Fatal(err)
}
domainList, err := domain.FromStringList([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.1"),
SSHEnabled: true,
Key: "peer-key",
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
turnCredentials := &TURNCredentials{
Username: "turn-user",
Password: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
{
ID: "route1",
Network: netip.MustParsePrefix("10.0.0.0/24"),
Domains: domainList,
KeepRoute: true,
NetID: "route1",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
},
},
DNSConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
{
ID: "ns1",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
checks := []*posture.Checks{
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
},
},
},
}
dnsCache := &DNSConfigCache{}
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
assert.NotNil(t, response)
// assert peer config
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
// assert wiretrustee config
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
// assert RemotePeers
assert.Equal(t, 1, len(response.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
// assert network map RemotePeers
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map OfflinePeers
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
// assert network map Routes
assert.Equal(t, 1, len(response.NetworkMap.Routes))
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
// assert network map DNSConfig
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
// assert network map DNSConfig.CustomZones
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
// assert network map DNSConfig.NameServerGroups
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
// assert network map Firewall
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
// assert posture checks
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}

View File

@ -213,7 +213,6 @@ type FirewallRule struct {
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
continue
}
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
fr.PeerIP = "0.0.0.0"
}
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {
continue
}
@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
group, ok := account.Groups[g]
group, ok := a.Groups[g]
if !ok {
continue
}
for _, p := range group.Peers {
peer, ok := account.Peers[p]
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
}
for _, postureChecksID := range sourcePostureChecksID {
postureChecks := getPostureChecks(a, postureChecksID)
postureChecks := a.getPostureChecks(postureChecksID)
if postureChecks == nil {
continue
}
@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
return true
}
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
for _, postureChecks := range account.PostureChecks {
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
for _, postureChecks := range a.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks
}

View File

@ -60,7 +60,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error())
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
exists, uniqName := am.savePostureChecks(account, postureChecks)

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createRouterStore(t *testing.T) (Store, error) {

View File

@ -223,10 +223,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, err
}
for _, group := range autoGroups {
if _, ok := account.Groups[group]; !ok {
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
return nil, err
}
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral)
@ -279,6 +277,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.Errorf(status.NotFound, "setup key not found")
}
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
return nil, err
}
// only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy()
newKey.Name = keyToSave.Name
@ -399,3 +401,16 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return foundKey, nil
}
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
for _, group := range autoGroups {
g, ok := account.Groups[group]
if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
}
}
return nil
}

View File

@ -26,10 +26,17 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
},
{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
},
})
if err != nil {
t.Fatal(err)
@ -70,6 +77,19 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
assert.Error(t, err, "should not save setup key with All group assigned in auto groups")
}
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
@ -102,6 +122,9 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err)
}
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
type testCase struct {
name string
@ -134,8 +157,14 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
expectedGroups: []string{"FAKE"},
expectedFailure: true,
}
testCase3 := testCase{
name: "Create Setup Key should fail because of All group",
expectedKeyName: "my-test-key",
expectedGroups: []string{groupAll.ID},
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} {
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)

View File

@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
}
return &user, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
}
return groups, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)

View File

@ -41,6 +41,8 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error

View File

@ -0,0 +1,69 @@
package telemetry
import (
"context"
"time"
"go.opentelemetry.io/otel/metric"
)
// AccountManagerMetrics represents all metrics related to the AccountManager
type AccountManagerMetrics struct {
ctx context.Context
updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
))
if err != nil {
return nil, err
}
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
))
if err != nil {
return nil, err
}
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
metric.WithUnit("objects"),
metric.WithExplicitBucketBoundaries(
50, 100, 200, 500, 1000, 2500, 5000, 10000,
))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount,
}, nil
}
// CountUpdateAccountPeersDuration counts the duration of updating account peers
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountNetworkMapObjects counts the number of network map objects
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}

View File

@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
// MockAppMetrics mocks the AppMetrics interface
type MockAppMetrics struct {
GetMeterFunc func() metric2.Meter
CloseFunc func() error
ExposeFunc func(ctx context.Context, port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics
StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
GetMeterFunc func() metric2.Meter
CloseFunc func() error
ExposeFunc func(ctx context.Context, port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics
StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
}
// GetMeter mocks the GetMeter function of the AppMetrics interface
@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
return nil
}
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
if mock.AddAccountManagerMetricsFunc != nil {
return mock.AddAccountManagerMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface
type AppMetrics interface {
GetMeter() metric2.Meter
@ -104,19 +113,21 @@ type AppMetrics interface {
GRPCMetrics() *GRPCMetrics
StoreMetrics() *StoreMetrics
UpdateChannelMetrics() *UpdateChannelMetrics
AccountManagerMetrics() *AccountManagerMetrics
}
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
type defaultAppMetrics struct {
// Meter can be used by different application parts to create counters and measure things
Meter metric2.Meter
listener net.Listener
ctx context.Context
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics
Meter metric2.Meter
listener net.Listener
ctx context.Context
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics
accountManagerMetrics *AccountManagerMetrics
}
// IDPMetrics returns metrics for the idp package
@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
return appMetrics.updateChannelMetrics
}
// AccountManagerMetrics returns metrics for the account manager
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
return appMetrics.accountManagerMetrics
}
// Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil {
@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
return nil, err
}
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, err
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
Meter: meter,
ctx: ctx,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil
}

View File

@ -2,6 +2,7 @@ package server
import (
"context"
"errors"
"fmt"
"strings"
"time"
@ -472,51 +473,18 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
}
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return err
}
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP.Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
}
}
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
return err
}
u, err := account.FindUser(targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
}
var tuCreatedAt time.Time
if u != nil {
tuCreatedAt = u.CreatedAt
}
delete(account.Users, targetUserID)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
am.updateAccountPeers(ctx, account)
return nil
@ -976,10 +944,14 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
}
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
group, ok := account.Groups[newGroupID]
if !ok {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
}
}
return nil
@ -1190,6 +1162,116 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context
return "", "", fmt.Errorf("user info not found for user: %s", targetId)
}
// DeleteRegularUsers deletes regular users from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting the user, the function skips it and continues deleting other users.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !executingUser.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
}
var allErrors error
deletedUsersMeta := make(map[string]map[string]any)
for _, targetUserID := range targetUserIDs {
if initiatorUserID == targetUserID {
allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed"))
continue
}
targetUser := account.Users[targetUserID]
if targetUser == nil {
allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID))
continue
}
if targetUser.Role == UserRoleOwner {
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
continue
}
// disable deleting integration user if the initiator is not admin service user
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
continue
}
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
continue
}
delete(account.Users, targetUserID)
deletedUsersMeta[targetUserID] = meta
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return fmt.Errorf("failed to delete users: %w", err)
}
am.updateAccountPeers(ctx, account)
for targetUserID, meta := range deletedUsersMeta {
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
}
return allErrors
}
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return nil, err
}
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP. Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return nil, err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
}
}
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil {
return nil, err
}
u, err := account.FindUser(targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
}
var tuCreatedAt time.Time
if u != nil {
tuCreatedAt = u.CreatedAt
}
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {

View File

@ -662,6 +662,157 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
}
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
}
account.Users["user6"] = &User{
Id: "user6",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user7"] = &User{
Id: "user7",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user8"] = &User{
Id: "user8",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
account.Users["user9"] = &User{
Id: "user9",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
}
testCases := []struct {
name string
userIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "Delete service user successfully ",
userIDs: []string{"user2"},
expectedDeleted: []string{"user2"},
},
{
name: "Delete regular user successfully",
userIDs: []string{"user3"},
expectedDeleted: []string{"user3"},
},
{
name: "Delete integration regular user permission denied",
userIDs: []string{"user4"},
expectedReasons: []string{"only integration service user can delete this user"},
expectedNotDeleted: []string{"user4"},
},
{
name: "Delete user with owner role should return permission denied",
userIDs: []string{"user5"},
expectedReasons: []string{"unable to delete a user: user5 with owner role"},
expectedNotDeleted: []string{"user5"},
},
{
name: "Delete multiple users with mixed results",
userIDs: []string{"user5", "user5", "user6", "user7"},
expectedReasons: []string{"only integration service user can delete this user", "unable to delete a user: user5 with owner role"},
expectedDeleted: []string{"user6", "user7"},
expectedNotDeleted: []string{"user4", "user5"},
},
{
name: "Delete non-existent user",
userIDs: []string{"non-existent-user"},
expectedReasons: []string{"target user: non-existent-user not found"},
expectedNotDeleted: []string{},
},
{
name: "Delete multiple regular users successfully",
userIDs: []string{"user8", "user9"},
expectedDeleted: []string{"user8", "user9"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
assert.Contains(t, tc.expectedReasons, e.Error(), "unexpected error message")
foundExpectedErrors++
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
_, exists := acc.Users[id]
assert.False(t, exists, "user should have been deleted: %s", id)
}
for _, id := range tc.expectedNotDeleted {
user, exists := acc.Users[id]
assert.True(t, exists, "user should not have been deleted: %s", id)
assert.NotNil(t, user, "user should exist: %s", id)
}
})
}
}
func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())

View File

@ -151,6 +151,22 @@ add_aur_repo() {
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
}
prepare_tun_module() {
# Create the necessary file structure for /dev/net/tun
if [ ! -c /dev/net/tun ]; then
if [ ! -d /dev/net ]; then
mkdir -m 755 /dev/net
fi
mknod /dev/net/tun c 10 200
chmod 0755 /dev/net/tun
fi
# Load the tun module if not already loaded
if ! lsmod | grep -q "^tun\s"; then
insmod /lib/modules/tun.ko
fi
}
install_native_binaries() {
# Checks for supported architecture
case "$ARCH" in
@ -268,6 +284,10 @@ install_netbird() {
;;
esac
if [ "$OS_NAME" = "synology" ]; then
prepare_tun_module
fi
# Add package manager to config
${SUDO} mkdir -p "$CONFIG_FOLDER"
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null

View File

@ -10,5 +10,5 @@ import (
// Listen is not supported on other platforms then Linux
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
return nil, fmt.Errorf(fmt.Sprintf("Not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS))
return nil, fmt.Errorf("not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS)
}