Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association

# Conflicts:
#	management/server/dns_test.go
#	management/server/group.go
#	management/server/nameserver.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/user.go
This commit is contained in:
bcmmbaga
2024-08-13 16:30:04 +03:00
106 changed files with 3739 additions and 1374 deletions

8
.editorconfig Normal file
View File

@ -0,0 +1,8 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
[*.go]
indent_style = tab

View File

@ -31,9 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -d output:** **NetBird status -dA output:**
If applicable, add the `netbird status -d' command output. If applicable, add the `netbird status -dA' command output.
**Do you face any client issues on desktop?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots** **Screenshots**

View File

@ -13,7 +13,7 @@ concurrency:
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Test in FreeBSD - name: Test in FreeBSD
@ -21,19 +21,26 @@ jobs:
uses: vmactions/freebsd-vm@v1 uses: vmactions/freebsd-vm@v1
with: with:
usesh: true usesh: true
copyback: false
release: "14.1"
prepare: | prepare: |
pkg install -y curl pkg install -y go
pkg install -y git
# -x - to print all executed commands
# -e - to faile on first error
run: | run: |
set -x set -e -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L time go build -o netbird client/main.go
tar zxf go.tar.gz # check all component except management, since we do not support management server on freebsd
mv go /usr/local/go time go test -timeout 1m -failfast ./base62/...
ln -s /usr/local/go/bin/go /usr/local/bin/go # NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
go mod tidy time go test -timeout 8m -failfast -p 1 ./client/...
go test -timeout 5m -p 1 ./iface/... time go test -timeout 1m -failfast ./dns/...
go test -timeout 5m -p 1 ./client/... time go test -timeout 1m -failfast ./encryption/...
cd client time go test -timeout 1m -failfast ./formatter/...
go build . time go test -timeout 1m -failfast ./iface/...
cd .. time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@ -79,15 +79,8 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso 386
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
- name: Generate windows syso arm
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
- name: Generate windows syso arm64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@ -170,7 +163,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4

View File

@ -151,10 +151,10 @@ jobs:
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
docker-compose up -d docker compose up -d
sleep 5 sleep 5
docker-compose ps docker compose ps
docker-compose logs --tail=20 docker compose logs --tail=20
- name: test running containers - name: test running containers
run: | run: |
@ -207,7 +207,7 @@ jobs:
- name: Postgres run cleanup - name: Postgres run cleanup
run: | run: |
docker-compose down --volumes --rmi all docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB - name: run script with Zitadel CockroachDB

View File

@ -11,8 +11,6 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows

View File

@ -10,10 +10,12 @@
<img width="234" src="docs/media/logo-full.png"/> <img width="234" src="docs/media/logo-full.png"/>
</p> </p>
<p> <p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE"> <a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>

View File

@ -178,6 +178,21 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
}) })
} }
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
// domain names with random strings.
func (a *Anonymizer) AnonymizeRoute(route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}
func isWellKnown(addr netip.Addr) bool { func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{ wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4 "8.8.8.8", "8.8.4.4", // Google DNS IPv4

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -13,6 +14,8 @@ import (
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
) )
const errCloseConnection = "Failed to close connection: %v"
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
@ -63,12 +66,17 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd), Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
@ -84,7 +92,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0]) level := server.ParseLogLevel(args[0])
@ -113,7 +125,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
@ -122,17 +138,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
} }
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting) stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{}) initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil { if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
} }
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if stateWasDown {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(time.Second * 10)
} }
cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace { if !initialLevelTrace {
@ -145,6 +164,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.") cmd.Println("Log level set to trace.")
} }
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
@ -162,21 +186,25 @@ func runForDuration(cmd *cobra.Command, args []string) error {
} }
cmd.Println("\nDuration completed") cmd.Println("\nDuration completed")
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down")
time.Sleep(1 * time.Second) if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
if restoreUp { return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird up") cmd.Println("Netbird down")
} }
if !initialLevelTrace { if !initialLevelTrace {
@ -186,16 +214,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel()) cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath()) cmd.Println(resp.GetPath())
return nil return nil

View File

@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@ -39,6 +39,11 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)
@ -62,7 +67,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey) err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@ -81,7 +86,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: setupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,

View File

@ -37,6 +37,7 @@ const (
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
) )
var ( var (
@ -55,6 +56,7 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
setupKeyPath string
hostName string hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
@ -69,6 +71,7 @@ var (
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
@ -91,12 +94,15 @@ func init() {
oldDefaultConfigPathDir = "/etc/wiretrustee/" oldDefaultConfigPathDir = "/etc/wiretrustee/"
oldDefaultLogFileDir = "/var/log/wiretrustee/" oldDefaultLogFileDir = "/var/log/wiretrustee/"
if runtime.GOOS == "windows" { switch runtime.GOOS {
case "windows":
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
case "freebsd":
defaultConfigPathDir = "/var/db/netbird/"
} }
defaultConfigPath = defaultConfigPathDir + "config.json" defaultConfigPath = defaultConfigPathDir + "config.json"
@ -123,6 +129,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
@ -165,6 +173,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success
@ -246,6 +256,21 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {

View File

@ -31,6 +31,8 @@ var installCmd = &cobra.Command{
configPath, configPath,
"--log-level", "--log-level",
logLevel, logLevel,
"--daemon-addr",
daemonAddr,
} }
if managementURL != "" { if managementURL != "" {

View File

@ -807,7 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
} }
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = anonymizeRoute(a, route) peer.Routes[i] = a.AnonymizeRoute(route)
} }
} }
@ -843,21 +843,8 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
} }
for i, route := range overview.Routes { for i, route := range overview.Routes {
overview.Routes[i] = anonymizeRoute(a, route) overview.Routes[i] = a.AnonymizeRoute(route)
} }
overview.FQDN = a.AnonymizeDomain(overview.FQDN) overview.FQDN = a.AnonymizeDomain(overview.FQDN)
} }
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -147,6 +147,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval ic.DNSRouteInterval = &dnsRouteInterval
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@ -154,7 +159,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey) err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@ -199,8 +204,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil return nil
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: setupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"os"
"testing" "testing"
"time" "time"
@ -40,6 +41,36 @@ func TestUpDaemon(t *testing.T) {
return return
} }
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{ rootCmd.SetArgs([]string{
"up", "up",
"--daemon-addr", "tcp://" + cliAddr, "--daemon-addr", "tcp://" + cliAddr,

View File

@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true return rule.drop, true
} }
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.drop, true
} }

View File

