Merge branch 'main' into feature/relay-integration

This commit is contained in:
Zoltán Papp 2024-08-20 16:44:04 +02:00
commit 2e6c6cd47d
61 changed files with 2470 additions and 481 deletions

View File

@ -31,9 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -d output:** **NetBird status -dA output:**
If applicable, add the `netbird status -d' command output. If applicable, add the `netbird status -dA' command output.
**Do you face any client issues on desktop?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots** **Screenshots**

View File

@ -11,8 +11,6 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows

View File

@ -17,7 +17,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
</p> </p>

View File

@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err) s, ok := gstatus.FromError(err)

View File

@ -39,6 +39,11 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)
@ -62,7 +67,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey) err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@ -81,7 +86,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: setupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,

View File

@ -56,6 +56,7 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
setupKeyPath string
hostName string hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
@ -128,6 +129,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
@ -253,6 +256,21 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {

View File

@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -147,6 +147,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval ic.DNSRouteInterval = &dnsRouteInterval
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@ -154,7 +159,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey) err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@ -202,8 +207,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil return nil
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: setupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"os"
"testing" "testing"
"time" "time"
@ -40,6 +41,36 @@ func TestUpDaemon(t *testing.T) {
return return
} }
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{ rootCmd.SetArgs([]string{
"up", "up",
"--daemon-addr", "tcp://" + cliAddr, "--daemon-addr", "tcp://" + cliAddr,

View File

@ -3,6 +3,7 @@ package auth
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -180,7 +181,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
continue continue
} }
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) return TokenInfo{}, errors.New(tokenResponse.ErrorDescription)
} }
tokenInfo := TokenInfo{ tokenInfo := TokenInfo{

View File

@ -86,7 +86,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@ -143,6 +144,18 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
cert := p.providerConfig.ClientCertPair
if cert != nil {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*cert},
},
}
sslClient := &http.Client{Transport: tr}
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
req = req.WithContext(ctx)
}
token, err := p.handleRequest(req) token, err := p.handleRequest(req)
if err != nil { if err != nil {
renderPKCEFlowTmpl(w, err) renderPKCEFlowTmpl(w, err)

View File

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
@ -57,6 +58,8 @@ type ConfigInput struct {
DisableAutoConnect *bool DisableAutoConnect *bool
ExtraIFaceBlackList []string ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
} }
// Config Configuration type // Config Configuration type
@ -102,6 +105,13 @@ type Config struct {
// DNSRouteInterval is the interval in which the DNS routes are updated // DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication
ClientCertPath string
//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@ -385,6 +395,26 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
} }
if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true
}
if input.ClientCertPath != "" {
config.ClientCertPath = input.ClientCertPath
updated = true
}
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
if err != nil {
log.Error("Failed to load mTLS cert/key pair: ", err)
} else {
config.ClientCertKeyPair = &cert
log.Info("Loaded client mTLS cert/key pair")
}
}
return updated, nil return updated, nil
} }

View File

@ -37,6 +37,7 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
@ -1097,7 +1098,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net/url" "net/url"
@ -36,10 +37,12 @@ type PKCEAuthProviderConfig struct {
RedirectURLs []string RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication // UseIDToken indicates if the id token should be used for authentication
UseIDToken bool UseIDToken bool
//ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
} }
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) { func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key // validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey) myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil { if err != nil {
@ -93,6 +96,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(), Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(), RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(), UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
}, },
} }

View File

@ -1,4 +1,5 @@
// go:build !android //go:build !android
package sysctl package sysctl
import ( import (

View File

@ -270,7 +270,14 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
} }
routesMap := engine.GetClientRoutesWithNetID() routesMap := engine.GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector() routeManager := engine.GetRouteManager()
if routeManager == nil {
return nil, fmt.Errorf("could not get route manager")
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
return nil, fmt.Errorf("could not get route selector")
}
var routes []*selectRoute var routes []*selectRoute
for id, rt := range routesMap { for id, rt := range routesMap {

View File

@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
supportsSSO = false supportsSSO = false
err = nil err = nil

View File

@ -19,6 +19,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
) )
@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -118,9 +118,9 @@ func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) b
func prepareUserEnv(user *user.User, shell string) []string { func prepareUserEnv(user *user.User, shell string) []string {
return []string{ return []string{
fmt.Sprintf("SHELL=" + shell), fmt.Sprint("SHELL=" + shell),
fmt.Sprintf("USER=" + user.Username), fmt.Sprint("USER=" + user.Username),
fmt.Sprintf("HOME=" + user.HomeDir), fmt.Sprint("HOME=" + user.HomeDir),
} }
} }

View File

