diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 000000000..3dcb869d2
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,8 @@
+root = true
+
+[*]
+end_of_line = lf
+insert_final_newline = true
+
+[*.go]
+indent_style = tab
diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml
index 15fc6a729..4f13ee30e 100644
--- a/.github/workflows/golang-test-freebsd.yml
+++ b/.github/workflows/golang-test-freebsd.yml
@@ -13,7 +13,7 @@ concurrency:
jobs:
test:
- runs-on: ubuntu-latest
+ runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
@@ -21,19 +21,26 @@ jobs:
uses: vmactions/freebsd-vm@v1
with:
usesh: true
+ copyback: false
+ release: "14.1"
prepare: |
- pkg install -y curl
- pkg install -y git
+ pkg install -y go
+ # -x - to print all executed commands
+ # -e - to faile on first error
run: |
- set -x
- curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
- tar zxf go.tar.gz
- mv go /usr/local/go
- ln -s /usr/local/go/bin/go /usr/local/bin/go
- go mod tidy
- go test -timeout 5m -p 1 ./iface/...
- go test -timeout 5m -p 1 ./client/...
- cd client
- go build .
- cd ..
\ No newline at end of file
+ set -e -x
+ time go build -o netbird client/main.go
+ # check all component except management, since we do not support management server on freebsd
+ time go test -timeout 1m -failfast ./base62/...
+ # NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
+ time go test -timeout 8m -failfast -p 1 ./client/...
+ time go test -timeout 1m -failfast ./dns/...
+ time go test -timeout 1m -failfast ./encryption/...
+ time go test -timeout 1m -failfast ./formatter/...
+ time go test -timeout 1m -failfast ./iface/...
+ time go test -timeout 1m -failfast ./route/...
+ time go test -timeout 1m -failfast ./sharedsock/...
+ time go test -timeout 1m -failfast ./signal/...
+ time go test -timeout 1m -failfast ./util/...
+ time go test -timeout 1m -failfast ./version/...
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 1889b58e7..30f24e92e 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -79,15 +79,8 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- - name: Generate windows syso 386
- run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
- - name: Generate windows syso arm
- run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
- - name: Generate windows syso arm64
- run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
- name: Generate windows syso amd64
- run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso
-
+ run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
@@ -170,7 +163,7 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
- run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso
+ run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index abdd18ceb..52b8ee3e2 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -151,10 +151,10 @@ jobs:
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
- docker-compose up -d
+ docker compose up -d
sleep 5
- docker-compose ps
- docker-compose logs --tail=20
+ docker compose ps
+ docker compose logs --tail=20
- name: test running containers
run: |
@@ -207,7 +207,7 @@ jobs:
- name: Postgres run cleanup
run: |
- docker-compose down --volumes --rmi all
+ docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
diff --git a/README.md b/README.md
index 5be1826b4..370445412 100644
--- a/README.md
+++ b/README.md
@@ -10,10 +10,12 @@
+
+
+
-
diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go
index acbd0441e..208e74d53 100644
--- a/client/anonymize/anonymize.go
+++ b/client/anonymize/anonymize.go
@@ -178,6 +178,21 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
})
}
+// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
+// domain names with random strings.
+func (a *Anonymizer) AnonymizeRoute(route string) string {
+ prefix, err := netip.ParsePrefix(route)
+ if err == nil {
+ ip := a.AnonymizeIPString(prefix.Addr().String())
+ return fmt.Sprintf("%s/%d", ip, prefix.Bits())
+ }
+ domains := strings.Split(route, ", ")
+ for i, domain := range domains {
+ domains[i] = a.AnonymizeDomain(domain)
+ }
+ return strings.Join(domains, ", ")
+}
+
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
diff --git a/client/cmd/debug.go b/client/cmd/debug.go
index da5e0945a..9abd2039d 100644
--- a/client/cmd/debug.go
+++ b/client/cmd/debug.go
@@ -5,6 +5,7 @@ import (
"fmt"
"time"
+ log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@@ -13,6 +14,8 @@ import (
"github.com/netbirdio/netbird/client/server"
)
+const errCloseConnection = "Failed to close connection: %v"
+
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -63,12 +66,17 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
if err != nil {
return err
}
- defer conn.Close()
+ defer func() {
+ if err := conn.Close(); err != nil {
+ log.Errorf(errCloseConnection, err)
+ }
+ }()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
- Anonymize: anonymizeFlag,
- Status: getStatusOutput(cmd),
+ Anonymize: anonymizeFlag,
+ Status: getStatusOutput(cmd),
+ SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
@@ -84,7 +92,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
- defer conn.Close()
+ defer func() {
+ if err := conn.Close(); err != nil {
+ log.Errorf(errCloseConnection, err)
+ }
+ }()
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
@@ -113,7 +125,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
- defer conn.Close()
+ defer func() {
+ if err := conn.Close(); err != nil {
+ log.Errorf(errCloseConnection, err)
+ }
+ }()
client := proto.NewDaemonServiceClient(conn)
@@ -122,17 +138,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
- restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
+ stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
- if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
- return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
+ if stateWasDown {
+ if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
+ return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
+ }
+ cmd.Println("Netbird up")
+ time.Sleep(time.Second * 10)
}
- cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
@@ -145,6 +164,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
+ if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
+ return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
+ }
+ cmd.Println("Netbird down")
+
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
@@ -162,21 +186,25 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
cmd.Println("\nDuration completed")
+ cmd.Println("Creating debug bundle...")
+
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
- if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
- return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
+ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
+ Anonymize: anonymizeFlag,
+ Status: statusOutput,
+ SystemInfo: debugSystemInfoFlag,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
- cmd.Println("Netbird down")
- time.Sleep(1 * time.Second)
-
- if restoreUp {
- if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
- return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
+ if stateWasDown {
+ if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
+ return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
- cmd.Println("Netbird up")
+ cmd.Println("Netbird down")
}
if !initialLevelTrace {
@@ -186,16 +214,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
- cmd.Println("Creating debug bundle...")
-
- resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
- Anonymize: anonymizeFlag,
- Status: statusOutput,
- })
- if err != nil {
- return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
- }
-
cmd.Println(resp.GetPath())
return nil
diff --git a/client/cmd/root.go b/client/cmd/root.go
index 1e5c56366..db02ff5ea 100644
--- a/client/cmd/root.go
+++ b/client/cmd/root.go
@@ -37,6 +37,7 @@ const (
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
+ systemInfoFlag = "system-info"
)
var (
@@ -69,6 +70,7 @@ var (
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
+ debugSystemInfoFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
@@ -91,12 +93,15 @@ func init() {
oldDefaultConfigPathDir = "/etc/wiretrustee/"
oldDefaultLogFileDir = "/var/log/wiretrustee/"
- if runtime.GOOS == "windows" {
+ switch runtime.GOOS {
+ case "windows":
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
+ case "freebsd":
+ defaultConfigPathDir = "/var/db/netbird/"
}
defaultConfigPath = defaultConfigPathDir + "config.json"
@@ -165,6 +170,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
+
+ debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success
diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go
index 5e147262b..99a4821b0 100644
--- a/client/cmd/service_installer.go
+++ b/client/cmd/service_installer.go
@@ -31,6 +31,8 @@ var installCmd = &cobra.Command{
configPath,
"--log-level",
logLevel,
+ "--daemon-addr",
+ daemonAddr,
}
if managementURL != "" {
diff --git a/client/cmd/status.go b/client/cmd/status.go
index 5ed2da301..1ef8b4913 100644
--- a/client/cmd/status.go
+++ b/client/cmd/status.go
@@ -810,7 +810,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
}
for i, route := range peer.Routes {
- peer.Routes[i] = anonymizeRoute(a, route)
+ peer.Routes[i] = a.AnonymizeRoute(route)
}
}
@@ -846,21 +846,8 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
for i, route := range overview.Routes {
- overview.Routes[i] = anonymizeRoute(a, route)
+ overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}
-
-func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
- prefix, err := netip.ParsePrefix(route)
- if err == nil {
- ip := a.AnonymizeIPString(prefix.Addr().String())
- return fmt.Sprintf("%s/%d", ip, prefix.Bits())
- }
- domains := strings.Split(route, ", ")
- for i, domain := range domains {
- domains[i] = a.AnonymizeDomain(domain)
- }
- return strings.Join(domains, ", ")
-}
diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go
index 7467584a3..c9f10ca86 100644
--- a/client/internal/auth/oauth.go
+++ b/client/internal/auth/oauth.go
@@ -69,6 +69,11 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
return authenticateWithDeviceCodeFlow(ctx, config)
}
+ // On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
+ if runtime.GOOS == "freebsd" {
+ return authenticateWithDeviceCodeFlow(ctx, config)
+ }
+
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index 267c1ed80..a4651ebb5 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -94,7 +94,7 @@ func NewDefaultServer(
var dnsService service
if wgInterface.IsUserspaceBind() {
- dnsService = newServiceViaMemory(wgInterface)
+ dnsService = NewServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
@@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
- ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
+ ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
@@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
- ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
+ ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager
return ds
}
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index 6cbd9ea15..b9552bc17 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
- service: newServiceViaMemory(&mocWGIface{}),
+ service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go
index 89cf4daf6..e0f9da26f 100644
--- a/client/internal/dns/service_listener.go
+++ b/client/internal/dns/service_listener.go
@@ -128,6 +128,9 @@ func (s *serviceViaListener) RuntimeIP() string {
}
func (s *serviceViaListener) setListenerStatus(running bool) {
+ s.listenerFlagLock.Lock()
+ defer s.listenerFlagLock.Unlock()
+
s.listenerIsRunning = running
}
diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go
index 757cd962a..729b90cc0 100644
--- a/client/internal/dns/service_memory.go
+++ b/client/internal/dns/service_memory.go
@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
)
-type serviceViaMemory struct {
+type ServiceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
@@ -22,8 +22,8 @@ type serviceViaMemory struct {
listenerFlagLock sync.Mutex
}
-func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
- s := &serviceViaMemory{
+func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
+ s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
@@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
return s
}
-func (s *serviceViaMemory) Listen() error {
+func (s *ServiceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
@@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
return nil
}
-func (s *serviceViaMemory) Stop() {
+func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
@@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
s.listenerIsRunning = false
}
-func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
+func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
-func (s *serviceViaMemory) DeregisterMux(pattern string) {
+func (s *ServiceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
-func (s *serviceViaMemory) RuntimePort() int {
+func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort
}
-func (s *serviceViaMemory) RuntimeIP() string {
+func (s *ServiceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
-func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
+func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go
index 0c01a013e..60ed79d87 100644
--- a/client/internal/dns/upstream_ios.go
+++ b/client/internal/dns/upstream_ios.go
@@ -4,6 +4,7 @@ package dns
import (
"context"
+ "fmt"
"net"
"syscall"
"time"
@@ -17,9 +18,9 @@ import (
type upstreamResolverIOS struct {
*upstreamResolverBase
- lIP net.IP
- lNet *net.IPNet
- iIndex int
+ lIP net.IP
+ lNet *net.IPNet
+ interfaceName string
}
func newUpstreamResolver(
@@ -32,17 +33,11 @@ func newUpstreamResolver(
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
- index, err := getInterfaceIndex(interfaceName)
- if err != nil {
- log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
- return nil, err
- }
-
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
- iIndex: index,
+ interfaceName: interfaceName,
}
ios.upstreamClient = ios
@@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
- log.Errorf("error while parsing upstream host: %s", err)
+ return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
}
timeout := upstreamTimeout
@@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
- client = u.getClientPrivate(timeout)
+ client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
+ if err != nil {
+ return nil, 0, fmt.Errorf("error while creating private client: %s", err)
+ }
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream)
}
-// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
+// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
-func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
+func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
+ index, err := getInterfaceIndex(interfaceName)
+ if err != nil {
+ log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
+ return nil, err
+ }
+
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
- IP: u.lIP,
+ IP: ip,
Port: 0, // Let the OS pick a free port
},
Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
- operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
+ operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
}
if err := c.Control(fn); err != nil {
@@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
client := &dns.Client{
Dialer: dialer,
}
- return client
+ return client, nil
}
func getInterfaceIndex(interfaceName string) (int, error) {
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index 779a09c65..3dea7213c 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@@ -64,7 +65,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
- handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
+ handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
}
return client
}
@@ -377,9 +378,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
-func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
+func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() {
- return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
+ dns := nbdns.NewServiceViaMemory(wgInterface)
+ return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go
index 8429b4534..3296f3ddf 100644
--- a/client/internal/routemanager/dynamic/route.go
+++ b/client/internal/routemanager/dynamic/route.go
@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
+ "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -47,6 +48,8 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
+ wgInterface *iface.WGIface
+ resolverAddr string
}
func NewRoute(
@@ -55,6 +58,8 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
+ wgInterface *iface.WGIface,
+ resolverAddr string,
) *Route {
return &Route{
route: rt,
@@ -63,6 +68,8 @@ func NewRoute(
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
+ wgInterface: wgInterface,
+ resolverAddr: resolverAddr,
}
}
@@ -189,9 +196,14 @@ func (r *Route) startResolver(ctx context.Context) {
}
func (r *Route) update(ctx context.Context) error {
- if resolved, err := r.resolveDomains(); err != nil {
- return fmt.Errorf("resolve domains: %w", err)
- } else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
+ resolved, err := r.resolveDomains()
+ if err != nil {
+ if len(resolved) == 0 {
+ return fmt.Errorf("resolve domains: %w", err)
+ }
+ log.Warnf("Failed to resolve domains: %v", err)
+ }
+ if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
@@ -223,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
- ips, err := net.LookupIP(string(domain))
+
+ ips, err := r.getIPsFromResolver(domain)
if err != nil {
- results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
- return
+ log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
+ ips, err = net.LookupIP(string(domain))
+ if err != nil {
+ results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
+ return
+ }
}
+
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
diff --git a/client/internal/routemanager/dynamic/route_generic.go b/client/internal/routemanager/dynamic/route_generic.go
new file mode 100644
index 000000000..cf3d913a4
--- /dev/null
+++ b/client/internal/routemanager/dynamic/route_generic.go
@@ -0,0 +1,13 @@
+//go:build !ios
+
+package dynamic
+
+import (
+ "net"
+
+ "github.com/netbirdio/netbird/management/domain"
+)
+
+func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
+ return net.LookupIP(string(domain))
+}
diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go
new file mode 100644
index 000000000..67138222f
--- /dev/null
+++ b/client/internal/routemanager/dynamic/route_ios.go
@@ -0,0 +1,55 @@
+//go:build ios
+
+package dynamic
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/miekg/dns"
+
+ nbdns "github.com/netbirdio/netbird/client/internal/dns"
+
+ "github.com/netbirdio/netbird/management/domain"
+)
+
+const dialTimeout = 10 * time.Second
+
+func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
+ privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("error while creating private client: %s", err)
+ }
+
+ msg := new(dns.Msg)
+ msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
+
+ startTime := time.Now()
+
+ response, _, err := privateClient.Exchange(msg, r.resolverAddr)
+ if err != nil {
+ return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
+ }
+
+ if response.Rcode != dns.RcodeSuccess {
+ return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
+ }
+
+ ips := make([]net.IP, 0)
+
+ for _, answ := range response.Answer {
+ if aRecord, ok := answ.(*dns.A); ok {
+ ips = append(ips, aRecord.A)
+ }
+ if aaaaRecord, ok := answ.(*dns.AAAA); ok {
+ ips = append(ips, aaaaRecord.AAAA)
+ }
+ }
+
+ if len(ips) == 0 {
+ return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
+ }
+
+ return ips, nil
+}
diff --git a/client/internal/routemanager/systemops/systemops_bsd.go b/client/internal/routemanager/systemops/systemops_bsd.go
index b7fb554db..5e3b20a86 100644
--- a/client/internal/routemanager/systemops/systemops_bsd.go
+++ b/client/internal/routemanager/systemops/systemops_bsd.go
@@ -22,7 +22,7 @@ type Route struct {
Interface *net.Interface
}
-func getRoutesFromTable() ([]netip.Prefix, error) {
+func GetRoutesFromTable() ([]netip.Prefix, error) {
tab, err := retryFetchRIB()
if err != nil {
return nil, fmt.Errorf("fetch RIB: %v", err)
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 4190debf9..d76824c10 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -427,7 +427,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
- routes, err := getRoutesFromTable()
+ routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
@@ -440,7 +440,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
}
func isSubRange(prefix netip.Prefix) (bool, error) {
- routes, err := getRoutesFromTable()
+ routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go
index c4f69fba5..2d0c57826 100644
--- a/client/internal/routemanager/systemops/systemops_linux.go
+++ b/client/internal/routemanager/systemops/systemops_linux.go
@@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
return nil
}
-func getRoutesFromTable() ([]netip.Prefix, error) {
+func GetRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
@@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int {
func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() {
- return getRoutesFromTable()
+ return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go
index 0adeb0992..3b52fc7af 100644
--- a/client/internal/routemanager/systemops/systemops_nonlinux.go
+++ b/client/internal/routemanager/systemops/systemops_nonlinux.go
@@ -24,5 +24,5 @@ func EnableIPForwarding() error {
}
func hasSeparateRouting() ([]netip.Prefix, error) {
- return getRoutesFromTable()
+ return GetRoutesFromTable()
}
diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go
index 88bdce7c9..0d3630cb8 100644
--- a/client/internal/routemanager/systemops/systemops_windows.go
+++ b/client/internal/routemanager/systemops/systemops_windows.go
@@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
return nil
}
-func getRoutesFromTable() ([]netip.Prefix, error) {
+func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 1d7388bfd..384bc0e62 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -263,6 +263,7 @@ message Route {
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
+ bool systemInfo = 3;
}
message DebugBundleResponse {
diff --git a/client/server/debug.go b/client/server/debug.go
index dcefb66ca..5ed43293b 100644
--- a/client/server/debug.go
+++ b/client/server/debug.go
@@ -1,3 +1,5 @@
+//go:build !android && !ios
+
package server
import (
@@ -6,16 +8,70 @@ import (
"context"
"fmt"
"io"
+ "net"
+ "net/netip"
"os"
+ "sort"
"strings"
+ "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/proto"
)
+const readmeContent = `Netbird debug bundle
+This debug bundle contains the following files:
+
+status.txt: Anonymized status information of the NetBird client.
+client.log: Most recent, anonymized log file of the NetBird client.
+routes.txt: Anonymized system routes, if --system-info flag was provided.
+interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
+config.txt: Anonymized configuration information of the NetBird client.
+
+
+Anonymization Process
+The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
+
+IP Addresses
+
+IPv4 addresses are replaced with addresses starting from 192.51.100.0
+IPv6 addresses are replaced with addresses starting from 100::
+
+IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
+Reoccuring IP addresses are replaced with the same anonymized address.
+
+Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
+
+Domains
+All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
+Reoccuring domain names are replaced with the same anonymized domain.
+
+Routes
+For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
+Network Interfaces
+The interfaces.txt file contains information about network interfaces, including:
+- Interface name
+- Interface index
+- MTU (Maximum Transmission Unit)
+- Flags
+- IP addresses associated with each interface
+
+The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
+
+Configuration
+The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
+- ManagementURL
+- AdminURL
+- NATExternalIPs
+- CustomDNSAddress
+
+Other non-sensitive configuration options are included without anonymization.
+`
+
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -30,93 +86,211 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
- if err := bundlePath.Close(); err != nil {
- log.Errorf("failed to close zip file: %v", err)
+ if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
+ err = fmt.Errorf("close zip file: %w", closeErr)
}
if err != nil {
- if err2 := os.Remove(bundlePath.Name()); err2 != nil {
- log.Errorf("Failed to remove zip file: %v", err2)
+ if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
+ log.Errorf("Failed to remove zip file: %v", removeErr)
}
}
}()
- archive := zip.NewWriter(bundlePath)
- defer func() {
- if err := archive.Close(); err != nil {
- log.Errorf("failed to close archive writer: %v", err)
- }
- }()
-
- if status := req.GetStatus(); status != "" {
- filename := "status.txt"
- if req.GetAnonymize() {
- filename = "status.anon.txt"
- }
- statusReader := strings.NewReader(status)
- if err := addFileToZip(archive, statusReader, filename); err != nil {
- return nil, fmt.Errorf("add status file to zip: %w", err)
- }
- }
-
- logFile, err := os.Open(s.logFile)
- if err != nil {
- return nil, fmt.Errorf("open log file: %w", err)
- }
- defer func() {
- if err := logFile.Close(); err != nil {
- log.Errorf("failed to close original log file: %v", err)
- }
- }()
-
- filename := "client.log.txt"
- var logReader io.Reader
- errChan := make(chan error, 1)
- if req.GetAnonymize() {
- filename = "client.anon.log.txt"
- var writer io.WriteCloser
- logReader, writer = io.Pipe()
-
- go s.anonymize(logFile, writer, errChan)
- } else {
- logReader = logFile
- }
- if err := addFileToZip(archive, logReader, filename); err != nil {
- return nil, fmt.Errorf("add log file to zip: %w", err)
- }
-
- select {
- case err := <-errChan:
- if err != nil {
- return nil, err
- }
- default:
+ if err := s.createArchive(bundlePath, req); err != nil {
+ return nil, err
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
-func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
- scanner := bufio.NewScanner(reader)
- anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
+func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
+ archive := zip.NewWriter(bundlePath)
+ if err := s.addReadme(req, archive); err != nil {
+ return fmt.Errorf("add readme: %w", err)
+ }
+ if err := s.addStatus(req, archive); err != nil {
+ return fmt.Errorf("add status: %w", err)
+ }
+
+ anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
+ if err := s.addConfig(req, anonymizer, archive); err != nil {
+ return fmt.Errorf("add config: %w", err)
+ }
+
+ if req.GetSystemInfo() {
+ if err := s.addRoutes(req, anonymizer, archive); err != nil {
+ return fmt.Errorf("add routes: %w", err)
+ }
+
+ if err := s.addInterfaces(req, anonymizer, archive); err != nil {
+ return fmt.Errorf("add interfaces: %w", err)
+ }
+ }
+
+ if err := s.addLogfile(req, anonymizer, archive); err != nil {
+ return fmt.Errorf("add log file: %w", err)
+ }
+
+ if err := archive.Close(); err != nil {
+ return fmt.Errorf("close archive writer: %w", err)
+ }
+ return nil
+}
+
+func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
+ if req.GetAnonymize() {
+ readmeReader := strings.NewReader(readmeContent)
+ if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
+ return fmt.Errorf("add README file to zip: %w", err)
+ }
+ }
+ return nil
+}
+
+func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
+ if status := req.GetStatus(); status != "" {
+ statusReader := strings.NewReader(status)
+ if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
+ return fmt.Errorf("add status file to zip: %w", err)
+ }
+ }
+ return nil
+}
+
+func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
+ var configContent strings.Builder
+ s.addCommonConfigFields(&configContent)
+
+ if req.GetAnonymize() {
+ if s.config.ManagementURL != nil {
+ configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
+ }
+ if s.config.AdminURL != nil {
+ configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
+ }
+ configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
+ if s.config.CustomDNSAddress != "" {
+ configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
+ }
+ } else {
+ if s.config.ManagementURL != nil {
+ configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
+ }
+ if s.config.AdminURL != nil {
+ configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
+ }
+ configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
+ if s.config.CustomDNSAddress != "" {
+ configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
+ }
+ }
+
+ // Add config content to zip file
+ configReader := strings.NewReader(configContent.String())
+ if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
+ return fmt.Errorf("add config file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
+ configContent.WriteString("NetBird Client Configuration:\n\n")
+
+ // Add non-sensitive fields
+ configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
+ configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
+ if s.config.NetworkMonitor != nil {
+ configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
+ }
+ configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
+ configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
+ configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
+ configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
+ if s.config.ServerSSHAllowed != nil {
+ configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
+ }
+ configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
+ configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
+}
+
+func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
+ if routes, err := systemops.GetRoutesFromTable(); err != nil {
+ log.Errorf("Failed to get routes: %v", err)
+ } else {
+ // TODO: get routes including nexthop
+ routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
+ routesReader := strings.NewReader(routesContent)
+ if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
+ return fmt.Errorf("add routes file to zip: %w", err)
+ }
+ }
+ return nil
+}
+
+func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
+ interfaces, err := net.Interfaces()
+ if err != nil {
+ return fmt.Errorf("get interfaces: %w", err)
+ }
+
+ interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
+ interfacesReader := strings.NewReader(interfacesContent)
+ if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
+ return fmt.Errorf("add interfaces file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
+ logFile, err := os.Open(s.logFile)
+ if err != nil {
+ return fmt.Errorf("open log file: %w", err)
+ }
defer func() {
- if err := writer.Close(); err != nil {
- log.Errorf("Failed to close writer: %v", err)
+ if err := logFile.Close(); err != nil {
+ log.Errorf("Failed to close original log file: %v", err)
}
}()
+
+ var logReader io.Reader
+ if req.GetAnonymize() {
+ var writer *io.PipeWriter
+ logReader, writer = io.Pipe()
+
+ go s.anonymize(logFile, writer, anonymizer)
+ } else {
+ logReader = logFile
+ }
+ if err := addFileToZip(archive, logReader, "client.log"); err != nil {
+ return fmt.Errorf("add log file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
+ defer func() {
+ // always nil
+ _ = writer.Close()
+ }()
+
+ scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
- errChan <- fmt.Errorf("write line to writer: %w", err)
+ writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
}
if err := scanner.Err(); err != nil {
- errChan <- fmt.Errorf("read line from scanner: %w", err)
+ writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return
}
}
@@ -141,8 +315,22 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
- Name: filename,
- Method: zip.Deflate,
+ Name: filename,
+ Method: zip.Deflate,
+ Modified: time.Now(),
+
+ CreatorVersion: 20, // Version 2.0
+ ReaderVersion: 20, // Version 2.0
+ Flags: 0x800, // UTF-8 filename
+ }
+
+ // If the reader is a file, we can get more accurate information
+ if f, ok := reader.(*os.File); ok {
+ if stat, err := f.Stat(); err != nil {
+ log.Tracef("Failed to get file stat for %s: %v", filename, err)
+ } else {
+ header.Modified = stat.ModTime()
+ }
}
writer, err := archive.CreateHeader(header)
@@ -165,6 +353,13 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
+ for route := range peer.GetRoutes() {
+ a.AnonymizeRoute(route)
+ }
+ }
+
+ for route := range status.LocalPeerState.Routes {
+ a.AnonymizeRoute(route)
}
for _, nsGroup := range status.NSGroupStates {
@@ -179,3 +374,113 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
}
}
+
+func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
+ var ipv4Routes, ipv6Routes []netip.Prefix
+
+ // Separate IPv4 and IPv6 routes
+ for _, route := range routes {
+ if route.Addr().Is4() {
+ ipv4Routes = append(ipv4Routes, route)
+ } else {
+ ipv6Routes = append(ipv6Routes, route)
+ }
+ }
+
+ // Sort IPv4 and IPv6 routes separately
+ sort.Slice(ipv4Routes, func(i, j int) bool {
+ return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
+ })
+ sort.Slice(ipv6Routes, func(i, j int) bool {
+ return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
+ })
+
+ var builder strings.Builder
+
+ // Format IPv4 routes
+ builder.WriteString("IPv4 Routes:\n")
+ for _, route := range ipv4Routes {
+ formatRoute(&builder, route, anonymize, anonymizer)
+ }
+
+ // Format IPv6 routes
+ builder.WriteString("\nIPv6 Routes:\n")
+ for _, route := range ipv6Routes {
+ formatRoute(&builder, route, anonymize, anonymizer)
+ }
+
+ return builder.String()
+}
+
+func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
+ if anonymize {
+ anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
+ builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
+ } else {
+ builder.WriteString(fmt.Sprintf("%s\n", route))
+ }
+}
+
+func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
+ sort.Slice(interfaces, func(i, j int) bool {
+ return interfaces[i].Name < interfaces[j].Name
+ })
+
+ var builder strings.Builder
+ builder.WriteString("Network Interfaces:\n")
+
+ for _, iface := range interfaces {
+ builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
+ builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
+ builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
+ builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
+
+ addrs, err := iface.Addrs()
+ if err != nil {
+ builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
+ } else {
+ builder.WriteString(" Addresses:\n")
+ for _, addr := range addrs {
+ prefix, err := netip.ParsePrefix(addr.String())
+ if err != nil {
+ builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
+ continue
+ }
+ ip := prefix.Addr()
+ if anonymize {
+ ip = anonymizer.AnonymizeIP(ip)
+ }
+ builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
+ }
+ }
+ }
+
+ return builder.String()
+}
+
+func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
+ anonymizedIPs := make([]string, len(ips))
+ for i, ip := range ips {
+ parts := strings.SplitN(ip, "/", 2)
+
+ ip1, err := netip.ParseAddr(parts[0])
+ if err != nil {
+ anonymizedIPs[i] = ip
+ continue
+ }
+ ip1anon := anonymizer.AnonymizeIP(ip1)
+
+ if len(parts) == 2 {
+ ip2, err := netip.ParseAddr(parts[1])
+ if err != nil {
+ anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
+ } else {
+ ip2anon := anonymizer.AnonymizeIP(ip2)
+ anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
+ }
+ } else {
+ anonymizedIPs[i] = ip1anon.String()
+ }
+ }
+ return anonymizedIPs
+}
diff --git a/client/system/info_darwin_test.go b/client/system/info_darwin_test.go
index 94e0b9e5e..5608bc776 100644
--- a/client/system/info_darwin_test.go
+++ b/client/system/info_darwin_test.go
@@ -1,11 +1,12 @@
package system
import (
- log "github.com/sirupsen/logrus"
"testing"
+
+ log "github.com/sirupsen/logrus"
)
-func Test_sysInfo(t *testing.T) {
+func Test_sysInfoMac(t *testing.T) {
t.Skip("skipping darwin test")
serialNum, prodName, manufacturer := sysInfo()
if serialNum == "" {
diff --git a/client/system/info_linux.go b/client/system/info_linux.go
index db58d913f..b6a142bce 100644
--- a/client/system/info_linux.go
+++ b/client/system/info_linux.go
@@ -21,6 +21,26 @@ import (
"github.com/netbirdio/netbird/version"
)
+type SysInfoGetter interface {
+ GetSysInfo() SysInfo
+}
+
+type SysInfoWrapper struct {
+ si sysinfo.SysInfo
+}
+
+func (s SysInfoWrapper) GetSysInfo() SysInfo {
+ s.si.GetSysInfo()
+ return SysInfo{
+ ChassisSerial: s.si.Chassis.Serial,
+ ProductSerial: s.si.Product.Serial,
+ BoardSerial: s.si.Board.Serial,
+ ProductName: s.si.Product.Name,
+ BoardName: s.si.Board.Name,
+ ProductVendor: s.si.Product.Vendor,
+ }
+}
+
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
info := _getInfo()
@@ -45,7 +65,8 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
- serialNum, prodName, manufacturer := sysInfo()
+ si := SysInfoWrapper{}
+ serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
env := Environment{
Cloud: detect_cloud.Detect(ctx),
@@ -87,20 +108,36 @@ func _getInfo() string {
return out.String()
}
-func sysInfo() (serialNumber string, productName string, manufacturer string) {
- var si sysinfo.SysInfo
- si.GetSysInfo()
+func sysInfo(si SysInfo) (string, string, string) {
isascii := regexp.MustCompile("^[[:ascii:]]+$")
- serial := si.Chassis.Serial
- if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
- serial = si.Product.Serial
+
+ serials := []string{si.ChassisSerial, si.ProductSerial}
+ serial := ""
+
+ for _, s := range serials {
+ if isascii.MatchString(s) {
+ serial = s
+ if s != "Default string" {
+ break
+ }
+ }
}
- if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
- serial = si.Board.Serial
+
+ if serial == "" && isascii.MatchString(si.BoardSerial) {
+ serial = si.BoardSerial
}
- name := si.Product.Name
- if (!isascii.MatchString(name)) && si.Board.Name != "" {
- name = si.Board.Name
+
+ var name string
+ for _, n := range []string{si.ProductName, si.BoardName} {
+ if isascii.MatchString(n) {
+ name = n
+ break
+ }
}
- return serial, name, si.Product.Vendor
+
+ var manufacturer string
+ if isascii.MatchString(si.ProductVendor) {
+ manufacturer = si.ProductVendor
+ }
+ return serial, name, manufacturer
}
diff --git a/client/system/sysinfo_linux.go b/client/system/sysinfo_linux.go
new file mode 100644
index 000000000..df0f5574c
--- /dev/null
+++ b/client/system/sysinfo_linux.go
@@ -0,0 +1,12 @@
+package system
+
+// SysInfo used to moc out the sysinfo getter
+type SysInfo struct {
+ ChassisSerial string
+ ProductSerial string
+ BoardSerial string
+
+ ProductName string
+ BoardName string
+ ProductVendor string
+}
diff --git a/client/system/sysinfo_linux_test.go b/client/system/sysinfo_linux_test.go
new file mode 100644
index 000000000..f6a0b7058
--- /dev/null
+++ b/client/system/sysinfo_linux_test.go
@@ -0,0 +1,198 @@
+package system
+
+import "testing"
+
+func Test_sysInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ sysInfo SysInfo
+ wantSerialNum string
+ wantProdName string
+ wantManufacturer string
+ }{
+ {
+ name: "Test Case 1",
+ sysInfo: SysInfo{
+ ChassisSerial: "Default string",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Empty Chassis Serial",
+ sysInfo: SysInfo{
+ ChassisSerial: "",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Empty Chassis Serial",
+ sysInfo: SysInfo{
+ ChassisSerial: "",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Fallback to Product Serial",
+ sysInfo: SysInfo{
+ ChassisSerial: "Default string",
+ ProductSerial: "Product serial",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Product serial",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Fallback to Product Serial with default string",
+ sysInfo: SysInfo{
+ ChassisSerial: "Default string",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Non UTF-8 in Chassis Serial",
+ sysInfo: SysInfo{
+ ChassisSerial: "\x80",
+ ProductSerial: "Product serial",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Product serial",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Non UTF-8 in Chassis Serial and Product Serial",
+ sysInfo: SysInfo{
+ ChassisSerial: "\x80",
+ ProductSerial: "\x80",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "M80-G8013200245",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial",
+ sysInfo: SysInfo{
+ ChassisSerial: "\x80",
+ ProductSerial: "\x80",
+ BoardSerial: "\x80",
+ ProductName: "B650M-HDV/M.2",
+ BoardName: "B650M-HDV/M.2",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "",
+ wantProdName: "B650M-HDV/M.2",
+ wantManufacturer: "ASRock",
+ },
+
+ {
+ name: "Empty Product Name",
+ sysInfo: SysInfo{
+ ChassisSerial: "Default string",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "",
+ BoardName: "boardname",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "boardname",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Invalid Product Name",
+ sysInfo: SysInfo{
+ ChassisSerial: "Default string",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "\x80",
+ BoardName: "boardname",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "boardname",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Invalid BoardName Name",
+ sysInfo: SysInfo{
+ ChassisSerial: "Default string",
+ ProductSerial: "Default string",
+ BoardSerial: "M80-G8013200245",
+ ProductName: "\x80",
+ BoardName: "\x80",
+ ProductVendor: "ASRock",
+ },
+ wantSerialNum: "Default string",
+ wantProdName: "",
+ wantManufacturer: "ASRock",
+ },
+ {
+ name: "Invalid chars",
+ sysInfo: SysInfo{
+ ChassisSerial: "\x80",
+ ProductSerial: "\x80",
+ BoardSerial: "\x80",
+ ProductName: "\x80",
+ BoardName: "\x80",
+ ProductVendor: "\x80",
+ },
+ wantSerialNum: "",
+ wantProdName: "",
+ wantManufacturer: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
+ if gotSerialNum != tt.wantSerialNum {
+ t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
+ }
+ if gotProdName != tt.wantProdName {
+ t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName)
+ }
+ if gotManufacturer != tt.wantManufacturer {
+ t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer)
+ }
+ })
+ }
+}
diff --git a/client/ui/font_bsd.go b/client/ui/font_bsd.go
index 41bccceca..84cb5993d 100644
--- a/client/ui/font_bsd.go
+++ b/client/ui/font_bsd.go
@@ -1,4 +1,4 @@
-//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+//go:build darwin
package main
diff --git a/encryption/message.go b/encryption/message.go
index a646fa679..6e4cd7391 100644
--- a/encryption/message.go
+++ b/encryption/message.go
@@ -10,7 +10,7 @@ import (
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
byteResp, err := pb.Marshal(message)
if err != nil {
- log.Errorf("failed marshalling message %v", err)
+ log.Errorf("failed marshalling message %v, %+v", err, message.String())
return nil, err
}
diff --git a/formatter/formatter.go b/formatter/formatter.go
index a37c67914..74de38603 100644
--- a/formatter/formatter.go
+++ b/formatter/formatter.go
@@ -14,14 +14,29 @@ type TextFormatter struct {
levelDesc []string
}
+// SyslogFormatter formats logs into text
+type SyslogFormatter struct {
+ levelDesc []string
+}
+
+var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
+
+
// NewTextFormatter create new MyTextFormatter instance
func NewTextFormatter() *TextFormatter {
return &TextFormatter{
- levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
+ levelDesc: validLevelDesc,
timestampFormat: time.RFC3339, // or RFC3339
}
}
+// NewSyslogFormatter create new MySyslogFormatter instance
+func NewSyslogFormatter() *SyslogFormatter {
+ return &SyslogFormatter{
+ levelDesc: validLevelDesc,
+ }
+}
+
// Format renders a single log entry
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
@@ -49,3 +64,20 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string {
return f.levelDesc[level]
}
+
+// Format renders a single log entry
+func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
+ var fields string
+ keys := make([]string, 0, len(entry.Data))
+ for k, v := range entry.Data {
+ if k == "source" {
+ continue
+ }
+ keys = append(keys, fmt.Sprintf("%s: %v", k, v))
+ }
+
+ if len(keys) > 0 {
+ fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
+ }
+ return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
+}
diff --git a/formatter/formatter_test.go b/formatter/formatter_test.go
index 54bc8a756..1ed207958 100644
--- a/formatter/formatter_test.go
+++ b/formatter/formatter_test.go
@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
)
-func TestLogMessageFormat(t *testing.T) {
+func TestLogTextFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
@@ -24,3 +24,20 @@ func TestLogMessageFormat(t *testing.T) {
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}
+
+func TestLogSyslogFormat(t *testing.T) {
+
+ someEntry := &logrus.Entry{
+ Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
+ Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
+ Level: 3,
+ Message: "Some Message",
+ }
+
+ formatter := NewSyslogFormatter()
+ result, _ := formatter.Format(someEntry)
+
+ parsedString := string(result)
+ expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
+ assert.Regexp(t, expectedString, parsedString)
+}
diff --git a/formatter/set.go b/formatter/set.go
index f9ccef601..9dfea5a7f 100644
--- a/formatter/set.go
+++ b/formatter/set.go
@@ -10,6 +10,12 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
+// SetSyslogFormatter set the text formatter for given logger.
+func SetSyslogFormatter(logger *logrus.Logger) {
+ logger.Formatter = NewSyslogFormatter()
+ logger.ReportCaller = true
+ logger.AddHook(NewContextHook())
+}
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {
diff --git a/go.mod b/go.mod
index c0a0dc3c7..4b5e0ede8 100644
--- a/go.mod
+++ b/go.mod
@@ -117,7 +117,7 @@ require (
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
- github.com/docker/docker v26.1.3+incompatible // indirect
+ github.com/docker/docker v26.1.4+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
diff --git a/go.sum b/go.sum
index 3f780eed5..dc5e7b90f 100644
--- a/go.sum
+++ b/go.sum
@@ -81,8 +81,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
-github.com/docker/docker v26.1.3+incompatible h1:lLCzRbrVZrljpVNobJu1J2FHk8V0s4BawoZippkc+xo=
-github.com/docker/docker v26.1.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
+github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
diff --git a/management/server/account.go b/management/server/account.go
index 558de6fbb..5d3ee6dc1 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -135,8 +135,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
- SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
- CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error
+ SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
+ OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@@ -770,10 +770,6 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
// Returns true if there are changes in the JWT group membership.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
- if len(groupsNames) == 0 {
- return false
- }
-
user, ok := a.Users[userID]
if !ok {
return false
@@ -978,7 +974,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -1029,7 +1025,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -1128,7 +1124,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -1588,7 +1584,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return err
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
@@ -1671,7 +1667,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
if err != nil {
return nil, nil, err
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
@@ -1827,7 +1823,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err == nil {
- unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id)
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
@@ -1847,7 +1843,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
- unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id)
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
@@ -1861,17 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
}
}
-func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
- accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
- if err != nil {
- if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
- return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
- }
- return nil, nil, nil, err
- }
-
- unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
- defer unlock()
+func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
+ accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
+ defer accountUnlock()
+ peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
+ defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -1891,26 +1881,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
return peer, netMap, postureChecks, nil
}
-func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error {
- accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key)
- if err != nil {
- if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
- return status.Errorf(status.Unauthenticated, "peer not registered")
- }
- return err
- }
-
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
- defer unlock()
+func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
+ accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
+ defer accountUnlock()
+ peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
+ defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
- err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account)
+ err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil {
- log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err)
+ log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return nil
@@ -1923,7 +1907,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
- unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
+ unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 71b43bd65..45b4fbd6f 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -2219,6 +2219,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
})
+
+ t.Run("remove all JWT groups", func(t *testing.T) {
+ updated := account.SetJWTGroups("user1", []string{})
+ assert.True(t, updated, "account should be updated")
+ assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
+ assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
+ })
}
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
diff --git a/management/server/config.go b/management/server/config.go
index beba239e6..96f0c7ffd 100644
--- a/management/server/config.go
+++ b/management/server/config.go
@@ -57,6 +57,10 @@ type Config struct {
func (c Config) GetAuthAudiences() []string {
audiences := []string{c.HttpConfig.AuthAudience}
+ if c.HttpConfig.ExtraAuthAudience != "" {
+ audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
+ }
+
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
}
@@ -95,6 +99,8 @@ type HttpServerConfig struct {
OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
+ // Extra audience
+ ExtraAuthAudience string
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
diff --git a/management/server/dns.go b/management/server/dns.go
index 8a889df3f..08732ad78 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -36,7 +36,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/event.go b/management/server/event.go
index 616cea287..93b809226 100644
--- a/management/server/event.go
+++ b/management/server/event.go
@@ -13,7 +13,7 @@ import (
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/file_store.go b/management/server/file_store.go
index c649602e2..6e3536bcd 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -39,8 +39,8 @@ type FileStore struct {
mux sync.Mutex `json:"-"`
storeFile string `json:"-"`
- // sync.Mutex indexed by accountID
- accountLocks sync.Map `json:"-"`
+ // sync.Mutex indexed by resource ID
+ resourceLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
metrics telemetry.AppMetrics `json:"-"`
@@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
-// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
-func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
- log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
+// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
+func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
+ log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
start := time.Now()
- value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
+ value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
- log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
+ log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
-// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
+// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
-func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
- return s.AcquireAccountWriteLock(ctx, accountID)
+func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
+ return s.AcquireWriteLockByUID(ctx, uniqueID)
}
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
@@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
return s.persist(ctx, s.storeFile)
}
+// SavePeer saves the peer in the account
+func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return err
+ }
+
+ newPeer := peer.Copy()
+
+ account.Peers[peer.ID] = newPeer
+
+ s.PeerKeyID2AccountID[peer.Key] = accountID
+ s.PeerID2AccountID[peer.ID] = accountID
+
+ return nil
+}
+
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
diff --git a/management/server/geolocation/database.go b/management/server/geolocation/database.go
index 1bada6075..c9b2eafff 100644
--- a/management/server/geolocation/database.go
+++ b/management/server/geolocation/database.go
@@ -9,6 +9,7 @@ import (
"path"
"strconv"
+ log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -30,6 +31,8 @@ func loadGeolocationDatabases(dataDir string) error {
continue
}
+ log.Infof("geo location file %s not found , file will be downloaded", file)
+
switch file {
case MMDBFileName:
extractFunc := func(src string, dst string) error {
diff --git a/management/server/group.go b/management/server/group.go
index 45c51bda2..37a6fc305 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -23,7 +23,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -110,7 +110,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
@@ -163,7 +163,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
account.Network.IncSerial()
- if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
+ if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
@@ -245,7 +245,7 @@ func difference(a, b []string) []string {
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -413,7 +413,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index 0e42b9f20..64931dd95 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -155,7 +155,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
- peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
+ peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
return mapError(ctx, err)
}
@@ -178,11 +178,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
- return s.handleUpdates(ctx, peerKey, peer, updates, srv)
+ return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
-func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
+func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@@ -193,12 +193,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
- s.cancelPeerRoutines(ctx, peer)
+ s.cancelPeerRoutines(ctx, accountID, peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
- if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
+ if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
return err
}
@@ -206,7 +206,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
- s.cancelPeerRoutines(ctx, peer)
+ s.cancelPeerRoutines(ctx, accountID, peer)
return srv.Context().Err()
}
}
@@ -214,10 +214,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
-func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
+func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
- s.cancelPeerRoutines(ctx, peer)
+ s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@@ -225,17 +225,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
Body: encryptedResp,
})
if err != nil {
- s.cancelPeerRoutines(ctx, peer)
+ s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
-func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
+func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnRelayTokenManager.CancelRefresh(peer.ID)
- _ = s.accountManager.CancelPeerRoutines(ctx, peer)
+ _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 30cb19c0c..45887dc2e 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -526,6 +526,43 @@ components:
- revoked
- auto_groups
- usage_limit
+ CreateSetupKeyRequest:
+ type: object
+ properties:
+ name:
+ description: Setup Key name
+ type: string
+ example: Default key
+ type:
+ description: Setup key type, one-off for single time usage and reusable
+ type: string
+ example: reusable
+ expires_in:
+ description: Expiration time in seconds
+ type: integer
+ minimum: 86400
+ maximum: 31536000
+ example: 86400
+ auto_groups:
+ description: List of group IDs to auto-assign to peers registered with this key
+ type: array
+ items:
+ type: string
+ example: "ch8i4ug6lnn4g9hqv7m0"
+ usage_limit:
+ description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
+ type: integer
+ example: 0
+ ephemeral:
+ description: Indicate that the peer will be ephemeral or not
+ type: boolean
+ example: true
+ required:
+ - name
+ - type
+ - expires_in
+ - auto_groups
+ - usage_limit
PersonalAccessToken:
type: object
properties:
@@ -1806,7 +1843,7 @@ paths:
content:
'application/json':
schema:
- $ref: '#/components/schemas/SetupKeyRequest'
+ $ref: '#/components/schemas/CreateSetupKeyRequest'
responses:
'200':
description: A Setup Keys Object
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index f731356ee..77a6c643d 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -254,6 +254,27 @@ type Country struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
type CountryCode = string
+// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
+type CreateSetupKeyRequest struct {
+ // AutoGroups List of group IDs to auto-assign to peers registered with this key
+ AutoGroups []string `json:"auto_groups"`
+
+ // Ephemeral Indicate that the peer will be ephemeral or not
+ Ephemeral *bool `json:"ephemeral,omitempty"`
+
+ // ExpiresIn Expiration time in seconds
+ ExpiresIn int `json:"expires_in"`
+
+ // Name Setup Key name
+ Name string `json:"name"`
+
+ // Type Setup key type, one-off for single time usage and reusable
+ Type string `json:"type"`
+
+ // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
+ UsageLimit int `json:"usage_limit"`
+}
+
// DNSSettings defines model for DNSSettings.
type DNSSettings struct {
// DisabledManagementGroups Groups whose DNS management is disabled
@@ -1241,7 +1262,7 @@ type PostApiRoutesJSONRequestBody = RouteRequest
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
-type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
+type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go
index 05537ada4..99e6b204c 100644
--- a/management/server/integrated_validator.go
+++ b/management/server/integrated_validator.go
@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return errors.New("invalid groups")
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID)
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index 5aaf95344..1748efa22 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -2,6 +2,7 @@ package server
import (
"context"
+ "fmt"
"net"
"os"
"path/filepath"
@@ -16,6 +17,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
+ "github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util"
@@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) {
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
- mgmtServer, mgmtAddr, err := startManagement(t, &Config{
+ mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -399,25 +401,28 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
-func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) {
+func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
- return nil, "", err
+ return nil, nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
- return nil, "", err
+ return nil, nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
- accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
+
+ ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
+
+ accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
- return nil, "", err
+ return nil, nil, "", err
}
rc := &RelayConfig{
@@ -428,7 +433,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
if err != nil {
- return nil, "", err
+ return nil, nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
@@ -438,7 +443,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
}
}()
- return s, lis.Addr().String(), nil
+ return s, accountManager, lis.Addr().String(), nil
}
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
@@ -458,3 +463,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
return mgmtProto.NewManagementServiceClient(conn), conn, nil
}
+func Test_SyncStatusRace(t *testing.T) {
+ if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
+ t.Skip("Skipping on CI and Postgres store")
+ }
+ for i := 0; i < 500; i++ {
+ t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
+ testSyncStatusRace(t)
+ })
+ }
+}
+func testSyncStatusRace(t *testing.T) {
+ t.Helper()
+ dir := t.TempDir()
+ err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ os.Remove(filepath.Join(dir, "store.json")) //nolint
+ }()
+
+ mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
+ Stuns: []*Host{{
+ Proto: "udp",
+ URI: "stun:stun.wiretrustee.com:3468",
+ }},
+ TURNConfig: &TURNConfig{
+ TimeBasedCredentials: false,
+ CredentialsTTL: util.Duration{},
+ Secret: "whatever",
+ Turns: []*Host{{
+ Proto: "udp",
+ URI: "turn:stun.wiretrustee.com:3468",
+ }},
+ },
+ Signal: &Host{
+ Proto: "http",
+ URI: "signal.wiretrustee.com:10000",
+ },
+ Datadir: dir,
+ HttpConfig: nil,
+ })
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+ defer mgmtServer.GracefulStop()
+
+ client, clientConn, err := createRawClient(mgmtAddr)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ defer clientConn.Close()
+
+ // there are two peers already in the store, add two more
+ peers, err := registerPeers(2, client)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ serverKey, err := getServerKey(client)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ concurrentPeerKey2 := peers[1]
+ t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
+
+ syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
+ message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ ctx2, cancelFunc2 := context.WithCancel(context.Background())
+
+ //client.
+ sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
+ WgPubKey: concurrentPeerKey2.PublicKey().String(),
+ Body: message2,
+ })
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ resp2 := &mgmtProto.EncryptedMessage{}
+ err = sync2.RecvMsg(resp2)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ peerWithInvalidStatus := peers[0]
+ t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
+
+ syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
+ message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ ctx, cancelFunc := context.WithCancel(context.Background())
+
+ //client.
+ sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
+ WgPubKey: peerWithInvalidStatus.PublicKey().String(),
+ Body: message,
+ })
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ // take the first registered peer as a base for the test. Total four.
+
+ resp := &mgmtProto.EncryptedMessage{}
+ err = sync.RecvMsg(resp)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ cancelFunc2()
+ time.Sleep(1 * time.Millisecond)
+ cancelFunc()
+ time.Sleep(10 * time.Millisecond)
+
+ ctx, cancelFunc = context.WithCancel(context.Background())
+ defer cancelFunc()
+ sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
+ WgPubKey: peerWithInvalidStatus.PublicKey().String(),
+ Body: message,
+ })
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ resp = &mgmtProto.EncryptedMessage{}
+ err = sync.RecvMsg(resp)
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+ if !peer.Status.Connected {
+ t.Fatal("Peer should be connected")
+ }
+}
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index 25bcdfcee..a66bdee2b 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -31,7 +31,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
- SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
+ SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
@@ -105,14 +105,14 @@ type MockAccountManager struct {
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
}
-func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
+func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil {
- return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
+ return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
-func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error {
+func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
}
diff --git a/management/server/nameserver.go b/management/server/nameserver.go
index f8d644ded..636f7cfee 100644
--- a/management/server/nameserver.go
+++ b/management/server/nameserver.go
@@ -20,7 +20,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -95,7 +95,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if nsGroupToSave == nil {
@@ -130,7 +130,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/peer.go b/management/server/peer.go
index 9b48276ce..964da3c53 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -7,10 +7,11 @@ import (
"strings"
"time"
- "github.com/netbirdio/netbird/management/server/posture"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/management/server/posture"
+
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -149,7 +150,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -271,7 +272,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -355,7 +356,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
@@ -379,7 +380,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
- // Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
+ // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
@@ -452,6 +453,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
Location: peer.Location,
}
+ if am.geo != nil && newPeer.Location.ConnectionIP != nil {
+ location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
+ if err != nil {
+ log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
+ } else {
+ newPeer.Location.CountryCode = location.Country.ISOCode
+ newPeer.Location.CityName = location.City.Names.En
+ newPeer.Location.GeoNameID = location.City.GeonameID
+ }
+ }
+
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
@@ -534,12 +546,12 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if peerLoginExpired(ctx, peer, account.Settings) {
- return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
+ return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
- err = am.Store.SaveAccount(ctx, account)
+ err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, err
}
@@ -585,21 +597,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it.
newPeer := &nbpeer.Peer{
- Key: login.WireGuardPubKey,
- Meta: login.Meta,
- SSHKey: login.SSHKey,
- }
- if am.geo != nil && login.ConnectionIP != nil {
- location, err := am.geo.Lookup(login.ConnectionIP)
- if err != nil {
- log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
- } else {
- newPeer.Location.ConnectionIP = login.ConnectionIP
- newPeer.Location.CountryCode = location.Country.ISOCode
- newPeer.Location.CityName = location.City.Names.En
- newPeer.Location.GeoNameID = location.City.GeonameID
-
- }
+ Key: login.WireGuardPubKey,
+ Meta: login.Meta,
+ SSHKey: login.SSHKey,
+ Location: nbpeer.Location{ConnectionIP: login.ConnectionIP},
}
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
@@ -609,44 +610,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
}
- peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
- if err != nil {
- return nil, nil, nil, status.NewPeerNotRegisteredError()
- }
-
- accSettings, err := am.Store.GetAccountSettings(ctx, accountID)
- if err != nil {
- return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
- }
-
- var isWriteLock bool
-
- // duplicated logic from after the lock to have an early exit
- expired := peerLoginExpired(ctx, peer, accSettings)
- switch {
- case expired:
- if err := checkAuth(ctx, login.UserID, peer); err != nil {
+ // when the client sends a login request with a JWT which is used to get the user ID,
+ // it means that the client has already checked if it needs login and had been through the SSO flow
+ // so, we can skip this check and directly proceed with the login
+ if login.UserID == "" {
+ err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
+ if err != nil {
return nil, nil, nil, err
}
- isWriteLock = true
- log.WithContext(ctx).Debugf("peer login expired, acquiring write lock")
-
- case peer.UpdateMetaIfNew(login.Meta):
- isWriteLock = true
- log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock")
-
- default:
- isWriteLock = false
- log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock")
}
- var unlock func()
-
- if isWriteLock {
- unlock = am.Store.AcquireAccountWriteLock(ctx, accountID)
- } else {
- unlock = am.Store.AcquireAccountReadLock(ctx, accountID)
- }
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
@@ -659,7 +633,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return nil, nil, nil, err
}
- peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
+ peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
@@ -670,53 +644,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
// this flag prevents unnecessary calls to the persistent store.
- shouldStoreAccount := false
+ shouldStorePeer := false
updateRemotePeers := false
if peerLoginExpired(ctx, peer, account.Settings) {
- err = checkAuth(ctx, login.UserID, peer)
+ err = am.handleExpiredPeer(ctx, login, account, peer)
if err != nil {
return nil, nil, nil, err
}
- // If peer was expired before and if it reached this point, it is re-authenticated.
- // UserID is present, meaning that JWT validation passed successfully in the API layer.
- updatePeerLastLogin(peer, account)
updateRemotePeers = true
- shouldStoreAccount = true
-
- // sync user last login with peer last login
- user, err := account.FindUser(login.UserID)
- if err != nil {
- return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
- }
- user.updateLastLogin(peer.LastLogin)
-
- am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
+ shouldStorePeer = true
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
+
peer, updated := updatePeerMeta(peer, login.Meta, account)
if updated {
- shouldStoreAccount = true
+ shouldStorePeer = true
}
- peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey)
- if err != nil {
- return nil, nil, nil, err
+ if peer.SSHKey != login.SSHKey {
+ peer.SSHKey = login.SSHKey
+ shouldStorePeer = true
}
- if shouldStoreAccount {
- if !isWriteLock {
- log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID)
- return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
- }
- err = am.Store.SaveAccount(ctx, account)
+ if shouldStorePeer {
+ err = am.Store.SavePeer(ctx, accountID, peer)
if err != nil {
return nil, nil, nil, err
}
}
+
unlock()
unlock = nil
@@ -724,13 +684,46 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
am.updateAccountPeers(ctx, account)
}
+ return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
+}
+
+// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
+// and if the peer login is expired.
+// The NetBird client doesn't have a way to check if the peer needs login besides sending a login request
+// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
+// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
+func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
+ peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
+ if err != nil {
+ return err
+ }
+
+ // if the peer was not added with SSO login we can exit early because peers activated with setup-key
+ // doesn't expire, and we avoid extra databases calls.
+ if !peer.AddedWithSSOLogin() {
+ return nil
+ }
+
+ settings, err := am.Store.GetAccountSettings(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
+ if peerLoginExpired(ctx, peer, settings) {
+ return status.NewPeerLoginExpiredError()
+ }
+
+ return nil
+}
+
+func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks
if isRequiresApproval {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
- return peer, emptyMap, postureChecks, nil
+ return peer, emptyMap, nil, nil
}
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -742,6 +735,30 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
}
+func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
+ err := checkAuth(ctx, login.UserID, peer)
+ if err != nil {
+ return err
+ }
+ // If peer was expired before and if it reached this point, it is re-authenticated.
+ // UserID is present, meaning that JWT validation passed successfully in the API layer.
+ updatePeerLastLogin(peer, account)
+
+ // sync user last login with peer last login
+ user, err := account.FindUser(login.UserID)
+ if err != nil {
+ return status.Errorf(status.Internal, "couldn't find user")
+ }
+
+ err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
+ if err != nil {
+ return err
+ }
+
+ am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
+ return nil
+}
+
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
@@ -758,11 +775,11 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.
- return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
+ return status.NewPeerLoginExpiredError()
}
if peer.UserID != loginUserID {
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
- return status.Errorf(status.Unauthenticated, "can't login")
+ return status.Errorf(status.Unauthenticated, "can't login with this credentials")
}
return nil
}
@@ -782,31 +799,6 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
account.UpdatePeer(peer)
}
-func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
- if len(newSSHKey) == 0 {
- log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
- return peer, nil
- }
-
- if peer.SSHKey == newSSHKey {
- log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
- return peer, nil
- }
-
- peer.SSHKey = newSSHKey
- account.UpdatePeer(peer)
-
- err := am.Store.SaveAccount(ctx, account)
- if err != nil {
- return nil, err
- }
-
- // trigger network map update
- am.updateAccountPeers(ctx, account)
-
- return peer, nil
-}
-
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
@@ -819,7 +811,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
return err
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -854,7 +846,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/policy.go b/management/server/policy.go
index a70d7f0ed..30614ed2d 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -315,7 +315,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -343,7 +343,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -371,7 +371,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -398,7 +398,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index 851d4d31f..4a7c9755d 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -15,7 +15,7 @@ const (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -42,7 +42,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -89,7 +89,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -121,7 +121,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/route.go b/management/server/route.go
index 6db00a255..064f3c105 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -17,7 +17,7 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -126,7 +126,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if routeToSave == nil {
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -311,7 +311,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index dcaee357c..8ef91755c 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -210,7 +210,7 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
keyDuration := DefaultSetupKeyDuration
@@ -256,7 +256,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil {
@@ -328,7 +328,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -360,7 +360,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 37cc10d8b..c44ab7f09 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -31,14 +31,16 @@ import (
)
const (
- storeSqliteFileName = "store.db"
- idQueryCondition = "id = ?"
+ storeSqliteFileName = "store.db"
+ idQueryCondition = "id = ?"
+ accountAndIDQueryCondition = "account_id = ? and id = ?"
+ peerNotFoundFMT = "peer %s not found"
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
db *gorm.DB
- accountLocks sync.Map
+ resourceLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
@@ -96,33 +98,35 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
-func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
- log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID)
+// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
+func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
+ log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
start := time.Now()
- value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
+ value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
- log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start))
+ log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
-func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
- log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID)
+// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
+func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
+ log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
start := time.Now()
- value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
+ value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
unlock = func() {
mtx.RUnlock()
- log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start))
+ log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
@@ -271,6 +275,38 @@ func (s *SqlStore) GetInstallationID() string {
return installation.InstallationIDValue
}
+func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
+ // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
+ peerCopy := peer.Copy()
+ peerCopy.AccountID = accountID
+
+ err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ // check if peer exists before saving
+ var peerID string
+ result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
+ if result.Error != nil {
+ return result.Error
+ }
+
+ if peerID == "" {
+ return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
+ }
+
+ result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
+ if result.Error != nil {
+ return result.Error
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -281,14 +317,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
- Where("account_id = ? AND id = ?", accountID, peerID).
+ Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
- return status.Errorf(status.NotFound, "peer %s not found", peerID)
+ return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
}
return nil
@@ -302,7 +338,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
- Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
+ Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
@@ -310,7 +346,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
}
if result.RowsAffected == 0 {
- return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
+ return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
}
return nil
@@ -644,7 +680,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
- result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
+ result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index f46ca7e5d..ce4ee531a 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
+func TestSqlite_SavePeer(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/store.json")
+
+ account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
+ require.NoError(t, err)
+
+ // save status of non-existing peer
+ peer := &nbpeer.Peer{
+ Key: "peerkey",
+ ID: "testpeer",
+ SetupKey: "peerkeysetupkey",
+ IP: net.IP{127, 0, 0, 1},
+ Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
+ Name: "peer name",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
+ }
+ ctx := context.Background()
+ err = store.SavePeer(ctx, account.Id, peer)
+ assert.Error(t, err)
+ parsedErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
+
+ // save new status of existing peer
+ account.Peers[peer.ID] = peer
+
+ err = store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ updatedPeer := peer.Copy()
+ updatedPeer.Status.Connected = false
+ updatedPeer.Meta.Hostname = "updatedpeer"
+
+ err = store.SavePeer(ctx, account.Id, updatedPeer)
+ require.NoError(t, err)
+
+ account, err = store.GetAccount(context.Background(), account.Id)
+ require.NoError(t, err)
+
+ actual := account.Peers[peer.ID]
+ assert.Equal(t, updatedPeer.Status, actual.Status)
+ assert.Equal(t, updatedPeer.Meta, actual.Meta)
+}
+
func TestSqlite_SavePeerStatus(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
@@ -402,7 +450,19 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
+
+ newStatus.Connected = true
+
+ err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
+ require.NoError(t, err)
+
+ account, err = store.GetAccount(context.Background(), account.Id)
+ require.NoError(t, err)
+
+ actual = account.Peers["testpeer"].Status
+ assert.Equal(t, newStatus, *actual)
}
+
func TestSqlite_SavePeerLocation(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
diff --git a/management/server/status/error.go b/management/server/status/error.go
index 39cd6c613..58b9a84a0 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -95,3 +95,8 @@ func NewUserNotFoundError(userKey string) error {
func NewPeerNotRegisteredError() error {
return Errorf(Unauthenticated, "peer is not registered")
}
+
+// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer
+func NewPeerLoginExpiredError() error {
+ return Errorf(PermissionDenied, "peer login has expired, please log in once more")
+}
diff --git a/management/server/store.go b/management/server/store.go
index 3ba73e8c7..864871c8e 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -12,10 +12,11 @@ import (
"strings"
"time"
- nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
+
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@@ -48,12 +49,13 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
- // AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
- AcquireAccountWriteLock(ctx context.Context, accountID string) func()
- // AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
- AcquireAccountReadLock(ctx context.Context, accountID string) func()
+ // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
+ AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
+ // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
+ AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
+ SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
diff --git a/management/server/user.go b/management/server/user.go
index 65b5c7878..b8afcda3a 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -211,7 +211,7 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -267,7 +267,7 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
@@ -368,7 +368,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccount(ctx, account.Id)
@@ -401,7 +401,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -428,7 +428,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -538,7 +538,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
@@ -578,7 +578,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if tokenName == "" {
@@ -628,7 +628,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -678,7 +678,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -710,7 +710,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -752,7 +752,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
- unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
@@ -859,7 +859,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
account.Network.IncSerial()
- if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
+ if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
diff --git a/util/log.go b/util/log.go
index 74b99311e..4bce75e4a 100644
--- a/util/log.go
+++ b/util/log.go
@@ -35,8 +35,11 @@ func InitLog(logLevel string, logPath string) error {
AddSyslogHook()
}
+ //nolint:gocritic
if os.Getenv("NB_LOG_FORMAT") == "json" {
formatter.SetJSONFormatter(log.StandardLogger())
+ } else if logPath == "syslog" {
+ formatter.SetSyslogFormatter(log.StandardLogger())
} else {
formatter.SetTextFormatter(log.StandardLogger())
}