@ -69,6 +69,11 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil { if err != nil {
// fallback to device code flow // fallback to device code flow

View File

@ -15,6 +15,12 @@ type hostManager interface {
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type SystemDNSSettings struct {
Domains []string
ServerIP string
ServerPort int
}
type HostDNSConfig struct { type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"` Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"` RouteAll bool `json:"routeAll"`

View File

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
@ -18,7 +19,7 @@ import (
const ( const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4" globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS" primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains" keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses" keyServerAddresses = "ServerAddresses"
@ -28,12 +29,12 @@ const (
scutilPath = "/usr/sbin/scutil" scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search" searchSuffix = "Search"
matchSuffix = "Match" matchSuffix = "Match"
localSuffix = "Local"
) )
type systemConfigurator struct { type systemConfigurator struct {
// primaryServiceID primary interface in the system. AKA the interface with the default route createdKeys map[string]struct{}
primaryServiceID string systemDNSSettings SystemDNSSettings
createdKeys map[string]struct{}
} }
func newHostManager() (hostManager, error) { func newHostManager() (hostManager, error) {
@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil {
return fmt.Errorf("add dns setup for all: %w", err)
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return fmt.Errorf("remote key from system config: %w", err)
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
}
// create a file for unclean shutdown detection // create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil { if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err) log.Errorf("failed to create unclean shutdown file: %s", err)
@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string matchDomains []string
) )
err = s.recordSystemDNSSettings(true)
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
}
}
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
if dConf.Disabled { if dConf.Disabled {
continue continue
@ -110,23 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
} }
func (s *systemConfigurator) restoreHostDNS() error { func (s *systemConfigurator) restoreHostDNS() error {
lines := "" keys := s.getRemovableKeysWithDefaults()
for key := range s.createdKeys { for _, key := range keys {
lines += buildRemoveKeyOperation(key)
keyType := "search" keyType := "search"
if strings.Contains(key, matchSuffix) { if strings.Contains(key, matchSuffix) {
keyType = "match" keyType = "match"
} }
log.Infof("removing %s domains from system", keyType) log.Infof("removing %s domains from system", keyType)
} err := s.removeKeyFromSystemConfig(key)
if s.primaryServiceID != "" { if err != nil {
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) log.Errorf("failed to remove %s domains from system: %s", keyType, err)
log.Infof("restoring DNS resolver configuration for system") }
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return fmt.Errorf("clean system: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil { if err := removeUncleanShutdownIndicator(); err != nil {
@ -136,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error {
return nil return nil
} }
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
for key := range s.createdKeys {
keys = append(keys, key)
}
return keys
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key) line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line)) _, err := runSystemConfigCommand(wrapCommand(line))
@ -148,6 +155,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
return nil return nil
} }
func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration")
return err
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server")
}
return nil
}
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil
}
systemDNSSettings, err := s.getSystemDNSSettings()
if err != nil {
return fmt.Errorf("couldn't get current DNS config: %w", err)
}
s.systemDNSSettings = systemDNSSettings
return nil
}
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
primaryServiceKey, _, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
}
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
line := buildCommandLine("show", dnsServiceKey, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
}
var dnsSettings SystemDNSSettings
inSearchDomainsArray := false
inServerAddressesArray := false
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
switch {
case strings.HasPrefix(line, "DomainName :"):
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
case line == "SearchDomains : <array> {":
inSearchDomainsArray = true
continue
case line == "ServerAddresses : <array> {":
inServerAddressesArray = true
continue
case line == "}":
inSearchDomainsArray = false
inServerAddressesArray = false
}
if inSearchDomainsArray {
searchDomain := strings.Split(line, " : ")[1]
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
dnsSettings.ServerIP = address
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
if err := scanner.Err(); err != nil {
return dnsSettings, err
}
// default to 53 port
dnsSettings.ServerPort = 53
return dnsSettings, nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {
@ -194,23 +292,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
return nil return nil
} }
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key: %w", err)
}
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string, error) { func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "") line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line) stdinCommands := wrapCommand(line)
@ -239,19 +320,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil return primaryService, router, nil
} }
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err) return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@ -94,7 +94,7 @@ func NewDefaultServer(
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface) dnsService = NewServiceViaMemory(wgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
return ds return ds
} }

View File

@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
service: newServiceViaMemory(&mocWGIface{}), service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },

View File

@ -128,6 +128,9 @@ func (s *serviceViaListener) RuntimeIP() string {
} }
func (s *serviceViaListener) setListenerStatus(running bool) { func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running s.listenerIsRunning = running
} }

View File

@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type serviceViaMemory struct { type ServiceViaMemory struct {
wgInterface WGIface wgInterface WGIface
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
runtimeIP string runtimeIP string
@ -22,8 +22,8 @@ type serviceViaMemory struct {
listenerFlagLock sync.Mutex listenerFlagLock sync.Mutex
} }
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
s := &serviceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
return s return s
} }
func (s *serviceViaMemory) Listen() error { func (s *ServiceViaMemory) Listen() error {
s.listenerFlagLock.Lock() s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
return nil return nil
} }
func (s *serviceViaMemory) Stop() { func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock() s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
s.listenerIsRunning = false s.listenerIsRunning = false
} }
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) { func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler) s.dnsMux.Handle(pattern, handler)
} }
func (s *serviceViaMemory) DeregisterMux(pattern string) { func (s *ServiceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern) s.dnsMux.HandleRemove(pattern)
} }
func (s *serviceViaMemory) RuntimePort() int { func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort return s.runtimePort
} }
func (s *serviceViaMemory) RuntimeIP() string { func (s *ServiceViaMemory) RuntimeIP() string {
return s.runtimeIP return s.runtimeIP
} }
func (s *serviceViaMemory) filterDNSTraffic() (string, error) { func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter() filter := s.wgInterface.GetFilter()
if filter == nil { if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized") return "", fmt.Errorf("can't set DNS filter, filter not initialized")

View File

@ -24,7 +24,7 @@ const (
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
const testRecord = "." const testRecord = "com."
type upstreamClient interface { type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
@ -42,6 +42,7 @@ type upstreamResolverBase struct {
upstreamServers []string upstreamServers []string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32 failsTillDeact int32
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
@ -124,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream) log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm) err = w.WriteMsg(rm)
@ -172,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() {
default: default:
} }
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
return
}
var success bool var success bool
var mu sync.Mutex var mu sync.Mutex
var wg sync.WaitGroup var wg sync.WaitGroup
@ -183,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := u.testNameserver(upstream) err := u.testNameserver(upstream, 500*time.Millisecond)
if err != nil { if err != nil {
errors = multierror.Append(errors, err) errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err) log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@ -224,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream); err != nil { if err := u.testNameserver(upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err) log.Tracef("upstream check for %s: %s", upstream, err)
} else { } else {
// at least one upstream server is available, stop probing // at least one upstream server is available, stop probing
@ -244,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
} }
@ -265,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) {
} }
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err) u.deactivate(err)
u.disabled = true u.disabled = true
go u.waitUntilResponse() go u.waitUntilResponse()
} }
func (u *upstreamResolverBase) testNameserver(server string) error { func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

View File

@ -4,6 +4,7 @@ package dns
import ( import (
"context" "context"
"fmt"
"net" "net"
"syscall" "syscall"
"time" "time"
@ -17,9 +18,9 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP net.IP
lNet *net.IPNet lNet *net.IPNet
iIndex int interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
@ -32,17 +33,11 @@ func newUpstreamResolver(
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
lIP: ip, lIP: ip,
lNet: net, lNet: net,
iIndex: index, interfaceName: interfaceName,
} }
ios.upstreamClient = ios ios.upstreamClient = ios
@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{} client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream) upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
log.Errorf("error while parsing upstream host: %s", err) return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
} }
timeout := upstreamTimeout timeout := upstreamTimeout
@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost) upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate(timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
} }
// Cannot use client.ExchangeContext because it overwrites our Dialer // Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream) return client.Exchange(r, upstream)
} }
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client { func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: u.lIP, IP: ip,
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var operr error var operr error
fn := func(s uintptr) { fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex) operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
} }
if err := c.Control(fn); err != nil { if err := c.Control(fn); err != nil {
@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
} }
return client return client, nil
} }
func getInterfaceIndex(interfaceName string) (int, error) { func getInterfaceIndex(interfaceName string) (int, error) {

View File

@ -266,8 +266,23 @@ func (e *Engine) Stop() error {
e.close() e.close()
e.wgConnWorker.Wait() e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
} }
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files) return slices.Equal(checks.Files, oChecks.Files)
}) })
} }
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@ -36,6 +36,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
@ -1069,7 +1070,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -4,6 +4,7 @@ package networkmonitor
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"syscall" "syscall"
"unsafe" "unsafe"
@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
} }
defer func() { defer func() {
if err := unix.Close(fd); err != nil { err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err) log.Errorf("Network monitor: failed to close routing socket: %v", err)
} }
}() }()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
if err != nil { if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err) if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue continue
} }
if n < unix.SizeofRtMsghdr { if n < unix.SizeofRtMsghdr {

View File

@ -99,6 +99,11 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste
return false return false
} }
if isSoftInterface(nexthop.Intf.Name) {
log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP) unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
@ -119,7 +124,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
return netip.PrefixFrom(netip.IPv4Unspecified(), 0) return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
} }
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string var defaultRoutes []string
foundMatchingRoute := false foundMatchingRoute := false
@ -128,7 +133,7 @@ func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []syst
routeInfo := formatRouteInfo(r) routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo) defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 { if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
foundMatchingRoute = true foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo) log.Debugf("network monitor: found matching default route: %s", routeInfo)
} }
@ -232,14 +237,18 @@ func stateFromInt(state uint8) string {
} }
func compareIntf(a, b *net.Interface) int { func compareIntf(a, b *net.Interface) int {
if a == nil && b == nil { switch {
case a == nil && b == nil:
return 0 return 0
} case a == nil:
if a == nil {
return -1 return -1
} case b == nil:
if b == nil {
return 1 return 1
default:
return a.Index - b.Index
} }
return a.Index - b.Index }
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
} }

View File

@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder), handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
} }
return client return client
} }
@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler { func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
} }
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -47,6 +48,8 @@ type Route struct {
currentPeerKey string currentPeerKey string
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface
resolverAddr string
} }
func NewRoute( func NewRoute(
@ -55,6 +58,8 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration, interval time.Duration,
statusRecorder *peer.Status, statusRecorder *peer.Status,
wgInterface *iface.WGIface,
resolverAddr string,
) *Route { ) *Route {
return &Route{ return &Route{
route: rt, route: rt,
@ -63,6 +68,8 @@ func NewRoute(
interval: interval, interval: interval,
dynamicDomains: domainMap{}, dynamicDomains: domainMap{},
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
} }
} }
@ -189,9 +196,14 @@ func (r *Route) startResolver(ctx context.Context) {
} }
func (r *Route) update(ctx context.Context) error { func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil { resolved, err := r.resolveDomains()
return fmt.Errorf("resolve domains: %w", err) if err != nil {
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil { if len(resolved) == 0 {
return fmt.Errorf("resolve domains: %w", err)
}
log.Warnf("Failed to resolve domains: %v", err)
}
if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err) return fmt.Errorf("update dynamic routes: %w", err)
} }
@ -223,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
wg.Add(1) wg.Add(1)
go func(domain domain.Domain) { go func(domain domain.Domain) {
defer wg.Done() defer wg.Done()
ips, err := net.LookupIP(string(domain))
ips, err := r.getIPsFromResolver(domain)
if err != nil { if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
return ips, err = net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
} }
for _, ip := range ips { for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {

View File

@ -0,0 +1,13 @@
//go:build !ios
package dynamic
import (
"net"
"github.com/netbirdio/netbird/management/domain"
)
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
return net.LookupIP(string(domain))
}

View File

@ -0,0 +1,55 @@
//go:build ios
package dynamic
import (
"fmt"
"net"
"time"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/management/domain"
)
const dialTimeout = 10 * time.Second
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
if err != nil {
return nil, fmt.Errorf("error while creating private client: %s", err)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
startTime := time.Now()
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
if err != nil {
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
}
if response.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
}
ips := make([]net.IP, 0)
for _, answ := range response.Answer {
if aRecord, ok := answ.(*dns.A); ok {
ips = append(ips, aRecord.A)
}
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
}
return ips, nil
}

View File

@ -22,7 +22,7 @@ type Route struct {
Interface *net.Interface Interface *net.Interface
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func GetRoutesFromTable() ([]netip.Prefix, error) {
tab, err := retryFetchRIB() tab, err := retryFetchRIB()
if err != nil { if err != nil {
return nil, fmt.Errorf("fetch RIB: %v", err) return nil, fmt.Errorf("fetch RIB: %v", err)

View File

@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop) nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) { if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err) log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either. // These errors are not critical, but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
return nexthop, err return nexthop, err
@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return Nexthop{}, vars.ErrRouteNotAllowed return Nexthop{}, vars.ErrRouteNotAllowed
} }
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route // Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr) nexthop, err := GetNextHop(addr)
if err != nil { if err != nil {
@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return exitNextHop, nil return exitNextHop, nil
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
localInterfaces, err := net.Interfaces()
if err != nil {
log.Errorf("Failed to get local interfaces: %v", err)
return false, nil
}
for _, intf := range localInterfaces {
addrs, err := intf.Addrs()
if err != nil {
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
log.Errorf("Failed to convert address to IPNet: %v", addr)
continue
}
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
}
}
return false, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route // in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
@ -392,7 +427,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) { func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable() routes, err := GetRoutesFromTable()
if err != nil { if err != nil {
return false, fmt.Errorf("get routes from table: %w", err) return false, fmt.Errorf("get routes from table: %w", err)
} }
@ -405,7 +440,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
} }
func isSubRange(prefix netip.Prefix) (bool, error) { func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable() routes, err := GetRoutesFromTable()
if err != nil { if err != nil {
return false, fmt.Errorf("get routes from table: %w", err) return false, fmt.Errorf("get routes from table: %w", err)
} }

View File

@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
return nil return nil
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func GetRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil { if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err) return nil, fmt.Errorf("get v4 routes: %w", err)
@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int {
func hasSeparateRouting() ([]netip.Prefix, error) { func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() { if isLegacy() {
return getRoutesFromTable() return GetRoutesFromTable()
} }
return nil, ErrRoutingIsSeparate return nil, ErrRoutingIsSeparate
} }

View File

@ -24,5 +24,5 @@ func EnableIPForwarding() error {
} }
func hasSeparateRouting() ([]netip.Prefix, error) { func hasSeparateRouting() ([]netip.Prefix, error) {
return getRoutesFromTable() return GetRoutesFromTable()
} }

View File

@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
return nil return nil
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock() mux.Lock()
defer mux.Unlock() defer mux.Unlock()

View File

@ -73,7 +73,7 @@ var testCases = []testCase{
{ {
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53", destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1", expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8", expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1", expectedInterface: "Loopback Pseudo-Interface 1",
@ -110,7 +110,7 @@ var testCases = []testCase{
{ {
name: "To more specific route (local) without custom dialer via physical interface", name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53", destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1", expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8", expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1", expectedInterface: "Loopback Pseudo-Interface 1",
@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
return combinedOutput return combinedOutput
} }
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) { func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json") cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
} }
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput var combined FindNetRouteOutput
@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
func setupDummyInterfacesAndRoutes(t *testing.T) { func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper() t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") addDummyRoute(t, "10.0.0.0/8")
}
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
}
})
} }

View File

@ -271,7 +271,14 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
} }
routesMap := engine.GetClientRoutesWithNetID() routesMap := engine.GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector() routeManager := engine.GetRouteManager()
if routeManager == nil {
return nil, fmt.Errorf("could not get route manager")
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
return nil, fmt.Errorf("could not get route selector")
}
var routes []*selectRoute var routes []*selectRoute
for id, rt := range routesMap { for id, rt := range routesMap {

View File

@ -1828,8 +1828,9 @@ type DebugBundleRequest struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"` Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
} }
func (x *DebugBundleRequest) Reset() { func (x *DebugBundleRequest) Reset() {
@ -1878,6 +1879,13 @@ func (x *DebugBundleRequest) GetStatus() string {
return "" return ""
} }
func (x *DebugBundleRequest) GetSystemInfo() bool {
if x != nil {
return x.SystemInfo
}
return false
}
type DebugBundleResponse struct { type DebugBundleResponse struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -2370,11 +2378,13 @@ var file_daemon_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c,
0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a,
0x02, 0x38, 0x01, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20,
0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22,
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,

View File

@ -263,6 +263,7 @@ message Route {
message DebugBundleRequest { message DebugBundleRequest {
bool anonymize = 1; bool anonymize = 1;
string status = 2; string status = 2;
bool systemInfo = 3;
} }
message DebugBundleResponse { message DebugBundleResponse {

View File

@ -1,3 +1,5 @@
//go:build !android && !ios
package server package server
import ( import (
@ -6,16 +8,70 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net"
"net/netip"
"os" "os"
"sort"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
const readmeContent = `Netbird debug bundle
This debug bundle contains the following files:
status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client.
Anonymization Process
The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
IP Addresses
IPv4 addresses are replaced with addresses starting from 192.51.100.0
IPv6 addresses are replaced with addresses starting from 100::
IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
Reoccuring IP addresses are replaced with the same anonymized address.
Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
Domains
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
Reoccuring domain names are replaced with the same anonymized domain.
Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
Network Interfaces
The interfaces.txt file contains information about network interfaces, including:
- Interface name
- Interface index
- MTU (Maximum Transmission Unit)
- Flags
- IP addresses associated with each interface
The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
Configuration
The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
- ManagementURL
- AdminURL
- NATExternalIPs
- CustomDNSAddress
Other non-sensitive configuration options are included without anonymization.
`
// DebugBundle creates a debug bundle and returns the location. // DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock() s.mutex.Lock()
@ -30,93 +86,211 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return nil, fmt.Errorf("create zip file: %w", err) return nil, fmt.Errorf("create zip file: %w", err)
} }
defer func() { defer func() {
if err := bundlePath.Close(); err != nil { if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
log.Errorf("failed to close zip file: %v", err) err = fmt.Errorf("close zip file: %w", closeErr)
} }
if err != nil { if err != nil {
if err2 := os.Remove(bundlePath.Name()); err2 != nil { if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
log.Errorf("Failed to remove zip file: %v", err2) log.Errorf("Failed to remove zip file: %v", removeErr)
} }
} }
}() }()
archive := zip.NewWriter(bundlePath) if err := s.createArchive(bundlePath, req); err != nil {
defer func() { return nil, err
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
} }
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
} }
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) { func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
scanner := bufio.NewScanner(reader) archive := zip.NewWriter(bundlePath)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) if err := s.addReadme(req, archive); err != nil {
return fmt.Errorf("add readme: %w", err)
}
if err := s.addStatus(req, archive); err != nil {
return fmt.Errorf("add status: %w", err)
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus() status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status) seedFromStatus(anonymizer, &status)
if err := s.addConfig(req, anonymizer, archive); err != nil {
return fmt.Errorf("add config: %w", err)
}
if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
return fmt.Errorf("add routes: %w", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
return fmt.Errorf("add interfaces: %w", err)
}
}
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
if err := archive.Close(); err != nil {
return fmt.Errorf("close archive writer: %w", err)
}
return nil
}
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if req.GetAnonymize() {
readmeReader := strings.NewReader(readmeContent)
if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
return fmt.Errorf("add README file to zip: %w", err)
}
}
return nil
}
func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if status := req.GetStatus(); status != "" {
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
}
return nil
}
func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
var configContent strings.Builder
s.addCommonConfigFields(&configContent)
if req.GetAnonymize() {
if s.config.ManagementURL != nil {
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
}
if s.config.AdminURL != nil {
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
}
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
if s.config.CustomDNSAddress != "" {
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
}
} else {
if s.config.ManagementURL != nil {
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
}
if s.config.AdminURL != nil {
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
}
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
if s.config.CustomDNSAddress != "" {
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
}
}
// Add config content to zip file
configReader := strings.NewReader(configContent.String())
if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
return fmt.Errorf("add config file to zip: %w", err)
}
return nil
}
func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
// Add non-sensitive fields
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
if s.config.NetworkMonitor != nil {
configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
}
configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
if s.config.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
}
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
}
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if routes, err := systemops.GetRoutesFromTable(); err != nil {
log.Errorf("Failed to get routes: %v", err)
} else {
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent)
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
}
return nil
}
func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
interfaces, err := net.Interfaces()
if err != nil {
return fmt.Errorf("get interfaces: %w", err)
}
interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
interfacesReader := strings.NewReader(interfacesContent)
if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
return fmt.Errorf("add interfaces file to zip: %w", err)
}
return nil
}
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
logFile, err := os.Open(s.logFile)
if err != nil {
return fmt.Errorf("open log file: %w", err)
}
defer func() { defer func() {
if err := writer.Close(); err != nil { if err := logFile.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err) log.Errorf("Failed to close original log file: %v", err)
} }
}() }()
var logReader io.Reader
if req.GetAnonymize() {
var writer *io.PipeWriter
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, anonymizer)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
return fmt.Errorf("add log file to zip: %w", err)
}
return nil
}
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
scanner := bufio.NewScanner(reader)
for scanner.Scan() { for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text()) line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil { if _, err := writer.Write([]byte(line + "\n")); err != nil {
errChan <- fmt.Errorf("write line to writer: %w", err) writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return return
} }
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
errChan <- fmt.Errorf("read line from scanner: %w", err) writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return return
} }
} }
@ -141,8 +315,22 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error { func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,
Method: zip.Deflate, Method: zip.Deflate,
Modified: time.Now(),
CreatorVersion: 20, // Version 2.0
ReaderVersion: 20, // Version 2.0
Flags: 0x800, // UTF-8 filename
}
// If the reader is a file, we can get more accurate information
if f, ok := reader.(*os.File); ok {
if stat, err := f.Stat(); err != nil {
log.Tracef("Failed to get file stat for %s: %v", filename, err)
} else {
header.Modified = stat.ModTime()
}
} }
writer, err := archive.CreateHeader(header) writer, err := archive.CreateHeader(header)
@ -165,6 +353,13 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
for _, peer := range status.Peers { for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN) a.AnonymizeDomain(peer.FQDN)
for route := range peer.GetRoutes() {
a.AnonymizeRoute(route)
}
}
for route := range status.LocalPeerState.Routes {
a.AnonymizeRoute(route)
} }
for _, nsGroup := range status.NSGroupStates { for _, nsGroup := range status.NSGroupStates {
@ -179,3 +374,113 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
} }
} }
} }
func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
var ipv4Routes, ipv6Routes []netip.Prefix
// Separate IPv4 and IPv6 routes
for _, route := range routes {
if route.Addr().Is4() {
ipv4Routes = append(ipv4Routes, route)
} else {
ipv6Routes = append(ipv6Routes, route)
}
}
// Sort IPv4 and IPv6 routes separately
sort.Slice(ipv4Routes, func(i, j int) bool {
return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
})
sort.Slice(ipv6Routes, func(i, j int) bool {
return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
})
var builder strings.Builder
// Format IPv4 routes
builder.WriteString("IPv4 Routes:\n")
for _, route := range ipv4Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
// Format IPv6 routes
builder.WriteString("\nIPv6 Routes:\n")
for _, route := range ipv6Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
return builder.String()
}
func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
if anonymize {
anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
} else {
builder.WriteString(fmt.Sprintf("%s\n", route))
}
}
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
sort.Slice(interfaces, func(i, j int) bool {
return interfaces[i].Name < interfaces[j].Name
})
var builder strings.Builder
builder.WriteString("Network Interfaces:\n")
for _, iface := range interfaces {
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
addrs, err := iface.Addrs()
if err != nil {
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
} else {
builder.WriteString(" Addresses:\n")
for _, addr := range addrs {
prefix, err := netip.ParsePrefix(addr.String())
if err != nil {
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
continue
}
ip := prefix.Addr()
if anonymize {
ip = anonymizer.AnonymizeIP(ip)
}
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
}
}
}
return builder.String()
}
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
anonymizedIPs := make([]string, len(ips))
for i, ip := range ips {
parts := strings.SplitN(ip, "/", 2)
ip1, err := netip.ParseAddr(parts[0])
if err != nil {
anonymizedIPs[i] = ip
continue
}
ip1anon := anonymizer.AnonymizeIP(ip1)
if len(parts) == 2 {
ip2, err := netip.ParseAddr(parts[1])
if err != nil {
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
} else {
ip2anon := anonymizer.AnonymizeIP(ip2)
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
}
} else {
anonymizedIPs[i] = ip1anon.String()
}
}
return anonymizedIPs
}