@ -22,8 +22,8 @@ import (
"fyne.io/fyne/v2/app" "fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget" "fyne.io/fyne/v2/widget"
"fyne.io/systray"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/getlantern/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.0 KiB

After

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

43
go.mod
View File

@ -30,15 +30,15 @@ require (
) )
require ( require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.5.0
fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.15.0 github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1 github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.6.0 github.com/fsnotify/fsnotify v1.7.0
github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4 github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0 github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
@ -83,7 +83,7 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.26.0 go.opentelemetry.io/otel/sdk/metric v1.26.0
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.26.0 golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.19.0 golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0 golang.org/x/sync v0.7.0
@ -102,7 +102,7 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect
dario.cat/mergo v1.0.0 // indirect dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.3.2 // indirect github.com/BurntSushi/toml v1.4.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
@ -121,27 +121,25 @@ require (
github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect github.com/fredbi/uri v1.1.0 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a // indirect
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect github.com/go-text/render v0.1.0 // indirect
github.com/go-text/typesetting v0.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.7 // indirect github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
@ -149,9 +147,11 @@ require (
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.17.8 // indirect github.com/klauspost/compress v1.17.8 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
@ -163,11 +163,11 @@ require (
github.com/moby/sys/user v0.1.0 // indirect github.com/moby/sys/user v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect github.com/pion/mdns v0.0.12 // indirect
@ -178,21 +178,24 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.15.0 // indirect github.com/prometheus/procfs v0.15.0 // indirect
github.com/rymdport/portal v0.2.2 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect github.com/vishvananda/netns v0.0.4 // indirect
github.com/yuin/goldmark v1.4.13 // indirect github.com/yuin/goldmark v1.7.1 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect

534
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -2,6 +2,7 @@ package client
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -267,7 +268,7 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) // GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection) return nil, errors.New(errMsgNoMgmtConnection)
} }
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
@ -314,7 +315,7 @@ func (c *GrpcClient) IsHealthy() bool {
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection) return nil, errors.New(errMsgNoMgmtConnection)
} }
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
@ -452,7 +453,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
// It should be used if there is changes on peer posture check after initial sync. // It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error { func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() { if !c.ready() {
return fmt.Errorf(errMsgNoMgmtConnection) return errors.New(errMsgNoMgmtConnection)
} }
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.GetServerPublicKey()

View File

@ -190,7 +190,7 @@ var (
return fmt.Errorf("failed to initialize integrated peer validator: %v", err) return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
} }
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }

View File

@ -18,6 +18,8 @@ import (
"github.com/eko/gocache/v3/cache" "github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -37,6 +39,7 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -65,6 +68,7 @@ type AccountManager interface {
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
@ -98,6 +102,7 @@ type AccountManager interface {
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
@ -170,6 +175,8 @@ type DefaultAccountManager struct {
userDeleteFromIDPEnabled bool userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
metrics telemetry.AppMetrics
} }
// Settings represents Account settings structure that can be modified via API and Dashboard // Settings represents Account settings structure that can be modified via API and Dashboard
@ -401,8 +408,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
return a.Groups[groupID] return a.Groups[groupID]
} }
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID] peer := a.Peers[peerID]
if peer == nil { if peer == nil {
return &NetworkMap{ return &NetworkMap{
@ -438,7 +453,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
if dnsManagementStatus { if dnsManagementStatus {
var zones []nbdns.CustomZone var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
if peersCustomZone.Domain != "" { if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone) zones = append(zones, peersCustomZone)
} }
@ -446,7 +461,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
} }
return &NetworkMap{ nm := &NetworkMap{
Peers: peersToConnect, Peers: peersToConnect,
Network: a.Network.Copy(), Network: a.Network.Copy(),
Routes: routesUpdate, Routes: routesUpdate,
@ -454,6 +469,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
OfflinePeers: expiredPeers, OfflinePeers: expiredPeers,
FirewallRules: firewallRules, FirewallRules: firewallRules,
} }
if metrics != nil {
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
}
return nm
}
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
var merr *multierror.Error
if dnsDomain == "" {
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
}
domainSuffix := "." + dnsDomain
var sb strings.Builder
for _, peer := range a.Peers {
if peer.DNSLabel == "" {
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
continue
}
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
sb.WriteString(peer.DNSLabel)
sb.WriteString(domainSuffix)
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: sb.String(),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
sb.Reset()
}
go func() {
if merr != nil {
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
}
}()
return customZone
} }
// GetExpiredPeers returns peers that have been expired // GetExpiredPeers returns peers that have been expired
@ -853,7 +922,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
for _, gid := range groups { for _, gid := range groups {
group, ok := a.Groups[gid] group, ok := a.Groups[gid]
if !ok { if !ok || group.Name == "All" {
continue continue
} }
update := make([]string, 0, len(group.Peers)) update := make([]string, 0, len(group.Peers))
@ -871,10 +940,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, ctx context.Context,
store Store,
peersUpdateManager *PeersUpdateManager,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
userDeleteFromIDPEnabled bool, userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator, integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,
) (*DefaultAccountManager, error) { ) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
@ -889,6 +966,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator, integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
} }
allAccounts := store.GetAllAccounts(ctx) allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
@ -1994,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
if err != nil {
return false, err
}
err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return false, err
}
if peerLoginExpired(ctx, peer, settings) {
err = am.handleExpiredPeer(ctx, user, peer)
if err != nil {
return false, err
}
return true, nil
}
return false, nil
}
// addAllGroup to account object if it doesn't exist // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{} validatedPeers[p] = struct{}{}
} }
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers) customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }
@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
}) })
} }
func createManager(t *testing.T) (*DefaultAccountManager, error) { type TB interface {
Cleanup(func())
Helper()
TempDir() string
}
func createManager(t TB) (*DefaultAccountManager, error) {
t.Helper() t.Helper()
store, err := createStore(t) store, err := createStore(t)
@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return manager, nil return manager, nil
} }
func createStore(t *testing.T) (Store, error) { func createStore(t TB) (Store, error) {
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)

View File

@ -4,8 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -17,6 +17,50 @@ import (
const defaultTTL = 300 const defaultTTL = 300
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
type lookupMap map[string]struct{} type lookupMap map[string]struct{}
// DNSSettings defines dns settings at the account level // DNSSettings defines dns settings at the account level
@ -113,11 +157,45 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return nil return nil
} }
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable} func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
}
for _, zone := range update.CustomZones { for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain} cacheKey := zone.Domain
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
} else {
protoZone := convertToProtoCustomZone(zone)
cache.SetCustomZone(cacheKey, protoZone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records { for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name, Name: record.Name,
@ -127,55 +205,25 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
RData: record.RData, RData: record.RData,
}) })
} }
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) return protoZone
} }
for _, nsGroup := range update.NameServerGroups { // Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{ protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary, Primary: nsGroup.Primary,
Domains: nsGroup.Domains, Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
} }
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{ protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(), IP: ns.IP.String(),
Port: int64(ns.Port), Port: int64(ns.Port),
NSType: int64(ns.NSType), NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
}) })
} }
return protoGroup
return customZone
} }
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {

View File

@ -2,9 +2,14 @@ package server
import ( import (
"context" "context"
"fmt"
"net/netip" "net/netip"
"reflect"
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
} }
func createDNSStore(t *testing.T) (Store, error) { func createDNSStore(t *testing.T) (Store, error) {
@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return am.Store.GetAccount(context.Background(), account.Id) return am.Store.GetAccount(context.Background(), account.Id)
} }
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache)
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache)
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache)
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache)
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache)
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
// Verify that the cache contains elements from both configs
if _, exists := cache.GetCustomZone("example.com"); !exists {
t.Errorf("Cache should contain custom zone for example.com")
}
if _, exists := cache.GetCustomZone("example.org"); !exists {
t.Errorf("Cache should contain custom zone for example.org")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}

