mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 18:01:13 +01:00
Merge branch 'main' into feature/relay-integration
This commit is contained in:
commit
2e6c6cd47d
9
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
9
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@ -31,9 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -d output:**
|
||||
**NetBird status -dA output:**
|
||||
|
||||
If applicable, add the `netbird status -d' command output.
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
|
||||
**Do you face any client issues on desktop?**
|
||||
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
|
@ -11,8 +11,6 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
|
@ -17,7 +17,7 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
</p>
|
||||
|
@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
|
@ -39,6 +39,11 @@ var loginCmd = &cobra.Command{
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
@ -62,7 +67,7 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@ -81,7 +86,7 @@ var loginCmd = &cobra.Command{
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
|
@ -56,6 +56,7 @@ var (
|
||||
managementURL string
|
||||
adminURL string
|
||||
setupKey string
|
||||
setupKeyPath string
|
||||
hostName string
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
@ -128,6 +129,8 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
|
||||
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
@ -253,6 +256,21 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
func getSetupKey() (string, error) {
|
||||
if setupKeyPath != "" && setupKey == "" {
|
||||
return getSetupKeyFromFile(setupKeyPath)
|
||||
}
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
|
||||
data, err := os.ReadFile(setupKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read setup key file: %v", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
func handleRebrand(cmd *cobra.Command) error {
|
||||
var err error
|
||||
if logFile == defaultLogFile {
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -147,6 +147,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@ -154,7 +159,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@ -202,8 +207,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
|
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -40,6 +41,36 @@ func TestUpDaemon(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test the setup-key-file flag.
|
||||
tempFile, err := os.CreateTemp("", "setup-key")
|
||||
if err != nil {
|
||||
t.Errorf("could not create temp file, got error %v", err)
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
|
||||
t.Errorf("could not write to temp file, got error %v", err)
|
||||
return
|
||||
}
|
||||
if err := tempFile.Close(); err != nil {
|
||||
t.Errorf("unable to close file, got error %v", err)
|
||||
}
|
||||
rootCmd.SetArgs([]string{
|
||||
"login",
|
||||
"--daemon-addr", "tcp://" + cliAddr,
|
||||
"--setup-key-file", tempFile.Name(),
|
||||
"--log-file", "",
|
||||
})
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
t.Errorf("expected no error while running up command, got %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
|
||||
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
|
||||
return
|
||||
}
|
||||
|
||||
rootCmd.SetArgs([]string{
|
||||
"up",
|
||||
"--daemon-addr", "tcp://" + cliAddr,
|
||||
|
@ -3,6 +3,7 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -180,7 +181,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
|
||||
continue
|
||||
}
|
||||
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
return TokenInfo{}, errors.New(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
|
@ -86,7 +86,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -143,6 +144,18 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
cert := p.providerConfig.ClientCertPair
|
||||
if cert != nil {
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
},
|
||||
}
|
||||
sslClient := &http.Client{Transport: tr}
|
||||
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
token, err := p.handleRequest(req)
|
||||
if err != nil {
|
||||
renderPKCEFlowTmpl(w, err)
|
||||
|
@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -57,6 +58,8 @@ type ConfigInput struct {
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@ -102,6 +105,13 @@ type Config struct {
|
||||
|
||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||
DNSRouteInterval time.Duration
|
||||
//Path to a certificate used for mTLS authentication
|
||||
ClientCertPath string
|
||||
|
||||
//Path to corresponding private key of ClientCertPath
|
||||
ClientCertKeyPath string
|
||||
|
||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@ -385,6 +395,26 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertPath != "" {
|
||||
config.ClientCertPath = input.ClientCertPath
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
|
||||
if err != nil {
|
||||
log.Error("Failed to load mTLS cert/key pair: ", err)
|
||||
} else {
|
||||
config.ClientCertKeyPair = &cert
|
||||
log.Info("Loaded client mTLS cert/key pair")
|
||||
}
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
@ -1097,7 +1098,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
@ -36,10 +37,12 @@ type PKCEAuthProviderConfig struct {
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
//ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
@ -93,6 +96,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
// go:build !android
|
||||
//go:build !android
|
||||
|
||||
package sysctl
|
||||
|
||||
import (
|
||||
|
@ -270,7 +270,14 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
}
|
||||
|
||||
routesMap := engine.GetClientRoutesWithNetID()
|
||||
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||
routeManager := engine.GetRouteManager()
|
||||
if routeManager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return nil, fmt.Errorf("could not get route selector")
|
||||
}
|
||||
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
|
@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -118,9 +118,9 @@ func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) b
|
||||
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
return []string{
|
||||
fmt.Sprintf("SHELL=" + shell),
|
||||
fmt.Sprintf("USER=" + user.Username),
|
||||
fmt.Sprintf("HOME=" + user.HomeDir),
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -22,8 +22,8 @@ import (
|
||||
"fyne.io/fyne/v2/app"
|
||||
"fyne.io/fyne/v2/dialog"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
"fyne.io/systray"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/getlantern/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 9.0 KiB After Width: | Height: | Size: 9.6 KiB |
Binary file not shown.
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 12 KiB |
43
go.mod
43
go.mod
@ -30,15 +30,15 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.1.4
|
||||
fyne.io/fyne/v2 v2.5.0
|
||||
fyne.io/systray v1.11.0
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/getlantern/systray v1.2.1
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang/mock v1.6.0
|
||||
@ -83,7 +83,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/sync v0.7.0
|
||||
@ -102,7 +102,7 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
github.com/BurntSushi/toml v1.3.2 // indirect
|
||||
github.com/BurntSushi/toml v1.4.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
|
||||
@ -121,27 +121,25 @@ require (
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
|
||||
github.com/fredbi/uri v1.1.0 // indirect
|
||||
github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
|
||||
github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a // indirect
|
||||
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/go-text/render v0.1.0 // indirect
|
||||
github.com/go-text/typesetting v0.1.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
github.com/gopherjs/gopherjs v1.17.2 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
@ -149,9 +147,11 @@ require (
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/klauspost/compress v1.17.8 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||
@ -163,11 +163,11 @@ require (
|
||||
github.com/moby/sys/user v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/mdns v0.0.12 // indirect
|
||||
@ -178,21 +178,24 @@ require (
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.53.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.0 // indirect
|
||||
github.com/rymdport/portal v0.2.2 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/yuin/goldmark v1.4.13 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
|
||||
|
@ -9,7 +9,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
@ -267,7 +268,7 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
|
||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf(errMsgNoMgmtConnection)
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
||||
@ -314,7 +315,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
|
||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf(errMsgNoMgmtConnection)
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
@ -452,7 +453,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
|
||||
// It should be used if there is changes on peer posture check after initial sync.
|
||||
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
||||
if !c.ready() {
|
||||
return fmt.Errorf(errMsgNoMgmtConnection)
|
||||
return errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
|
@ -190,7 +190,7 @@ var (
|
||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
||||
}
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
|
@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -37,6 +39,7 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -65,6 +68,7 @@ type AccountManager interface {
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||
@ -98,6 +102,7 @@ type AccountManager interface {
|
||||
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
@ -170,6 +175,8 @@ type DefaultAccountManager struct {
|
||||
userDeleteFromIDPEnabled bool
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||
@ -401,8 +408,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
@ -438,7 +453,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
@ -446,7 +461,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
@ -454,6 +469,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||
var merr *multierror.Error
|
||||
|
||||
if dnsDomain == "" {
|
||||
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
}
|
||||
|
||||
customZone := nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(dnsDomain),
|
||||
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
|
||||
}
|
||||
|
||||
domainSuffix := "." + dnsDomain
|
||||
|
||||
var sb strings.Builder
|
||||
for _, peer := range a.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
|
||||
sb.WriteString(peer.DNSLabel)
|
||||
sb.WriteString(domainSuffix)
|
||||
|
||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||
Name: sb.String(),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: defaultTTL,
|
||||
RData: peer.IP.String(),
|
||||
})
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
go func() {
|
||||
if merr != nil {
|
||||
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
|
||||
}
|
||||
}()
|
||||
|
||||
return customZone
|
||||
}
|
||||
|
||||
// GetExpiredPeers returns peers that have been expired
|
||||
@ -853,7 +922,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
for _, gid := range groups {
|
||||
group, ok := a.Groups[gid]
|
||||
if !ok {
|
||||
if !ok || group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
@ -871,10 +940,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
|
||||
func BuildManager(
|
||||
ctx context.Context,
|
||||
store Store,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
idpManager idp.Manager,
|
||||
singleAccountModeDomain string,
|
||||
dnsDomain string,
|
||||
eventStore activity.Store,
|
||||
geo *geolocation.Geolocation,
|
||||
userDeleteFromIDPEnabled bool,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
metrics telemetry.AppMetrics,
|
||||
) (*DefaultAccountManager, error) {
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
@ -889,6 +966,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
metrics: metrics,
|
||||
}
|
||||
allAccounts := store.GetAllAccounts(ctx)
|
||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||
@ -1994,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
||||
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
err = am.handleExpiredPeer(ctx, user, peer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
type TB interface {
|
||||
Cleanup(func())
|
||||
Helper()
|
||||
TempDir() string
|
||||
}
|
||||
|
||||
func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
func createStore(t TB) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
|
@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -17,6 +17,50 @@ import (
|
||||
|
||||
const defaultTTL = 300
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetCustomZone retrieves a cached custom zone
|
||||
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.CustomZones.Load(key); ok {
|
||||
return value.(*proto.CustomZone), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetCustomZone stores a custom zone in the cache
|
||||
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.CustomZones.Store(key, value)
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
|
||||
type lookupMap map[string]struct{}
|
||||
|
||||
// DNSSettings defines dns settings at the account level
|
||||
@ -113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := &proto.CustomZone{Domain: zone.Domain}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
cacheKey := zone.Domain
|
||||
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||
} else {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
cache.SetCustomZone(cacheKey, protoZone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoNS := &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
}
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
|
||||
}
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
if dnsDomain == "" {
|
||||
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
|
||||
customZone := nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(dnsDomain),
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: defaultTTL,
|
||||
RData: peer.IP.String(),
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
return customZone
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
|
@ -2,9 +2,14 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache)
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache)
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache)
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
// Verify that the cache contains elements from both configs
|
||||
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.com")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.org")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
|
||||
for _, group := range account.Groups {
|
||||
groupsSlice = append(groupsSlice, group)
|
||||
}
|
||||
|
||||
return groupsSlice, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts returns all accounts
|
||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
s.mux.Lock()
|
||||
|
@ -2,8 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -243,7 +247,7 @@ func difference(a, b []string) []string {
|
||||
return diff
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
@ -253,96 +257,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
return err
|
||||
}
|
||||
|
||||
g, ok := account.Groups[groupID]
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if g.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userId]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check route links
|
||||
for _, r := range account.Routes {
|
||||
for _, g := range r.Groups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"route", string(r.NetID)}
|
||||
}
|
||||
}
|
||||
for _, g := range r.PeerGroups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"route", string(r.NetID)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DNS links
|
||||
for _, dns := range account.NameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"name server groups", dns.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check ACL links
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, src := range rule.Sources {
|
||||
if src == groupID {
|
||||
return &GroupLinkError{"policy", policy.Name}
|
||||
}
|
||||
}
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if dst == groupID {
|
||||
return &GroupLinkError{"policy", policy.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check setup key links
|
||||
for _, setupKey := range account.SetupKeys {
|
||||
for _, grp := range setupKey.AutoGroups {
|
||||
if grp == groupID {
|
||||
return &GroupLinkError{"setup key", setupKey.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check user links
|
||||
for _, user := range account.Users {
|
||||
for _, grp := range user.AutoGroups {
|
||||
if grp == groupID {
|
||||
return &GroupLinkError{"user", user.Id}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DisabledManagementGroups
|
||||
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
|
||||
if disabledMgmGrp == groupID {
|
||||
return &GroupLinkError{"disabled DNS management groups", g.Name}
|
||||
}
|
||||
}
|
||||
|
||||
// check integrated peer validator groups
|
||||
if account.Settings.Extra != nil {
|
||||
for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups {
|
||||
if groupID == integratedPeerValidatorGroups {
|
||||
return &GroupLinkError{"integrated validator", g.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
@ -350,13 +272,57 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@ -440,3 +406,102 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
return true, dns
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
@ -3,12 +3,14 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -21,7 +23,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestGroupAccount(am)
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@ -56,7 +58,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestGroupAccount(am)
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@ -132,7 +134,136 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
assert.NoError(t, err, "Failed to create account manager")
|
||||
|
||||
manager, account, err := initTestGroupAccount(am)
|
||||
assert.NoError(t, err, "Failed to init testing account")
|
||||
|
||||
groups := make([]*nbgroup.Group, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
groups[i] = &nbgroup.Group{
|
||||
ID: fmt.Sprintf("group-%d", i+1),
|
||||
AccountID: account.Id,
|
||||
Name: fmt.Sprintf("group-%d", i+1),
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
}
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
|
||||
assert.NoError(t, err, "Failed to save test groups")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedReasons []string
|
||||
expectedDeleted []string
|
||||
expectedNotDeleted []string
|
||||
}{
|
||||
{
|
||||
name: "route",
|
||||
groupIDs: []string{"grp-for-route"},
|
||||
expectedReasons: []string{"route"},
|
||||
},
|
||||
{
|
||||
name: "route with peer groups",
|
||||
groupIDs: []string{"grp-for-route2"},
|
||||
expectedReasons: []string{"route"},
|
||||
},
|
||||
{
|
||||
name: "name server groups",
|
||||
groupIDs: []string{"grp-for-name-server-grp"},
|
||||
expectedReasons: []string{"name server groups"},
|
||||
},
|
||||
{
|
||||
name: "policy",
|
||||
groupIDs: []string{"grp-for-policies"},
|
||||
expectedReasons: []string{"policy"},
|
||||
},
|
||||
{
|
||||
name: "setup keys",
|
||||
groupIDs: []string{"grp-for-keys"},
|
||||
expectedReasons: []string{"setup key"},
|
||||
},
|
||||
{
|
||||
name: "users",
|
||||
groupIDs: []string{"grp-for-users"},
|
||||
expectedReasons: []string{"user"},
|
||||
},
|
||||
{
|
||||
name: "integration",
|
||||
groupIDs: []string{"grp-for-integration"},
|
||||
expectedReasons: []string{"only service users with admin power can delete integration group"},
|
||||
},
|
||||
{
|
||||
name: "successfully delete multiple groups",
|
||||
groupIDs: []string{"group-1", "group-2"},
|
||||
expectedDeleted: []string{"group-1", "group-2"},
|
||||
},
|
||||
{
|
||||
name: "delete non-existent group",
|
||||
groupIDs: []string{"non-existent-group"},
|
||||
expectedDeleted: []string{"non-existent-group"},
|
||||
},
|
||||
{
|
||||
name: "delete multiple groups with mixed results",
|
||||
groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"},
|
||||
expectedReasons: []string{"policy", "user"},
|
||||
expectedDeleted: []string{"group-3", "group-4"},
|
||||
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
|
||||
},
|
||||
{
|
||||
name: "delete groups with multiple errors",
|
||||
groupIDs: []string{"grp-for-policies", "grp-for-users"},
|
||||
expectedReasons: []string{"policy", "user"},
|
||||
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs)
|
||||
if len(tc.expectedReasons) > 0 {
|
||||
assert.Error(t, err)
|
||||
var foundExpectedErrors int
|
||||
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
assert.Equal(t, ok, true)
|
||||
|
||||
for _, e := range wrappedErr.Unwrap() {
|
||||
var sErr *status.Error
|
||||
if errors.As(e, &sErr) {
|
||||
assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message")
|
||||
foundExpectedErrors++
|
||||
}
|
||||
|
||||
var gErr *GroupLinkError
|
||||
if errors.As(e, &gErr) {
|
||||
assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource")
|
||||
foundExpectedErrors++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, groupID := range tc.expectedDeleted {
|
||||
_, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
|
||||
assert.Error(t, err, "group should have been deleted: %s", groupID)
|
||||
}
|
||||
|
||||
for _, groupID := range tc.expectedNotDeleted {
|
||||
group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
|
||||
assert.NoError(t, err, "group should not have been deleted: %s", groupID)
|
||||
assert.NotNil(t, group, "group should exist: %s", groupID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
@ -236,7 +367,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
@ -247,5 +378,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return am, acc, nil
|
||||
}
|
||||
|
@ -256,7 +256,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
}
|
||||
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
|
||||
return "", status.Errorf(codes.PermissionDenied, err.Error())
|
||||
return "", status.Error(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
return claims.UserId, nil
|
||||
@ -267,15 +267,15 @@ func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case internalStatus.PermissionDenied:
|
||||
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.Unauthorized:
|
||||
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.Unauthenticated:
|
||||
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.PreconditionFailed:
|
||||
return status.Errorf(codes.FailedPrecondition, e.Message)
|
||||
return status.Error(codes.FailedPrecondition, e.Message)
|
||||
case internalStatus.NotFound:
|
||||
return status.Errorf(codes.NotFound, e.Message)
|
||||
return status.Error(codes.NotFound, e.Message)
|
||||
default:
|
||||
}
|
||||
}
|
||||
@ -550,53 +550,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
|
||||
}
|
||||
}
|
||||
|
||||
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
remotePeers := []*proto.RemotePeerConfig{}
|
||||
for _, rPeer := range peers {
|
||||
fqdn := rPeer.FQDN(dnsName)
|
||||
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: fqdn,
|
||||
})
|
||||
}
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
|
||||
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
|
||||
|
||||
routesUpdate := toProtocolRoutes(networkMap.Routes)
|
||||
|
||||
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
|
||||
|
||||
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
|
||||
return &proto.SyncResponse{
|
||||
WiretrusteeConfig: wtConfig,
|
||||
PeerConfig: pConfig,
|
||||
RemotePeers: remotePeers,
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
PeerConfig: pConfig,
|
||||
RemotePeers: remotePeers,
|
||||
OfflinePeers: offlinePeers,
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
FirewallRules: firewallRules,
|
||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = allPeers
|
||||
response.NetworkMap.RemotePeers = allPeers
|
||||
response.RemotePeersIsEmpty = len(allPeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// IsHealthy indicates whether the service is healthy
|
||||
@ -615,7 +608,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
turnCredentials = trt
|
||||
}
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
|
@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
return
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
|
||||
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||
for _, peer := range respBody {
|
||||
_, ok := approvedPeersMap[peer.Id]
|
||||
|
@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error())
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -3,6 +3,7 @@ package idp
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -44,14 +45,14 @@ type mockJsonParser struct {
|
||||
|
||||
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
|
||||
if m.marshalErrorString != "" {
|
||||
return nil, fmt.Errorf(m.marshalErrorString)
|
||||
return nil, errors.New(m.marshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Marshal(v)
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
|
||||
if m.unmarshalErrorString != "" {
|
||||
return fmt.Errorf(m.unmarshalErrorString)
|
||||
return errors.New(m.unmarshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Unmarshal(data, v)
|
||||
}
|
||||
|
@ -150,7 +150,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
|
||||
// If we get here, the required token is missing
|
||||
errorMsg := "required authorization token not found"
|
||||
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, fmt.Errorf(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
@ -173,7 +173,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
errorMsg := "token is invalid"
|
||||
log.WithContext(ctx).Debugf(errorMsg)
|
||||
log.WithContext(ctx).Debug(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
|
||||
|
||||
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating metrics: %v", err)
|
||||
}
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ type MockAccountManager struct {
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
@ -67,6 +68,7 @@ type MockAccountManager struct {
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
@ -326,6 +328,14 @@ func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
|
||||
}
|
||||
|
||||
// DeleteGroups mock implementation of DeleteGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
if am.DeleteGroupsFunc != nil {
|
||||
return am.DeleteGroupsFunc(ctx, accountId, userId, groupIDs)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
@ -528,6 +538,14 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
|
||||
if am.DeleteRegularUsersFunc != nil {
|
||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.InviteUserFunc != nil {
|
||||
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@ -65,12 +66,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
regularUser := !user.HasAdminPower() && !user.IsServiceUser
|
||||
|
||||
if regularUser && account.Settings.RegularUsersViewBlocked {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID {
|
||||
if regularUser && user.Id != peer.UserID {
|
||||
// only display peers that belong to the current user if the current user is not an admin
|
||||
continue
|
||||
}
|
||||
@ -79,6 +82,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peersMap[peer.ID] = p
|
||||
}
|
||||
|
||||
if !regularUser {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
@ -316,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
@ -529,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
@ -540,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
if peer.UserID != "" {
|
||||
log.Infof("Peer has no userID")
|
||||
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
return nil, nil, nil, status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
@ -585,7 +603,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
@ -614,31 +633,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// it means that the client has already checked if it needs login and had been through the SSO flow
|
||||
// so, we can skip this check and directly proceed with the login
|
||||
if login.UserID == "" {
|
||||
log.Info("Peer needs login")
|
||||
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
if unlockPeer != nil {
|
||||
unlockPeer()
|
||||
}
|
||||
}()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -646,21 +662,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStorePeer := false
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
err = am.handleExpiredPeer(ctx, login, account, peer)
|
||||
|
||||
if login.UserID != "" {
|
||||
changed, err := am.handleUserPeer(ctx, peer, settings)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
updateRemotePeers = true
|
||||
shouldStorePeer = true
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, login.Meta, account)
|
||||
var grps []string
|
||||
for _, group := range groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peer.ID {
|
||||
grps = append(grps, group.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
shouldStorePeer = true
|
||||
}
|
||||
@ -677,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
unlockPeer()
|
||||
unlockPeer = nil
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
@ -732,39 +771,34 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, login.UserID, peer)
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, user.Id, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||
updatePeerLastLogin(peer, account)
|
||||
|
||||
// sync user last login with peer last login
|
||||
user, err := account.FindUser(login.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "couldn't find user")
|
||||
}
|
||||
|
||||
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
|
||||
peer = peer.UpdateLastLogin()
|
||||
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.PermissionDenied, "user doesn't exist")
|
||||
}
|
||||
if user.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||
}
|
||||
@ -794,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
||||
return false
|
||||
}
|
||||
|
||||
func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
|
||||
peer.UpdateLastLogin()
|
||||
account.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
@ -897,33 +926,48 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
|
||||
}
|
||||
|
||||
func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
|
||||
if peer.UpdateMetaIfNew(meta) {
|
||||
account.UpdatePeer(peer)
|
||||
return peer, true
|
||||
}
|
||||
return peer, false
|
||||
}
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||
}
|
||||
}()
|
||||
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &DNSConfigCache{}
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
|
||||
for _, peer := range peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||
return p.DNSLabel + "." + dnsDomain
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the peer
|
||||
|
31
management/server/peer/peer_test.go
Normal file
31
management/server/peer/peer_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FQDNOld is the original implementation for benchmarking purposes
|
||||
func (p *Peer) FQDNOld(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||
}
|
||||
|
||||
func BenchmarkFQDN(b *testing.B) {
|
||||
p := &Peer{DNSLabel: "test-peer"}
|
||||
dnsDomain := "example.com"
|
||||
|
||||
b.Run("Old", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.FQDNOld(dnsDomain)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("New", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.FQDN(dnsDomain)
|
||||
}
|
||||
})
|
||||
}
|
@ -2,15 +2,26 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestPeer_LoginExpired(t *testing.T) {
|
||||
@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||
b.Helper()
|
||||
|
||||
manager, err := createManager(b)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &User{
|
||||
Id: regularUser,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
// Create peers
|
||||
for i := 0; i < peers; i++ {
|
||||
peerKey, _ := wgtypes.GeneratePrivateKey()
|
||||
peer := &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer-%d", i),
|
||||
DNSLabel: fmt.Sprintf("peer-%d", i),
|
||||
Key: peerKey.PublicKey().String(),
|
||||
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
UserID: regularUser,
|
||||
}
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
account.Policies = make([]*Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &nbgroup.Group{
|
||||
ID: groupID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
}
|
||||
for j := 0; j < peers/groups; j++ {
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
account.Groups[groupID] = group
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &Policy{
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: fmt.Sprintf("rule-%d", i),
|
||||
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||
Enabled: true,
|
||||
Sources: []string{groupID},
|
||||
Destinations: []string{groupID},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
account.Policies = append(account.Policies, policy)
|
||||
}
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
return manager, accountID, regularUser, nil
|
||||
}
|
||||
|
||||
func BenchmarkGetPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := manager.GetPeers(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetPeers failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
||||
b.ReportMetric(0, "ns/op")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToSyncResponse(t *testing.T) {
|
||||
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
turnCredentials := &TURNCredentials{
|
||||
Username: "turn-user",
|
||||
Password: "turn-pass",
|
||||
}
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipnet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
|
||||
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
|
||||
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
|
||||
// assert wiretrustee config
|
||||
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
|
||||
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
|
||||
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
|
||||
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
|
||||
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
|
||||
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
|
||||
// assert RemotePeers
|
||||
assert.Equal(t, 1, len(response.RemotePeers))
|
||||
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
|
||||
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
|
||||
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map
|
||||
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
|
||||
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
|
||||
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
|
||||
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
|
||||
// assert network map RemotePeers
|
||||
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
|
||||
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map OfflinePeers
|
||||
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
|
||||
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
|
||||
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map Routes
|
||||
assert.Equal(t, 1, len(response.NetworkMap.Routes))
|
||||
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
|
||||
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
|
||||
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
|
||||
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
|
||||
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
|
||||
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
|
||||
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
|
||||
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
|
||||
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
|
||||
// assert network map DNSConfig
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
|
||||
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
|
||||
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
|
||||
// assert network map DNSConfig.CustomZones
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
|
||||
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
|
||||
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
|
||||
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
|
||||
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
|
||||
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
|
||||
// assert network map DNSConfig.NameServerGroups
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
|
||||
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
|
||||
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
|
||||
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
|
||||
// assert network map Firewall
|
||||
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
|
||||
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
|
||||
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
|
||||
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
|
||||
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
|
||||
// assert posture checks
|
||||
assert.Equal(t, 1, len(response.Checks))
|
||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||
}
|
||||
|
@ -213,7 +213,6 @@ type FirewallRule struct {
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
group, ok := account.Groups[g]
|
||||
group, ok := a.Groups[g]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range group.Peers {
|
||||
peer, ok := account.Peers[p]
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := getPostureChecks(a, postureChecksID)
|
||||
postureChecks := a.getPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
||||
return true
|
||||
}
|
||||
|
||||
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range account.PostureChecks {
|
||||
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range a.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
return postureChecks
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error())
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
|
@ -223,10 +223,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, group := range autoGroups {
|
||||
if _, ok := account.Groups[group]; !ok {
|
||||
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
}
|
||||
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral)
|
||||
@ -279,6 +277,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newKey := oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
@ -399,3 +401,16 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
|
||||
return foundKey, nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
for _, group := range autoGroups {
|
||||
g, ok := account.Groups[group]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
}
|
||||
if g.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -26,10 +26,17 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -70,6 +77,19 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
assert.Equal(t, key.Id, ev.TargetID)
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// saving setup key with All group assigned to auto groups should return error
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
assert.Error(t, err, "should not save setup key with All group assigned in auto groups")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
@ -102,6 +122,9 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
|
||||
@ -134,8 +157,14 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
expectedGroups: []string{"FAKE"},
|
||||
expectedFailure: true,
|
||||
}
|
||||
testCase3 := testCase{
|
||||
name: "Create Setup Key should fail because of All group",
|
||||
expectedKeyName: "my-test-key",
|
||||
expectedGroups: []string{groupAll.ID},
|
||||
expectedFailure: true,
|
||||
}
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
|
||||
|
@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting groups from store")
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
|
||||
var accounts []Account
|
||||
result := s.db.Find(&accounts)
|
||||
|
@ -41,6 +41,8 @@ type Store interface {
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
|
69
management/server/telemetry/accountmanager_metrics.go
Normal file
69
management/server/telemetry/accountmanager_metrics.go
Normal file
@ -0,0 +1,69 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// AccountManagerMetrics represents all metrics related to the AccountManager
|
||||
type AccountManagerMetrics struct {
|
||||
ctx context.Context
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
|
||||
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
|
||||
metric.WithUnit("objects"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
50, 100, 200, 500, 1000, 2500, 5000, 10000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
|
||||
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
|
||||
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// CountNetworkMapObjects counts the number of network map objects
|
||||
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
|
||||
|
||||
// MockAppMetrics mocks the AppMetrics interface
|
||||
type MockAppMetrics struct {
|
||||
GetMeterFunc func() metric2.Meter
|
||||
CloseFunc func() error
|
||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||
IDPMetricsFunc func() *IDPMetrics
|
||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||
GRPCMetricsFunc func() *GRPCMetrics
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
GetMeterFunc func() metric2.Meter
|
||||
CloseFunc func() error
|
||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||
IDPMetricsFunc func() *IDPMetrics
|
||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||
GRPCMetricsFunc func() *GRPCMetrics
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||
@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
|
||||
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
if mock.AddAccountManagerMetricsFunc != nil {
|
||||
return mock.AddAccountManagerMetricsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppMetrics is metrics interface
|
||||
type AppMetrics interface {
|
||||
GetMeter() metric2.Meter
|
||||
@ -104,19 +113,21 @@ type AppMetrics interface {
|
||||
GRPCMetrics() *GRPCMetrics
|
||||
StoreMetrics() *StoreMetrics
|
||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||
AccountManagerMetrics() *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||
type defaultAppMetrics struct {
|
||||
// Meter can be used by different application parts to create counters and measure things
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
accountManagerMetrics *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// IDPMetrics returns metrics for the idp package
|
||||
@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
|
||||
return appMetrics.updateChannelMetrics
|
||||
}
|
||||
|
||||
// AccountManagerMetrics returns metrics for the account manager
|
||||
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
return appMetrics.accountManagerMetrics
|
||||
}
|
||||
|
||||
// Close stop application metrics HTTP handler and closes listener.
|
||||
func (appMetrics *defaultAppMetrics) Close() error {
|
||||
if appMetrics.listener == nil {
|
||||
@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@ -472,51 +473,18 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
// Delete if the user already exists in the IdP.Necessary in cases where a user account
|
||||
// was created where a user account was provisioned but the user did not sign in
|
||||
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
if err == nil {
|
||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := account.FindUser(targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
|
||||
}
|
||||
|
||||
var tuCreatedAt time.Time
|
||||
if u != nil {
|
||||
tuCreatedAt = u.CreatedAt
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
@ -976,10 +944,14 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
|
||||
}
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
group, ok := account.Groups[newGroupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
if group.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -1190,6 +1162,116 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context
|
||||
return "", "", fmt.Errorf("user info not found for user: %s", targetId)
|
||||
}
|
||||
|
||||
// DeleteRegularUsers deletes regular users from an account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
//
|
||||
// If an error occurs while deleting the user, the function skips it and continues deleting other users.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if !executingUser.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
|
||||
deletedUsersMeta := make(map[string]map[string]any)
|
||||
for _, targetUserID := range targetUserIDs {
|
||||
if initiatorUserID == targetUserID {
|
||||
allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed"))
|
||||
continue
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID))
|
||||
continue
|
||||
}
|
||||
|
||||
if targetUser.Role == UserRoleOwner {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
|
||||
continue
|
||||
}
|
||||
|
||||
// disable deleting integration user if the initiator is not admin service user
|
||||
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
|
||||
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
|
||||
continue
|
||||
}
|
||||
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
deletedUsersMeta[targetUserID] = meta
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete users: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
for targetUserID, meta := range deletedUsersMeta {
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
// Delete if the user already exists in the IdP. Necessary in cases where a user account
|
||||
// was created where a user account was provisioned but the user did not sign in
|
||||
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
if err == nil {
|
||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := account.FindUser(targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
|
||||
}
|
||||
|
||||
var tuCreatedAt time.Time
|
||||
if u != nil {
|
||||
tuCreatedAt = u.CreatedAt
|
||||
}
|
||||
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
|
@ -662,6 +662,157 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
}
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
}
|
||||
account.Users["user6"] = &User{
|
||||
Id: "user6",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user7"] = &User{
|
||||
Id: "user7",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user8"] = &User{
|
||||
Id: "user8",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
account.Users["user9"] = &User{
|
||||
Id: "user9",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
integratedPeerValidator: MocIntegratedValidator{},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userIDs []string
|
||||
expectedReasons []string
|
||||
expectedDeleted []string
|
||||
expectedNotDeleted []string
|
||||
}{
|
||||
{
|
||||
name: "Delete service user successfully ",
|
||||
userIDs: []string{"user2"},
|
||||
expectedDeleted: []string{"user2"},
|
||||
},
|
||||
{
|
||||
name: "Delete regular user successfully",
|
||||
userIDs: []string{"user3"},
|
||||
expectedDeleted: []string{"user3"},
|
||||
},
|
||||
{
|
||||
name: "Delete integration regular user permission denied",
|
||||
userIDs: []string{"user4"},
|
||||
expectedReasons: []string{"only integration service user can delete this user"},
|
||||
expectedNotDeleted: []string{"user4"},
|
||||
},
|
||||
{
|
||||
name: "Delete user with owner role should return permission denied",
|
||||
userIDs: []string{"user5"},
|
||||
expectedReasons: []string{"unable to delete a user: user5 with owner role"},
|
||||
expectedNotDeleted: []string{"user5"},
|
||||
},
|
||||
{
|
||||
name: "Delete multiple users with mixed results",
|
||||
userIDs: []string{"user5", "user5", "user6", "user7"},
|
||||
expectedReasons: []string{"only integration service user can delete this user", "unable to delete a user: user5 with owner role"},
|
||||
expectedDeleted: []string{"user6", "user7"},
|
||||
expectedNotDeleted: []string{"user4", "user5"},
|
||||
},
|
||||
{
|
||||
name: "Delete non-existent user",
|
||||
userIDs: []string{"non-existent-user"},
|
||||
expectedReasons: []string{"target user: non-existent-user not found"},
|
||||
expectedNotDeleted: []string{},
|
||||
},
|
||||
{
|
||||
name: "Delete multiple regular users successfully",
|
||||
userIDs: []string{"user8", "user9"},
|
||||
expectedDeleted: []string{"user8", "user9"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
|
||||
if len(tc.expectedReasons) > 0 {
|
||||
assert.Error(t, err)
|
||||
var foundExpectedErrors int
|
||||
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
assert.Equal(t, ok, true)
|
||||
|
||||
for _, e := range wrappedErr.Unwrap() {
|
||||
assert.Contains(t, tc.expectedReasons, e.Error(), "unexpected error message")
|
||||
foundExpectedErrors++
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, id := range tc.expectedDeleted {
|
||||
_, exists := acc.Users[id]
|
||||
assert.False(t, exists, "user should have been deleted: %s", id)
|
||||
}
|
||||
|
||||
for _, id := range tc.expectedNotDeleted {
|
||||
user, exists := acc.Users[id]
|
||||
assert.True(t, exists, "user should not have been deleted: %s", id)
|
||||
assert.NotNil(t, user, "user should exist: %s", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
@ -151,6 +151,22 @@ add_aur_repo() {
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
}
|
||||
|
||||
prepare_tun_module() {
|
||||
# Create the necessary file structure for /dev/net/tun
|
||||
if [ ! -c /dev/net/tun ]; then
|
||||
if [ ! -d /dev/net ]; then
|
||||
mkdir -m 755 /dev/net
|
||||
fi
|
||||
mknod /dev/net/tun c 10 200
|
||||
chmod 0755 /dev/net/tun
|
||||
fi
|
||||
|
||||
# Load the tun module if not already loaded
|
||||
if ! lsmod | grep -q "^tun\s"; then
|
||||
insmod /lib/modules/tun.ko
|
||||
fi
|
||||
}
|
||||
|
||||
install_native_binaries() {
|
||||
# Checks for supported architecture
|
||||
case "$ARCH" in
|
||||
@ -268,6 +284,10 @@ install_netbird() {
|
||||
;;
|
||||
esac
|
||||
|
||||
if [ "$OS_NAME" = "synology" ]; then
|
||||
prepare_tun_module
|
||||
fi
|
||||
|
||||
# Add package manager to config
|
||||
${SUDO} mkdir -p "$CONFIG_FOLDER"
|
||||
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
|
||||
|
@ -10,5 +10,5 @@ import (
|
||||
|
||||
// Listen is not supported on other platforms then Linux
|
||||
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
|
||||
return nil, fmt.Errorf(fmt.Sprintf("Not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS))
|
||||
return nil, fmt.Errorf("not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user