View File

@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
// Down engine work in the daemon. // Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
} }
// Status returns the daemon status // Status returns the daemon status

View File

@ -19,6 +19,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
) )
@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -1,11 +1,12 @@
package system package system
import ( import (
log "github.com/sirupsen/logrus"
"testing" "testing"
log "github.com/sirupsen/logrus"
) )
func Test_sysInfo(t *testing.T) { func Test_sysInfoMac(t *testing.T) {
t.Skip("skipping darwin test") t.Skip("skipping darwin test")
serialNum, prodName, manufacturer := sysInfo() serialNum, prodName, manufacturer := sysInfo()
if serialNum == "" { if serialNum == "" {

View File

@ -21,6 +21,26 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type SysInfoGetter interface {
GetSysInfo() SysInfo
}
type SysInfoWrapper struct {
si sysinfo.SysInfo
}
func (s SysInfoWrapper) GetSysInfo() SysInfo {
s.si.GetSysInfo()
return SysInfo{
ChassisSerial: s.si.Chassis.Serial,
ProductSerial: s.si.Product.Serial,
BoardSerial: s.si.Board.Serial,
ProductName: s.si.Product.Name,
BoardName: s.si.Board.Name,
ProductVendor: s.si.Product.Vendor,
}
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
info := _getInfo() info := _getInfo()
@ -45,7 +65,8 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
serialNum, prodName, manufacturer := sysInfo() si := SysInfoWrapper{}
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
env := Environment{ env := Environment{
Cloud: detect_cloud.Detect(ctx), Cloud: detect_cloud.Detect(ctx),
@ -87,20 +108,36 @@ func _getInfo() string {
return out.String() return out.String()
} }
func sysInfo() (serialNumber string, productName string, manufacturer string) { func sysInfo(si SysInfo) (string, string, string) {
var si sysinfo.SysInfo
si.GetSysInfo()
isascii := regexp.MustCompile("^[[:ascii:]]+$") isascii := regexp.MustCompile("^[[:ascii:]]+$")
serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" { serials := []string{si.ChassisSerial, si.ProductSerial}
serial = si.Product.Serial serial := ""
for _, s := range serials {
if isascii.MatchString(s) {
serial = s
if s != "Default string" {
break
}
}
} }
if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
serial = si.Board.Serial if serial == "" && isascii.MatchString(si.BoardSerial) {
serial = si.BoardSerial
} }
name := si.Product.Name
if (!isascii.MatchString(name)) && si.Board.Name != "" { var name string
name = si.Board.Name for _, n := range []string{si.ProductName, si.BoardName} {
if isascii.MatchString(n) {
name = n
break
}
} }
return serial, name, si.Product.Vendor
var manufacturer string
if isascii.MatchString(si.ProductVendor) {
manufacturer = si.ProductVendor
}
return serial, name, manufacturer
} }

View File

@ -0,0 +1,12 @@
package system
// SysInfo used to moc out the sysinfo getter
type SysInfo struct {
ChassisSerial string
ProductSerial string
BoardSerial string
ProductName string
BoardName string
ProductVendor string
}

View File

@ -0,0 +1,198 @@
package system
import "testing"
func Test_sysInfo(t *testing.T) {
tests := []struct {
name string
sysInfo SysInfo
wantSerialNum string
wantProdName string
wantManufacturer string
}{
{
name: "Test Case 1",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Fallback to Product Serial",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Product serial",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Product serial",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Fallback to Product Serial with default string",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "Product serial",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Product serial",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial and Product Serial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "M80-G8013200245",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "\x80",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Product Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "",
BoardName: "boardname",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "boardname",
wantManufacturer: "ASRock",
},
{
name: "Invalid Product Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "\x80",
BoardName: "boardname",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "boardname",
wantManufacturer: "ASRock",
},
{
name: "Invalid BoardName Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "\x80",
BoardName: "\x80",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "",
wantManufacturer: "ASRock",
},
{
name: "Invalid chars",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "\x80",
ProductName: "\x80",
BoardName: "\x80",
ProductVendor: "\x80",
},
wantSerialNum: "",
wantProdName: "",
wantManufacturer: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
if gotSerialNum != tt.wantSerialNum {
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
}
if gotProdName != tt.wantProdName {
t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName)
}
if gotManufacturer != tt.wantManufacturer {
t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer)
}
})
}
}

View File

@ -22,8 +22,8 @@ import (
"fyne.io/fyne/v2/app" "fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget" "fyne.io/fyne/v2/widget"
"fyne.io/systray"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/getlantern/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"

View File

@ -1,4 +1,4 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd //go:build darwin
package main package main

View File

@ -10,7 +10,7 @@ import (
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) { func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
byteResp, err := pb.Marshal(message) byteResp, err := pb.Marshal(message)
if err != nil { if err != nil {
log.Errorf("failed marshalling message %v", err) log.Errorf("failed marshalling message %v, %+v", err, message.String())
return nil, err return nil, err
} }

View File

@ -14,14 +14,29 @@ type TextFormatter struct {
levelDesc []string levelDesc []string
} }
// SyslogFormatter formats logs into text
type SyslogFormatter struct {
levelDesc []string
}
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
// NewTextFormatter create new MyTextFormatter instance // NewTextFormatter create new MyTextFormatter instance
func NewTextFormatter() *TextFormatter { func NewTextFormatter() *TextFormatter {
return &TextFormatter{ return &TextFormatter{
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}, levelDesc: validLevelDesc,
timestampFormat: time.RFC3339, // or RFC3339 timestampFormat: time.RFC3339, // or RFC3339
} }
} }
// NewSyslogFormatter create new MySyslogFormatter instance
func NewSyslogFormatter() *SyslogFormatter {
return &SyslogFormatter{
levelDesc: validLevelDesc,
}
}
// Format renders a single log entry // Format renders a single log entry
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) { func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string var fields string
@ -49,3 +64,20 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string {
return f.levelDesc[level] return f.levelDesc[level]
} }
// Format renders a single log entry
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
keys := make([]string, 0, len(entry.Data))
for k, v := range entry.Data {
if k == "source" {
continue
}
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
}
if len(keys) > 0 {
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
}
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
}