View File

@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil return account.Users[userID].Copy(), nil
} }
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
for _, group := range account.Groups {
groupsSlice = append(groupsSlice, group)
}
return groupsSlice, nil
}
// GetAllAccounts returns all accounts // GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock() s.mux.Lock()

View File

@ -2,8 +2,12 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -243,7 +247,7 @@ func difference(a, b []string) []string {
return diff return diff
} }
// DeleteGroup object of the peers // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock() defer unlock()
@ -253,96 +257,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return err return err
} }
g, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return nil return nil
} }
// disable a deleting integration group if the initiator is not an admin service user if err = validateDeleteGroup(account, group, userId); err != nil {
if g.Issued == nbgroup.GroupIssuedIntegration { return err
executingUser := account.Users[userId]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
// check route links
for _, r := range account.Routes {
for _, g := range r.Groups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
for _, g := range r.PeerGroups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
}
// check DNS links
for _, dns := range account.NameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return &GroupLinkError{"name server groups", dns.Name}
}
}
}
// check ACL links
for _, policy := range account.Policies {
for _, rule := range policy.Rules {
for _, src := range rule.Sources {
if src == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
for _, dst := range rule.Destinations {
if dst == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
}
}
// check setup key links
for _, setupKey := range account.SetupKeys {
for _, grp := range setupKey.AutoGroups {
if grp == groupID {
return &GroupLinkError{"setup key", setupKey.Name}
}
}
}
// check user links
for _, user := range account.Users {
for _, grp := range user.AutoGroups {
if grp == groupID {
return &GroupLinkError{"user", user.Id}
}
}
}
// check DisabledManagementGroups
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
if disabledMgmGrp == groupID {
return &GroupLinkError{"disabled DNS management groups", g.Name}
}
}
// check integrated peer validator groups
if account.Settings.Extra != nil {
for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups {
if groupID == integratedPeerValidatorGroups {
return &GroupLinkError{"integrated validator", g.Name}
}
}
}
delete(account.Groups, groupID) delete(account.Groups, groupID)
account.Network.IncSerial() account.Network.IncSerial()
@ -350,13 +272,57 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return err return err
} }
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
} }
// DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
var allErrors error
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
}
am.updateAccountPeers(ctx, account)
return allErrors
}
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@ -440,3 +406,102 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}

View File

@ -3,12 +3,14 @@ package server
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -21,7 +23,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestGroupAccount(am) _, account, err := initTestGroupAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@ -56,7 +58,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestGroupAccount(am) _, account, err := initTestGroupAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@ -132,7 +134,136 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
} }
} }
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
assert.NoError(t, err, "Failed to init testing account")
groups := make([]*nbgroup.Group, 10)
for i := 0; i < 10; i++ {
groups[i] = &nbgroup.Group{
ID: fmt.Sprintf("group-%d", i+1),
AccountID: account.Id,
Name: fmt.Sprintf("group-%d", i+1),
Issued: nbgroup.GroupIssuedAPI,
}
}
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
name string
groupIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "route",
groupIDs: []string{"grp-for-route"},
expectedReasons: []string{"route"},
},
{
name: "route with peer groups",
groupIDs: []string{"grp-for-route2"},
expectedReasons: []string{"route"},
},
{
name: "name server groups",
groupIDs: []string{"grp-for-name-server-grp"},
expectedReasons: []string{"name server groups"},
},
{
name: "policy",
groupIDs: []string{"grp-for-policies"},
expectedReasons: []string{"policy"},
},
{
name: "setup keys",
groupIDs: []string{"grp-for-keys"},
expectedReasons: []string{"setup key"},
},
{
name: "users",
groupIDs: []string{"grp-for-users"},
expectedReasons: []string{"user"},
},
{
name: "integration",
groupIDs: []string{"grp-for-integration"},
expectedReasons: []string{"only service users with admin power can delete integration group"},
},
{
name: "successfully delete multiple groups",
groupIDs: []string{"group-1", "group-2"},
expectedDeleted: []string{"group-1", "group-2"},
},
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"},
},
{
name: "delete multiple groups with mixed results",
groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedDeleted: []string{"group-3", "group-4"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
{
name: "delete groups with multiple errors",
groupIDs: []string{"grp-for-policies", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
var sErr *status.Error
if errors.As(e, &sErr) {
assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message")
foundExpectedErrors++
}
var gErr *GroupLinkError
if errors.As(e, &gErr) {
assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource")
foundExpectedErrors++
}
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
for _, groupID := range tc.expectedDeleted {
_, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.Error(t, err, "group should have been deleted: %s", groupID)
}
for _, groupID := range tc.expectedNotDeleted {
group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.NoError(t, err, "group should not have been deleted: %s", groupID)
assert.NotNil(t, group, "group should exist: %s", groupID)
}
})
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
@ -236,7 +367,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
err := am.Store.SaveAccount(context.Background(), account) err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
@ -247,5 +378,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(context.Background(), account.Id) acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, nil, err
}
return am, acc, nil
} }

View File

@ -256,7 +256,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
} }
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil { if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error()) return "", status.Error(codes.PermissionDenied, err.Error())
} }
return claims.UserId, nil return claims.UserId, nil
@ -267,15 +267,15 @@ func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok { if e, ok := internalStatus.FromError(err); ok {
switch e.Type() { switch e.Type() {
case internalStatus.PermissionDenied: case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, e.Message) return status.Error(codes.PermissionDenied, e.Message)
case internalStatus.Unauthorized: case internalStatus.Unauthorized:
return status.Errorf(codes.PermissionDenied, e.Message) return status.Error(codes.PermissionDenied, e.Message)
case internalStatus.Unauthenticated: case internalStatus.Unauthenticated:
return status.Errorf(codes.PermissionDenied, e.Message) return status.Error(codes.PermissionDenied, e.Message)
case internalStatus.PreconditionFailed: case internalStatus.PreconditionFailed:
return status.Errorf(codes.FailedPrecondition, e.Message) return status.Error(codes.FailedPrecondition, e.Message)
case internalStatus.NotFound: case internalStatus.NotFound:
return status.Errorf(codes.NotFound, e.Message) return status.Error(codes.NotFound, e.Message)
default: default:
} }
} }
@ -550,53 +550,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
} }
} }
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
remotePeers := []*proto.RemotePeerConfig{} response := &proto.SyncResponse{
for _, rPeer := range peers { WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
fqdn := rPeer.FQDN(dnsName) PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig, Routes: toProtocolRoutes(networkMap.Routes),
RemotePeers: remotePeers, DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
}, },
Checks: toProtocolChecks(ctx, checks), Checks: toProtocolChecks(ctx, checks),
} }
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
})
}
return dst
} }
// IsHealthy indicates whether the service is healthy // IsHealthy indicates whether the service is healthy
@ -615,7 +608,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
if s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig.TimeBasedCredentials {
turnCredentials = trt turnCredentials = trt
} }
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks) plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return return
} }
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(account) validPeersMap, err := h.accountManager.GetValidatedPeers(account)
@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, respBody) util.WriteJSONObject(r.Context(), w, respBody)
} }
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
for _, peer := range respBody { for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id] _, ok := approvedPeersMap[peer.Id]

View File

@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
testPostureChecks[postureChecks.ID] = postureChecks testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil { if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) return status.Errorf(status.InvalidArgument, err.Error()) //nolint
} }
return nil return nil

View File

@ -3,6 +3,7 @@ package idp
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -44,14 +45,14 @@ type mockJsonParser struct {
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) { func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
if m.marshalErrorString != "" { if m.marshalErrorString != "" {
return nil, fmt.Errorf(m.marshalErrorString) return nil, errors.New(m.marshalErrorString)
} }
return m.jsonParser.Marshal(v) return m.jsonParser.Marshal(v)
} }
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error { func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
if m.unmarshalErrorString != "" { if m.unmarshalErrorString != "" {
return fmt.Errorf(m.unmarshalErrorString) return errors.New(m.unmarshalErrorString)
} }
return m.jsonParser.Unmarshal(data, v) return m.jsonParser.Unmarshal(data, v)
} }

View File