View File

@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestLogMessageFormat(t *testing.T) { func TestLogTextFormat(t *testing.T) {
someEntry := &logrus.Entry{ someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"}, Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
@ -24,3 +24,20 @@ func TestLogMessageFormat(t *testing.T) {
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$" expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString) assert.Regexp(t, expectedString, parsedString)
} }
func TestLogSyslogFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
Level: 3,
Message: "Some Message",
}
formatter := NewSyslogFormatter()
result, _ := formatter.Format(someEntry)
parsedString := string(result)
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}

View File

@ -10,6 +10,12 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.ReportCaller = true logger.ReportCaller = true
logger.AddHook(NewContextHook()) logger.AddHook(NewContextHook())
} }
// SetSyslogFormatter set the text formatter for given logger.
func SetSyslogFormatter(logger *logrus.Logger) {
logger.Formatter = NewSyslogFormatter()
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetJSONFormatter set the JSON formatter for given logger. // SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) { func SetJSONFormatter(logger *logrus.Logger) {

12
go.mod
View File

@ -31,6 +31,7 @@ require (
require ( require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.1.4
fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.15.0 github.com/cilium/ebpf v0.15.0
@ -38,7 +39,6 @@ require (
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1 github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.6.0 github.com/fsnotify/fsnotify v1.6.0
github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4 github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0 github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
@ -116,24 +116,17 @@ require (
github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v26.1.3+incompatible // indirect github.com/docker/docker v26.1.4+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
@ -165,7 +158,6 @@ require (
github.com/nxadm/tail v1.4.8 // indirect github.com/nxadm/tail v1.4.8 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect github.com/pion/mdns v0.0.12 // indirect

25
go.sum
View File

@ -12,6 +12,8 @@ dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
fyne.io/fyne/v2 v2.1.4 h1:bt1+28++kAzRzPB0GM2EuSV4cnl8rXNX4cjfd8G06Rc= fyne.io/fyne/v2 v2.1.4 h1:bt1+28++kAzRzPB0GM2EuSV4cnl8rXNX4cjfd8G06Rc=
fyne.io/fyne/v2 v2.1.4/go.mod h1:p+E/Dh+wPW8JwR2DVcsZ9iXgR9ZKde80+Y+40Is54AQ= fyne.io/fyne/v2 v2.1.4/go.mod h1:p+E/Dh+wPW8JwR2DVcsZ9iXgR9ZKde80+Y+40Is54AQ=
fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg=
fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
@ -81,8 +83,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v26.1.3+incompatible h1:lLCzRbrVZrljpVNobJu1J2FHk8V0s4BawoZippkc+xo= github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
github.com/docker/docker v26.1.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@ -111,18 +113,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 h1:NRUJuo3v3WGC/g5YiyF790gut6oQr5f3FBI88Wv0dx4=
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520/go.mod h1:L+mq6/vvYHKjCX2oez0CgEAJmbq1fbb/oNJIWQkBybY=
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 h1:6uJ+sZ/e03gkbqZ0kUG6mfKoqDb4XMAzMIwlajq19So=
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A=
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 h1:guBYzEaLz0Vfc/jv0czrr2z7qyzTOGC9hiQ0VC+hKjk=
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7/go.mod h1:zx/1xUUeYPy3Pcmet8OSXLbF47l+3y6hIPpyLWoR9oc=
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 h1:micT5vkcr9tOVk1FiH8SWKID8ultN44Z+yzd2y/Vyb0=
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7/go.mod h1:dD3CgOrwlzca8ed61CsZouQS5h5jIzkK9ZWrTcf0s+o=
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0FDNrgJqGRo8PCMFOBFL9py72DRs7bmc=
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA=
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA=
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
@ -151,8 +141,6 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@ -337,8 +325,6 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
@ -368,8 +354,6 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs= github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c h1:rp5dCmg/yLR3mgFuSOe4oEnDDmGLROTvMragMUXpTQw=
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c/go.mod h1:X07ZCGwUbLaax7L0S3Tw4hpejzu63ZrrQiUe6W0hcy0=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4= github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
@ -611,7 +595,6 @@ golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -9,7 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -2,7 +2,6 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@ -51,26 +46,21 @@ type GrpcClient struct {
// NewClient creates a new client to Management service // NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) var conn *grpc.ClientConn
if tlsEnabled { operation := func() error {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
} }
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil { if err != nil {
log.Errorf("failed creating connection to Management Service %v", err) log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err return nil, err
} }
@ -326,25 +316,44 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection) return nil, fmt.Errorf(errMsgNoMgmtConnection)
} }
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil { if err != nil {
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)
return nil, err return nil, err
} }
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel() var resp *proto.EncryptedMessage
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ operation := func() error {
WgPubKey: c.key.PublicKey().String(), mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
Body: loginReq, defer cancel()
})
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil { if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err return nil, err
} }
loginResp := &proto.LoginResponse{} loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil { if err != nil {
log.Errorf("failed to decrypt registration message: %s", err) log.Errorf("failed to decrypt login response: %s", err)
return nil, err return nil, err
} }

View File

@ -190,7 +190,7 @@ var (
return fmt.Errorf("failed to initialize integrated peer validator: %v", err) return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
} }
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }

View File

@ -18,6 +18,8 @@ import (
"github.com/eko/gocache/v3/cache" "github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -37,6 +39,7 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -65,6 +68,7 @@ type AccountManager interface {
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
@ -97,6 +101,7 @@ type AccountManager interface {
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
@ -134,8 +139,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error) GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@ -169,6 +174,8 @@ type DefaultAccountManager struct {
userDeleteFromIDPEnabled bool userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
metrics telemetry.AppMetrics
} }
// Settings represents Account settings structure that can be modified via API and Dashboard // Settings represents Account settings structure that can be modified via API and Dashboard
@ -400,8 +407,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
return a.Groups[groupID] return a.Groups[groupID]
} }
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID] peer := a.Peers[peerID]
if peer == nil { if peer == nil {
return &NetworkMap{ return &NetworkMap{
@ -437,7 +452,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
if dnsManagementStatus { if dnsManagementStatus {
var zones []nbdns.CustomZone var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
if peersCustomZone.Domain != "" { if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone) zones = append(zones, peersCustomZone)
} }
@ -445,7 +460,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
} }
return &NetworkMap{ nm := &NetworkMap{
Peers: peersToConnect, Peers: peersToConnect,
Network: a.Network.Copy(), Network: a.Network.Copy(),
Routes: routesUpdate, Routes: routesUpdate,
@ -453,6 +468,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
OfflinePeers: expiredPeers, OfflinePeers: expiredPeers,
FirewallRules: firewallRules, FirewallRules: firewallRules,
} }
if metrics != nil {
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
}
return nm
}
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
var merr *multierror.Error
if dnsDomain == "" {
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
}
domainSuffix := "." + dnsDomain
var sb strings.Builder
for _, peer := range a.Peers {
if peer.DNSLabel == "" {
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
continue
}
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
sb.WriteString(peer.DNSLabel)
sb.WriteString(domainSuffix)
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: sb.String(),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
sb.Reset()
}
go func() {
if merr != nil {
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
}
}()
return customZone
} }
// GetExpiredPeers returns peers that have been expired // GetExpiredPeers returns peers that have been expired
@ -769,10 +838,6 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. // SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
// Returns true if there are changes in the JWT group membership. // Returns true if there are changes in the JWT group membership.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
if len(groupsNames) == 0 {
return false
}
user, ok := a.Users[userID] user, ok := a.Users[userID]
if !ok { if !ok {
return false return false
@ -856,7 +921,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
for _, gid := range groups { for _, gid := range groups {
group, ok := a.Groups[gid] group, ok := a.Groups[gid]
if !ok { if !ok || group.Name == "All" {
continue continue
} }
update := make([]string, 0, len(group.Peers)) update := make([]string, 0, len(group.Peers))
@ -874,10 +939,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, ctx context.Context,
store Store,
peersUpdateManager *PeersUpdateManager,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
userDeleteFromIDPEnabled bool, userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator, integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,
) (*DefaultAccountManager, error) { ) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
@ -892,6 +965,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator, integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
} }
allAccounts := store.GetAllAccounts(ctx) allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
@ -977,7 +1051,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -1028,7 +1102,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -1127,7 +1201,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
@ -1587,7 +1661,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return err return err
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock() defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id) account, err = am.Store.GetAccountByUser(ctx, user.Id)
@ -1670,7 +1744,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id) unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false alreadyUnlocked := false
defer func() { defer func() {
if !alreadyUnlocked { if !alreadyUnlocked {
@ -1831,7 +1905,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
account, err := am.Store.GetAccountByUser(ctx, claims.UserId) account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err == nil { if err == nil {
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlockAccount() defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId) account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil { if err != nil {
@ -1851,7 +1925,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
return account, nil return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil { if domainAccount != nil {
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
defer unlockAccount() defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil { if err != nil {
@ -1865,17 +1939,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
} }
} }
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
if err != nil { defer accountUnlock()
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") defer peerUnlock()
}
return nil, nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
@ -1895,26 +1963,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
return peer, netMap, postureChecks, nil return peer, netMap, postureChecks, nil
} }
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error { func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key) accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
if err != nil { defer accountUnlock()
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
return status.Errorf(status.Unauthenticated, "peer not registered") defer peerUnlock()
}
return err
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account) err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err) log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
} }
return nil return nil
@ -1927,7 +1989,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err return err
} }
unlock := am.Store.AcquireAccountReadLock(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{} validatedPeers[p] = struct{}{}
} }
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers) customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }
@ -2238,6 +2240,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
}) })
t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
})
} }
func TestAccount_UserGroupsAddToPeers(t *testing.T) { func TestAccount_UserGroupsAddToPeers(t *testing.T) {
@ -2305,7 +2314,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
}) })
} }
func createManager(t *testing.T) (*DefaultAccountManager, error) { type TB interface {
Cleanup(func())
Helper()
TempDir() string
}
func createManager(t TB) (*DefaultAccountManager, error) {
t.Helper() t.Helper()
store, err := createStore(t) store, err := createStore(t)
@ -2314,7 +2329,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2322,7 +2342,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return manager, nil return manager, nil
} }
func createStore(t *testing.T) (Store, error) { func createStore(t TB) (Store, error) {
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)

View File

@ -56,6 +56,10 @@ type Config struct {
func (c Config) GetAuthAudiences() []string { func (c Config) GetAuthAudiences() []string {
audiences := []string{c.HttpConfig.AuthAudience} audiences := []string{c.HttpConfig.AuthAudience}
if c.HttpConfig.ExtraAuthAudience != "" {
audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
}
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" { if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience) audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
} }
@ -90,6 +94,8 @@ type HttpServerConfig struct {
OIDCConfigEndpoint string OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not // IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool IdpSignKeyRefreshEnabled bool
// Extra audience
ExtraAuthAudience string
} }
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal) // Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@ -4,8 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -17,6 +17,50 @@ import (
const defaultTTL = 300 const defaultTTL = 300
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
type lookupMap map[string]struct{} type lookupMap map[string]struct{}
// DNSSettings defines dns settings at the account level // DNSSettings defines dns settings at the account level
@ -36,7 +80,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -58,7 +102,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -118,69 +162,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return nil return nil
} }
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable} func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
}
for _, zone := range update.CustomZones { for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain} cacheKey := zone.Domain
for _, record := range zone.Records { if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
Name: record.Name, } else {
Type: int64(record.Type), protoZone := convertToProtoCustomZone(zone)
Class: record.Class, cache.SetCustomZone(cacheKey, protoZone)
TTL: int64(record.TTL), protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
RData: record.RData,
})
} }
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
} }
for _, nsGroup := range update.NameServerGroups { for _, nsGroup := range update.NameServerGroups {
protoGroup := &proto.NameServerGroup{ cacheKey := nsGroup.ID
Primary: nsGroup.Primary, if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
Domains: nsGroup.Domains, protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, } else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
} }
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
} }
return protoUpdate return protoUpdate
} }
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone { // Helper function to convert nbdns.CustomZone to proto.CustomZone
if dnsDomain == "" { func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone") protoZone := &proto.CustomZone{
return nbdns.CustomZone{} Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
} }
for _, record := range zone.Records {
customZone := nbdns.CustomZone{ protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Domain: dns.Fqdn(dnsDomain), Name: record.Name,
} Type: int64(record.Type),
Class: record.Class,
for _, peer := range account.Peers { TTL: int64(record.TTL),
if peer.DNSLabel == "" { RData: record.RData,
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
}) })
} }
return protoZone
}
return customZone // Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
} }
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {

View File

@ -2,11 +2,14 @@ package server
import ( import (
"context" "context"
"fmt"
"net/netip" "net/netip"
"reflect"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
@ -197,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
} }
func createDNSStore(t *testing.T) (Store, error) { func createDNSStore(t *testing.T) (Store, error) {
@ -323,91 +330,149 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return am.Store.GetAccount(context.Background(), account.Id) return am.Store.GetAccount(context.Background(), account.Id)
} }
func TestDNSAccountPeerUpdate(t *testing.T) { func generateTestData(size int) nbdns.Config {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ for i := 0; i < size; i++ {
ID: "group-id", config.CustomZones[i] = nbdns.CustomZone{
Name: "GroupA", Domain: fmt.Sprintf("domain%d.com", i),
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Records: []nbdns.SimpleRecord{
}) {
assert.NoError(t, err) Name: fmt.Sprintf("record%d", i),
Type: 1,
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) Class: "IN",
t.Cleanup(func() { TTL: 3600,
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) RData: "192.168.1.1",
}) },
},
// Saving DNS settings with unused groups should not update account peers and not send peer update
t.Run("saving dns setting with unused groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"group-id"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
} }
})
_, err = manager.CreateNameServerGroup( config.NameServerGroups[i] = &nbdns.NameServerGroup{
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ ID: fmt.Sprintf("group%d", i),
IP: netip.MustParseAddr(peer1.IP.String()), Primary: i == 0,
NSType: dns.UDPNameServerType, Domains: []string{fmt.Sprintf("domain%d.com", i)},
Port: dns.DefaultDNSPort, SearchDomainsEnabled: true,
}}, NameServers: []nbdns.NameServer{
[]string{"group-id"}, {
true, []string{}, true, userID, false, IP: netip.MustParseAddr("8.8.8.8"),
) Port: 53,
assert.NoError(t, err) NSType: 1,
},
// Saving DNS settings with used groups should update account peers and send peer update },
t.Run("saving dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"group-id"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) }
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
// since there is no change in the network map
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"group-id"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache)
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache)
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache)
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache)
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache)
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
// Verify that the cache contains elements from both configs
if _, exists := cache.GetCustomZone("example.com"); !exists {
t.Errorf("Cache should contain custom zone for example.com")
}
if _, exists := cache.GetCustomZone("example.org"); !exists {
t.Errorf("Cache should contain custom zone for example.org")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
} }