@ -150,7 +150,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// If we get here, the required token is missing // If we get here, the required token is missing
errorMsg := "required authorization token not found" errorMsg := "required authorization token not found"
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg) return nil, errors.New(errorMsg)
} }
// Now parse the token // Now parse the token
@ -173,7 +173,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if the parsed token is valid... // Check if the parsed token is valid...
if !parsedToken.Valid { if !parsedToken.Valid {
errorMsg := "token is invalid" errorMsg := "token is invalid"
log.WithContext(ctx).Debugf(errorMsg) log.WithContext(ctx).Debug(errorMsg)
return nil, errors.New(errorMsg) return nil, errors.New(errorMsg)
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}) eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
return nil, nil, "", err return nil, nil, "", err
} }

View File

@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("failed creating metrics: %v", err)
}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }

View File

@ -42,6 +42,7 @@ type MockAccountManager struct {
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@ -67,6 +68,7 @@ type MockAccountManager struct {
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
@ -326,6 +328,14 @@ func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
} }
// DeleteGroups mock implementation of DeleteGroups from server.AccountManager interface
func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
if am.DeleteGroupsFunc != nil {
return am.DeleteGroupsFunc(ctx, accountId, userId, groupIDs)
}
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface // ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil { if am.ListGroupsFunc != nil {
@ -528,6 +538,14 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
} }
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
if am.DeleteRegularUsersFunc != nil {
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
}
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
}
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.InviteUserFunc != nil { if am.InviteUserFunc != nil {
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID) return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)

View File

@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
const ( const (
@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
} }
func createNSStore(t *testing.T) (Store, error) { func createNSStore(t *testing.T) (Store, error) {

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
@ -65,12 +66,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peers := make([]*nbpeer.Peer, 0) peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { regularUser := !user.HasAdminPower() && !user.IsServiceUser
if regularUser && account.Settings.RegularUsersViewBlocked {
return peers, nil return peers, nil
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID { if regularUser && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin // only display peers that belong to the current user if the current user is not an admin
continue continue
} }
@ -79,6 +82,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peersMap[peer.ID] = p peersMap[peer.ID] = p
} }
if !regularUser {
return peers, nil
}
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
@ -316,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
} }
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
@ -529,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
postureChecks := am.getPeerPostureChecks(account, peer) postureChecks := am.getPeerPostureChecks(account, peer)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
} }
@ -540,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, status.NewPeerNotRegisteredError() return nil, nil, nil, status.NewPeerNotRegisteredError()
} }
err = checkIfPeerOwnerIsBlocked(peer, account) if peer.UserID != "" {
log.Infof("Peer has no userID")
user, err := account.FindUser(peer.UserID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return nil, nil, nil, err
}
}
if peerLoginExpired(ctx, peer, account.Settings) { if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError() return nil, nil, nil, status.NewPeerLoginExpiredError()
} }
peer, updated := updatePeerMeta(peer, sync.Meta, account) updated := peer.UpdateMetaIfNew(sync.Meta)
if updated { if updated {
err = am.Store.SavePeer(ctx, account.Id, peer) err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil { if err != nil {
@ -585,7 +603,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
} }
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
@ -614,31 +633,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// it means that the client has already checked if it needs login and had been through the SSO flow // it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login // so, we can skip this check and directly proceed with the login
if login.UserID == "" { if login.UserID == "" {
log.Info("Peer needs login")
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlockAccount()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
defer func() { defer func() {
if unlock != nil { if unlockPeer != nil {
unlock() unlockPeer()
} }
}() }()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey) settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -646,21 +662,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// this flag prevents unnecessary calls to the persistent store. // this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false shouldStorePeer := false
updateRemotePeers := false updateRemotePeers := false
if peerLoginExpired(ctx, peer, account.Settings) {
err = am.handleExpiredPeer(ctx, login, account, peer) if login.UserID != "" {
changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
updateRemotePeers = true if changed {
shouldStorePeer = true shouldStorePeer = true
updateRemotePeers = true
}
} }
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
peer, updated := updatePeerMeta(peer, login.Meta, account) var grps []string
for _, group := range groups {
for _, id := range group.Peers {
if id == peer.ID {
grps = append(grps, group.ID)
break
}
}
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(login.Meta)
if updated { if updated {
shouldStorePeer = true shouldStorePeer = true
} }
@ -677,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
} }
unlock() unlockPeer()
unlock = nil unlockPeer = nil
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if updateRemotePeers || isStatusChanged { if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
@ -732,39 +771,34 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
} }
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
} }
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, login.UserID, peer) err := checkAuth(ctx, user.Id, peer)
if err != nil { if err != nil {
return err return err
} }
// If peer was expired before and if it reached this point, it is re-authenticated. // If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer. // UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account) peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return status.Errorf(status.Internal, "couldn't find user")
}
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
if err != nil { if err != nil {
return err return err
} }
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil return nil
} }
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
if err != nil {
return status.Errorf(status.PermissionDenied, "user doesn't exist")
}
if user.IsBlocked() { if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked") return status.Errorf(status.PermissionDenied, "user is blocked")
} }
@ -794,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false return false
} }
func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
peer.UpdateLastLogin()
account.UpdatePeer(peer)
}
// UpdatePeerSSHKey updates peer's public SSH key // UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" { if sshKey == "" {
@ -897,33 +926,48 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
} }
func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
if peer.UpdateMetaIfNew(meta) {
account.UpdatePeer(peer)
return peer, true
}
return peer, false
}
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
start := time.Now()
defer func() {
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
}
}()
peers := account.GetPeers() peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err) log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return return
} }
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
for _, peer := range peers { for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) { if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue continue
} }
postureChecks := am.getPeerPostureChecks(account, peer) wg.Add(1)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) semaphore <- struct{}{}
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) go func(p *nbpeer.Peer) {
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) defer wg.Done()
defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
} }
wg.Wait()
} }

View File

@ -1,7 +1,6 @@
package peer package peer
import ( import (
"fmt"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" { if dnsDomain == "" {
return "" return ""
} }
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain) return p.DNSLabel + "." + dnsDomain
} }
// EventMeta returns activity event meta related to the peer // EventMeta returns activity event meta related to the peer

View File

@ -0,0 +1,31 @@
package peer
import (
"fmt"
"testing"
)
// FQDNOld is the original implementation for benchmarking purposes
func (p *Peer) FQDNOld(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
func BenchmarkFQDN(b *testing.B) {
p := &Peer{DNSLabel: "test-peer"}
dnsDomain := "example.com"
b.Run("Old", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDNOld(dnsDomain)
}
})
b.Run("New", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDN(dnsDomain)
}
})
}

View File

@ -2,15 +2,26 @@ package server
import ( import (
"context" "context"
"fmt"
"io"
"net"
"net/netip"
"os"
"testing" "testing"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
nbroute "github.com/netbirdio/netbird/route"
) )
func TestPeer_LoginExpired(t *testing.T) { func TestPeer_LoginExpired(t *testing.T) {
@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
} }
} }
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
if err != nil {
return nil, "", "", err
}
accountID := "test_account"
adminUser := "account_creator"
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[regularUser] = &User{
Id: regularUser,
Role: UserRoleUser,
}
// Create peers
for i := 0; i < peers; i++ {
peerKey, _ := wgtypes.GeneratePrivateKey()
peer := &nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
DNSLabel: fmt.Sprintf("peer-%d", i),
Key: peerKey.PublicKey().String(),
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
Status: &nbpeer.PeerStatus{},
UserID: regularUser,
}
account.Peers[peer.ID] = peer
}
// Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{
ID: groupID,
Name: fmt.Sprintf("Group %d", i),
}
for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
account.Groups[groupID] = group
// Create a policy for this group
policy := &Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true,
Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
},
},
}
account.Policies = append(account.Policies, policy)
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
}
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, "", "", err
}
return manager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID)
if err != nil {
b.Fatalf("GetPeers failed: %v", err)
}
}
})
}
}
func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
}
duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
b.ReportMetric(0, "ns/op")
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
t.Fatal(err)
}
domainList, err := domain.FromStringList([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.1"),
SSHEnabled: true,
Key: "peer-key",
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
turnCredentials := &TURNCredentials{
Username: "turn-user",
Password: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
{
ID: "route1",
Network: netip.MustParsePrefix("10.0.0.0/24"),
Domains: domainList,
KeepRoute: true,
NetID: "route1",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
},
},
DNSConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
{
ID: "ns1",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
checks := []*posture.Checks{
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
},
},
},
}
dnsCache := &DNSConfigCache{}
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
assert.NotNil(t, response)
// assert peer config
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
// assert wiretrustee config
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
// assert RemotePeers
assert.Equal(t, 1, len(response.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
// assert network map RemotePeers
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map OfflinePeers
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
// assert network map Routes
assert.Equal(t, 1, len(response.NetworkMap.Routes))
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
// assert network map DNSConfig
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
// assert network map DNSConfig.CustomZones
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
// assert network map DNSConfig.NameServerGroups
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
// assert network map Firewall
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
// assert posture checks
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}

View File

@ -213,7 +213,6 @@ type FirewallRule struct {
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
continue continue
} }
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap) destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
fr.PeerIP = "0.0.0.0" fr.PeerIP = "0.0.0.0"
} }
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")) fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok { if _, ok := rulesExists[ruleID]; ok {
continue continue
} }
@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
// //
// Important: Posture checks are applicable only to source group peers, // Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups { for _, g := range groups {
group, ok := account.Groups[g] group, ok := a.Groups[g]
if !ok { if !ok {
continue continue
} }
for _, p := range group.Peers { for _, p := range group.Peers {
peer, ok := account.Peers[p] peer, ok := a.Peers[p]
if !ok || peer == nil { if !ok || peer == nil {
continue continue
} }
// validate the peer based on policy posture checks applied // validate the peer based on policy posture checks applied
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid { if !isValid {
continue continue
} }
@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
} }
for _, postureChecksID := range sourcePostureChecksID { for _, postureChecksID := range sourcePostureChecksID {
postureChecks := getPostureChecks(a, postureChecksID) postureChecks := a.getPostureChecks(postureChecksID)
if postureChecks == nil { if postureChecks == nil {
continue continue
} }
@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
return true return true
} }
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks { func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
for _, postureChecks := range account.PostureChecks { for _, postureChecks := range a.PostureChecks {
if postureChecks.ID == postureChecksID { if postureChecks.ID == postureChecksID {
return postureChecks return postureChecks
} }

View File

@ -60,7 +60,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
} }
if err := postureChecks.Validate(); err != nil { if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) return status.Errorf(status.InvalidArgument, err.Error()) //nolint
} }
exists, uniqName := am.savePostureChecks(account, postureChecks) exists, uniqName := am.savePostureChecks(account, postureChecks)

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
} }
func createRouterStore(t *testing.T) (Store, error) { func createRouterStore(t *testing.T) (Store, error) {

View File

@ -223,10 +223,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, err return nil, err
} }
for _, group := range autoGroups { if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
if _, ok := account.Groups[group]; !ok { return nil, err
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
} }
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral)
@ -279,6 +277,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.Errorf(status.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
return nil, err
}
// only auto groups, revoked status, and name can be updated for now // only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy() newKey := oldKey.Copy()
newKey.Name = keyToSave.Name newKey.Name = keyToSave.Name
@ -399,3 +401,16 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return foundKey, nil return foundKey, nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
for _, group := range autoGroups {
g, ok := account.Groups[group]
if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
}
}
return nil
}