View File

@ -13,7 +13,7 @@ import (
// GetEvents returns a list of activity events of an account // GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)

View File

@ -39,8 +39,8 @@ type FileStore struct {
mux sync.Mutex `json:"-"` mux sync.Mutex `json:"-"`
storeFile string `json:"-"` storeFile string `json:"-"`
// sync.Mutex indexed by accountID // sync.Mutex indexed by resource ID
accountLocks sync.Map `json:"-"` resourceLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"` globalAccountLock sync.Mutex `json:"-"`
metrics telemetry.AppMetrics `json:"-"` metrics telemetry.AppMetrics `json:"-"`
@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock return unlock
} }
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock // AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID) log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
mtx := value.(*sync.Mutex) mtx := value.(*sync.Mutex)
mtx.Lock() mtx.Lock()
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start)) log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
} }
return unlock return unlock
} }
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock // AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks // This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
return s.AcquireAccountWriteLock(ctx, accountID) return s.AcquireWriteLockByUID(ctx, uniqueID)
} }
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
return s.persist(ctx, s.storeFile) return s.persist(ctx, s.storeFile)
} }
// SavePeer saves the peer in the account
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
newPeer := peer.Copy()
account.Peers[peer.ID] = newPeer
s.PeerKeyID2AccountID[peer.Key] = accountID
s.PeerID2AccountID[peer.ID] = accountID
return nil
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur. // PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {

View File

@ -9,6 +9,7 @@ import (
"path" "path"
"strconv" "strconv"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@ -30,6 +31,8 @@ func loadGeolocationDatabases(dataDir string) error {
continue continue
} }
log.Infof("geo location file %s not found , file will be downloaded", file)
switch file { switch file {
case MMDBFileName: case MMDBFileName:
extractFunc := func(src string, dst string) error { extractFunc := func(src string, dst string) error {

View File

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"slices" "slices"
@ -26,7 +27,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -53,7 +54,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -80,7 +81,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -113,7 +114,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
} }
@ -165,19 +166,12 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
} }
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
@ -253,12 +247,12 @@ func difference(a, b []string) []string {
return diff return diff
} }
// DeleteGroup object of the peers // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
@ -268,22 +262,70 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
return nil return nil
} }
if err := validateDeleteGroup(account, group, userID); err != nil { if err = validateDeleteGroup(account, group, userId); err != nil {
return err return err
} }
delete(account.Groups, groupID) delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
am.updateAccountPeers(ctx, account)
return nil return nil
} }
// DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
var allErrors error
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
}
am.updateAccountPeers(ctx, account)
return allErrors
}
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -301,7 +343,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -330,16 +372,14 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err return err
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -362,29 +402,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
} }
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
@ -478,18 +500,8 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return false, user return true, user
} }
} }
return false, nil return false, nil
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}

View File

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"testing" "testing"
"time" "time"
@ -23,7 +24,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestGroupAccount(am) _, account, err := initTestGroupAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@ -58,7 +59,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestGroupAccount(am) _, account, err := initTestGroupAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@ -134,7 +135,136 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
} }
} }
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
assert.NoError(t, err, "Failed to init testing account")
groups := make([]*nbgroup.Group, 10)
for i := 0; i < 10; i++ {
groups[i] = &nbgroup.Group{
ID: fmt.Sprintf("group-%d", i+1),
AccountID: account.Id,
Name: fmt.Sprintf("group-%d", i+1),
Issued: nbgroup.GroupIssuedAPI,
}
}
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
name string
groupIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "route",
groupIDs: []string{"grp-for-route"},
expectedReasons: []string{"route"},
},
{
name: "route with peer groups",
groupIDs: []string{"grp-for-route2"},
expectedReasons: []string{"route"},
},
{
name: "name server groups",
groupIDs: []string{"grp-for-name-server-grp"},
expectedReasons: []string{"name server groups"},
},
{
name: "policy",
groupIDs: []string{"grp-for-policies"},
expectedReasons: []string{"policy"},
},
{
name: "setup keys",
groupIDs: []string{"grp-for-keys"},
expectedReasons: []string{"setup key"},
},
{
name: "users",
groupIDs: []string{"grp-for-users"},
expectedReasons: []string{"user"},
},
{
name: "integration",
groupIDs: []string{"grp-for-integration"},
expectedReasons: []string{"only service users with admin power can delete integration group"},
},
{
name: "successfully delete multiple groups",
groupIDs: []string{"group-1", "group-2"},
expectedDeleted: []string{"group-1", "group-2"},
},
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"},
},
{
name: "delete multiple groups with mixed results",
groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedDeleted: []string{"group-3", "group-4"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
{
name: "delete groups with multiple errors",
groupIDs: []string{"grp-for-policies", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
var sErr *status.Error
if errors.As(e, &sErr) {
assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message")
foundExpectedErrors++
}
var gErr *GroupLinkError
if errors.As(e, &gErr) {
assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource")
foundExpectedErrors++
}
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
for _, groupID := range tc.expectedDeleted {
_, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.Error(t, err, "group should have been deleted: %s", groupID)
}
for _, groupID := range tc.expectedNotDeleted {
group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.NoError(t, err, "group should not have been deleted: %s", groupID)
assert.NotNil(t, group, "group should exist: %s", groupID)
}
})
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
@ -238,7 +368,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
err := am.Store.SaveAccount(context.Background(), account) err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
@ -249,7 +379,11 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(context.Background(), account.Id) acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, nil, err
}
return am, acc, nil
} }
func TestGroupAccountPeerUpdate(t *testing.T) { func TestGroupAccountPeerUpdate(t *testing.T) {

View File

@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
} }
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil { if err != nil {
return mapError(ctx, err) return mapError(ctx, err)
} }
@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
} }
return s.handleUpdates(ctx, peerKey, peer, updates, srv) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
} }
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for { for {
select { select {
// condition when there are some updates // condition when there are some updates
@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
if !open { if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer) s.cancelPeerRoutines(ctx, accountID, peer)
return nil return nil
} }
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil { if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
return err return err
} }
@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
case <-srv.Context().Done(): case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects // happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer) s.cancelPeerRoutines(ctx, accountID, peer)
return srv.Context().Err() return srv.Context().Err()
} }
} }
@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
// sendUpdate encrypts the update message using the peer key and the server's wireguard key, // sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server. // then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil { if err != nil {
s.cancelPeerRoutines(ctx, peer) s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message") return status.Errorf(codes.Internal, "failed processing update message")
} }
err = srv.SendMsg(&proto.EncryptedMessage{ err = srv.SendMsg(&proto.EncryptedMessage{
@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
Body: encryptedResp, Body: encryptedResp,
}) })
if err != nil { if err != nil {
s.cancelPeerRoutines(ctx, peer) s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed sending update message") return status.Errorf(codes.Internal, "failed sending update message")
} }
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil return nil
} }
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(ctx, peer) _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
} }
@ -533,53 +533,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
} }
} }
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
remotePeers := []*proto.RemotePeerConfig{} response := &proto.SyncResponse{
for _, rPeer := range peers { WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
fqdn := rPeer.FQDN(dnsName) PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig, Routes: toProtocolRoutes(networkMap.Routes),
RemotePeers: remotePeers, DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
}, },
Checks: toProtocolChecks(ctx, checks), Checks: toProtocolChecks(ctx, checks),
} }
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
})
}
return dst
} }
// IsHealthy indicates whether the service is healthy // IsHealthy indicates whether the service is healthy
@ -597,7 +590,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
} else { } else {
turnCredentials = nil turnCredentials = nil
} }
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks) plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@ -526,6 +526,43 @@ components:
- revoked - revoked
- auto_groups - auto_groups
- usage_limit - usage_limit
CreateSetupKeyRequest:
type: object
properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds
type: integer
minimum: 86400
maximum: 31536000
example: 86400
auto_groups:
description: List of group IDs to auto-assign to peers registered with this key
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required:
- name
- type
- expires_in
- auto_groups
- usage_limit
PersonalAccessToken: PersonalAccessToken:
type: object type: object
properties: properties:
@ -1806,7 +1843,7 @@ paths:
content: content:
'application/json': 'application/json':
schema: schema:
$ref: '#/components/schemas/SetupKeyRequest' $ref: '#/components/schemas/CreateSetupKeyRequest'
responses: responses:
'200': '200':
description: A Setup Keys Object description: A Setup Keys Object

View File

@ -254,6 +254,27 @@ type Country struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
type CountryCode = string type CountryCode = string
// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
type CreateSetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
}
// DNSSettings defines model for DNSSettings. // DNSSettings defines model for DNSSettings.
type DNSSettings struct { type DNSSettings struct {
// DisabledManagementGroups Groups whose DNS management is disabled // DisabledManagementGroups Groups whose DNS management is disabled
@ -1241,7 +1262,7 @@ type PostApiRoutesJSONRequestBody = RouteRequest
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType. // PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType. // PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest

View File

@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return return
} }
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(account) validPeersMap, err := h.accountManager.GetValidatedPeers(account)
@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, respBody) util.WriteJSONObject(r.Context(), w, respBody)
} }
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
for _, peer := range respBody { for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id] _, ok := approvedPeersMap[peer.Id]

View File

@ -32,7 +32,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return errors.New("invalid groups") return errors.New("invalid groups")
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID) a, err := am.Store.GetAccountByUser(ctx, userID)