View File

@ -26,10 +26,17 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
},
{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -70,6 +77,19 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID) assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
assert.Error(t, err, "should not save setup key with All group assigned in auto groups")
} }
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
@ -102,6 +122,9 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
type testCase struct { type testCase struct {
name string name string
@ -134,8 +157,14 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
expectedGroups: []string{"FAKE"}, expectedGroups: []string{"FAKE"},
expectedFailure: true, expectedFailure: true,
} }
testCase3 := testCase{
name: "Create Setup Key should fail because of All group",
expectedKeyName: "my-test-key",
expectedGroups: []string{groupAll.ID},
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} { for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)

View File

@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil return &user, nil
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
}
return &user, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
}
return groups, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account var accounts []Account
result := s.db.Find(&accounts) result := s.db.Find(&accounts)

View File

@ -41,6 +41,8 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error

View File

@ -0,0 +1,69 @@
package telemetry
import (
"context"
"time"
"go.opentelemetry.io/otel/metric"
)
// AccountManagerMetrics represents all metrics related to the AccountManager
type AccountManagerMetrics struct {
ctx context.Context
updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
))
if err != nil {
return nil, err
}
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
))
if err != nil {
return nil, err
}
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
metric.WithUnit("objects"),
metric.WithExplicitBucketBoundaries(
50, 100, 200, 500, 1000, 2500, 5000, 10000,
))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount,
}, nil
}
// CountUpdateAccountPeersDuration counts the duration of updating account peers
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountNetworkMapObjects counts the number of network map objects
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}

View File

@ -28,6 +28,7 @@ type MockAppMetrics struct {
GRPCMetricsFunc func() *GRPCMetrics GRPCMetricsFunc func() *GRPCMetrics
StoreMetricsFunc func() *StoreMetrics StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics UpdateChannelMetricsFunc func() *UpdateChannelMetrics
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
} }
// GetMeter mocks the GetMeter function of the AppMetrics interface // GetMeter mocks the GetMeter function of the AppMetrics interface
@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
return nil return nil
} }
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
if mock.AddAccountManagerMetricsFunc != nil {
return mock.AddAccountManagerMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface // AppMetrics is metrics interface
type AppMetrics interface { type AppMetrics interface {
GetMeter() metric2.Meter GetMeter() metric2.Meter
@ -104,6 +113,7 @@ type AppMetrics interface {
GRPCMetrics() *GRPCMetrics GRPCMetrics() *GRPCMetrics
StoreMetrics() *StoreMetrics StoreMetrics() *StoreMetrics
UpdateChannelMetrics() *UpdateChannelMetrics UpdateChannelMetrics() *UpdateChannelMetrics
AccountManagerMetrics() *AccountManagerMetrics
} }
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/ // defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
@ -117,6 +127,7 @@ type defaultAppMetrics struct {
grpcMetrics *GRPCMetrics grpcMetrics *GRPCMetrics
storeMetrics *StoreMetrics storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics updateChannelMetrics *UpdateChannelMetrics
accountManagerMetrics *AccountManagerMetrics
} }
// IDPMetrics returns metrics for the idp package // IDPMetrics returns metrics for the idp package
@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
return appMetrics.updateChannelMetrics return appMetrics.updateChannelMetrics
} }
// AccountManagerMetrics returns metrics for the account manager
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
return appMetrics.accountManagerMetrics
}
// Close stop application metrics HTTP handler and closes listener. // Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error { func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil { if appMetrics.listener == nil {
@ -220,6 +236,11 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
return nil, err return nil, err
} }
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, err
}
return &defaultAppMetrics{ return &defaultAppMetrics{
Meter: meter, Meter: meter,
ctx: ctx, ctx: ctx,
@ -228,5 +249,6 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
grpcMetrics: grpcMetrics, grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics, storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics, updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil }, nil
} }

View File

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -472,51 +473,18 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
} }
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return err
}
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP.Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
}
}
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil { if err != nil {
return err return err
} }
u, err := account.FindUser(targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
}
var tuCreatedAt time.Time
if u != nil {
tuCreatedAt = u.CreatedAt
}
delete(account.Users, targetUserID) delete(account.Users, targetUserID)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return err return err
} }
meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
@ -976,10 +944,14 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
} }
for _, newGroupID := range update.AutoGroups { for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok { group, ok := account.Groups[newGroupID]
if !ok {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id) newGroupID, update.Id)
} }
if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
}
} }
return nil return nil
@ -1190,6 +1162,116 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context
return "", "", fmt.Errorf("user info not found for user: %s", targetId) return "", "", fmt.Errorf("user info not found for user: %s", targetId)
} }
// DeleteRegularUsers deletes regular users from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting the user, the function skips it and continues deleting other users.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !executingUser.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
}
var allErrors error
deletedUsersMeta := make(map[string]map[string]any)
for _, targetUserID := range targetUserIDs {
if initiatorUserID == targetUserID {
allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed"))
continue
}
targetUser := account.Users[targetUserID]
if targetUser == nil {
allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID))
continue
}
if targetUser.Role == UserRoleOwner {
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
continue
}
// disable deleting integration user if the initiator is not admin service user
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
continue
}
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
continue
}
delete(account.Users, targetUserID)
deletedUsersMeta[targetUserID] = meta
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return fmt.Errorf("failed to delete users: %w", err)
}
am.updateAccountPeers(ctx, account)
for targetUserID, meta := range deletedUsersMeta {
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
}
return allErrors
}
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return nil, err
}
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP. Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return nil, err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
}
}
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil {
return nil, err
}
u, err := account.FindUser(targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
}
var tuCreatedAt time.Time
if u != nil {
tuCreatedAt = u.CreatedAt
}
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData { for _, user := range userData {
if user.ID == userID { if user.ID == userID {

View File

@ -662,6 +662,157 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
} }
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
}
account.Users["user6"] = &User{
Id: "user6",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user7"] = &User{
Id: "user7",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user8"] = &User{
Id: "user8",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
account.Users["user9"] = &User{
Id: "user9",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
}
testCases := []struct {
name string
userIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "Delete service user successfully ",
userIDs: []string{"user2"},
expectedDeleted: []string{"user2"},
},
{
name: "Delete regular user successfully",
userIDs: []string{"user3"},
expectedDeleted: []string{"user3"},
},
{
name: "Delete integration regular user permission denied",
userIDs: []string{"user4"},
expectedReasons: []string{"only integration service user can delete this user"},
expectedNotDeleted: []string{"user4"},
},
{
name: "Delete user with owner role should return permission denied",
userIDs: []string{"user5"},
expectedReasons: []string{"unable to delete a user: user5 with owner role"},
expectedNotDeleted: []string{"user5"},
},
{
name: "Delete multiple users with mixed results",
userIDs: []string{"user5", "user5", "user6", "user7"},
expectedReasons: []string{"only integration service user can delete this user", "unable to delete a user: user5 with owner role"},
expectedDeleted: []string{"user6", "user7"},
expectedNotDeleted: []string{"user4", "user5"},
},
{
name: "Delete non-existent user",
userIDs: []string{"non-existent-user"},
expectedReasons: []string{"target user: non-existent-user not found"},
expectedNotDeleted: []string{},
},
{
name: "Delete multiple regular users successfully",
userIDs: []string{"user8", "user9"},
expectedDeleted: []string{"user8", "user9"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
assert.Contains(t, tc.expectedReasons, e.Error(), "unexpected error message")
foundExpectedErrors++
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
_, exists := acc.Users[id]
assert.False(t, exists, "user should have been deleted: %s", id)
}
for _, id := range tc.expectedNotDeleted {
user, exists := acc.Users[id]
assert.True(t, exists, "user should not have been deleted: %s", id)
assert.NotNil(t, user, "user should exist: %s", id)
}
})
}
}
func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())

View File

@ -151,6 +151,22 @@ add_aur_repo() {
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
} }
prepare_tun_module() {
# Create the necessary file structure for /dev/net/tun
if [ ! -c /dev/net/tun ]; then
if [ ! -d /dev/net ]; then
mkdir -m 755 /dev/net
fi
mknod /dev/net/tun c 10 200
chmod 0755 /dev/net/tun
fi
# Load the tun module if not already loaded
if ! lsmod | grep -q "^tun\s"; then
insmod /lib/modules/tun.ko
fi
}
install_native_binaries() { install_native_binaries() {
# Checks for supported architecture # Checks for supported architecture
case "$ARCH" in case "$ARCH" in
@ -268,6 +284,10 @@ install_netbird() {
;; ;;
esac esac
if [ "$OS_NAME" = "synology" ]; then
prepare_tun_module
fi
# Add package manager to config # Add package manager to config
${SUDO} mkdir -p "$CONFIG_FOLDER" ${SUDO} mkdir -p "$CONFIG_FOLDER"
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null

View File

@ -10,5 +10,5 @@ import (
// Listen is not supported on other platforms then Linux // Listen is not supported on other platforms then Linux
func Listen(port int, filter BPFFilter) (net.PacketConn, error) { func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
return nil, fmt.Errorf(fmt.Sprintf("Not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS)) return nil, fmt.Errorf("not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS)
} }