View File

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -16,8 +17,10 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -83,7 +86,7 @@ func Test_SyncProtocol(t *testing.T) {
defer func() { defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint os.Remove(filepath.Join(dir, "store.json")) //nolint
}() }()
mgmtServer, mgmtAddr, err := startManagement(t, &Config{ mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{ Stuns: []*Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468", URI: "stun:stun.wiretrustee.com:3468",
@ -399,32 +402,39 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
} }
} }
func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) { func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
return nil, "", err return nil, nil, "", err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
return nil, "", err return nil, nil, "", err
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := NewPeersUpdateManager(nil) peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}) ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
return nil, "", err return nil, nil, "", err
} }
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
if err != nil { if err != nil {
return nil, "", err return nil, nil, "", err
} }
mgmtProto.RegisterManagementServiceServer(s, mgmtServer) mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
@ -434,7 +444,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
} }
}() }()
return s, lis.Addr().String(), nil return s, accountManager, lis.Addr().String(), nil
} }
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
@ -454,3 +464,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
return mgmtProto.NewManagementServiceClient(conn), conn, nil return mgmtProto.NewManagementServiceClient(conn), conn, nil
} }
func Test_SyncStatusRace(t *testing.T) {
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
t.Skip("Skipping on CI and Postgres store")
}
for i := 0; i < 500; i++ {
t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
testSyncStatusRace(t)
})
}
}
func testSyncStatusRace(t *testing.T) {
t.Helper()
dir := t.TempDir()
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
if err != nil {
t.Fatal(err)
}
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
}},
TURNConfig: &TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Proto: "udp",
URI: "turn:stun.wiretrustee.com:3468",
}},
},
Signal: &Host{
Proto: "http",
URI: "signal.wiretrustee.com:10000",
},
Datadir: dir,
HttpConfig: nil,
})
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
client, clientConn, err := createRawClient(mgmtAddr)
if err != nil {
t.Fatal(err)
return
}
defer clientConn.Close()
// there are two peers already in the store, add two more
peers, err := registerPeers(2, client)
if err != nil {
t.Fatal(err)
return
}
serverKey, err := getServerKey(client)
if err != nil {
t.Fatal(err)
return
}
concurrentPeerKey2 := peers[1]
t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
if err != nil {
t.Fatal(err)
return
}
ctx2, cancelFunc2 := context.WithCancel(context.Background())
//client.
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
WgPubKey: concurrentPeerKey2.PublicKey().String(),
Body: message2,
})
if err != nil {
t.Fatal(err)
return
}
resp2 := &mgmtProto.EncryptedMessage{}
err = sync2.RecvMsg(resp2)
if err != nil {
t.Fatal(err)
return
}
peerWithInvalidStatus := peers[0]
t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
if err != nil {
t.Fatal(err)
return
}
ctx, cancelFunc := context.WithCancel(context.Background())
//client.
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
})
if err != nil {
t.Fatal(err)
return
}
// take the first registered peer as a base for the test. Total four.
resp := &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(resp)
if err != nil {
t.Fatal(err)
return
}
cancelFunc2()
time.Sleep(1 * time.Millisecond)
cancelFunc()
time.Sleep(10 * time.Millisecond)
ctx, cancelFunc = context.WithCancel(context.Background())
defer cancelFunc()
sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
})
if err != nil {
t.Fatal(err)
return
}
resp = &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(resp)
if err != nil {
t.Fatal(err)
return
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
}
if !peer.Status.Connected {
t.Fatal("Peer should be connected")
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("failed creating metrics: %v", err)
}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }

View File

@ -31,7 +31,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
@ -42,6 +42,7 @@ type MockAccountManager struct {
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@ -66,6 +67,7 @@ type MockAccountManager struct {
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
@ -104,14 +106,14 @@ type MockAccountManager struct {
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
} }
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil { if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP) return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error { func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }
@ -325,6 +327,14 @@ func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
} }
// DeleteGroups mock implementation of DeleteGroups from server.AccountManager interface
func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
if am.DeleteGroupsFunc != nil {
return am.DeleteGroupsFunc(ctx, accountId, userId, groupIDs)
}
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface // ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil { if am.ListGroupsFunc != nil {
@ -519,6 +529,14 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
} }
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
if am.DeleteRegularUsersFunc != nil {
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
}
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
}
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.InviteUserFunc != nil { if am.InviteUserFunc != nil {
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID) return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)

View File

@ -20,7 +20,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -80,13 +80,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
account.NameServerGroups[newNSGroup.ID] = newNSGroup account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial() account.Network.IncSerial()
if err := am.Store.SaveAccount(ctx, account); err != nil { err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err return nil, err
} }
if anyGroupHasPeers(account, newNSGroup.Groups) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
@ -94,7 +94,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
// SaveNameServerGroup saves nameserver group // SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
@ -111,17 +112,16 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err return err
} }
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil { err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err return err
} }
if anyGroupHasPeers(account, nsGroupToSave.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil return nil
@ -130,7 +130,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -145,13 +145,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
delete(account.NameServerGroups, nsGroupID) delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial() account.Network.IncSerial()
if err := am.Store.SaveAccount(ctx, account); err != nil { err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err return err
} }
if anyGroupHasPeers(account, nsGroup.Groups) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil return nil
@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account // ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
const ( const (
@ -764,7 +765,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
} }
func createNSStore(t *testing.T) (Store, error) { func createNSStore(t *testing.T) (Store, error) {

View File

@ -4,14 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"slices"
"strings" "strings"
"sync"
"time" "time"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -65,12 +66,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peers := make([]*nbpeer.Peer, 0) peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { regularUser := !user.HasAdminPower() && !user.IsServiceUser
if regularUser && account.Settings.RegularUsersViewBlocked {
return peers, nil return peers, nil
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID { if regularUser && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin // only display peers that belong to the current user if the current user is not an admin
continue continue
} }
@ -79,6 +82,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peersMap[peer.ID] = p peersMap[peer.ID] = p
} }
if !regularUser {
return peers, nil
}
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
@ -150,7 +157,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -212,17 +219,13 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
account.UpdatePeer(peer) account.UpdatePeer(peer)
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) am.updateAccountPeers(ctx, account)
if expired && peer.LoginExpirationEnabled {
am.updateAccountPeers(ctx, account)
}
return peer, nil return peer, nil
} }
@ -278,7 +281,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -286,7 +289,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
updateAccountPeers := isPeerInActiveGroup(account, peerID)
err = am.deletePeers(ctx, account, []string{peerID}, userID) err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil { if err != nil {
return err return err
@ -297,9 +299,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
if updateAccountPeers { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
@ -325,7 +325,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
} }
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
@ -365,7 +366,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() { defer func() {
if unlock != nil { if unlock != nil {
unlock() unlock()
@ -389,7 +390,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow) // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
@ -462,6 +463,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
Location: peer.Location, Location: peer.Location,
} }
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
// add peer to 'All' group // add peer to 'All' group
group, err := account.GetGroupAll() group, err := account.GetGroupAll()
if err != nil { if err != nil {
@ -501,7 +513,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
account.Peers[newPeer.ID] = newPeer account.Peers[newPeer.ID] = newPeer
account.Network.IncSerial() account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
@ -520,9 +531,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if areGroupChangesAffectPeers(account, groupsToAdd) { am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account)
}
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
@ -530,7 +539,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
postureChecks := am.getPeerPostureChecks(account, peer) postureChecks := am.getPeerPostureChecks(account, peer)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
} }
@ -547,7 +557,19 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
} }
if peerLoginExpired(ctx, peer, account.Settings) { if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, err
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account)
}
} }
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
@ -555,22 +577,16 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, err return nil, nil, nil, err
} }
var postureChecks []*posture.Checks
if peerNotValid { if peerNotValid {
emptyMap := &NetworkMap{ emptyMap := &NetworkMap{
Network: account.Network.Copy(), Network: account.Network.Copy(),
} }
return peer, emptyMap, nil, nil return peer, emptyMap, postureChecks, nil
} }
peer, peerMetaUpdated := updatePeerMeta(peer, sync.Meta, account) if isStatusChanged {
if peerMetaUpdated {
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, nil, nil, err
}
}
if isStatusChanged || (peerMetaUpdated && sync.UpdateAccountPeers) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} }
@ -578,9 +594,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks := am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
@ -592,21 +609,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it. // Try registering it.
newPeer := &nbpeer.Peer{ newPeer := &nbpeer.Peer{
Key: login.WireGuardPubKey, Key: login.WireGuardPubKey,
Meta: login.Meta, Meta: login.Meta,
SSHKey: login.SSHKey, SSHKey: login.SSHKey,
} Location: nbpeer.Location{ConnectionIP: login.ConnectionIP},
if am.geo != nil && login.ConnectionIP != nil {
location, err := am.geo.Lookup(login.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
} else {
newPeer.Location.ConnectionIP = login.ConnectionIP
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
} }
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
@ -616,44 +622,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
} }
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) // when the client sends a login request with a JWT which is used to get the user ID,
if err != nil { // it means that the client has already checked if it needs login and had been through the SSO flow
return nil, nil, nil, status.NewPeerNotRegisteredError() // so, we can skip this check and directly proceed with the login
} if login.UserID == "" {
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
accSettings, err := am.Store.GetAccountSettings(ctx, accountID) if err != nil {
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
}
var isWriteLock bool
// duplicated logic from after the lock to have an early exit
expired := peerLoginExpired(ctx, peer, accSettings)
switch {
case expired:
if err := checkAuth(ctx, login.UserID, peer); err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
isWriteLock = true
log.WithContext(ctx).Debugf("peer login expired, acquiring write lock")
case peer.UpdateMetaIfNew(login.Meta):
isWriteLock = true
log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock")
default:
isWriteLock = false
log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock")
} }
var unlock func() unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
if isWriteLock {
unlock = am.Store.AcquireAccountWriteLock(ctx, accountID)
} else {
unlock = am.Store.AcquireAccountReadLock(ctx, accountID)
}
defer func() { defer func() {
if unlock != nil { if unlock != nil {
unlock() unlock()
@ -666,7 +645,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return nil, nil, nil, err return nil, nil, nil, err
} }
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey) peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError() return nil, nil, nil, status.NewPeerNotRegisteredError()
} }
@ -677,53 +656,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
// this flag prevents unnecessary calls to the persistent store. // this flag prevents unnecessary calls to the persistent store.
shouldStoreAccount := false shouldStorePeer := false
updateRemotePeers := false updateRemotePeers := false
if peerLoginExpired(ctx, peer, account.Settings) { if peerLoginExpired(ctx, peer, account.Settings) {
err = checkAuth(ctx, login.UserID, peer) err = am.handleExpiredPeer(ctx, login, account, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account)
updateRemotePeers = true updateRemotePeers = true
shouldStoreAccount = true shouldStorePeer = true
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
}
user.updateLastLogin(peer.LastLogin)
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
} }
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
peer, updated := updatePeerMeta(peer, login.Meta, account) peer, updated := updatePeerMeta(peer, login.Meta, account)
if updated { if updated {
shouldStoreAccount = true shouldStorePeer = true
} }
peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey) if peer.SSHKey != login.SSHKey {
if err != nil { peer.SSHKey = login.SSHKey
return nil, nil, nil, err shouldStorePeer = true
} }
if shouldStoreAccount { if shouldStorePeer {
if !isWriteLock { err = am.Store.SavePeer(ctx, accountID, peer)
log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID)
return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
}
err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
} }
unlock() unlock()
unlock = nil unlock = nil
@ -731,13 +696,46 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} }
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
// and if the peer login is expired.
// The NetBird client doesn't have a way to check if the peer needs login besides sending a login request
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return err
}
// if the peer was not added with SSO login we can exit early because peers activated with setup-key
// doesn't expire, and we avoid extra databases calls.
if !peer.AddedWithSSOLogin() {
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return err
}
if peerLoginExpired(ctx, peer, settings) {
return status.NewPeerLoginExpiredError()
}
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
if isRequiresApproval { if isRequiresApproval {
emptyMap := &NetworkMap{ emptyMap := &NetworkMap{
Network: account.Network.Copy(), Network: account.Network.Copy(),
} }
return peer, emptyMap, postureChecks, nil return peer, emptyMap, nil, nil
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
@ -746,7 +744,32 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
err := checkAuth(ctx, login.UserID, peer)
if err != nil {
return err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account)
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return status.Errorf(status.Internal, "couldn't find user")
}
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
if err != nil {
return err
}
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil
} }
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
@ -765,11 +788,11 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error { func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error {
if loginUserID == "" { if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided. // absence of a user ID indicates that JWT wasn't provided.
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") return status.NewPeerLoginExpiredError()
} }
if peer.UserID != loginUserID { if peer.UserID != loginUserID {
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
return status.Errorf(status.Unauthenticated, "can't login") return status.Errorf(status.Unauthenticated, "can't login with this credentials")
} }
return nil return nil
} }
@ -789,31 +812,54 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
account.UpdatePeer(peer) account.UpdatePeer(peer)
} }
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) { // UpdatePeerSSHKey updates peer's public SSH key
if len(newSSHKey) == 0 { func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) if sshKey == "" {
return peer, nil log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
} }
if peer.SSHKey == newSSHKey { account, err := am.Store.GetAccountByPeerID(ctx, peerID)
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID) if err != nil {
return peer, nil return err
} }
peer.SSHKey = newSSHKey unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
if peer.SSHKey == sshKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer) account.UpdatePeer(peer)
err := am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return nil, err return err
} }
return peer, nil // trigger network map update
am.updateAccountPeers(ctx, account)
return nil
} }
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -875,34 +921,45 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
start := time.Now()
defer func() {
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
}
}()
peers := account.GetPeers() peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err) log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return return
} }
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
for _, peer := range peers { for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) { if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue continue
} }
postureChecks := am.getPeerPostureChecks(account, peer) wg.Add(1)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) semaphore <- struct{}{}
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) go func(p *nbpeer.Peer) {
go am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks}) defer wg.Done()
} defer func() { <-semaphore }()
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used postureChecks := am.getPeerPostureChecks(account, p)
// in an active DNS, route, or ACL configuration. remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
func isPeerInActiveGroup(account *Account, peerID string) bool { update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
peerGroupIDs := make([]string, 0) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks})
for _, group := range account.Groups { }(peer)
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
} }
return areGroupChangesAffectPeers(account, peerGroupIDs)
wg.Wait()
} }

View File

@ -1,7 +1,6 @@
package peer package peer
import ( import (
"fmt"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" { if dnsDomain == "" {
return "" return ""
} }
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain) return p.DNSLabel + "." + dnsDomain
} }
// EventMeta returns activity event meta related to the peer // EventMeta returns activity event meta related to the peer

View File

@ -0,0 +1,31 @@
package peer
import (
"fmt"
"testing"
)
// FQDNOld is the original implementation for benchmarking purposes
func (p *Peer) FQDNOld(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
func BenchmarkFQDN(b *testing.B) {
p := &Peer{DNSLabel: "test-peer"}
dnsDomain := "example.com"
b.Run("Old", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDNOld(dnsDomain)
}
})
b.Run("New", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDN(dnsDomain)
}
})
}

View File

@ -2,16 +2,26 @@ package server
import ( import (
"context" "context"
"fmt"
"io"
"net"
"net/netip"
"os"
"testing" "testing"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
nbroute "github.com/netbirdio/netbird/route"
) )
func TestPeer_LoginExpired(t *testing.T) { func TestPeer_LoginExpired(t *testing.T) {
@ -635,155 +645,353 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
} }
func TestPeerAccountPeerUpdate(t *testing.T) { func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) b.Helper()
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) manager, err := createManager(b)
require.NoError(t, err) if err != nil {
return nil, "", "", err
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ accountID := "test_account"
ID: "group-id", adminUser := "account_creator"
Name: "GroupA", regularUser := "regular_user"
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) account := newAccountWithId(context.Background(), accountID, adminUser, "")
t.Cleanup(func() { account.Users[regularUser] = &User{
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) Id: regularUser,
}) Role: UserRoleUser,
}
// create a user with auto groups // Create peers
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ for i := 0; i < peers; i++ {
Id: "regularUser1", peerKey, _ := wgtypes.GeneratePrivateKey()
AccountID: account.Id, peer := &nbpeer.Peer{
Role: UserRoleAdmin, ID: fmt.Sprintf("peer-%d", i),
Issued: UserIssuedAPI, DNSLabel: fmt.Sprintf("peer-%d", i),
AutoGroups: []string{"group-id"}, Key: peerKey.PublicKey().String(),
}, true) IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
require.NoError(t, err) Status: &nbpeer.PeerStatus{},
UserID: regularUser,
var peer4 *nbpeer.Peer
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2)
require.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
} }
}) account.Peers[peer.ID] = peer
}
// Adding peer with an unused group in active dns, route, acl should not update account peers and not send peer update // Create groups and policies
t.Run("adding peer with unused group", func(t *testing.T) { account.Policies = make([]*Policy, 0, groups)
done := make(chan struct{}) for i := 0; i < groups; i++ {
go func() { groupID := fmt.Sprintf("group-%d", i)
peerShouldNotReceiveUpdate(t, updMsg) group := &nbgroup.Group{
close(done) ID: groupID,
}() Name: fmt.Sprintf("Group %d", i),
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
} }
}) for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
// Deleting peer with an unused group in active dns, route, acl should not update account peers and not send peer update group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
t.Run("deleting peer with unused group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
} }
}) account.Groups[groupID] = group
// use the group-id in policy // Create a policy for this group
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ policy := &Policy{
ID: "policy", ID: fmt.Sprintf("policy-%d", i),
Enabled: true, Name: fmt.Sprintf("Policy for Group %d", i),
Rules: []*PolicyRule{ Enabled: true,
{ Rules: []*PolicyRule{
Enabled: true, {
Sources: []string{"group-id"}, ID: fmt.Sprintf("rule-%d", i),
Destinations: []string{"group-id"}, Name: fmt.Sprintf("Rule for Group %d", i),
Bidirectional: true, Enabled: true,
Action: PolicyTrafficActionAccept, Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
},
},
}
account.Policies = append(account.Policies, policy)
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
}, },
}, },
}) }
require.NoError(t, err)
// Adding peer with a used group in active dns, route or policy should update account peers and send peer update err = manager.Store.SaveAccount(context.Background(), account)
t.Run("adding peer with used group", func(t *testing.T) { if err != nil {
done := make(chan struct{}) return nil, "", "", err
go func() { }
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
//Deleting peer with a used group in active dns, route or acl should update account peers and send peer update
t.Run("deleting peer with used group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
return manager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID)
if err != nil {
b.Fatalf("GetPeers failed: %v", err)
}
}
})
}
}
func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
}
duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
b.ReportMetric(0, "ns/op")
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
t.Fatal(err)
}
domainList, err := domain.FromStringList([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.1"),
SSHEnabled: true,
Key: "peer-key",
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
turnCredentials := &TURNCredentials{
Username: "turn-user",
Password: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
{
ID: "route1",
Network: netip.MustParsePrefix("10.0.0.0/24"),
Domains: domainList,
KeepRoute: true,
NetID: "route1",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
},
},
DNSConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
{
ID: "ns1",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
checks := []*posture.Checks{
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
},
},
},
}
dnsCache := &DNSConfigCache{}
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
assert.NotNil(t, response)
// assert peer config
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
// assert wiretrustee config
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
// assert RemotePeers
assert.Equal(t, 1, len(response.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
// assert network map RemotePeers
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map OfflinePeers
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
// assert network map Routes
assert.Equal(t, 1, len(response.NetworkMap.Routes))
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
// assert network map DNSConfig
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
// assert network map DNSConfig.CustomZones
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
// assert network map DNSConfig.NameServerGroups
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
// assert network map Firewall
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
// assert posture checks
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
} }

View File

@ -223,7 +223,6 @@ type FirewallRule struct {
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
@ -235,8 +234,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
continue continue
} }
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap) destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@ -300,8 +299,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
fr.PeerIP = "0.0.0.0" fr.PeerIP = "0.0.0.0"
} }
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")) fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok { if _, ok := rulesExists[ruleID]; ok {
continue continue
} }
@ -325,7 +324,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store // GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -353,7 +352,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -383,7 +382,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store // DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -412,7 +411,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store // ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -505,23 +504,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
// //
// Important: Posture checks are applicable only to source group peers, // Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups { for _, g := range groups {
group, ok := account.Groups[g] group, ok := a.Groups[g]
if !ok { if !ok {
continue continue
} }
for _, p := range group.Peers { for _, p := range group.Peers {
peer, ok := account.Peers[p] peer, ok := a.Peers[p]
if !ok || peer == nil { if !ok || peer == nil {
continue continue
} }
// validate the peer based on policy posture checks applied // validate the peer based on policy posture checks applied
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid { if !isValid {
continue continue
} }
@ -549,7 +548,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
} }
for _, postureChecksID := range sourcePostureChecksID { for _, postureChecksID := range sourcePostureChecksID {
postureChecks := getPostureChecks(a, postureChecksID) postureChecks := a.getPostureChecks(postureChecksID)
if postureChecks == nil { if postureChecks == nil {
continue continue
} }
@ -567,8 +566,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
return true return true
} }
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks { func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
for _, postureChecks := range account.PostureChecks { for _, postureChecks := range a.PostureChecks {
if postureChecks.ID == postureChecksID { if postureChecks.ID == postureChecksID {
return postureChecks return postureChecks
} }

View File

@ -15,7 +15,7 @@ const (
) )
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -42,7 +42,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
} }
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -91,7 +91,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
} }
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -123,7 +123,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
} }
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)

View File

@ -17,7 +17,7 @@ import (
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -126,7 +126,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -216,7 +216,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
// SaveRoute saves route // SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
if routeToSave == nil { if routeToSave == nil {
@ -288,7 +288,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -318,7 +318,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account // ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)

View File

@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -1234,7 +1235,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
} }
func createRouterStore(t *testing.T) (Store, error) { func createRouterStore(t *testing.T) (Store, error) {

View File

@ -210,7 +210,7 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty. // and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
keyDuration := DefaultSetupKeyDuration keyDuration := DefaultSetupKeyDuration
@ -223,10 +223,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, err return nil, err
} }
for _, group := range autoGroups { if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
if _, ok := account.Groups[group]; !ok { return nil, err
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
} }
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral)
@ -256,7 +254,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
if keyToSave == nil { if keyToSave == nil {
@ -279,6 +277,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.Errorf(status.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
return nil, err
}
// only auto groups, revoked status, and name can be updated for now // only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy() newKey := oldKey.Copy()
newKey.Name = keyToSave.Name newKey.Name = keyToSave.Name
@ -326,7 +328,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account // ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
@ -358,7 +360,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -397,3 +399,16 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return foundKey, nil return foundKey, nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
for _, group := range autoGroups {
g, ok := account.Groups[group]
if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
}
}
return nil
}

View File

@ -27,10 +27,17 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
ID: "group_1", {
Name: "group_name_1", ID: "group_1",
Peers: []string{}, Name: "group_name_1",
Peers: []string{},
},
{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -71,6 +78,19 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID) assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
assert.Error(t, err, "should not save setup key with All group assigned in auto groups")
} }
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
@ -103,6 +123,9 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
type testCase struct { type testCase struct {
name string name string
@ -135,8 +158,14 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
expectedGroups: []string{"FAKE"}, expectedGroups: []string{"FAKE"},
expectedFailure: true, expectedFailure: true,
} }
testCase3 := testCase{
name: "Create Setup Key should fail because of All group",
expectedKeyName: "my-test-key",
expectedGroups: []string{groupAll.ID},
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} { for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)

View File

@ -31,14 +31,16 @@ import (
) )
const ( const (
storeSqliteFileName = "store.db" storeSqliteFileName = "store.db"
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
) )
// SqlStore represents an account storage backed by a Sql DB persisted to disk // SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct { type SqlStore struct {
db *gorm.DB db *gorm.DB
accountLocks sync.Map resourceLocks sync.Map
globalAccountLock sync.Mutex globalAccountLock sync.Mutex
metrics telemetry.AppMetrics metrics telemetry.AppMetrics
installationPK int installationPK int
@ -96,33 +98,35 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock return unlock
} }
func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { // AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID) func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex) mtx := value.(*sync.RWMutex)
mtx.Lock() mtx.Lock()
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start)) log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
} }
return unlock return unlock
} }
func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { // AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID) func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex) mtx := value.(*sync.RWMutex)
mtx.RLock() mtx.RLock()
unlock = func() { unlock = func() {
mtx.RUnlock() mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start)) log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
} }
return unlock return unlock
@ -271,6 +275,38 @@ func (s *SqlStore) GetInstallationID() string {
return installation.InstallationIDValue return installation.InstallationIDValue
} }
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
if result.Error != nil {
return result.Error
}
if peerID == "" {
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
}
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
if result.Error != nil {
return result.Error
}
return nil
})
if err != nil {
return err
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
@ -281,14 +317,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
} }
result := s.db.Model(&nbpeer.Peer{}). result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate). Select(fieldsToUpdate).
Where("account_id = ? AND id = ?", accountID, peerID). Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy) Updates(&peerCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerID) return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
} }
return nil return nil
@ -302,7 +338,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
peerCopy.Location = peerWithLocation.Location peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}). result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy) Updates(peerCopy)
if result.Error != nil { if result.Error != nil {
@ -310,7 +346,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
} }
return nil return nil
@ -644,7 +680,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User var user User
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID) return status.Errorf(status.NotFound, "user %s not found", userID)

View File

@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
} }
func TestSqlite_SavePeer(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/store.json")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
// save status of non-existing peer
peer := &nbpeer.Peer{
Key: "peerkey",
ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
ctx := context.Background()
err = store.SavePeer(ctx, account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
// save new status of existing peer
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
updatedPeer := peer.Copy()
updatedPeer.Status.Connected = false
updatedPeer.Meta.Hostname = "updatedpeer"
err = store.SavePeer(ctx, account.Id, updatedPeer)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID]
assert.Equal(t, updatedPeer.Status, actual.Status)
assert.Equal(t, updatedPeer.Meta, actual.Meta)
}
func TestSqlite_SavePeerStatus(t *testing.T) { func TestSqlite_SavePeerStatus(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
@ -402,7 +450,19 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
actual := account.Peers["testpeer"].Status actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual) assert.Equal(t, newStatus, *actual)
newStatus.Connected = true
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual = account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
} }
func TestSqlite_SavePeerLocation(t *testing.T) { func TestSqlite_SavePeerLocation(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@ -95,3 +95,8 @@ func NewUserNotFoundError(userKey string) error {
func NewPeerNotRegisteredError() error { func NewPeerNotRegisteredError() error {
return Errorf(Unauthenticated, "peer is not registered") return Errorf(Unauthenticated, "peer is not registered")
} }
// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}

View File

@ -12,10 +12,11 @@ import (
"strings" "strings"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -48,12 +49,13 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error SaveInstallationID(ctx context.Context, ID string) error
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireAccountWriteLock(ctx context.Context, accountID string) func() AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireAccountReadLock(ctx context.Context, accountID string) func() AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func() AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error

View File

@ -0,0 +1,69 @@
package telemetry
import (
"context"
"time"
"go.opentelemetry.io/otel/metric"
)
// AccountManagerMetrics represents all metrics related to the AccountManager
type AccountManagerMetrics struct {
ctx context.Context
updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
))
if err != nil {
return nil, err
}
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
))
if err != nil {
return nil, err
}
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
metric.WithUnit("objects"),
metric.WithExplicitBucketBoundaries(
50, 100, 200, 500, 1000, 2500, 5000, 10000,
))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount,
}, nil
}
// CountUpdateAccountPeersDuration counts the duration of updating account peers
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountNetworkMapObjects counts the number of network map objects
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}

View File

@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
// MockAppMetrics mocks the AppMetrics interface // MockAppMetrics mocks the AppMetrics interface
type MockAppMetrics struct { type MockAppMetrics struct {
GetMeterFunc func() metric2.Meter GetMeterFunc func() metric2.Meter
CloseFunc func() error CloseFunc func() error
ExposeFunc func(ctx context.Context, port int, endpoint string) error ExposeFunc func(ctx context.Context, port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics GRPCMetricsFunc func() *GRPCMetrics
StoreMetricsFunc func() *StoreMetrics StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics UpdateChannelMetricsFunc func() *UpdateChannelMetrics
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
} }
// GetMeter mocks the GetMeter function of the AppMetrics interface // GetMeter mocks the GetMeter function of the AppMetrics interface
@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
return nil return nil
} }
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
if mock.AddAccountManagerMetricsFunc != nil {
return mock.AddAccountManagerMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface // AppMetrics is metrics interface
type AppMetrics interface { type AppMetrics interface {
GetMeter() metric2.Meter GetMeter() metric2.Meter
@ -104,19 +113,21 @@ type AppMetrics interface {
GRPCMetrics() *GRPCMetrics GRPCMetrics() *GRPCMetrics
StoreMetrics() *StoreMetrics StoreMetrics() *StoreMetrics
UpdateChannelMetrics() *UpdateChannelMetrics UpdateChannelMetrics() *UpdateChannelMetrics
AccountManagerMetrics() *AccountManagerMetrics
} }
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/ // defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
type defaultAppMetrics struct { type defaultAppMetrics struct {
// Meter can be used by different application parts to create counters and measure things // Meter can be used by different application parts to create counters and measure things
Meter metric2.Meter Meter metric2.Meter
listener net.Listener listener net.Listener
ctx context.Context ctx context.Context
idpMetrics *IDPMetrics idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics grpcMetrics *GRPCMetrics
storeMetrics *StoreMetrics storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics updateChannelMetrics *UpdateChannelMetrics
accountManagerMetrics *AccountManagerMetrics
} }
// IDPMetrics returns metrics for the idp package // IDPMetrics returns metrics for the idp package
@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
return appMetrics.updateChannelMetrics return appMetrics.updateChannelMetrics
} }
// AccountManagerMetrics returns metrics for the account manager
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
return appMetrics.accountManagerMetrics
}
// Close stop application metrics HTTP handler and closes listener. // Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error { func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil { if appMetrics.listener == nil {
@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
return nil, err return nil, err
} }
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, err
}
return &defaultAppMetrics{ return &defaultAppMetrics{
Meter: meter, Meter: meter,
ctx: ctx, ctx: ctx,
idpMetrics: idpMetrics, idpMetrics: idpMetrics,
httpMiddleware: middleware, httpMiddleware: middleware,
grpcMetrics: grpcMetrics, grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics, storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics, updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil }, nil
} }

View File

@ -2,8 +2,8 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
@ -212,7 +212,7 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account. // createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -268,7 +268,7 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
// inviteNewUser Invites a USer to a given account and creates reference in datastore // inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
@ -369,7 +369,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
return nil, fmt.Errorf("failed to get account with token claims %v", err) return nil, fmt.Errorf("failed to get account with token claims %v", err)
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock() defer unlock()
account, err = am.Store.GetAccount(ctx, account.Id) account, err = am.Store.GetAccount(ctx, account.Id)
@ -402,7 +402,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// ListUsers returns lists of all users under the account. // ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name. // It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -429,7 +429,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
if initiatorUserID == targetUserID { if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed") return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -473,68 +473,27 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
} }
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return err
}
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP.Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
}
}
userHasPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil { if err != nil {
return err return err
} }
u, err := account.FindUser(targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
}
var tuCreatedAt time.Time
if u != nil {
tuCreatedAt = u.CreatedAt
}
delete(account.Users, targetUserID) delete(account.Users, targetUserID)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return err return err
} }
meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
am.updateAccountPeers(ctx, account)
if userHasPeers && account.Settings.GroupsPropagationEnabled {
am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
// deleteUserPeers deletes all peers associated with the target user in the specified account. func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error {
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
peers, err := account.FindUserPeers(targetUserID) peers, err := account.FindUserPeers(targetUserID)
if err != nil { if err != nil {
return false, status.Errorf(status.Internal, "failed to find user peers") return status.Errorf(status.Internal, "failed to find user peers")
}
hadPeers := len(peers) > 0
if !hadPeers {
return false, nil
} }
peerIDs := make([]string, 0, len(peers)) peerIDs := make([]string, 0, len(peers))
@ -542,12 +501,12 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
} }
return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) return am.deletePeers(ctx, account, peerIDs, initiatorUserID)
} }
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
@ -587,7 +546,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
// CreatePAT creates a new PAT for the given user // CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
if tokenName == "" { if tokenName == "" {
@ -637,7 +596,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
// DeletePAT deletes a specific PAT from a user // DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -687,7 +646,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user // GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -719,7 +678,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -761,7 +720,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
} }
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
@ -801,7 +760,6 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
updatedUsers := make([]*UserInfo, 0, len(updates)) updatedUsers := make([]*UserInfo, 0, len(updates))
var ( var (
expiredPeers []*nbpeer.Peer expiredPeers []*nbpeer.Peer
userIDs []string
eventsToStore []func() eventsToStore []func()
) )
@ -810,8 +768,6 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
} }
userIDs = append(userIDs, update.Id)
oldUser := account.Users[update.Id] oldUser := account.Users[update.Id]
if oldUser == nil { if oldUser == nil {
if !addIfNotExists { if !addIfNotExists {
@ -871,11 +827,11 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveUsers(account.Id, account.Users); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err return nil, err
} }
if areUsersLinkedToPeers(account, userIDs) && account.Settings.GroupsPropagationEnabled { if account.Settings.GroupsPropagationEnabled {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} }
@ -988,10 +944,14 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
} }
for _, newGroupID := range update.AutoGroups { for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok { group, ok := account.Groups[newGroupID]
if !ok {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id) newGroupID, update.Id)
} }
if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
}
} }
return nil return nil
@ -1202,14 +1162,114 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context
return "", "", fmt.Errorf("user info not found for user: %s", targetId) return "", "", fmt.Errorf("user info not found for user: %s", targetId)
} }
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. // DeleteRegularUsers deletes regular users from an account.
func areUsersLinkedToPeers(account *Account, userIDs []string) bool { // Note: This function does not acquire the global lock.
for _, peer := range account.Peers { // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
if slices.Contains(userIDs, peer.UserID) { //
return true // If an error occurs while deleting the user, the function skips it and continues deleting other users.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !executingUser.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
}
var allErrors error
deletedUsersMeta := make(map[string]map[string]any)
for _, targetUserID := range targetUserIDs {
if initiatorUserID == targetUserID {
allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed"))
continue
}
targetUser := account.Users[targetUserID]
if targetUser == nil {
allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID))
continue
}
if targetUser.Role == UserRoleOwner {
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
continue
}
// disable deleting integration user if the initiator is not admin service user
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
continue
}
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
continue
}
delete(account.Users, targetUserID)
deletedUsersMeta[targetUserID] = meta
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return fmt.Errorf("failed to delete users: %w", err)
}
am.updateAccountPeers(ctx, account)
for targetUserID, meta := range deletedUsersMeta {
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
}
return allErrors
}
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return nil, err
}
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP. Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return nil, err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
} }
} }
return false
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil {
return nil, err
}
u, err := account.FindUser(targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
}
var tuCreatedAt time.Time
if u != nil {
tuCreatedAt = u.CreatedAt
}
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
} }
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {

View File

@ -665,6 +665,157 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
} }
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
}
account.Users["user6"] = &User{
Id: "user6",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user7"] = &User{
Id: "user7",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user8"] = &User{
Id: "user8",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
account.Users["user9"] = &User{
Id: "user9",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
}
testCases := []struct {
name string
userIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "Delete service user successfully ",
userIDs: []string{"user2"},
expectedDeleted: []string{"user2"},
},
{
name: "Delete regular user successfully",
userIDs: []string{"user3"},
expectedDeleted: []string{"user3"},
},
{
name: "Delete integration regular user permission denied",
userIDs: []string{"user4"},
expectedReasons: []string{"only integration service user can delete this user"},
expectedNotDeleted: []string{"user4"},
},
{
name: "Delete user with owner role should return permission denied",
userIDs: []string{"user5"},
expectedReasons: []string{"unable to delete a user: user5 with owner role"},
expectedNotDeleted: []string{"user5"},
},
{
name: "Delete multiple users with mixed results",
userIDs: []string{"user5", "user5", "user6", "user7"},
expectedReasons: []string{"only integration service user can delete this user", "unable to delete a user: user5 with owner role"},
expectedDeleted: []string{"user6", "user7"},
expectedNotDeleted: []string{"user4", "user5"},
},
{
name: "Delete non-existent user",
userIDs: []string{"non-existent-user"},
expectedReasons: []string{"target user: non-existent-user not found"},
expectedNotDeleted: []string{},
},
{
name: "Delete multiple regular users successfully",
userIDs: []string{"user8", "user9"},
expectedDeleted: []string{"user8", "user9"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
assert.Contains(t, tc.expectedReasons, e.Error(), "unexpected error message")
foundExpectedErrors++
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
_, exists := acc.Users[id]
assert.False(t, exists, "user should have been deleted: %s", id)
}
for _, id := range tc.expectedNotDeleted {
user, exists := acc.Users[id]
assert.True(t, exists, "user should not have been deleted: %s", id)
assert.NotNil(t, user, "user should exist: %s", id)
}
})
}
}
func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())

Some files were not shown because too many files have changed in this diff Show More