From f17dd3619cf6c121fa4ea76d3f1e8192b15497d8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 9 Jul 2025 15:49:09 +0200 Subject: [PATCH 01/50] [misc] update image in README.md (#4122) --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c3b365694..d5469c28b 100644 --- a/README.md +++ b/README.md @@ -50,10 +50,9 @@ **Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. -### Open-Source Network Security in a Single Platform +### Open Source Network Security in a Single Platform - -![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +centralized-network-management 1 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) From 408f423adcf7a51f1ae56cdb6016c0b0158643fb Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 9 Jul 2025 22:16:08 +0200 Subject: [PATCH 02/50] [client] Disable pidfd check on Android 11 and below (#4127) Disable pidfd check on Android 11 and below On Android 11 (SDK <= 30) and earlier, pidfd-related system calls are blocked by seccomp policies, causing SIGSYS crashes. This change overrides `checkPidfdOnce` to return an error on affected versions, preventing the use of unsupported pidfd features. --- .github/workflows/mobile-build-validation.yml | 2 +- client/android/client.go | 4 ++- client/android/exec.go | 26 +++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 client/android/exec.go diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index 569956a54..c7d43695b 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -43,7 +43,7 @@ jobs: - name: gomobile init run: gomobile init - name: build android netbird lib - run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android + run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android env: CGO_ENABLED: 0 ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 diff --git a/client/android/client.go b/client/android/client.go index a17439696..0d0c76549 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -64,7 +64,9 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { + execWorkaround(androidSDKVersion) + net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ cfgFile: cfgFile, diff --git a/client/android/exec.go b/client/android/exec.go new file mode 100644 index 000000000..805d3129b --- /dev/null +++ b/client/android/exec.go @@ -0,0 +1,26 @@ +//go:build android + +package android + +import ( + "fmt" + _ "unsafe" +) + +// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520 +// In Android version 11 and earlier, pidfd-related system calls +// are not allowed by the seccomp policy, which causes crashes due +// to SIGSYS signals. + +//go:linkname checkPidfdOnce os.checkPidfdOnce +var checkPidfdOnce func() error + +func execWorkaround(androidSDKVersion int) { + if androidSDKVersion > 30 { // above Android 11 + return + } + + checkPidfdOnce = func() error { + return fmt.Errorf("unsupported Android version") + } +} From e59d75d56ab047c5374177f16976988aac486405 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 10 Jul 2025 14:24:20 +0200 Subject: [PATCH 03/50] Nil check in iface configurer (#4132) --- client/iface/iface.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/client/iface/iface.go b/client/iface/iface.go index 1b9055e6c..e90c3536b 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -29,6 +29,11 @@ const ( WgInterfaceDefault = configurer.WgInterfaceDefault ) +var ( + // ErrIfaceNotFound is returned when the WireGuard interface is not found + ErrIfaceNotFound = fmt.Errorf("wireguard interface not found") +) + type wgProxyFactory interface { GetProxy() wgproxy.Proxy Free() error @@ -117,6 +122,9 @@ func (w *WGIface) UpdateAddr(newAddr string) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) @@ -126,6 +134,9 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) return w.configurer.RemovePeer(peerKey) @@ -135,6 +146,9 @@ func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) return w.configurer.AddAllowedIP(peerKey, allowedIP) @@ -144,6 +158,9 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) return w.configurer.RemoveAllowedIP(peerKey, allowedIP) @@ -214,6 +231,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device { // GetStats returns the last handshake time, rx and tx bytes func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { + if w.configurer == nil { + return nil, ErrIfaceNotFound + } return w.configurer.GetStats() } @@ -221,11 +241,19 @@ func (w *WGIface) LastActivities() map[string]time.Time { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return nil + } + return w.configurer.LastActivities() } func (w *WGIface) FullStats() (*configurer.Stats, error) { + if w.configurer == nil { + return nil, ErrIfaceNotFound + } + return w.configurer.FullStats() } From e3b40ba694a5f3b3396b09ebac86b224848901d9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 10 Jul 2025 15:00:58 +0200 Subject: [PATCH 04/50] Update cli description of lazy connection (#4133) --- client/cmd/root.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 16e445f4d..e00a9b073 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -184,7 +184,7 @@ func init() { upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") - upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.") + upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) From 8632dd15f13a6954d71f09af667089bb07571a26 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:21:01 +0200 Subject: [PATCH 05/50] [management] added cleanupWindow for collecting several ephemeral peers to delete (#4130) --------- Co-authored-by: Maycon Santos Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com> --- management/client/client_test.go | 6 ++ management/server/account/manager.go | 1 + management/server/dns_test.go | 2 + management/server/ephemeral.go | 37 +++++-- management/server/ephemeral_test.go | 98 ++++++++++++++++--- management/server/management_proto_test.go | 6 +- management/server/mock_server/account_mock.go | 4 + management/server/nameserver_test.go | 6 ++ management/server/peer.go | 27 +++-- management/server/peer_test.go | 10 ++ 10 files changed, 171 insertions(+), 26 deletions(-) diff --git a/management/client/client_test.go b/management/client/client_test.go index c163d1833..1847af73e 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -87,6 +87,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ). Return(&types.Settings{}, nil). AnyTimes() + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() + permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock. EXPECT(). diff --git a/management/server/account/manager.go b/management/server/account/manager.go index ed17fa5ec..f8aa2756a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -112,6 +112,7 @@ type Manager interface { GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) + BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 02bb042d7..31c944a25 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -216,6 +216,8 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + // return empty extra settings for expected calls to UpdateAccountPeers + settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 3cb9b7536..9f4348ebb 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -15,6 +15,8 @@ import ( const ( ephemeralLifeTime = 10 * time.Minute + // cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure. + cleanupWindow = 1 * time.Minute ) var ( @@ -41,6 +43,9 @@ type EphemeralManager struct { tailPeer *ephemeralPeer peersLock sync.Mutex timer *time.Timer + + lifeTime time.Duration + cleanupWindow time.Duration } // NewEphemeralManager instantiate new EphemeralManager @@ -48,6 +53,9 @@ func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *E return &EphemeralManager{ store: store, accountManager: accountManager, + + lifeTime: ephemeralLifeTime, + cleanupWindow: cleanupWindow, } } @@ -60,7 +68,7 @@ func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) { e.loadEphemeralPeers(ctx) if e.headPeer != nil { - e.timer = time.AfterFunc(ephemeralLifeTime, func() { + e.timer = time.AfterFunc(e.lifeTime, func() { e.cleanup(ctx) }) } @@ -113,9 +121,13 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. return } - e.addPeer(peer.AccountID, peer.ID, newDeadLine()) + e.addPeer(peer.AccountID, peer.ID, e.newDeadLine()) if e.timer == nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow + if delay < 0 { + delay = 0 + } + e.timer = time.AfterFunc(delay, func() { e.cleanup(ctx) }) } @@ -128,7 +140,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { return } - t := newDeadLine() + t := e.newDeadLine() for _, p := range peers { e.addPeer(p.AccountID, p.ID, t) } @@ -155,7 +167,11 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } if e.headPeer != nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow + if delay < 0 { + delay = 0 + } + e.timer = time.AfterFunc(delay, func() { e.cleanup(ctx) }) } else { @@ -164,13 +180,20 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() + bufferAccountCall := make(map[string]struct{}) + for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) + } else { + bufferAccountCall[p.accountID] = struct{}{} } } + for accountID := range bufferAccountCall { + e.accountManager.BufferUpdateAccountPeers(ctx, accountID) + } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { @@ -223,6 +246,6 @@ func (e *EphemeralManager) isPeerOnList(id string) bool { return false } -func newDeadLine() time.Time { - return timeNow().Add(ephemeralLifeTime) +func (e *EphemeralManager) newDeadLine() time.Time { + return timeNow().Add(e.lifeTime) } diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 3cf6ae7f3..f71d48c58 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -3,9 +3,12 @@ package server import ( "context" "fmt" + "sync" "testing" "time" + "github.com/stretchr/testify/assert" + nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -27,28 +30,65 @@ func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStren return peers, nil } -type MocAccountManager struct { +type MockAccountManager struct { + mu sync.Mutex nbAccount.Manager - store *MockStore + store *MockStore + deletePeerCalls int + bufferUpdateCalls map[string]int + wg *sync.WaitGroup } -func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { +func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { + a.mu.Lock() + defer a.mu.Unlock() + a.deletePeerCalls++ + if a.wg != nil { + a.wg.Done() + } delete(a.store.account.Peers, peerID) - return nil //nolint:nil + return nil } -func (a MocAccountManager) GetStore() store.Store { +func (a *MockAccountManager) GetDeletePeerCalls() int { + a.mu.Lock() + defer a.mu.Unlock() + return a.deletePeerCalls +} + +func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + a.mu.Lock() + defer a.mu.Unlock() + if a.bufferUpdateCalls == nil { + a.bufferUpdateCalls = make(map[string]int) + } + a.bufferUpdateCalls[accountID]++ +} + +func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int { + a.mu.Lock() + defer a.mu.Unlock() + if a.bufferUpdateCalls == nil { + return 0 + } + return a.bufferUpdateCalls[accountID] +} + +func (a *MockAccountManager) GetStore() store.Store { return a.store } func TestNewManager(t *testing.T) { + t.Cleanup(func() { + timeNow = time.Now + }) startTime := time.Now() timeNow = func() time.Time { return startTime } store := &MockStore{} - am := MocAccountManager{ + am := MockAccountManager{ store: store, } @@ -56,7 +96,7 @@ func TestNewManager(t *testing.T) { numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, am) + mgr := NewEphemeralManager(store, &am) mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) @@ -67,13 +107,16 @@ func TestNewManager(t *testing.T) { } func TestNewManagerPeerConnected(t *testing.T) { + t.Cleanup(func() { + timeNow = time.Now + }) startTime := time.Now() timeNow = func() time.Time { return startTime } store := &MockStore{} - am := MocAccountManager{ + am := MockAccountManager{ store: store, } @@ -81,7 +124,7 @@ func TestNewManagerPeerConnected(t *testing.T) { numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, am) + mgr := NewEphemeralManager(store, &am) mgr.loadEphemeralPeers(context.Background()) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) @@ -95,13 +138,16 @@ func TestNewManagerPeerConnected(t *testing.T) { } func TestNewManagerPeerDisconnected(t *testing.T) { + t.Cleanup(func() { + timeNow = time.Now + }) startTime := time.Now() timeNow = func() time.Time { return startTime } store := &MockStore{} - am := MocAccountManager{ + am := MockAccountManager{ store: store, } @@ -109,7 +155,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, am) + mgr := NewEphemeralManager(store, &am) mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { mgr.OnPeerConnected(context.Background(), v) @@ -126,6 +172,36 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } } +func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { + const ( + ephemeralPeers = 10 + testLifeTime = 1 * time.Second + testCleanupWindow = 100 * time.Millisecond + ) + mockStore := &MockStore{} + mockAM := &MockAccountManager{ + store: mockStore, + } + mockAM.wg = &sync.WaitGroup{} + mockAM.wg.Add(ephemeralPeers) + mgr := NewEphemeralManager(mockStore, mockAM) + mgr.lifeTime = testLifeTime + mgr.cleanupWindow = testCleanupWindow + + account := newAccountWithId(context.Background(), "account", "", "", false) + mockStore.account = account + for i := range ephemeralPeers { + p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true} + mockStore.account.Peers[p.ID] = p + time.Sleep(testCleanupWindow / ephemeralPeers) + mgr.OnPeerDisconnected(context.Background(), p) + } + mockAM.wg.Wait() + assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime") + assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once") + assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers") +} + func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { store.account = newAccountWithId(context.Background(), "my account", "", "", false) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 337890ef9..57c00ed9f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -440,7 +440,11 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config) GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). AnyTimes(). Return(&types.Settings{}, nil) - + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(store) accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8837f9f50..4004f1b57 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -126,6 +126,10 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID // do nothing } +func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 75d1e7972..8fada742c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -778,6 +778,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() + permissionsManager := permissions.NewManager(store) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer.go b/management/server/peer.go index 44156e534..a60513b38 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -375,7 +375,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers { + if updateAccountPeers && userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -1177,6 +1177,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account globalStart := time.Now() + hasPeersConnected := false + for _, peer := range account.Peers { + if am.peersUpdateManager.HasChannel(peer.ID) { + hasPeersConnected = true + break + } + + } + + if !hasPeersConnected { + return + } + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) @@ -1198,6 +1211,12 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } + extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) + return + } + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -1232,12 +1251,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - start = time.Now() update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 31439d670..07ec5037b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1344,6 +1344,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(s) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) @@ -1556,6 +1561,11 @@ func Test_LoginPeer(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(s) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) From a7ea881900b53bdafc80ccb7ad42e96eda308ef7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 10 Jul 2025 16:13:53 +0200 Subject: [PATCH 06/50] [client] Add rotated logs flag for debug bundle generation (#4100) --- client/cmd/debug.go | 46 +++++--- client/cmd/root.go | 10 -- client/internal/debug/debug.go | 83 +++++++++----- client/proto/daemon.pb.go | 13 ++- client/proto/daemon.proto | 1 + client/proto/daemon_grpc.pb.go | 192 ++++++++++++++------------------- client/proto/generate.sh | 2 +- client/server/debug.go | 1 + go.mod | 2 +- go.sum | 4 +- util/log.go | 2 +- 11 files changed, 188 insertions(+), 168 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 385bd95f5..4036bb8f6 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -17,10 +17,18 @@ import ( "github.com/netbirdio/netbird/client/server" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/upload-server/types" ) const errCloseConnection = "Failed to close connection: %v" +var ( + logFileCount uint32 + systemInfoFlag bool + uploadBundleFlag bool + uploadBundleURLFlag string +) + var debugCmd = &cobra.Command{ Use: "debug", Short: "Debugging commands", @@ -88,12 +96,13 @@ func debugBundle(cmd *cobra.Command, _ []string) error { client := proto.NewDaemonServiceClient(conn) request := &proto.DebugBundleRequest{ - Anonymize: anonymizeFlag, - Status: getStatusOutput(cmd, anonymizeFlag), - SystemInfo: debugSystemInfoFlag, + Anonymize: anonymizeFlag, + Status: getStatusOutput(cmd, anonymizeFlag), + SystemInfo: systemInfoFlag, + LogFileCount: logFileCount, } - if debugUploadBundle { - request.UploadURL = debugUploadBundleURL + if uploadBundleFlag { + request.UploadURL = uploadBundleURLFlag } resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { @@ -105,7 +114,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error { return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) } - if debugUploadBundle { + if uploadBundleFlag { cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) } @@ -223,12 +232,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) request := &proto.DebugBundleRequest{ - Anonymize: anonymizeFlag, - Status: statusOutput, - SystemInfo: debugSystemInfoFlag, + Anonymize: anonymizeFlag, + Status: statusOutput, + SystemInfo: systemInfoFlag, + LogFileCount: logFileCount, } - if debugUploadBundle { - request.UploadURL = debugUploadBundleURL + if uploadBundleFlag { + request.UploadURL = uploadBundleURLFlag } resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { @@ -255,7 +265,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) } - if debugUploadBundle { + if uploadBundleFlag { cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) } @@ -375,3 +385,15 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect } log.Infof("Generated debug bundle from SIGUSR1 at: %s", path) } + +func init() { + debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle") + debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") + debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") + debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") + + forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle") + forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") + forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") + forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") +} diff --git a/client/cmd/root.go b/client/cmd/root.go index e00a9b073..fa4bd4d42 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -22,7 +22,6 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/upload-server/types" ) const ( @@ -38,10 +37,7 @@ const ( serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" - systemInfoFlag = "system-info" enableLazyConnectionFlag = "enable-lazy-connection" - uploadBundle = "upload-bundle" - uploadBundleURL = "upload-bundle-url" ) var ( @@ -75,10 +71,7 @@ var ( autoConnectDisabled bool extraIFaceBlackList []string anonymizeFlag bool - debugSystemInfoFlag bool dnsRouteInterval time.Duration - debugUploadBundle bool - debugUploadBundleURL string lazyConnEnabled bool rootCmd = &cobra.Command{ @@ -186,9 +179,6 @@ func init() { upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") - debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") - debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) - debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index dfed47f05..6455b3aaf 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -167,6 +167,7 @@ type BundleGenerator struct { anonymize bool clientStatus string includeSystemInfo bool + logFileCount uint32 archive *zip.Writer } @@ -175,6 +176,7 @@ type BundleConfig struct { Anonymize bool ClientStatus string IncludeSystemInfo bool + LogFileCount uint32 } type GeneratorDependencies struct { @@ -185,6 +187,12 @@ type GeneratorDependencies struct { } func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { + // Default to 1 log file for backward compatibility when 0 is provided + logFileCount := cfg.LogFileCount + if logFileCount == 0 { + logFileCount = 1 + } + return &BundleGenerator{ anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), @@ -196,6 +204,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen anonymize: cfg.Anonymize, clientStatus: cfg.ClientStatus, includeSystemInfo: cfg.IncludeSystemInfo, + logFileCount: logFileCount, } } @@ -561,32 +570,8 @@ func (g *BundleGenerator) addLogfile() error { return fmt.Errorf("add client log file to zip: %w", err) } - // add latest rotated log file - pattern := filepath.Join(logDir, "client-*.log.gz") - files, err := filepath.Glob(pattern) - if err != nil { - log.Warnf("failed to glob rotated logs: %v", err) - } else if len(files) > 0 { - // pick the file with the latest ModTime - sort.Slice(files, func(i, j int) bool { - fi, err := os.Stat(files[i]) - if err != nil { - log.Warnf("failed to stat rotated log %s: %v", files[i], err) - return false - } - fj, err := os.Stat(files[j]) - if err != nil { - log.Warnf("failed to stat rotated log %s: %v", files[j], err) - return false - } - return fi.ModTime().Before(fj.ModTime()) - }) - latest := files[len(files)-1] - name := filepath.Base(latest) - if err := g.addSingleLogFileGz(latest, name); err != nil { - log.Warnf("failed to add rotated log %s: %v", name, err) - } - } + // add rotated log files based on logFileCount + g.addRotatedLogFiles(logDir) stdErrLogPath := filepath.Join(logDir, errorLogFile) stdoutLogPath := filepath.Join(logDir, stdoutLogFile) @@ -670,6 +655,52 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error { return nil } +// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount +func (g *BundleGenerator) addRotatedLogFiles(logDir string) { + if g.logFileCount == 0 { + return + } + + pattern := filepath.Join(logDir, "client-*.log.gz") + files, err := filepath.Glob(pattern) + if err != nil { + log.Warnf("failed to glob rotated logs: %v", err) + return + } + + if len(files) == 0 { + return + } + + // sort files by modification time (newest first) + sort.Slice(files, func(i, j int) bool { + fi, err := os.Stat(files[i]) + if err != nil { + log.Warnf("failed to stat rotated log %s: %v", files[i], err) + return false + } + fj, err := os.Stat(files[j]) + if err != nil { + log.Warnf("failed to stat rotated log %s: %v", files[j], err) + return false + } + return fi.ModTime().After(fj.ModTime()) + }) + + // include up to logFileCount rotated files + maxFiles := int(g.logFileCount) + if maxFiles > len(files) { + maxFiles = len(files) + } + + for i := 0; i < maxFiles; i++ { + name := filepath.Base(files[i]) + if err := g.addSingleLogFileGz(files[i], name); err != nil { + log.Warnf("failed to add rotated log %s: %v", name, err) + } + } +} + func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error { header := &zip.FileHeader{ Name: filename, diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 202dc6f89..26e58d183 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -2290,6 +2290,7 @@ type DebugBundleRequest struct { Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"` UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"` + LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -2352,6 +2353,13 @@ func (x *DebugBundleRequest) GetUploadURL() string { return "" } +func (x *DebugBundleRequest) GetLogFileCount() uint32 { + if x != nil { + return x.LogFileCount + } + return 0 +} + type DebugBundleResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` @@ -3746,14 +3754,15 @@ const file_daemon_proto_rawDesc = "" + "\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" + "\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" + "\x17ForwardingRulesResponse\x12,\n" + - "\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x88\x01\n" + + "\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" + "\x12DebugBundleRequest\x12\x1c\n" + "\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" + "\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" + "\n" + "systemInfo\x18\x03 \x01(\bR\n" + "systemInfo\x12\x1c\n" + - "\tuploadURL\x18\x04 \x01(\tR\tuploadURL\"}\n" + + "\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" + + "\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" + "\x13DebugBundleResponse\x12\x12\n" + "\x04path\x18\x01 \x01(\tR\x04path\x12 \n" + "\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index f488e69e7..462555c82 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -356,6 +356,7 @@ message DebugBundleRequest { string status = 2; bool systemInfo = 3; string uploadURL = 4; + uint32 logFileCount = 5; } message DebugBundleResponse { diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e0612a6d1..6251f7c52 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -1,8 +1,4 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.3 -// source: daemon.proto package proto @@ -15,31 +11,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 - -const ( - DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login" - DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin" - DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up" - DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status" - DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down" - DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig" - DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks" - DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks" - DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks" - DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules" - DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle" - DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel" - DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel" - DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates" - DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState" - DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState" - DaemonService_SetNetworkMapPersistence_FullMethodName = "/daemon.DaemonService/SetNetworkMapPersistence" - DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket" - DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents" - DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents" -) +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 // DaemonServiceClient is the client API for DaemonService service. // @@ -80,7 +53,7 @@ type DaemonServiceClient interface { // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) - SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) + SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) } @@ -93,9 +66,8 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient { } func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LoginResponse) - err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...) if err != nil { return nil, err } @@ -103,9 +75,8 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts } func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(WaitSSOLoginResponse) - err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...) if err != nil { return nil, err } @@ -113,9 +84,8 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin } func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(UpResponse) - err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...) if err != nil { return nil, err } @@ -123,9 +93,8 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp } func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StatusResponse) - err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...) if err != nil { return nil, err } @@ -133,9 +102,8 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt } func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DownResponse) - err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...) if err != nil { return nil, err } @@ -143,9 +111,8 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts .. } func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetConfigResponse) - err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...) if err != nil { return nil, err } @@ -153,9 +120,8 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques } func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListNetworksResponse) - err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) if err != nil { return nil, err } @@ -163,9 +129,8 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks } func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -173,9 +138,8 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw } func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -183,9 +147,8 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe } func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ForwardingRulesResponse) - err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...) if err != nil { return nil, err } @@ -193,9 +156,8 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ } func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DebugBundleResponse) - err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...) if err != nil { return nil, err } @@ -203,9 +165,8 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe } func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetLogLevelResponse) - err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...) if err != nil { return nil, err } @@ -213,9 +174,8 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe } func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetLogLevelResponse) - err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...) if err != nil { return nil, err } @@ -223,9 +183,8 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe } func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListStatesResponse) - err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...) if err != nil { return nil, err } @@ -233,9 +192,8 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ } func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CleanStateResponse) - err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...) if err != nil { return nil, err } @@ -243,9 +201,8 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ } func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DeleteStateResponse) - err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...) if err != nil { return nil, err } @@ -253,9 +210,8 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe } func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetNetworkMapPersistenceResponse) - err := c.cc.Invoke(ctx, DaemonService_SetNetworkMapPersistence_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...) if err != nil { return nil, err } @@ -263,22 +219,20 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in * } func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(TracePacketResponse) - err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...) +func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) { + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream} + x := &daemonServiceSubscribeEventsClient{stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -288,13 +242,26 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent] +type DaemonService_SubscribeEventsClient interface { + Recv() (*SystemEvent, error) + grpc.ClientStream +} + +type daemonServiceSubscribeEventsClient struct { + grpc.ClientStream +} + +func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) { + m := new(SystemEvent) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetEventsResponse) - err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...) if err != nil { return nil, err } @@ -303,7 +270,7 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer -// for forward compatibility. +// for forward compatibility type DaemonServiceServer interface { // Login uses setup key to prepare configuration for the daemon. Login(context.Context, *LoginRequest) (*LoginResponse, error) @@ -340,17 +307,14 @@ type DaemonServiceServer interface { // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) - SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error + SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) mustEmbedUnimplementedDaemonServiceServer() } -// UnimplementedDaemonServiceServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedDaemonServiceServer struct{} +// UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations. +type UnimplementedDaemonServiceServer struct { +} func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") @@ -406,14 +370,13 @@ func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") } -func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error { +func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error { return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented") } func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented") } func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} -func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to DaemonServiceServer will @@ -423,13 +386,6 @@ type UnsafeDaemonServiceServer interface { } func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) { - // If the following call pancis, it indicates UnimplementedDaemonServiceServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&DaemonService_ServiceDesc, srv) } @@ -443,7 +399,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Login_FullMethodName, + FullMethod: "/daemon.DaemonService/Login", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest)) @@ -461,7 +417,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_WaitSSOLogin_FullMethodName, + FullMethod: "/daemon.DaemonService/WaitSSOLogin", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest)) @@ -479,7 +435,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Up_FullMethodName, + FullMethod: "/daemon.DaemonService/Up", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest)) @@ -497,7 +453,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Status_FullMethodName, + FullMethod: "/daemon.DaemonService/Status", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest)) @@ -515,7 +471,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func( } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Down_FullMethodName, + FullMethod: "/daemon.DaemonService/Down", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest)) @@ -533,7 +489,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_GetConfig_FullMethodName, + FullMethod: "/daemon.DaemonService/GetConfig", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest)) @@ -551,7 +507,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_ListNetworks_FullMethodName, + FullMethod: "/daemon.DaemonService/ListNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) @@ -569,7 +525,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_SelectNetworks_FullMethodName, + FullMethod: "/daemon.DaemonService/SelectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -587,7 +543,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_DeselectNetworks_FullMethodName, + FullMethod: "/daemon.DaemonService/DeselectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -605,7 +561,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_ForwardingRules_FullMethodName, + FullMethod: "/daemon.DaemonService/ForwardingRules", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest)) @@ -623,7 +579,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_DebugBundle_FullMethodName, + FullMethod: "/daemon.DaemonService/DebugBundle", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest)) @@ -641,7 +597,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_GetLogLevel_FullMethodName, + FullMethod: "/daemon.DaemonService/GetLogLevel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest)) @@ -659,7 +615,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_SetLogLevel_FullMethodName, + FullMethod: "/daemon.DaemonService/SetLogLevel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest)) @@ -677,7 +633,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_ListStates_FullMethodName, + FullMethod: "/daemon.DaemonService/ListStates", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) @@ -695,7 +651,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_CleanState_FullMethodName, + FullMethod: "/daemon.DaemonService/CleanState", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) @@ -713,7 +669,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_DeleteState_FullMethodName, + FullMethod: "/daemon.DaemonService/DeleteState", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) @@ -731,7 +687,7 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_SetNetworkMapPersistence_FullMethodName, + FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest)) @@ -749,7 +705,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_TracePacket_FullMethodName, + FullMethod: "/daemon.DaemonService/TracePacket", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) @@ -762,11 +718,21 @@ func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerS if err := stream.RecvMsg(m); err != nil { return err } - return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream}) + return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent] +type DaemonService_SubscribeEventsServer interface { + Send(*SystemEvent) error + grpc.ServerStream +} + +type daemonServiceSubscribeEventsServer struct { + grpc.ServerStream +} + +func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error { + return x.ServerStream.SendMsg(m) +} func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetEventsRequest) @@ -778,7 +744,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_GetEvents_FullMethodName, + FullMethod: "/daemon.DaemonService/GetEvents", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest)) diff --git a/client/proto/generate.sh b/client/proto/generate.sh index 52fe23d7f..f9a2c3750 100755 --- a/client/proto/generate.sh +++ b/client/proto/generate.sh @@ -11,7 +11,7 @@ fi old_pwd=$(pwd) script_path=$(dirname $(realpath "$0")) cd "$script_path" -go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 +go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional cd "$old_pwd" \ No newline at end of file diff --git a/client/server/debug.go b/client/server/debug.go index 7de3e8609..412602b00 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -42,6 +42,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( Anonymize: req.GetAnonymize(), ClientStatus: req.GetStatus(), IncludeSystemInfo: req.GetSystemInfo(), + LogFileCount: req.GetLogFileCount(), }, ) diff --git a/go.mod b/go.mod index a12058278..4a9727373 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/grpc v1.64.1 - google.golang.org/protobuf v1.36.5 + google.golang.org/protobuf v1.36.6 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) diff --git a/go.sum b/go.sum index 6ce503dd1..a622f203f 100644 --- a/go.sum +++ b/go.sum @@ -1164,8 +1164,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/util/log.go b/util/log.go index 59a064366..53d2b0684 100644 --- a/util/log.go +++ b/util/log.go @@ -14,7 +14,7 @@ import ( "github.com/netbirdio/netbird/formatter" ) -const defaultLogSize = 5 +const defaultLogSize = 15 // InitLog parses and sets log-level input func InitLog(logLevel string, logPath string) error { From 2b9f3319803e74b81f16b5216ab99e70acad24ea Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Fri, 11 Jul 2025 10:29:10 +0100 Subject: [PATCH 07/50] always suffix ephemeral peer name (#4138) --- management/server/peer.go | 43 +++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index a60513b38..21a9579fc 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -236,11 +236,23 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peer.Name != update.Name { var newLabel string - newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name) + + newLabel, err = nbdns.GetParsedDomainLabel(update.Name) if err != nil { - return fmt.Errorf("failed to get free DNS label: %w", err) + newLabel = "" + } else { + _, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name) + if err == nil { + newLabel = "" + } } + if newLabel == "" { + newLabel, err = getPeerIPDNSLabel(peer.IP, update.Name) + if err != nil { + return fmt.Errorf("failed to get free DNS label: %w", err) + } + } peer.Name = update.Name peer.DNSLabel = newLabel peerLabelChanged = true @@ -472,6 +484,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var groupsToAdd []string var allowExtraDNSLabels bool var accountID string + var isEphemeral bool if addedByUser { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -501,7 +514,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name allowExtraDNSLabels = sk.AllowExtraDNSLabels accountID = sk.AccountID - + isEphemeral = sk.Ephemeral if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } @@ -573,11 +586,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var freeLabel string - freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + if isEphemeral || attempt > 1 { + freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + } + } else { + freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + } } - newPeer.DNSLabel = freeLabel newPeer.IP = freeIP @@ -647,7 +666,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if isUniqueConstraintError(err) { unlock() unlock = nil - log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err) + log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err) continue } @@ -681,7 +700,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } -func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) { +func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { ip = ip.To4() dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) @@ -689,12 +708,6 @@ func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) } - _, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName) - if err != nil { - //nolint:nilerr - return dnsName, nil - } - return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil } From a76c8eafb46cda5e44de0f3784160a34e8cdb4a2 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:37:14 +0200 Subject: [PATCH 08/50] [management] sync calls to UpdateAccountPeers from BufferUpdateAccountPeers (#4137) --------- Co-authored-by: Maycon Santos Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com> --- management/server/mock_server/account_mock.go | 10 +- management/server/peer.go | 34 ++++- management/server/peer_test.go | 130 ++++++++++++++++++ 3 files changed, 166 insertions(+), 8 deletions(-) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4004f1b57..b1ec66286 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -120,14 +120,20 @@ type MockAccountManager struct { GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - // do nothing + if am.UpdateAccountPeersFunc != nil { + am.UpdateAccountPeersFunc(ctx, accountID) + } } func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - // do nothing + if am.BufferUpdateAccountPeersFunc != nil { + am.BufferUpdateAccountPeersFunc(ctx, accountID) + } } func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { diff --git a/management/server/peer.go b/management/server/peer.go index 21a9579fc..c6ade83c0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -9,6 +9,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/rs/xid" @@ -1280,18 +1281,39 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } } -func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) - lock := mu.(*sync.Mutex) +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} - if !lock.TryLock() { +func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) return } + if b.next != nil { + b.next.Stop() + } + go func() { - time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) - lock.Unlock() + defer b.mu.Unlock() am.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { + am.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) }() } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 07ec5037b..d41020514 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -25,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" @@ -2251,3 +2253,131 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes) assert.Equal(t, uint64(totalPeers), account.Network.Serial) } + +func TestBufferUpdateAccountPeers(t *testing.T) { + const ( + peersCount = 1000 + updateAccountInterval = 50 * time.Millisecond + ) + + var ( + deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 + uapLastRun, dpLastRun atomic.Int64 + + totalNewRuns, totalOldRuns int + ) + + uap := func(ctx context.Context, accountID string) { + updatePeersDeleted.Store(deletedPeers.Load()) + updatePeersRuns.Add(1) + uapLastRun.Store(time.Now().UnixMilli()) + time.Sleep(100 * time.Millisecond) + } + + t.Run("new approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) + b := mu.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + uap(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + b.next = time.AfterFunc(updateAccountInterval, func() { + uap(ctx, accountID) + }) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalNewRuns = int(updatePeersRuns.Load()) + }) + + t.Run("old approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) + b := mu.(*sync.Mutex) + + if !b.TryLock() { + return + } + + go func() { + time.Sleep(updateAccountInterval) + b.Unlock() + uap(ctx, accountID) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalOldRuns = int(updatePeersRuns.Load()) + }) + assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) + t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) +} From 3e6eede1523f5073cd3d35f068c771b6d88ba397 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 12 Jul 2025 11:10:45 +0200 Subject: [PATCH 09/50] [client] Fix elapsed time calculation when machine is in sleep mode (#4140) --- client/iface/bind/activity.go | 24 ++++++++++--------- client/iface/bind/activity_test.go | 8 +++---- client/iface/configurer/kernel_unix.go | 4 +++- client/iface/configurer/usp.go | 3 ++- client/iface/device/interface.go | 3 ++- client/iface/iface.go | 3 ++- client/internal/conn_mgr.go | 1 - client/internal/engine_test.go | 5 ++-- client/internal/iface_common.go | 3 ++- client/internal/lazyconn/activity/listener.go | 7 +++--- .../internal/lazyconn/inactivity/manager.go | 9 ++++--- .../lazyconn/inactivity/manager_test.go | 13 +++++----- client/internal/lazyconn/manager/manager.go | 7 +++--- client/internal/lazyconn/wgiface.go | 4 +++- monotime/time.go | 10 ++++++-- 15 files changed, 62 insertions(+), 42 deletions(-) diff --git a/client/iface/bind/activity.go b/client/iface/bind/activity.go index d3b406bcd..57862e3d1 100644 --- a/client/iface/bind/activity.go +++ b/client/iface/bind/activity.go @@ -34,14 +34,14 @@ func NewActivityRecorder() *ActivityRecorder { } // GetLastActivities returns a snapshot of peer last activity -func (r *ActivityRecorder) GetLastActivities() map[string]time.Time { +func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time { r.mu.RLock() defer r.mu.RUnlock() - activities := make(map[string]time.Time, len(r.peers)) + activities := make(map[string]monotime.Time, len(r.peers)) for key, record := range r.peers { - unixNano := record.LastActivity.Load() - activities[key] = time.Unix(0, unixNano) + monoTime := record.LastActivity.Load() + activities[key] = monotime.Time(monoTime) } return activities } @@ -51,18 +51,20 @@ func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPor r.mu.Lock() defer r.mu.Unlock() - if pr, exists := r.peers[publicKey]; exists { - delete(r.addrToPeer, pr.Address) - pr.Address = address + var record *PeerRecord + record, exists := r.peers[publicKey] + if exists { + delete(r.addrToPeer, record.Address) + record.Address = address } else { - record := &PeerRecord{ + record = &PeerRecord{ Address: address, } - record.LastActivity.Store(monotime.Now()) + record.LastActivity.Store(int64(monotime.Now())) r.peers[publicKey] = record } - r.addrToPeer[address] = r.peers[publicKey] + r.addrToPeer[address] = record } func (r *ActivityRecorder) Remove(publicKey string) { @@ -84,7 +86,7 @@ func (r *ActivityRecorder) record(address netip.AddrPort) { return } - now := monotime.Now() + now := int64(monotime.Now()) last := record.LastActivity.Load() if now-last < saveFrequency { return diff --git a/client/iface/bind/activity_test.go b/client/iface/bind/activity_test.go index 598607b95..bdd0dca29 100644 --- a/client/iface/bind/activity_test.go +++ b/client/iface/bind/activity_test.go @@ -4,6 +4,8 @@ import ( "net/netip" "testing" "time" + + "github.com/netbirdio/netbird/monotime" ) func TestActivityRecorder_GetLastActivities(t *testing.T) { @@ -17,11 +19,7 @@ func TestActivityRecorder_GetLastActivities(t *testing.T) { t.Fatalf("Expected activity for peer %s, but got none", peer) } - if p.IsZero() { - t.Fatalf("Expected activity for peer %s, but got zero", peer) - } - - if p.Before(time.Now().Add(-2 * time.Minute)) { + if monotime.Since(p) > 5*time.Second { t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p) } } diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index e2ea19144..84afc38f5 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -11,6 +11,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/monotime" ) var zeroKey wgtypes.Key @@ -277,6 +279,6 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) { return stats, nil } -func (c *KernelConfigurer) LastActivities() map[string]time.Time { +func (c *KernelConfigurer) LastActivities() map[string]monotime.Time { return nil } diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 6ead716f1..1ff4d839c 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,6 +17,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/monotime" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -223,7 +224,7 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) { return parseStatus(c.deviceName, ipcStr) } -func (c *WGUSPConfigurer) LastActivities() map[string]time.Time { +func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time { return c.activityRecorder.GetLastActivities() } diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index d68e6bf90..1f40b0d46 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -8,6 +8,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/monotime" ) type WGConfigurer interface { @@ -19,5 +20,5 @@ type WGConfigurer interface { Close() GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) - LastActivities() map[string]time.Time + LastActivities() map[string]monotime.Time } diff --git a/client/iface/iface.go b/client/iface/iface.go index e90c3536b..0e41f8e64 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/monotime" ) const ( @@ -237,7 +238,7 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { return w.configurer.GetStats() } -func (w *WGIface) LastActivities() map[string]time.Time { +func (w *WGIface) LastActivities() map[string]monotime.Time { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/internal/conn_mgr.go b/client/internal/conn_mgr.go index c76b0a99f..112559132 100644 --- a/client/internal/conn_mgr.go +++ b/client/internal/conn_mgr.go @@ -226,7 +226,6 @@ func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) { } if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found { - conn.Log.Infof("activated peer from inactive state") if err := conn.Open(ctx); err != nil { conn.Log.Errorf("failed to open connection: %v", err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f4ed8f1c0..4b7a2d600 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -52,6 +52,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/monotime" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" signal "github.com/netbirdio/netbird/signal/client" @@ -96,7 +97,7 @@ type MockWGIface struct { GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy GetNetFunc func() *netstack.Net - LastActivitiesFunc func() map[string]time.Time + LastActivitiesFunc func() map[string]monotime.Time } func (m *MockWGIface) FullStats() (*configurer.Stats, error) { @@ -187,7 +188,7 @@ func (m *MockWGIface) GetNet() *netstack.Net { return m.GetNetFunc() } -func (m *MockWGIface) LastActivities() map[string]time.Time { +func (m *MockWGIface) LastActivities() map[string]monotime.Time { if m.LastActivitiesFunc != nil { return m.LastActivitiesFunc() } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 38fb3561e..bf96153ea 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/monotime" ) type wgIfaceBase interface { @@ -38,5 +39,5 @@ type wgIfaceBase interface { GetStats() (map[string]configurer.WGStats, error) GetNet() *netstack.Net FullStats() (*configurer.Stats, error) - LastActivities() map[string]time.Time + LastActivities() map[string]monotime.Time } diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener.go index 81b5da17b..817ff00c3 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener.go @@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { if d.isClosed.Load() { - d.peerCfg.Log.Debugf("exit from activity listener") + d.peerCfg.Log.Infof("exit from activity listener") } else { d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err) } @@ -59,9 +59,11 @@ func (d *Listener) ReadPackets() { d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr) continue } + d.peerCfg.Log.Infof("activity detected") break } + d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) if err := d.removeEndpoint(); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } @@ -71,7 +73,7 @@ func (d *Listener) ReadPackets() { } func (d *Listener) Close() { - d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String()) + d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) if err := d.conn.Close(); err != nil { @@ -81,7 +83,6 @@ func (d *Listener) Close() { } func (d *Listener) removeEndpoint() error { - d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) return d.wgIface.RemovePeer(d.peerCfg.PublicKey) } diff --git a/client/internal/lazyconn/inactivity/manager.go b/client/internal/lazyconn/inactivity/manager.go index 854951729..0120f4430 100644 --- a/client/internal/lazyconn/inactivity/manager.go +++ b/client/internal/lazyconn/inactivity/manager.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/lazyconn" + "github.com/netbirdio/netbird/monotime" ) const ( @@ -18,7 +19,7 @@ const ( ) type WgInterface interface { - LastActivities() map[string]time.Time + LastActivities() map[string]monotime.Time } type Manager struct { @@ -124,6 +125,7 @@ func (m *Manager) checkStats() (map[string]struct{}, error) { idlePeers := make(map[string]struct{}) + checkTime := time.Now() for peerID, peerCfg := range m.interestedPeers { lastActive, ok := lastActivities[peerID] if !ok { @@ -132,8 +134,9 @@ func (m *Manager) checkStats() (map[string]struct{}, error) { continue } - if time.Since(lastActive) > m.inactivityThreshold { - peerCfg.Log.Infof("peer is inactive since: %v", lastActive) + since := monotime.Since(lastActive) + if since > m.inactivityThreshold { + peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String()) idlePeers[peerID] = struct{}{} } } diff --git a/client/internal/lazyconn/inactivity/manager_test.go b/client/internal/lazyconn/inactivity/manager_test.go index d012b41a2..10b4ef1eb 100644 --- a/client/internal/lazyconn/inactivity/manager_test.go +++ b/client/internal/lazyconn/inactivity/manager_test.go @@ -9,13 +9,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/client/internal/lazyconn" + "github.com/netbirdio/netbird/monotime" ) type mockWgInterface struct { - lastActivities map[string]time.Time + lastActivities map[string]monotime.Time } -func (m *mockWgInterface) LastActivities() map[string]time.Time { +func (m *mockWgInterface) LastActivities() map[string]monotime.Time { return m.lastActivities } @@ -23,8 +24,8 @@ func TestPeerTriggersInactivity(t *testing.T) { peerID := "peer1" wgMock := &mockWgInterface{ - lastActivities: map[string]time.Time{ - peerID: time.Now().Add(-20 * time.Minute), + lastActivities: map[string]monotime.Time{ + peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)), }, } @@ -64,8 +65,8 @@ func TestPeerTriggersActivity(t *testing.T) { peerID := "peer1" wgMock := &mockWgInterface{ - lastActivities: map[string]time.Time{ - peerID: time.Now().Add(-5 * time.Minute), + lastActivities: map[string]monotime.Time{ + peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)), }, } diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index 416e4e7e7..b6b3c6091 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -258,12 +258,13 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) { return false } + cfg.Log.Infof("activate peer from inactive state by remote signal message") + if !m.activateSinglePeer(cfg, mp) { return false } m.activateHAGroupPeers(cfg) - return true } @@ -571,12 +572,12 @@ func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) { // this is blocking operation, potentially can be optimized m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey) - mp.peerCfg.Log.Infof("start activity monitor") - mp.expectedWatcher = watcherActivity m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey) + mp.peerCfg.Log.Infof("start activity monitor") + if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil { mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err) continue diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index d55ff9670..0351904f7 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -6,11 +6,13 @@ import ( "time" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/monotime" ) type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool - LastActivities() map[string]time.Time + LastActivities() map[string]monotime.Time } diff --git a/monotime/time.go b/monotime/time.go index 6032fb60b..ba45b6659 100644 --- a/monotime/time.go +++ b/monotime/time.go @@ -9,6 +9,8 @@ var ( baseWallNano int64 ) +type Time int64 + func init() { baseWallTime = time.Now() baseWallNano = baseWallTime.UnixNano() @@ -23,7 +25,11 @@ func init() { // and using time.Since() for elapsed calculation, this avoids repeated // time.Now() calls and leverages Go's internal monotonic clock for // efficient duration measurement. -func Now() int64 { +func Now() Time { elapsed := time.Since(baseWallTime) - return baseWallNano + int64(elapsed) + return Time(baseWallNano + int64(elapsed)) +} + +func Since(t Time) time.Duration { + return time.Duration(Now() - t) } From e49bcc343d30ef6c491782c45bea81c372bcf0e4 Mon Sep 17 00:00:00 2001 From: iisteev Date: Sun, 13 Jul 2025 15:42:48 +0200 Subject: [PATCH 10/50] [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145) Signed-off-by: iisteev --- client/iface/netstack/tun.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/client/iface/netstack/tun.go b/client/iface/netstack/tun.go index aec9d4faa..b2506b50d 100644 --- a/client/iface/netstack/tun.go +++ b/client/iface/netstack/tun.go @@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { } t.tundev = nsTunDev - skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) - if err != nil { - log.Errorf("failed to parse %s: %s", EnvSkipProxy, err) + var skipProxy bool + if val := os.Getenv(EnvSkipProxy); val != "" { + skipProxy, err = strconv.ParseBool(val) + if err != nil { + log.Errorf("failed to parse %s: %s", EnvSkipProxy, err) + } } if skipProxy { return nsTunDev, tunNet, nil From 0dab03252c38e547b346dad5be91ac983b3504a1 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 15 Jul 2025 10:43:42 +0200 Subject: [PATCH 11/50] [client, relay-server] Feature/relay notification (#4083) - Clients now subscribe to peer status changes. - The server manages and maintains these subscriptions. - Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity. --- client/iface/wgproxy/bind/proxy.go | 20 +++ client/iface/wgproxy/ebpf/wrapper.go | 18 +++ client/iface/wgproxy/factory_kernel.go | 5 +- client/iface/wgproxy/factory_usp.go | 4 +- client/iface/wgproxy/listener/listener.go | 19 +++ client/iface/wgproxy/proxy.go | 1 + client/iface/wgproxy/proxy_test.go | 4 +- client/iface/wgproxy/udp/proxy.go | 11 ++ client/internal/peer/conn.go | 4 +- client/internal/peer/worker_relay.go | 6 +- relay/auth/validator.go | 7 - relay/client/client.go | 155 ++++++++++++++------ relay/client/client_test.go | 163 ++++++++++++--------- relay/client/conn.go | 13 +- relay/client/guard.go | 2 +- relay/client/manager.go | 20 +-- relay/client/manager_test.go | 102 ++++++++----- relay/client/peer_subscription.go | 168 ++++++++++++++++++++++ relay/client/peer_subscription_test.go | 99 +++++++++++++ relay/client/picker.go | 4 +- relay/cmd/root.go | 9 +- relay/messages/id.go | 24 ++-- relay/messages/id_test.go | 13 -- relay/messages/message.go | 115 ++++++++------- relay/messages/message_test.go | 16 +-- relay/messages/peer_state.go | 92 ++++++++++++ relay/messages/peer_state_test.go | 144 +++++++++++++++++++ relay/server/handshake.go | 47 +++--- relay/server/peer.go | 113 ++++++++++++--- relay/server/relay.go | 116 ++++++++------- relay/server/server.go | 28 ++-- relay/server/store/listener.go | 121 ++++++++++++++++ relay/server/store/notifier.go | 64 +++++++++ relay/server/{ => store}/store.go | 31 ++-- relay/server/store/store_test.go | 49 +++++++ relay/server/store_test.go | 85 ----------- relay/server/url.go | 33 +++++ relay/test/benchmark_test.go | 22 +-- relay/testec2/relay.go | 12 +- 39 files changed, 1464 insertions(+), 495 deletions(-) create mode 100644 client/iface/wgproxy/listener/listener.go create mode 100644 relay/client/peer_subscription.go create mode 100644 relay/client/peer_subscription_test.go delete mode 100644 relay/messages/id_test.go create mode 100644 relay/messages/peer_state.go create mode 100644 relay/messages/peer_state_test.go create mode 100644 relay/server/store/listener.go create mode 100644 relay/server/store/notifier.go rename relay/server/{ => store}/store.go (61%) create mode 100644 relay/server/store/store_test.go delete mode 100644 relay/server/store_test.go create mode 100644 relay/server/url.go diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 614787e17..179ac0b75 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) type ProxyBind struct { @@ -28,6 +29,17 @@ type ProxyBind struct { pausedMu sync.Mutex paused bool isStarted bool + + closeListener *listener.CloseListener +} + +func NewProxyBind(bind *bind.ICEBind) *ProxyBind { + p := &ProxyBind{ + Bind: bind, + closeListener: listener.NewCloseListener(), + } + + return p } // AddTurnConn adds a new connection to the bind. @@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr { } } +func (p *ProxyBind) SetDisconnectListener(disconnected func()) { + p.closeListener.SetCloseListener(disconnected) +} + func (p *ProxyBind) Work() { if p.remoteConn == nil { return @@ -96,6 +112,9 @@ func (p *ProxyBind) close() error { if p.closed { return nil } + + p.closeListener.SetCloseListener(nil) + p.closed = true p.cancel() @@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { if ctx.Err() != nil { return } + p.closeListener.Notify() log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 54cab4e1b..dbf9128a8 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -11,6 +11,8 @@ import ( "sync" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -26,6 +28,15 @@ type ProxyWrapper struct { pausedMu sync.Mutex paused bool isStarted bool + + closeListener *listener.CloseListener +} + +func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { + return &ProxyWrapper{ + WgeBPFProxy: WgeBPFProxy, + closeListener: listener.NewCloseListener(), + } } func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { @@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { return p.wgEndpointAddr } +func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { + p.closeListener.SetCloseListener(disconnected) +} + func (p *ProxyWrapper) Work() { if p.remoteConn == nil { return @@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error { e.cancel() + e.closeListener.SetCloseListener(nil) + if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { return fmt.Errorf("failed to close remote conn: %w", err) } @@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err if ctx.Err() != nil { return 0, ctx.Err() } + p.closeListener.Notify() if !errors.Is(err, io.EOF) { log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 3ad7dc59d..e62cd97be 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy { return udpProxy.NewWGUDPProxy(w.wgPort) } - return &ebpf.ProxyWrapper{ - WgeBPFProxy: w.ebpfProxy, - } + return ebpf.NewProxyWrapper(w.ebpfProxy) + } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index e2d479331..141b4c1f9 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { } func (w *USPFactory) GetProxy() Proxy { - return &proxyBind.ProxyBind{ - Bind: w.bind, - } + return proxyBind.NewProxyBind(w.bind) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/listener/listener.go b/client/iface/wgproxy/listener/listener.go new file mode 100644 index 000000000..bfd651548 --- /dev/null +++ b/client/iface/wgproxy/listener/listener.go @@ -0,0 +1,19 @@ +package listener + +type CloseListener struct { + listener func() +} + +func NewCloseListener() *CloseListener { + return &CloseListener{} +} + +func (c *CloseListener) SetCloseListener(listener func()) { + c.listener = listener +} + +func (c *CloseListener) Notify() { + if c.listener != nil { + c.listener() + } +} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index 243aa2bd2..c2879877e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -12,4 +12,5 @@ type Proxy interface { Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. CloseConn() error + SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 64b617621..2165b8aba 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { t.Errorf("failed to free ebpf proxy: %s", err) } }() - proxyWrapper := &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - } + proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) tests = append(tests, struct { name string diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index ba0004b8a..df45d8ca5 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" cerrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // WGUDPProxy proxies @@ -28,6 +29,8 @@ type WGUDPProxy struct { pausedMu sync.Mutex paused bool isStarted bool + + closeListener *listener.CloseListener } // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation @@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) p := &WGUDPProxy{ localWGListenPort: wgPort, + closeListener: listener.NewCloseListener(), } return p } @@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr { return endpointUdpAddr } +func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) { + p.closeListener.SetCloseListener(disconnected) +} + // Work starts the proxy or resumes it if it was paused func (p *WGUDPProxy) Work() { if p.remoteConn == nil { @@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error { if p.closed { return nil } + + p.closeListener.SetCloseListener(nil) p.closed = true p.cancel() @@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { if ctx.Err() != nil { return } + p.closeListener.Notify() log.Debugf("failed to read from wg interface conn: %s", err) return } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 1f0ec164e..7765bb51c 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -167,7 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) - conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) + conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) @@ -489,6 +489,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } + wgProxy.SetDisconnectListener(conn.onRelayDisconnected) + conn.dumpState.NewLocalProxy() conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index aa8f7d635..5e2900609 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -19,6 +19,7 @@ type RelayConnInfo struct { } type WorkerRelay struct { + peerCtx context.Context log *log.Entry isController bool config ConnConfig @@ -33,8 +34,9 @@ type WorkerRelay struct { wgWatcher *WGWatcher } -func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { +func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { r := &WorkerRelay{ + peerCtx: ctx, log: log, isController: ctrl, config: config, @@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) - relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) + relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { w.log.Debugf("handled offer by reusing existing relay connection") diff --git a/relay/auth/validator.go b/relay/auth/validator.go index 854efd5bb..56a20fbfe 100644 --- a/relay/auth/validator.go +++ b/relay/auth/validator.go @@ -7,13 +7,6 @@ import ( authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" ) -// Validator is an interface that defines the Validate method. -type Validator interface { - Validate(any) error - // Deprecated: Use Validate instead. - ValidateHelloMsgType(any) error -} - type TimedHMACValidator struct { authenticatorV2 *authv2.Validator authenticator *auth.TimedHMACValidator diff --git a/relay/client/client.go b/relay/client/client.go index 9e7e54393..2bf679ecb 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -124,15 +124,14 @@ func (cc *connContainer) close() { // While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. type Client struct { log *log.Entry - parentCtx context.Context connectionURL string authTokenStore *auth.TokenStore - hashedID []byte + hashedID messages.PeerID bufPool *sync.Pool relayConn net.Conn - conns map[string]*connContainer + conns map[messages.PeerID]*connContainer serviceIsRunning bool mu sync.Mutex // protect serviceIsRunning and conns readLoopMutex sync.Mutex @@ -142,14 +141,17 @@ type Client struct { onDisconnectListener func(string) listenerMutex sync.Mutex + + stateSubscription *PeersStateSubscription } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect -func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { - hashedID, hashedStringId := messages.HashID(peerID) +func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { + hashedID := messages.HashID(peerID) + relayLog := log.WithFields(log.Fields{"relay": serverURL}) + c := &Client{ - log: log.WithFields(log.Fields{"relay": serverURL}), - parentCtx: ctx, + log: relayLog, connectionURL: serverURL, authTokenStore: authTokenStore, hashedID: hashedID, @@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token return &buf }, }, - conns: make(map[string]*connContainer), + conns: make(map[messages.PeerID]*connContainer), } - c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId) + + c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID) return c } // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. -func (c *Client) Connect() error { +func (c *Client) Connect(ctx context.Context) error { c.log.Infof("connecting to relay server") c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -178,17 +181,23 @@ func (c *Client) Connect() error { return nil } - if err := c.connect(); err != nil { + if err := c.connect(ctx); err != nil { return err } + c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID) + c.log = c.log.WithField("relay", c.instanceURL.String()) c.log.Infof("relay connection established") c.serviceIsRunning = true + internallyStoppedFlag := newInternalStopFlag() + hc := healthcheck.NewReceiver(c.log) + go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag) + c.wgReadLoop.Add(1) - go c.readLoop(c.relayConn) + go c.readLoop(hc, c.relayConn, internallyStoppedFlag) return nil } @@ -196,26 +205,41 @@ func (c *Client) Connect() error { // OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress // to the relay server, the function will block until the connection is established or timed out. Otherwise, // it will return immediately. +// It block until the server confirm the peer is online. // todo: what should happen if call with the same peerID with multiple times? -func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) { + peerID := messages.HashID(dstPeerID) + c.mu.Lock() if !c.serviceIsRunning { + c.mu.Unlock() return nil, fmt.Errorf("relay connection is not established") } - - hashedID, hashedStringID := messages.HashID(dstPeerID) - _, ok := c.conns[hashedStringID] + _, ok := c.conns[peerID] if ok { + c.mu.Unlock() return nil, ErrConnAlreadyExists } + c.mu.Unlock() - c.log.Infof("open connection to peer: %s", hashedStringID) + if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil { + c.log.Errorf("peer not available: %s, %s", peerID, err) + return nil, err + } + + c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID) msgChannel := make(chan Msg, 100) - conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) + conn := NewConn(c, peerID, msgChannel, c.instanceURL) - c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel) + c.mu.Lock() + _, ok = c.conns[peerID] + if ok { + c.mu.Unlock() + _ = conn.Close() + return nil, ErrConnAlreadyExists + } + c.conns[peerID] = newConnContainer(c.log, conn, msgChannel) + c.mu.Unlock() return conn, nil } @@ -254,7 +278,7 @@ func (c *Client) Close() error { return c.close(true) } -func (c *Client) connect() error { +func (c *Client) connect(ctx context.Context) error { rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) conn, err := rd.Dial() if err != nil { @@ -262,7 +286,7 @@ func (c *Client) connect() error { } c.relayConn = conn - if err = c.handShake(); err != nil { + if err = c.handShake(ctx); err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) @@ -273,7 +297,7 @@ func (c *Client) connect() error { return nil } -func (c *Client) handShake() error { +func (c *Client) handShake(ctx context.Context) error { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { c.log.Errorf("failed to marshal auth message: %s", err) @@ -286,7 +310,7 @@ func (c *Client) handShake() error { return err } buf := make([]byte, messages.MaxHandshakeRespSize) - n, err := c.readWithTimeout(buf) + n, err := c.readWithTimeout(ctx, buf) if err != nil { c.log.Errorf("failed to read auth response: %s", err) return err @@ -319,11 +343,7 @@ func (c *Client) handShake() error { return nil } -func (c *Client) readLoop(relayConn net.Conn) { - internallyStoppedFlag := newInternalStopFlag() - hc := healthcheck.NewReceiver(c.log) - go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag) - +func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) { var ( errExit error n int @@ -370,6 +390,7 @@ func (c *Client) readLoop(relayConn net.Conn) { c.instanceURL = nil c.muInstanceURL.Unlock() + c.stateSubscription.Cleanup() c.wgReadLoop.Done() _ = c.close(false) c.notifyDisconnected() @@ -382,6 +403,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, c.bufPool.Put(bufPtr) case messages.MsgTypeTransport: return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) + case messages.MsgTypePeersOnline: + c.handlePeersOnlineMsg(buf) + c.bufPool.Put(bufPtr) + return true + case messages.MsgTypePeersWentOffline: + c.handlePeersWentOfflineMsg(buf) + c.bufPool.Put(bufPtr) + return true case messages.MsgTypeClose: c.log.Debugf("relay connection close by server") c.bufPool.Put(bufPtr) @@ -413,18 +442,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe return true } - stringID := messages.HashIDToString(peerID) - c.mu.Lock() if !c.serviceIsRunning { c.mu.Unlock() c.bufPool.Put(bufPtr) return false } - container, ok := c.conns[stringID] + container, ok := c.conns[*peerID] c.mu.Unlock() if !ok { - c.log.Errorf("peer not found: %s", stringID) + c.log.Errorf("peer not found: %s", peerID.String()) c.bufPool.Put(bufPtr) return true } @@ -437,9 +464,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe return true } -func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { +func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) { c.mu.Lock() - conn, ok := c.conns[id] + conn, ok := c.conns[dstID] c.mu.Unlock() if !ok { return 0, net.ErrClosed @@ -464,7 +491,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ return len(payload), err } -func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { +func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { for { select { case _, ok := <-hc.OnTimeout: @@ -478,7 +505,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in c.log.Warnf("failed to close connection: %s", err) } return - case <-c.parentCtx.Done(): + case <-ctx.Done(): err := c.close(true) if err != nil { c.log.Errorf("failed to teardown connection: %s", err) @@ -492,10 +519,31 @@ func (c *Client) closeAllConns() { for _, container := range c.conns { container.close() } - c.conns = make(map[string]*connContainer) + c.conns = make(map[messages.PeerID]*connContainer) } -func (c *Client) closeConn(connReference *Conn, id string) error { +func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, peerID := range peerIDs { + container, ok := c.conns[peerID] + if !ok { + c.log.Warnf("can not close connection, peer not found: %s", peerID) + continue + } + + container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID) + container.close() + delete(c.conns, peerID) + } + + if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil { + c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err) + } +} + +func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error { c.mu.Lock() defer c.mu.Unlock() @@ -507,6 +555,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error { if container.conn != connReference { return fmt.Errorf("conn reference mismatch") } + + if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil { + container.log.Errorf("failed to unsubscribe from peer state change: %s", err) + } + c.log.Infof("free up connection to peer: %s", id) delete(c.conns, id) container.close() @@ -559,8 +612,8 @@ func (c *Client) writeCloseMsg() { } } -func (c *Client) readWithTimeout(buf []byte) (int, error) { - ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) +func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) { + ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout) defer cancel() readDone := make(chan struct{}) @@ -581,3 +634,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) { return n, err } } + +func (c *Client) handlePeersOnlineMsg(buf []byte) { + peersID, err := messages.UnmarshalPeersOnlineMsg(buf) + if err != nil { + c.log.Errorf("failed to unmarshal peers online msg: %s", err) + return + } + c.stateSubscription.OnPeersOnline(peersID) +} + +func (c *Client) handlePeersWentOfflineMsg(buf []byte) { + peersID, err := messages.UnMarshalPeersWentOffline(buf) + if err != nil { + c.log.Errorf("failed to unmarshal peers went offline msg: %s", err) + return + } + c.stateSubscription.OnPeersWentOffline(peersID) +} diff --git a/relay/client/client_test.go b/relay/client/client_test.go index 7ddfba4c6..dd5f5fe1e 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -18,14 +18,19 @@ import ( ) var ( - av = &allow.Auth{} hmacTokenStore = &hmac.TokenStore{} serverListenAddr = "127.0.0.1:1234" serverURL = "rel://127.0.0.1:1234" + serverCfg = server.Config{ + Meter: otel.Meter(""), + ExposedAddress: serverURL, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } ) func TestMain(m *testing.M) { - _ = util.InitLog("error", "console") + _ = util.InitLog("debug", "console") code := m.Run() os.Exit(code) } @@ -33,7 +38,7 @@ func TestMain(m *testing.M) { func TestClient(t *testing.T) { ctx := context.Background() - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -58,37 +63,37 @@ func TestClient(t *testing.T) { t.Fatalf("failed to start server: %s", err) } t.Log("alice connecting to server") - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() t.Log("placeholder connecting to server") - clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") - err = clientPlaceHolder.Connect() + clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder") + err = clientPlaceHolder.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() t.Log("Bob connecting to server") - clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") - err = clientBob.Connect() + clientBob := NewClient(serverURL, hmacTokenStore, "bob") + err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientBob.Close() t.Log("Alice open connection to Bob") - connAliceToBob, err := clientAlice.OpenConn("bob") + connAliceToBob, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } t.Log("Bob open connection to Alice") - connBobToAlice, err := clientBob.OpenConn("alice") + connBobToAlice, err := clientBob.OpenConn(ctx, "alice") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -115,7 +120,7 @@ func TestClient(t *testing.T) { func TestRegistration(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { _ = srv.Shutdown(ctx) t.Fatalf("failed to connect to server: %s", err) @@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) { _ = fakeTCPListener.Close() }(fakeTCPListener) - clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err == nil { t.Errorf("failed to connect to server: %s", err) } @@ -189,7 +194,7 @@ func TestEcho(t *testing.T) { idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -213,8 +218,8 @@ func TestEcho(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -225,8 +230,8 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) - err = clientBob.Connect() + clientBob := NewClient(serverURL, hmacTokenStore, idBob) + err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -237,12 +242,12 @@ func TestEcho(t *testing.T) { } }() - connAliceToBob, err := clientAlice.OpenConn(idBob) + connAliceToBob, err := clientAlice.OpenConn(ctx, idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(idAlice) + connBobToAlice, err := clientBob.OpenConn(ctx, idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - _, err = clientAlice.OpenConn("bob") - if err != nil { - t.Errorf("failed to bind channel: %s", err) + _, err = clientAlice.OpenConn(ctx, "bob") + if err == nil { + t.Errorf("expected error when binding to unavailable peer, got nil") } log.Infof("closing client") @@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + clientBob := NewClient(serverURL, hmacTokenStore, "bob") + err = clientBob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - _, err = clientAlice.OpenConn("bob") + _, err = clientAlice.OpenConn(ctx, "bob") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } - clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") - err = clientBob.Connect() - if err != nil { - t.Errorf("failed to connect to server: %s", err) - } - - chBob, err := clientBob.OpenConn("alice") + chBob, err := clientBob.OpenConn(ctx, "alice") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice = NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - chAlice, err := clientAlice.OpenConn("bob") + chAlice, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } testString := "hello alice, I am bob" + _, err = chBob.Write([]byte(testString)) + if err == nil { + t.Errorf("expected error when writing to channel, got nil") + } + + chBob, err = clientBob.OpenConn(ctx, "alice") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + _, err = chBob.Write([]byte(testString)) if err != nil { t.Errorf("failed to write to channel: %s", err) @@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + bob := NewClient(serverURL, hmacTokenStore, "bob") + err = bob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - conn, err := clientAlice.OpenConn("bob") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + bob := NewClient(serverURL, hmacTokenStore, "bob") + err = bob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - conn, err := clientAlice.OpenConn("bob") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) { t.Errorf("unexpected reading from closed connection") } - _, err = clientAlice.OpenConn("bob") + _, err = clientAlice.OpenConn(ctx, "bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } @@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv1, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = relayClient.Connect() + relayClient := NewClient(serverURL, hmacTokenStore, idAlice) + err = relayClient.Connect(ctx) if err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) { log.Fatalf("timeout waiting for client to disconnect") } - _, err = relayClient.OpenConn("bob") + _, err = relayClient.OpenConn(ctx, "bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } @@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = relayClient.Connect() + relayClient := NewClient(serverURL, hmacTokenStore, idAlice) + err = relayClient.Connect(ctx) if err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) { t.Errorf("failed to close client: %s", err) } - _, err = relayClient.OpenConn("bob") + _, err = relayClient.OpenConn(ctx, "bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } @@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) - err = clientBob.Connect() + clientBob := NewClient(serverURL, hmacTokenStore, idBob) + err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - connAliceToBob, err := clientAlice.OpenConn(idBob) + connAliceToBob, err := clientAlice.OpenConn(ctx, idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(idAlice) + connBobToAlice, err := clientBob.OpenConn(ctx, idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/relay/client/conn.go b/relay/client/conn.go index fe1b6fb52..d8cffa695 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -3,13 +3,14 @@ package client import ( "net" "time" + + "github.com/netbirdio/netbird/relay/messages" ) // Conn represent a connection to a relayed remote peer. type Conn struct { client *Client - dstID []byte - dstStringID string + dstID messages.PeerID messageChan chan Msg instanceURL *RelayAddr } @@ -17,14 +18,12 @@ type Conn struct { // NewConn creates a new connection to a relayed remote peer. // client: the client instance, it used to send messages to the destination peer // dstID: the destination peer ID -// dstStringID: the destination peer ID in string format // messageChan: the channel where the messages will be received // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer -func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { +func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn { c := &Conn{ client: client, dstID: dstID, - dstStringID: dstStringID, messageChan: messageChan, instanceURL: instanceURL, } @@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan } func (c *Conn) Write(p []byte) (n int, err error) { - return c.client.writeTo(c, c.dstStringID, c.dstID, p) + return c.client.writeTo(c, c.dstID, p) } func (c *Conn) Read(b []byte) (n int, err error) { @@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Close() error { - return c.client.closeConn(c, c.dstStringID) + return c.client.closeConn(c, c.dstID) } func (c *Conn) LocalAddr() net.Addr { diff --git a/relay/client/guard.go b/relay/client/guard.go index 554330ea3..100892d81 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - if err := rc.Connect(); err != nil { + if err := rc.Connect(parentCtx); err != nil { log.Errorf("failed to reconnect to relay server: %s", err) return false } diff --git a/relay/client/manager.go b/relay/client/manager.go index 26b113050..0fb682d95 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -42,7 +42,7 @@ type OnServerCloseListener func() // ManagerService is the interface for the relay manager. type ManagerService interface { Serve() error - OpenConn(serverAddress, peerKey string) (net.Conn, error) + OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error RelayInstanceAddress() (string, error) ServerURLs() []string @@ -123,7 +123,7 @@ func (m *Manager) Serve() error { // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. -func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { m.relayClientMu.Lock() defer m.relayClientMu.Unlock() @@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { ) if !foreign { log.Debugf("open peer connection via permanent server: %s", peerKey) - netConn, err = m.relayClient.OpenConn(peerKey) + netConn, err = m.relayClient.OpenConn(ctx, peerKey) } else { log.Debugf("open peer connection via foreign server: %s", serverAddress) - netConn, err = m.openConnVia(serverAddress, peerKey) + netConn, err = m.openConnVia(ctx, serverAddress, peerKey) } if err != nil { return nil, err @@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error { return m.tokenStore.UpdateToken(token) } -func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { // check if already has a connection to the desired relay server m.relayClientsMutex.RLock() rt, ok := m.relayClients[serverAddress] @@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { if rt.err != nil { return nil, rt.err } - return rt.relayClient.OpenConn(peerKey) + return rt.relayClient.OpenConn(ctx, peerKey) } m.relayClientsMutex.RUnlock() @@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { if rt.err != nil { return nil, rt.err } - return rt.relayClient.OpenConn(peerKey) + return rt.relayClient.OpenConn(ctx, peerKey) } // create a new relay client and store it in the relayClients map @@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) - err := relayClient.Connect() + relayClient := NewClient(serverAddress, m.tokenStore, m.peerID) + err := relayClient.Connect(m.ctx) if err != nil { rt.err = err rt.Unlock() @@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { rt.relayClient = relayClient rt.Unlock() - conn, err := relayClient.OpenConn(peerKey) + conn, err := relayClient.OpenConn(ctx, peerKey) if err != nil { return nil, err } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index bfc342f25..d20cdaac0 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" + "github.com/netbirdio/netbird/relay/auth/allow" "github.com/netbirdio/netbird/relay/server" ) @@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) { func TestForeignConn(t *testing.T) { ctx := context.Background() - srvCfg1 := server.ListenerConfig{ + lstCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + + srv1, err := server.NewServer(server.Config{ + Meter: otel.Meter(""), + ExposedAddress: lstCfg1.Address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + }) if err != nil { t.Fatalf("failed to create server: %s", err) } errChan := make(chan error, 1) go func() { - err := srv1.Listen(srvCfg1) + err := srv1.Listen(lstCfg1) if err != nil { errChan <- err } @@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + srv2, err := server.NewServer(server.Config{ + Meter: otel.Meter(""), + ExposedAddress: srvCfg2.Address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + }) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - idAlice := "alice" - log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) - err = clientAlice.Serve() - if err != nil { + clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice") + if err := clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - idBob := "bob" - log.Debugf("connect by bob") - clientBob := NewManager(mCtx, toURL(srvCfg2), idBob) - err = clientBob.Serve() - if err != nil { + clientBob := NewManager(mCtx, toURL(srvCfg2), "bob") + if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } bobsSrvAddr, err := clientBob.RelayInstanceAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) + connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) + connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + srv1, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + srv2, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - idAlice := "alice" - log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, toURL(srvCfg1), idAlice) + + mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob") + if err := mgrBob.Serve(); err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + mgr := NewManager(mCtx, toURL(srvCfg1), "alice") err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -212,7 +222,7 @@ func TestForeginAutoClose(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + srv1, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -241,7 +251,7 @@ func TestForeginAutoClose(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + srv2, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -277,7 +287,7 @@ func TestForeginAutoClose(t *testing.T) { } t.Log("open connection to another peer") - conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -305,7 +315,7 @@ func TestAutoReconnect(t *testing.T) { srvCfg := server.ListenerConfig{ Address: "localhost:1234", } - srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -330,6 +340,13 @@ func TestAutoReconnect(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() + + clientBob := NewManager(mCtx, toURL(srvCfg), "bob") + err = clientBob.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") err = clientAlice.Serve() if err != nil { @@ -339,7 +356,7 @@ func TestAutoReconnect(t *testing.T) { if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ra, "bob") + conn, err := clientAlice.OpenConn(ctx, ra, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -357,7 +374,7 @@ func TestAutoReconnect(t *testing.T) { time.Sleep(reconnectingTimeout + 1*time.Second) log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ra, "bob") + _, err = clientAlice.OpenConn(ctx, ra, "bob") if err != nil { t.Errorf("failed to open channel: %s", err) } @@ -366,24 +383,27 @@ func TestAutoReconnect(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) { ctx := context.Background() - srvCfg1 := server.ListenerConfig{ + listenerCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + srv, err := server.NewServer(server.Config{ + Meter: otel.Meter(""), + ExposedAddress: listenerCfg1.Address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + }) if err != nil { t.Fatalf("failed to create server: %s", err) } errChan := make(chan error, 1) go func() { - err := srv1.Listen(srvCfg1) - if err != nil { + if err := srv.Listen(listenerCfg1); err != nil { errChan <- err } }() defer func() { - err := srv1.Shutdown(ctx) - if err != nil { + if err := srv.Shutdown(ctx); err != nil { t.Errorf("failed to close server: %s", err) } }() @@ -392,17 +412,21 @@ func TestNotifierDoubleAdd(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - idAlice := "alice" log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) - err = clientAlice.Serve() - if err != nil { + + clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob") + if err = clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") + clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice") + if err = clientAlice.Serve(); err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/relay/client/peer_subscription.go b/relay/client/peer_subscription.go new file mode 100644 index 000000000..03e7127b3 --- /dev/null +++ b/relay/client/peer_subscription.go @@ -0,0 +1,168 @@ +package client + +import ( + "context" + "errors" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/messages" +) + +const ( + OpenConnectionTimeout = 30 * time.Second +) + +type relayedConnWriter interface { + Write(p []byte) (n int, err error) +} + +// PeersStateSubscription manages subscriptions to peer state changes (online/offline) +// over a relay connection. It allows tracking peers' availability and handling offline +// events via a callback. We get online notification from the server only once. +type PeersStateSubscription struct { + log *log.Entry + relayConn relayedConnWriter + offlineCallback func(peerIDs []messages.PeerID) + + listenForOfflinePeers map[messages.PeerID]struct{} + waitingPeers map[messages.PeerID]chan struct{} +} + +func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription { + return &PeersStateSubscription{ + log: log, + relayConn: relayConn, + offlineCallback: offlineCallback, + listenForOfflinePeers: make(map[messages.PeerID]struct{}), + waitingPeers: make(map[messages.PeerID]chan struct{}), + } +} + +// OnPeersOnline should be called when a notification is received that certain peers have come online. +// It checks if any of the peers are being waited on and signals their availability. +func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) { + for _, peerID := range peersID { + waitCh, ok := s.waitingPeers[peerID] + if !ok { + continue + } + + close(waitCh) + delete(s.waitingPeers, peerID) + } +} + +func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { + relevantPeers := make([]messages.PeerID, 0, len(peersID)) + for _, peerID := range peersID { + if _, ok := s.listenForOfflinePeers[peerID]; ok { + relevantPeers = append(relevantPeers, peerID) + } + } + + if len(relevantPeers) > 0 { + s.offlineCallback(relevantPeers) + } +} + +// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes. +// todo: when we unsubscribe while this is running, this will not return with error +func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error { + // Check if already waiting for this peer + if _, exists := s.waitingPeers[peerID]; exists { + return errors.New("already waiting for peer to come online") + } + + // Create a channel to wait for the peer to come online + waitCh := make(chan struct{}) + s.waitingPeers[peerID] = waitCh + + if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil { + s.log.Errorf("failed to subscribe to peer state: %s", err) + close(waitCh) + delete(s.waitingPeers, peerID) + return err + } + + defer func() { + if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { + close(waitCh) + delete(s.waitingPeers, peerID) + } + }() + + // Wait for peer to come online or context to be cancelled + timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout) + defer cancel() + select { + case <-waitCh: + s.log.Debugf("peer %s is now online", peerID) + return nil + case <-timeoutCtx.Done(): + s.log.Debugf("context timed out while waiting for peer %s to come online", peerID) + if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil { + s.log.Errorf("failed to unsubscribe from peer state: %s", err) + } + return timeoutCtx.Err() + } +} + +func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error { + msgErr := s.unsubscribeStateChange(peerIDs) + + for _, peerID := range peerIDs { + if wch, ok := s.waitingPeers[peerID]; ok { + close(wch) + delete(s.waitingPeers, peerID) + } + + delete(s.listenForOfflinePeers, peerID) + } + + return msgErr +} + +func (s *PeersStateSubscription) Cleanup() { + for _, waitCh := range s.waitingPeers { + close(waitCh) + } + + s.waitingPeers = make(map[messages.PeerID]chan struct{}) + s.listenForOfflinePeers = make(map[messages.PeerID]struct{}) +} + +func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error { + msgs, err := messages.MarshalSubPeerStateMsg(peerIDs) + if err != nil { + return err + } + + for _, peer := range peerIDs { + s.listenForOfflinePeers[peer] = struct{}{} + } + + for _, msg := range msgs { + if _, err := s.relayConn.Write(msg); err != nil { + return err + } + + } + return nil +} + +func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error { + msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs) + if err != nil { + return err + } + + var connWriteErr error + for _, msg := range msgs { + if _, err := s.relayConn.Write(msg); err != nil { + connWriteErr = err + } + } + return connWriteErr +} diff --git a/relay/client/peer_subscription_test.go b/relay/client/peer_subscription_test.go new file mode 100644 index 000000000..0437efa04 --- /dev/null +++ b/relay/client/peer_subscription_test.go @@ -0,0 +1,99 @@ +package client + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/netbirdio/netbird/relay/messages" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockRelayedConn struct { +} + +func (m *mockRelayedConn) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) { + peerID := messages.HashID("peer1") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) // discard log output + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Launch wait in background + go func() { + time.Sleep(100 * time.Millisecond) + sub.OnPeersOnline([]messages.PeerID{peerID}) + }() + + err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + assert.NoError(t, err) +} + +func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) { + peerID := messages.HashID("peer2") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) +} + +func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) { + peerID := messages.HashID("peer3") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + ctx := context.Background() + go func() { + _ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + + }() + time.Sleep(100 * time.Millisecond) + err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + require.Error(t, err) + assert.Contains(t, err.Error(), "already waiting") +} + +func TestUnsubscribeStateChange(t *testing.T) { + peerID := messages.HashID("peer4") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + doneChan := make(chan struct{}) + go func() { + _ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID) + close(doneChan) + }() + time.Sleep(100 * time.Millisecond) + + err := sub.UnsubscribeStateChange([]messages.PeerID{peerID}) + assert.NoError(t, err) + + select { + case <-doneChan: + case <-time.After(200 * time.Millisecond): + // Expected timeout, meaning the subscription was successfully unsubscribed + t.Errorf("timeout") + } +} diff --git a/relay/client/picker.go b/relay/client/picker.go index eb5062dbb..9565425a8 100644 --- a/relay/client/picker.go +++ b/relay/client/picker.go @@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { log.Infof("try to connecting to relay server: %s", url) - relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) - err := relayClient.Connect() + relayClient := NewClient(url, sp.TokenStore, sp.PeerID) + err := relayClient.Connect(ctx) resultChan <- connResult{ RelayClient: relayClient, Url: url, diff --git a/relay/cmd/root.go b/relay/cmd/root.go index d603ff73b..15090024c 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error { hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) - srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) + cfg := server.Config{ + Meter: metricsServer.Meter, + ExposedAddress: cobraConfig.ExposedAddress, + AuthValidator: authenticator, + TLSSupport: tlsSupport, + } + + srv, err := server.NewServer(cfg) if err != nil { log.Debugf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err) diff --git a/relay/messages/id.go b/relay/messages/id.go index e2162cd3b..96ace3478 100644 --- a/relay/messages/id.go +++ b/relay/messages/id.go @@ -8,24 +8,24 @@ import ( const ( prefixLength = 4 - IDSize = prefixLength + sha256.Size + peerIDSize = prefixLength + sha256.Size ) var ( prefix = []byte("sha-") // 4 bytes ) -// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string -func HashID(peerID string) ([]byte, string) { - idHash := sha256.Sum256([]byte(peerID)) - idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) - var prefixedHash []byte - prefixedHash = append(prefixedHash, prefix...) - prefixedHash = append(prefixedHash, idHash[:]...) - return prefixedHash, idHashString +type PeerID [peerIDSize]byte + +func (p PeerID) String() string { + return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:])) } -// HashIDToString converts a hash to a human-readable string -func HashIDToString(idHash []byte) string { - return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) +// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string +func HashID(peerID string) PeerID { + idHash := sha256.Sum256([]byte(peerID)) + var prefixedHash [peerIDSize]byte + copy(prefixedHash[:prefixLength], prefix) + copy(prefixedHash[prefixLength:], idHash[:]) + return prefixedHash } diff --git a/relay/messages/id_test.go b/relay/messages/id_test.go deleted file mode 100644 index 271a8f90d..000000000 --- a/relay/messages/id_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package messages - -import ( - "testing" -) - -func TestHashID(t *testing.T) { - hashedID, hashedStringId := HashID("alice") - enc := HashIDToString(hashedID) - if enc != hashedStringId { - t.Errorf("expected %s, got %s", hashedStringId, enc) - } -} diff --git a/relay/messages/message.go b/relay/messages/message.go index 7794c57bc..54671f5df 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -9,19 +9,26 @@ import ( const ( MaxHandshakeSize = 212 MaxHandshakeRespSize = 8192 + MaxMessageSize = 8820 CurrentProtocolVersion = 1 MsgTypeUnknown MsgType = 0 // Deprecated: Use MsgTypeAuth instead. - MsgTypeHello MsgType = 1 + MsgTypeHello = 1 // Deprecated: Use MsgTypeAuthResponse instead. - MsgTypeHelloResponse MsgType = 2 - MsgTypeTransport MsgType = 3 - MsgTypeClose MsgType = 4 - MsgTypeHealthCheck MsgType = 5 - MsgTypeAuth = 6 - MsgTypeAuthResponse = 7 + MsgTypeHelloResponse = 2 + MsgTypeTransport = 3 + MsgTypeClose = 4 + MsgTypeHealthCheck = 5 + MsgTypeAuth = 6 + MsgTypeAuthResponse = 7 + + // Peers state messages + MsgTypeSubscribePeerState = 8 + MsgTypeUnsubscribePeerState = 9 + MsgTypePeersOnline = 10 + MsgTypePeersWentOffline = 11 // base size of the message sizeOfVersionByte = 1 @@ -30,17 +37,17 @@ const ( // auth message sizeOfMagicByte = 4 - headerSizeAuth = sizeOfMagicByte + IDSize + headerSizeAuth = sizeOfMagicByte + peerIDSize offsetMagicByte = sizeOfProtoHeader offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth // hello message - headerSizeHello = sizeOfMagicByte + IDSize + headerSizeHello = sizeOfMagicByte + peerIDSize headerSizeHelloResp = 0 // transport - headerSizeTransport = IDSize + headerSizeTransport = peerIDSize offsetTransportID = sizeOfProtoHeader headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport ) @@ -72,6 +79,14 @@ func (m MsgType) String() string { return "close" case MsgTypeHealthCheck: return "health check" + case MsgTypeSubscribePeerState: + return "subscribe peer state" + case MsgTypeUnsubscribePeerState: + return "unsubscribe peer state" + case MsgTypePeersOnline: + return "peers online" + case MsgTypePeersWentOffline: + return "peers went offline" default: return "unknown" } @@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { MsgTypeAuth, MsgTypeTransport, MsgTypeClose, - MsgTypeHealthCheck: + MsgTypeHealthCheck, + MsgTypeSubscribePeerState, + MsgTypeUnsubscribePeerState: return msgType, nil default: return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) @@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { MsgTypeAuthResponse, MsgTypeTransport, MsgTypeClose, - MsgTypeHealthCheck: + MsgTypeHealthCheck, + MsgTypePeersOnline, + MsgTypePeersWentOffline: return msgType, nil default: return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) @@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { // message is used to authenticate the client with the server. The authentication is done using an HMAC method. // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // close the network connection without any response. -func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { - if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) - } - +func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) { msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg[0] = byte(CurrentProtocolVersion) @@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) - msg = append(msg, peerID...) + msg = append(msg, peerID[:]...) msg = append(msg, additions...) return msg, nil @@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { // Deprecated: Use UnmarshalAuthMsg instead. // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // authenticate the client with the server. -func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { +func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) { if len(msg) < sizeOfProtoHeader+headerSizeHello { return nil, nil, ErrInvalidMessageLength } @@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil + peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello]) + + return &peerID, msg[headerSizeHello:], nil } // Deprecated: Use MarshalAuthResponse instead. @@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) { // message is used to authenticate the client with the server. The authentication is done using an HMAC method. // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // close the network connection without any response. -func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { - if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) +func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) { + if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize { + return nil, fmt.Errorf("too large auth payload") } - msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) - + msg := make([]byte, headerTotalSizeAuth+len(authPayload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuth) - copy(msg[sizeOfProtoHeader:], magicHeader) - - msg = append(msg, peerID...) - msg = append(msg, authPayload...) - + copy(msg[offsetAuthPeerID:], peerID[:]) + copy(msg[headerTotalSizeAuth:], authPayload) return msg, nil } // UnmarshalAuthMsg extracts peerID and the auth payload from the message -func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { +func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) { if len(msg) < headerTotalSizeAuth { return nil, nil, ErrInvalidMessageLength } + + // Validate the magic header if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil + peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth]) + return &peerID, msg[headerTotalSizeAuth:], nil } // MarshalAuthResponse creates a response message to the auth. @@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte { // MarshalTransportMsg creates a transport message. // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // destination peer hashed ID. -func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { - if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) - } - - msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload)) +func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) { + // todo validate size + msg := make([]byte, headerTotalSizeTransport+len(payload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeTransport) - copy(msg[sizeOfProtoHeader:], peerID) - msg = append(msg, payload...) - + copy(msg[sizeOfProtoHeader:], peerID[:]) + copy(msg[sizeOfProtoHeader+peerIDSize:], payload) return msg, nil } // UnmarshalTransportMsg extracts the peerID and the payload from the transport message. -func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { +func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) { if len(buf) < headerTotalSizeTransport { return nil, nil, ErrInvalidMessageLength } - return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil + const offsetEnd = offsetTransportID + peerIDSize + var peerID PeerID + copy(peerID[:], buf[offsetTransportID:offsetEnd]) + return &peerID, buf[headerTotalSizeTransport:], nil } // UnmarshalTransportID extracts the peerID from the transport message. -func UnmarshalTransportID(buf []byte) ([]byte, error) { +func UnmarshalTransportID(buf []byte) (*PeerID, error) { if len(buf) < headerTotalSizeTransport { return nil, ErrInvalidMessageLength } - return buf[offsetTransportID:headerTotalSizeTransport], nil + + const offsetEnd = offsetTransportID + peerIDSize + var id PeerID + copy(id[:], buf[offsetTransportID:offsetEnd]) + return &id, nil } // UpdateTransportMsg updates the peerID in the transport message. // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // need to allocate a new byte slice. -func UpdateTransportMsg(msg []byte, peerID []byte) error { - if len(msg) < offsetTransportID+len(peerID) { +func UpdateTransportMsg(msg []byte, peerID PeerID) error { + if len(msg) < offsetTransportID+peerIDSize { return ErrInvalidMessageLength } - copy(msg[offsetTransportID:], peerID) + copy(msg[offsetTransportID:], peerID[:]) return nil } diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go index 19bede07b..59a89cad1 100644 --- a/relay/messages/message_test.go +++ b/relay/messages/message_test.go @@ -5,7 +5,7 @@ import ( ) func TestMarshalHelloMsg(t *testing.T) { - peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") msg, err := MarshalHelloMsg(peerID, nil) if err != nil { t.Fatalf("error: %v", err) @@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) { if err != nil { t.Fatalf("error: %v", err) } - if string(receivedPeerID) != string(peerID) { + if receivedPeerID.String() != peerID.String() { t.Errorf("expected %s, got %s", peerID, receivedPeerID) } } func TestMarshalAuthMsg(t *testing.T) { - peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") msg, err := MarshalAuthMsg(peerID, []byte{}) if err != nil { t.Fatalf("error: %v", err) @@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) { if err != nil { t.Fatalf("error: %v", err) } - if string(receivedPeerID) != string(peerID) { + if receivedPeerID.String() != peerID.String() { t.Errorf("expected %s, got %s", peerID, receivedPeerID) } } @@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) { } func TestMarshalTransportMsg(t *testing.T) { - peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") payload := []byte("payload") msg, err := MarshalTransportMsg(peerID, payload) if err != nil { @@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("failed to unmarshal transport id: %v", err) } - if string(uPeerID) != string(peerID) { + if uPeerID.String() != peerID.String() { t.Errorf("expected %s, got %s", peerID, uPeerID) } @@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("error: %v", err) } - if string(id) != string(peerID) { - t.Errorf("expected %s, got %s", peerID, id) + if id.String() != peerID.String() { + t.Errorf("expected: '%s', got: '%s'", peerID, id) } if string(respPayload) != string(payload) { diff --git a/relay/messages/peer_state.go b/relay/messages/peer_state.go new file mode 100644 index 000000000..f10bc7bdf --- /dev/null +++ b/relay/messages/peer_state.go @@ -0,0 +1,92 @@ +package messages + +import ( + "fmt" +) + +func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState)) +} + +func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState)) +} + +func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +func MarshalPeersOnline(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypePeersOnline)) +} + +func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline)) +} + +func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type +func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) { + if len(ids) == 0 { + return nil, fmt.Errorf("no list of peer ids provided") + } + + const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize + var messages [][]byte + + for i := 0; i < len(ids); i += maxPeersPerMessage { + end := i + maxPeersPerMessage + if end > len(ids) { + end = len(ids) + } + chunk := ids[i:end] + + totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize + buf := make([]byte, totalSize) + buf[0] = byte(CurrentProtocolVersion) + buf[1] = msgType + + offset := sizeOfProtoHeader + for _, id := range chunk { + copy(buf[offset:], id[:]) + offset += peerIDSize + } + + messages = append(messages, buf) + } + + return messages, nil +} + +// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer +func unmarshalPeerIDs(buf []byte) ([]PeerID, error) { + if len(buf) < sizeOfProtoHeader { + return nil, fmt.Errorf("invalid message format") + } + + if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 { + return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader) + } + + numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize + + ids := make([]PeerID, numIDs) + offset := sizeOfProtoHeader + for i := 0; i < numIDs; i++ { + copy(ids[i][:], buf[offset:offset+peerIDSize]) + offset += peerIDSize + } + + return ids, nil +} diff --git a/relay/messages/peer_state_test.go b/relay/messages/peer_state_test.go new file mode 100644 index 000000000..9e366da55 --- /dev/null +++ b/relay/messages/peer_state_test.go @@ -0,0 +1,144 @@ +package messages + +import ( + "bytes" + "testing" +) + +const ( + testPeerCount = 10 +) + +// Helper function to generate test PeerIDs +func generateTestPeerIDs(n int) []PeerID { + ids := make([]PeerID, n) + for i := 0; i < n; i++ { + for j := 0; j < peerIDSize; j++ { + ids[i][j] = byte(i + j) + } + } + return ids +} + +// Helper function to compare slices of PeerID +func peerIDEqual(a, b []PeerID) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !bytes.Equal(a[i][:], b[i][:]) { + return false + } + } + return true +} + +func TestMarshalUnmarshalSubPeerState(t *testing.T) { + ids := generateTestPeerIDs(testPeerCount) + + msgs, err := MarshalSubPeerStateMsg(ids) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var allIDs []PeerID + for _, msg := range msgs { + decoded, err := UnmarshalSubPeerStateMsg(msg) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + allIDs = append(allIDs, decoded...) + } + + if !peerIDEqual(ids, allIDs) { + t.Errorf("expected %v, got %v", ids, allIDs) + } +} + +func TestMarshalSubPeerState_EmptyInput(t *testing.T) { + _, err := MarshalSubPeerStateMsg([]PeerID{}) + if err == nil { + t.Errorf("expected error for empty input") + } +} + +func TestUnmarshalSubPeerState_Invalid(t *testing.T) { + // Too short + _, err := UnmarshalSubPeerStateMsg([]byte{1}) + if err == nil { + t.Errorf("expected error for short input") + } + + // Misaligned length + buf := make([]byte, sizeOfProtoHeader+1) + _, err = UnmarshalSubPeerStateMsg(buf) + if err == nil { + t.Errorf("expected error for misaligned input") + } +} + +func TestMarshalUnmarshalPeersOnline(t *testing.T) { + ids := generateTestPeerIDs(testPeerCount) + + msgs, err := MarshalPeersOnline(ids) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var allIDs []PeerID + for _, msg := range msgs { + decoded, err := UnmarshalPeersOnlineMsg(msg) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + allIDs = append(allIDs, decoded...) + } + + if !peerIDEqual(ids, allIDs) { + t.Errorf("expected %v, got %v", ids, allIDs) + } +} + +func TestMarshalPeersOnline_EmptyInput(t *testing.T) { + _, err := MarshalPeersOnline([]PeerID{}) + if err == nil { + t.Errorf("expected error for empty input") + } +} + +func TestUnmarshalPeersOnline_Invalid(t *testing.T) { + _, err := UnmarshalPeersOnlineMsg([]byte{1}) + if err == nil { + t.Errorf("expected error for short input") + } +} + +func TestMarshalUnmarshalPeersWentOffline(t *testing.T) { + ids := generateTestPeerIDs(testPeerCount) + + msgs, err := MarshalPeersWentOffline(ids) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var allIDs []PeerID + for _, msg := range msgs { + // MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline + decoded, err := UnmarshalPeersOnlineMsg(msg) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + allIDs = append(allIDs, decoded...) + } + + if !peerIDEqual(ids, allIDs) { + t.Errorf("expected %v, got %v", ids, allIDs) + } +} + +func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) { + _, err := MarshalPeersWentOffline([]PeerID{}) + if err == nil { + t.Errorf("expected error for empty input") + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go index babd6f955..eb72b3bae 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -6,7 +6,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck "github.com/netbirdio/netbird/relay/messages/address" @@ -14,6 +13,12 @@ import ( authmsg "github.com/netbirdio/netbird/relay/messages/auth" ) +type Validator interface { + Validate(any) error + // Deprecated: Use Validate instead. + ValidateHelloMsgType(any) error +} + // preparedMsg contains the marshalled success response messages type preparedMsg struct { responseHelloMsg []byte @@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { type handshake struct { conn net.Conn - validator auth.Validator + validator Validator preparedMsg *preparedMsg handshakeMethodAuth bool - peerID string + peerID *messages.PeerID } -func (h *handshake) handshakeReceive() ([]byte, error) { +func (h *handshake) handshakeReceive() (*messages.PeerID, error) { buf := make([]byte, messages.MaxHandshakeSize) n, err := h.conn.Read(buf) if err != nil { @@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) { return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) } - var ( - bytePeerID []byte - peerID string - ) + var peerID *messages.PeerID switch msgType { //nolint:staticcheck case messages.MsgTypeHello: - bytePeerID, peerID, err = h.handleHelloMsg(buf) + peerID, err = h.handleHelloMsg(buf) case messages.MsgTypeAuth: h.handshakeMethodAuth = true - bytePeerID, peerID, err = h.handleAuthMsg(buf) + peerID, err = h.handleAuthMsg(buf) default: return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } @@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) { return nil, err } h.peerID = peerID - return bytePeerID, nil + return peerID, nil } func (h *handshake) handshakeResponse() error { @@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error { return nil } -func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { +func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) { //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + peerID, authData, err := messages.UnmarshalHelloMsg(buf) if err != nil { - return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + return nil, fmt.Errorf("unmarshal hello message: %w", err) } - peerID := messages.HashIDToString(rawPeerID) log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) authMsg, err := authmsg.UnmarshalMsg(authData) if err != nil { - return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + return nil, fmt.Errorf("unmarshal auth message: %w", err) } //nolint:staticcheck if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) } - return rawPeerID, peerID, nil + return peerID, nil } -func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { +func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) { rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) if err != nil { - return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + return nil, fmt.Errorf("unmarshal hello message: %w", err) } - peerID := messages.HashIDToString(rawPeerID) - if err := h.validator.Validate(authPayload); err != nil { - return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) } - return rawPeerID, peerID, nil + return rawPeerID, nil } diff --git a/relay/server/peer.go b/relay/server/peer.go index aa9790f63..c6fa8508f 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -12,43 +12,50 @@ import ( "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/store" ) const ( - bufferSize = 8820 + bufferSize = messages.MaxMessageSize errCloseConn = "failed to close connection to peer: %s" ) // Peer represents a peer connection type Peer struct { - metrics *metrics.Metrics - log *log.Entry - idS string - idB []byte - conn net.Conn - connMu sync.RWMutex - store *Store + metrics *metrics.Metrics + log *log.Entry + id messages.PeerID + conn net.Conn + connMu sync.RWMutex + store *store.Store + notifier *store.PeerNotifier + + peersListener *store.Listener } // NewPeer creates a new Peer instance and prepare custom logging -func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { - stringID := messages.HashIDToString(id) - return &Peer{ - metrics: metrics, - log: log.WithField("peer_id", stringID), - idS: stringID, - idB: id, - conn: conn, - store: store, +func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { + p := &Peer{ + metrics: metrics, + log: log.WithField("peer_id", id.String()), + id: id, + conn: conn, + store: store, + notifier: notifier, } + + return p } // Work reads data from the connection // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // the message accordingly. func (p *Peer) Work() { + p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) defer func() { + p.notifier.RemoveListener(p.peersListener) + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.log.Errorf(errCloseConn, err) } @@ -94,6 +101,10 @@ func (p *Peer) Work() { } } +func (p *Peer) ID() messages.PeerID { + return p.id +} + func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { switch msgType { case messages.MsgTypeHealthCheck: @@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * if err := p.conn.Close(); err != nil { log.Errorf(errCloseConn, err) } + case messages.MsgTypeSubscribePeerState: + p.handleSubscribePeerState(msg) + case messages.MsgTypeUnsubscribePeerState: + p.handleUnsubscribePeerState(msg) default: p.log.Warnf("received unexpected message type: %s", msgType) } @@ -145,7 +160,7 @@ func (p *Peer) Close() { // String returns the peer ID func (p *Peer) String() string { - return p.idS + return p.id.String() } func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { @@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - stringPeerID := messages.HashIDToString(peerID) - dp, ok := p.store.Peer(stringPeerID) + item, ok := p.store.Peer(*peerID) if !ok { - p.log.Debugf("peer not found: %s", stringPeerID) + p.log.Debugf("peer not found: %s", peerID) return } + dp := item.(*Peer) - err = messages.UpdateTransportMsg(msg, p.idB) + err = messages.UpdateTransportMsg(msg, p.id) if err != nil { p.log.Errorf("failed to update transport message: %s", err) return @@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) { } p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) } + +func (p *Peer) handleSubscribePeerState(msg []byte) { + peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg) + if err != nil { + p.log.Errorf("failed to unmarshal open connection message: %s", err) + return + } + + p.log.Debugf("received subscription message for %d peers", len(peerIDs)) + onlinePeers := p.peersListener.AddInterestedPeers(peerIDs) + if len(onlinePeers) == 0 { + return + } + p.log.Debugf("response with %d online peers", len(onlinePeers)) + p.sendPeersOnline(onlinePeers) +} + +func (p *Peer) handleUnsubscribePeerState(msg []byte) { + peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg) + if err != nil { + p.log.Errorf("failed to unmarshal open connection message: %s", err) + return + } + + p.peersListener.RemoveInterestedPeer(peerIDs) +} + +func (p *Peer) sendPeersOnline(peers []messages.PeerID) { + msgs, err := messages.MarshalPeersOnline(peers) + if err != nil { + p.log.Errorf("failed to marshal peer location message: %s", err) + return + } + + for n, msg := range msgs { + if _, err := p.Write(msg); err != nil { + p.log.Errorf("failed to write %d. peers offline message: %s", n, err) + } + } +} + +func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { + msgs, err := messages.MarshalPeersWentOffline(peers) + if err != nil { + p.log.Errorf("failed to marshal peer location message: %s", err) + return + } + + for n, msg := range msgs { + if _, err := p.Write(msg); err != nil { + p.log.Errorf("failed to write %d. peers offline message: %s", n, err) + } + } +} diff --git a/relay/server/relay.go b/relay/server/relay.go index a5e77bc61..93fb00edb 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -4,26 +4,55 @@ import ( "context" "fmt" "net" - "net/url" - "strings" "sync" "time" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" - "github.com/netbirdio/netbird/relay/auth" //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/store" ) +type Config struct { + Meter metric.Meter + ExposedAddress string + TLSSupport bool + AuthValidator Validator + + instanceURL string +} + +func (c *Config) validate() error { + if c.Meter == nil { + c.Meter = otel.Meter("") + } + if c.ExposedAddress == "" { + return fmt.Errorf("exposed address is required") + } + + instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport) + if err != nil { + return fmt.Errorf("invalid url: %v", err) + } + c.instanceURL = instanceURL + + if c.AuthValidator == nil { + return fmt.Errorf("auth validator is required") + } + return nil +} + // Relay represents the relay server type Relay struct { metrics *metrics.Metrics metricsCancel context.CancelFunc - validator auth.Validator + validator Validator - store *Store + store *store.Store + notifier *store.PeerNotifier instanceURL string preparedMsg *preparedMsg @@ -31,40 +60,40 @@ type Relay struct { closeMu sync.RWMutex } -// NewRelay creates a new Relay instance +// NewRelay creates and returns a new Relay instance. // // Parameters: -// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage -// metrics for the relay server. -// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this -// address as the relay server's instance URL. -// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The -// instance URL depends on this value. -// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the -// peers. +// +// config: A Config struct that holds the configuration needed to initialize the relay server. +// - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used. +// - ExposedAddress: The external address clients use to reach this relay. Required. +// - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL. +// - AuthValidator: A Validator implementation used to authenticate peers. Required. // // Returns: -// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil. -// Otherwise, the error contains the details of what went wrong. -func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) { +// +// A pointer to a Relay instance and an error. If initialization is successful, the error will be nil; +// otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration). +func NewRelay(config Config) (*Relay, error) { + if err := config.validate(); err != nil { + return nil, fmt.Errorf("invalid config: %v", err) + } + ctx, metricsCancel := context.WithCancel(context.Background()) - m, err := metrics.NewMetrics(ctx, meter) + m, err := metrics.NewMetrics(ctx, config.Meter) if err != nil { metricsCancel() return nil, fmt.Errorf("creating app metrics: %v", err) } + peerStore := store.NewStore() r := &Relay{ metrics: m, metricsCancel: metricsCancel, - validator: validator, - store: NewStore(), - } - - r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport) - if err != nil { - metricsCancel() - return nil, fmt.Errorf("get instance URL: %v", err) + validator: config.AuthValidator, + instanceURL: config.instanceURL, + store: peerStore, + notifier: store.NewPeerNotifier(peerStore), } r.preparedMsg, err = newPreparedMsg(r.instanceURL) @@ -76,32 +105,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return r, nil } -// getInstanceURL checks if user supplied a URL scheme otherwise adds to the -// provided address according to TLS definition and parses the address before returning it -func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { - addr := exposedAddress - split := strings.Split(exposedAddress, "://") - switch { - case len(split) == 1 && tlsSupported: - addr = "rels://" + exposedAddress - case len(split) == 1 && !tlsSupported: - addr = "rel://" + exposedAddress - case len(split) > 2: - return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) - } - - parsedURL, err := url.ParseRequestURI(addr) - if err != nil { - return "", fmt.Errorf("invalid exposed address: %v", err) - } - - if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { - return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) - } - - return parsedURL.String(), nil -} - // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { acceptTime := time.Now() @@ -125,14 +128,17 @@ func (r *Relay) Accept(conn net.Conn) { return } - peer := NewPeer(r.metrics, peerID, conn, r.store) + peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) storeTime := time.Now() r.store.AddPeer(peer) + r.notifier.PeerCameOnline(peer.ID()) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() + r.notifier.PeerWentOffline(peer.ID()) r.store.DeletePeer(peer) peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) @@ -154,12 +160,12 @@ func (r *Relay) Shutdown(ctx context.Context) { wg := sync.WaitGroup{} peers := r.store.Peers() - for _, peer := range peers { + for _, v := range peers { wg.Add(1) go func(p *Peer) { p.CloseGracefully(ctx) wg.Done() - }(peer) + }(v.(*Peer)) } wg.Wait() r.metricsCancel() diff --git a/relay/server/server.go b/relay/server/server.go index 10aabcace..f0b480ee4 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -6,15 +6,12 @@ import ( "sync" "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/metric" - nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/ws" quictls "github.com/netbirdio/netbird/relay/tls" + log "github.com/sirupsen/logrus" ) // ListenerConfig is the configuration for the listener. @@ -33,13 +30,22 @@ type Server struct { listeners []listener.Listener } -// NewServer creates a new relay server instance. -// meter: the OpenTelemetry meter -// exposedAddress: this address will be used as the instance URL. It should be a domain:port format. -// tlsSupport: if true, the server will support TLS -// authValidator: the auth validator to use for the server -func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) { - relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator) +// NewServer creates and returns a new relay server instance. +// +// Parameters: +// +// config: A Config struct containing the necessary configuration: +// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used. +// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required. +// - TLSSupport: A boolean indicating whether TLS is enabled for the server. +// - AuthValidator: A Validator used to authenticate peers. Required. +// +// Returns: +// +// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds, +// the returned error will be nil. Otherwise, the error will describe the problem. +func NewServer(config Config) (*Server, error) { + relay, err := NewRelay(config) if err != nil { return nil, err } diff --git a/relay/server/store/listener.go b/relay/server/store/listener.go new file mode 100644 index 000000000..e5f455795 --- /dev/null +++ b/relay/server/store/listener.go @@ -0,0 +1,121 @@ +package store + +import ( + "context" + "sync" + + "github.com/netbirdio/netbird/relay/messages" +) + +type Listener struct { + store *Store + + onlineChan chan messages.PeerID + offlineChan chan messages.PeerID + interestedPeersForOffline map[messages.PeerID]struct{} + interestedPeersForOnline map[messages.PeerID]struct{} + mu sync.RWMutex + + listenerCtx context.Context +} + +func newListener(store *Store) *Listener { + l := &Listener{ + store: store, + + onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol + offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol + interestedPeersForOffline: make(map[messages.PeerID]struct{}), + interestedPeersForOnline: make(map[messages.PeerID]struct{}), + } + + return l +} + +func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID { + availablePeers := make([]messages.PeerID, 0) + l.mu.Lock() + defer l.mu.Unlock() + + for _, id := range peerIDs { + l.interestedPeersForOnline[id] = struct{}{} + l.interestedPeersForOffline[id] = struct{}{} + } + + // collect online peers to response back to the caller + for _, id := range peerIDs { + _, ok := l.store.Peer(id) + if !ok { + continue + } + + availablePeers = append(availablePeers, id) + } + return availablePeers +} + +func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { + l.mu.Lock() + defer l.mu.Unlock() + + for _, id := range peerIDs { + delete(l.interestedPeersForOffline, id) + delete(l.interestedPeersForOnline, id) + + } +} + +func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { + l.listenerCtx = ctx + for { + select { + case <-ctx.Done(): + return + case pID := <-l.onlineChan: + peers := make([]messages.PeerID, 0) + peers = append(peers, pID) + + for len(l.onlineChan) > 0 { + pID = <-l.onlineChan + peers = append(peers, pID) + } + + onPeersComeOnline(peers) + case pID := <-l.offlineChan: + peers := make([]messages.PeerID, 0) + peers = append(peers, pID) + + for len(l.offlineChan) > 0 { + pID = <-l.offlineChan + peers = append(peers, pID) + } + + onPeersWentOffline(peers) + } + } +} + +func (l *Listener) peerWentOffline(peerID messages.PeerID) { + l.mu.RLock() + defer l.mu.RUnlock() + + if _, ok := l.interestedPeersForOffline[peerID]; ok { + select { + case l.offlineChan <- peerID: + case <-l.listenerCtx.Done(): + } + } +} + +func (l *Listener) peerComeOnline(peerID messages.PeerID) { + l.mu.Lock() + defer l.mu.Unlock() + + if _, ok := l.interestedPeersForOnline[peerID]; ok { + select { + case l.onlineChan <- peerID: + case <-l.listenerCtx.Done(): + } + delete(l.interestedPeersForOnline, peerID) + } +} diff --git a/relay/server/store/notifier.go b/relay/server/store/notifier.go new file mode 100644 index 000000000..d04db478b --- /dev/null +++ b/relay/server/store/notifier.go @@ -0,0 +1,64 @@ +package store + +import ( + "context" + "sync" + + "github.com/netbirdio/netbird/relay/messages" +) + +type PeerNotifier struct { + store *Store + + listeners map[*Listener]context.CancelFunc + listenersMutex sync.RWMutex +} + +func NewPeerNotifier(store *Store) *PeerNotifier { + pn := &PeerNotifier{ + store: store, + listeners: make(map[*Listener]context.CancelFunc), + } + return pn +} + +func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { + ctx, cancel := context.WithCancel(context.Background()) + listener := newListener(pn.store) + go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline) + + pn.listenersMutex.Lock() + pn.listeners[listener] = cancel + pn.listenersMutex.Unlock() + return listener +} + +func (pn *PeerNotifier) RemoveListener(listener *Listener) { + pn.listenersMutex.Lock() + defer pn.listenersMutex.Unlock() + + cancel, ok := pn.listeners[listener] + if !ok { + return + } + cancel() + delete(pn.listeners, listener) +} + +func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) { + pn.listenersMutex.RLock() + defer pn.listenersMutex.RUnlock() + + for listener := range pn.listeners { + listener.peerWentOffline(peerID) + } +} + +func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) { + pn.listenersMutex.RLock() + defer pn.listenersMutex.RUnlock() + + for listener := range pn.listeners { + listener.peerComeOnline(peerID) + } +} diff --git a/relay/server/store.go b/relay/server/store/store.go similarity index 61% rename from relay/server/store.go rename to relay/server/store/store.go index 4288e62c5..c19fb416f 100644 --- a/relay/server/store.go +++ b/relay/server/store/store.go @@ -1,41 +1,48 @@ -package server +package store import ( "sync" + + "github.com/netbirdio/netbird/relay/messages" ) +type IPeer interface { + Close() + ID() messages.PeerID +} + // Store is a thread-safe store of peers // It is used to store the peers that are connected to the relay server type Store struct { - peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster + peers map[messages.PeerID]IPeer peersLock sync.RWMutex } // NewStore creates a new Store instance func NewStore() *Store { return &Store{ - peers: make(map[string]*Peer), + peers: make(map[messages.PeerID]IPeer), } } // AddPeer adds a peer to the store -func (s *Store) AddPeer(peer *Peer) { +func (s *Store) AddPeer(peer IPeer) { s.peersLock.Lock() defer s.peersLock.Unlock() - odlPeer, ok := s.peers[peer.String()] + odlPeer, ok := s.peers[peer.ID()] if ok { odlPeer.Close() } - s.peers[peer.String()] = peer + s.peers[peer.ID()] = peer } // DeletePeer deletes a peer from the store -func (s *Store) DeletePeer(peer *Peer) { +func (s *Store) DeletePeer(peer IPeer) { s.peersLock.Lock() defer s.peersLock.Unlock() - dp, ok := s.peers[peer.String()] + dp, ok := s.peers[peer.ID()] if !ok { return } @@ -43,11 +50,11 @@ func (s *Store) DeletePeer(peer *Peer) { return } - delete(s.peers, peer.String()) + delete(s.peers, peer.ID()) } // Peer returns a peer by its ID -func (s *Store) Peer(id string) (*Peer, bool) { +func (s *Store) Peer(id messages.PeerID) (IPeer, bool) { s.peersLock.RLock() defer s.peersLock.RUnlock() @@ -56,11 +63,11 @@ func (s *Store) Peer(id string) (*Peer, bool) { } // Peers returns all the peers in the store -func (s *Store) Peers() []*Peer { +func (s *Store) Peers() []IPeer { s.peersLock.RLock() defer s.peersLock.RUnlock() - peers := make([]*Peer, 0, len(s.peers)) + peers := make([]IPeer, 0, len(s.peers)) for _, p := range s.peers { peers = append(peers, p) } diff --git a/relay/server/store/store_test.go b/relay/server/store/store_test.go new file mode 100644 index 000000000..ad549a62c --- /dev/null +++ b/relay/server/store/store_test.go @@ -0,0 +1,49 @@ +package store + +import ( + "testing" + + "github.com/netbirdio/netbird/relay/messages" +) + +type MocPeer struct { + id messages.PeerID +} + +func (m *MocPeer) Close() { + +} + +func (m *MocPeer) ID() messages.PeerID { + return m.id +} + +func TestStore_DeletePeer(t *testing.T) { + s := NewStore() + + pID := messages.HashID("peer_one") + p := &MocPeer{id: pID} + s.AddPeer(p) + s.DeletePeer(p) + if _, ok := s.Peer(pID); ok { + t.Errorf("peer was not deleted") + } +} + +func TestStore_DeleteDeprecatedPeer(t *testing.T) { + s := NewStore() + + pID1 := messages.HashID("peer_one") + pID2 := messages.HashID("peer_one") + + p1 := &MocPeer{id: pID1} + p2 := &MocPeer{id: pID2} + + s.AddPeer(p1) + s.AddPeer(p2) + s.DeletePeer(p1) + + if _, ok := s.Peer(pID2); !ok { + t.Errorf("second peer was deleted") + } +} diff --git a/relay/server/store_test.go b/relay/server/store_test.go deleted file mode 100644 index 41c7baa92..000000000 --- a/relay/server/store_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package server - -import ( - "context" - "net" - "testing" - "time" - - "go.opentelemetry.io/otel" - - "github.com/netbirdio/netbird/relay/metrics" -) - -type mockConn struct { -} - -func (m mockConn) Read(b []byte) (n int, err error) { - //TODO implement me - panic("implement me") -} - -func (m mockConn) Write(b []byte) (n int, err error) { - //TODO implement me - panic("implement me") -} - -func (m mockConn) Close() error { - return nil -} - -func (m mockConn) LocalAddr() net.Addr { - //TODO implement me - panic("implement me") -} - -func (m mockConn) RemoteAddr() net.Addr { - //TODO implement me - panic("implement me") -} - -func (m mockConn) SetDeadline(t time.Time) error { - //TODO implement me - panic("implement me") -} - -func (m mockConn) SetReadDeadline(t time.Time) error { - //TODO implement me - panic("implement me") -} - -func (m mockConn) SetWriteDeadline(t time.Time) error { - //TODO implement me - panic("implement me") -} - -func TestStore_DeletePeer(t *testing.T) { - s := NewStore() - - m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - - p := NewPeer(m, []byte("peer_one"), nil, nil) - s.AddPeer(p) - s.DeletePeer(p) - if _, ok := s.Peer(p.String()); ok { - t.Errorf("peer was not deleted") - } -} - -func TestStore_DeleteDeprecatedPeer(t *testing.T) { - s := NewStore() - - m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - - conn := &mockConn{} - p1 := NewPeer(m, []byte("peer_id"), conn, nil) - p2 := NewPeer(m, []byte("peer_id"), conn, nil) - - s.AddPeer(p1) - s.AddPeer(p2) - s.DeletePeer(p1) - - if _, ok := s.Peer(p2.String()); !ok { - t.Errorf("second peer was deleted") - } -} diff --git a/relay/server/url.go b/relay/server/url.go new file mode 100644 index 000000000..9cbf44642 --- /dev/null +++ b/relay/server/url.go @@ -0,0 +1,33 @@ +package server + +import ( + "fmt" + "net/url" + "strings" +) + +// getInstanceURL checks if user supplied a URL scheme otherwise adds to the +// provided address according to TLS definition and parses the address before returning it +func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { + addr := exposedAddress + split := strings.Split(exposedAddress, "://") + switch { + case len(split) == 1 && tlsSupported: + addr = "rels://" + exposedAddress + case len(split) == 1 && !tlsSupported: + addr = "rel://" + exposedAddress + case len(split) > 2: + return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) + } + + parsedURL, err := url.ParseRequestURI(addr) + if err != nil { + return "", fmt.Errorf("invalid exposed address: %v", err) + } + + if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { + return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) + } + + return parsedURL.String(), nil +} diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index ec2aa488c..2e67ab803 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -12,7 +12,6 @@ import ( "github.com/pion/logging" "github.com/pion/turn/v3" - "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/relay/auth/allow" "github.com/netbirdio/netbird/relay/auth/hmac" @@ -22,7 +21,6 @@ import ( ) var ( - av = &allow.Auth{} hmacTokenStore = &hmac.TokenStore{} pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} dataSize = 1024 * 1024 * 10 @@ -70,8 +68,12 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { port := 35000 + peerPairs serverAddress := fmt.Sprintf("127.0.0.1:%d", port) serverConnURL := fmt.Sprintf("rel://%s", serverAddress) - - srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av) + serverCfg := server.Config{ + ExposedAddress: serverConnURL, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -98,8 +100,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) - err := c.Connect() + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -108,8 +110,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) - err := c.Connect() + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -119,13 +121,13 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { connsSender := make([]net.Conn, 0, peerPairs) connsReceiver := make([]net.Conn, 0, peerPairs) for i := 0; i < len(clientsSender); i++ { - conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i)) if err != nil { t.Fatalf("failed to bind channel: %s", err) } connsSender = append(connsSender, conn) - conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + conn, err = clientsReceiver[i].OpenConn(ctx, "sender-"+fmt.Sprint(i)) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go index 93d084387..9e22a80ea 100644 --- a/relay/testec2/relay.go +++ b/relay/testec2/relay.go @@ -70,8 +70,8 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { ctx := context.Background() clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) - if err := c.Connect(); err != nil { + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + if err := c.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } clientsSender[i] = c @@ -79,7 +79,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { connsSender := make([]net.Conn, 0, peerPairs) for i := 0; i < len(clientsSender); i++ { - conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i)) if err != nil { log.Fatalf("failed to bind channel: %s", err) } @@ -156,8 +156,8 @@ func runReader(conn net.Conn) time.Duration { func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) - err := c.Connect() + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect(context.Background()) if err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -166,7 +166,7 @@ func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { connsReceiver := make([]net.Conn, 0, peerPairs) for i := 0; i < len(clientsReceiver); i++ { - conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + conn, err := clientsReceiver[i].OpenConn(context.Background(), "sender-"+fmt.Sprint(i)) if err != nil { log.Fatalf("failed to bind channel: %s", err) } From b524f486e2434bf07bd51efc7f10652cc7292d7b Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 16 Jul 2025 00:00:18 +0200 Subject: [PATCH 12/50] [client] Fix/nil relayed address (#4153) Fix nil pointer in Relay conn address Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil. panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer Fix relayed URL variable protection Protect the channel closing --- relay/client/client.go | 61 ++++++++++++++++++------------- relay/client/manager_test.go | 17 ++------- relay/client/peer_subscription.go | 59 +++++++++++++++++++++--------- relay/server/store/listener.go | 15 ++++---- relay/server/store/notifier.go | 4 +- 5 files changed, 90 insertions(+), 66 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index 2bf679ecb..32dfbb4db 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -181,13 +181,17 @@ func (c *Client) Connect(ctx context.Context) error { return nil } - if err := c.connect(ctx); err != nil { + instanceURL, err := c.connect(ctx) + if err != nil { return err } + c.muInstanceURL.Lock() + c.instanceURL = instanceURL + c.muInstanceURL.Unlock() c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID) - c.log = c.log.WithField("relay", c.instanceURL.String()) + c.log = c.log.WithField("relay", instanceURL.String()) c.log.Infof("relay connection established") c.serviceIsRunning = true @@ -229,9 +233,18 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID) msgChannel := make(chan Msg, 100) - conn := NewConn(c, peerID, msgChannel, c.instanceURL) c.mu.Lock() + if !c.serviceIsRunning { + c.mu.Unlock() + return nil, fmt.Errorf("relay connection is not established") + } + + c.muInstanceURL.Lock() + instanceURL := c.instanceURL + c.muInstanceURL.Unlock() + conn := NewConn(c, peerID, msgChannel, instanceURL) + _, ok = c.conns[peerID] if ok { c.mu.Unlock() @@ -278,69 +291,67 @@ func (c *Client) Close() error { return c.close(true) } -func (c *Client) connect(ctx context.Context) error { +func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) conn, err := rd.Dial() if err != nil { - return err + return nil, err } c.relayConn = conn - if err = c.handShake(ctx); err != nil { + instanceURL, err := c.handShake(ctx) + if err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) } - return err + return nil, err } - return nil + return instanceURL, nil } -func (c *Client) handShake(ctx context.Context) error { +func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { c.log.Errorf("failed to marshal auth message: %s", err) - return err + return nil, err } _, err = c.relayConn.Write(msg) if err != nil { c.log.Errorf("failed to send auth message: %s", err) - return err + return nil, err } buf := make([]byte, messages.MaxHandshakeRespSize) n, err := c.readWithTimeout(ctx, buf) if err != nil { c.log.Errorf("failed to read auth response: %s", err) - return err + return nil, err } _, err = messages.ValidateVersion(buf[:n]) if err != nil { - return fmt.Errorf("validate version: %w", err) + return nil, fmt.Errorf("validate version: %w", err) } msgType, err := messages.DetermineServerMessageType(buf[:n]) if err != nil { c.log.Errorf("failed to determine message type: %s", err) - return err + return nil, err } if msgType != messages.MsgTypeAuthResponse { c.log.Errorf("unexpected message type: %s", msgType) - return fmt.Errorf("unexpected message type") + return nil, fmt.Errorf("unexpected message type") } addr, err := messages.UnmarshalAuthResponse(buf[:n]) if err != nil { - return err + return nil, err } - c.muInstanceURL.Lock() - c.instanceURL = &RelayAddr{addr: addr} - c.muInstanceURL.Unlock() - return nil + return &RelayAddr{addr: addr}, nil } func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) { @@ -386,10 +397,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internal hc.Stop() - c.muInstanceURL.Lock() - c.instanceURL = nil - c.muInstanceURL.Unlock() - c.stateSubscription.Cleanup() c.wgReadLoop.Done() _ = c.close(false) @@ -578,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error { c.log.Warn("relay connection was already marked as not running") return nil } - c.serviceIsRunning = false + + c.muInstanceURL.Lock() + c.instanceURL = nil + c.muInstanceURL.Unlock() + c.log.Infof("closing all peer connections") c.closeAllConns() if gracefullyExit { diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index d20cdaac0..52f2833e4 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -229,16 +229,14 @@ func TestForeginAutoClose(t *testing.T) { errChan := make(chan error, 1) go func() { t.Log("binding server 1.") - err := srv1.Listen(srvCfg1) - if err != nil { + if err := srv1.Listen(srvCfg1); err != nil { errChan <- err } }() defer func() { t.Logf("closing server 1.") - err := srv1.Shutdown(ctx) - if err != nil { + if err := srv1.Shutdown(ctx); err != nil { t.Errorf("failed to close server: %s", err) } t.Logf("server 1. closed") @@ -287,15 +285,8 @@ func TestForeginAutoClose(t *testing.T) { } t.Log("open connection to another peer") - conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer") - if err != nil { - t.Fatalf("failed to bind channel: %s", err) - } - - t.Log("close conn") - err = conn.Close() - if err != nil { - t.Fatalf("failed to close connection: %s", err) + if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil { + t.Fatalf("should have failed to open connection to another peer") } timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second diff --git a/relay/client/peer_subscription.go b/relay/client/peer_subscription.go index 03e7127b3..85bd41cbd 100644 --- a/relay/client/peer_subscription.go +++ b/relay/client/peer_subscription.go @@ -3,6 +3,8 @@ package client import ( "context" "errors" + "fmt" + "sync" "time" log "github.com/sirupsen/logrus" @@ -28,6 +30,7 @@ type PeersStateSubscription struct { listenForOfflinePeers map[messages.PeerID]struct{} waitingPeers map[messages.PeerID]chan struct{} + mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers } func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription { @@ -43,24 +46,31 @@ func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offl // OnPeersOnline should be called when a notification is received that certain peers have come online. // It checks if any of the peers are being waited on and signals their availability. func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) { + s.mu.Lock() + defer s.mu.Unlock() + for _, peerID := range peersID { waitCh, ok := s.waitingPeers[peerID] if !ok { + // If meanwhile the peer was unsubscribed, we don't need to signal it continue } - close(waitCh) + waitCh <- struct{}{} delete(s.waitingPeers, peerID) + close(waitCh) } } func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { + s.mu.Lock() relevantPeers := make([]messages.PeerID, 0, len(peersID)) for _, peerID := range peersID { if _, ok := s.listenForOfflinePeers[peerID]; ok { relevantPeers = append(relevantPeers, peerID) } } + s.mu.Unlock() if len(relevantPeers) > 0 { s.offlineCallback(relevantPeers) @@ -68,36 +78,41 @@ func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { } // WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes. -// todo: when we unsubscribe while this is running, this will not return with error func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error { // Check if already waiting for this peer + s.mu.Lock() if _, exists := s.waitingPeers[peerID]; exists { + s.mu.Unlock() return errors.New("already waiting for peer to come online") } // Create a channel to wait for the peer to come online - waitCh := make(chan struct{}) + waitCh := make(chan struct{}, 1) s.waitingPeers[peerID] = waitCh + s.listenForOfflinePeers[peerID] = struct{}{} + s.mu.Unlock() - if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil { + if err := s.subscribeStateChange(peerID); err != nil { s.log.Errorf("failed to subscribe to peer state: %s", err) - close(waitCh) - delete(s.waitingPeers, peerID) - return err - } - - defer func() { + s.mu.Lock() if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { close(waitCh) delete(s.waitingPeers, peerID) + delete(s.listenForOfflinePeers, peerID) } - }() + s.mu.Unlock() + return err + } // Wait for peer to come online or context to be cancelled timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout) defer cancel() select { - case <-waitCh: + case _, ok := <-waitCh: + if !ok { + return fmt.Errorf("wait for peer to come online has been cancelled") + } + s.log.Debugf("peer %s is now online", peerID) return nil case <-timeoutCtx.Done(): @@ -105,6 +120,13 @@ func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil { s.log.Errorf("failed to unsubscribe from peer state: %s", err) } + s.mu.Lock() + if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { + close(waitCh) + delete(s.waitingPeers, peerID) + delete(s.listenForOfflinePeers, peerID) + } + s.mu.Unlock() return timeoutCtx.Err() } } @@ -112,6 +134,7 @@ func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error { msgErr := s.unsubscribeStateChange(peerIDs) + s.mu.Lock() for _, peerID := range peerIDs { if wch, ok := s.waitingPeers[peerID]; ok { close(wch) @@ -120,11 +143,15 @@ func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerI delete(s.listenForOfflinePeers, peerID) } + s.mu.Unlock() return msgErr } func (s *PeersStateSubscription) Cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + for _, waitCh := range s.waitingPeers { close(waitCh) } @@ -133,16 +160,12 @@ func (s *PeersStateSubscription) Cleanup() { s.listenForOfflinePeers = make(map[messages.PeerID]struct{}) } -func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error { - msgs, err := messages.MarshalSubPeerStateMsg(peerIDs) +func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error { + msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID}) if err != nil { return err } - for _, peer := range peerIDs { - s.listenForOfflinePeers[peer] = struct{}{} - } - for _, msg := range msgs { if _, err := s.relayConn.Write(msg); err != nil { return err diff --git a/relay/server/store/listener.go b/relay/server/store/listener.go index e5f455795..b7c5f4ce8 100644 --- a/relay/server/store/listener.go +++ b/relay/server/store/listener.go @@ -8,6 +8,7 @@ import ( ) type Listener struct { + ctx context.Context store *Store onlineChan chan messages.PeerID @@ -15,12 +16,11 @@ type Listener struct { interestedPeersForOffline map[messages.PeerID]struct{} interestedPeersForOnline map[messages.PeerID]struct{} mu sync.RWMutex - - listenerCtx context.Context } -func newListener(store *Store) *Listener { +func newListener(ctx context.Context, store *Store) *Listener { l := &Listener{ + ctx: ctx, store: store, onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol @@ -65,11 +65,10 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { } } -func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { - l.listenerCtx = ctx +func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { for { select { - case <-ctx.Done(): + case <-l.ctx.Done(): return case pID := <-l.onlineChan: peers := make([]messages.PeerID, 0) @@ -102,7 +101,7 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) { if _, ok := l.interestedPeersForOffline[peerID]; ok { select { case l.offlineChan <- peerID: - case <-l.listenerCtx.Done(): + case <-l.ctx.Done(): } } } @@ -114,7 +113,7 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) { if _, ok := l.interestedPeersForOnline[peerID]; ok { select { case l.onlineChan <- peerID: - case <-l.listenerCtx.Done(): + case <-l.ctx.Done(): } delete(l.interestedPeersForOnline, peerID) } diff --git a/relay/server/store/notifier.go b/relay/server/store/notifier.go index d04db478b..ad2e53545 100644 --- a/relay/server/store/notifier.go +++ b/relay/server/store/notifier.go @@ -24,8 +24,8 @@ func NewPeerNotifier(store *Store) *PeerNotifier { func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { ctx, cancel := context.WithCancel(context.Background()) - listener := newListener(pn.store) - go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline) + listener := newListener(ctx, pn.store) + go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline) pn.listenersMutex.Lock() pn.listeners[listener] = cancel From e67f44f47c2d19bf7902c5303c4ee6b242c23fd4 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Wed, 16 Jul 2025 11:09:38 +0100 Subject: [PATCH 13/50] [client] fix test (#4156) --- client/internal/engine_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 4b7a2d600..01bfbcef5 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1481,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). Return(&types.Settings{}, nil). AnyTimes() + settingsMockManager.EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(store) From 58185ced1664e37e7680aa2aa241f220777c185b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 16 Jul 2025 14:10:28 +0200 Subject: [PATCH 14/50] [misc] add forum post and update sign pipeline (#4155) use old git-town version --- .github/workflows/git-town.yml | 4 ++-- .github/workflows/release.yml | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.github/workflows/git-town.yml b/.github/workflows/git-town.yml index c54fcb449..699ed7d93 100644 --- a/.github/workflows/git-town.yml +++ b/.github/workflows/git-town.yml @@ -16,6 +16,6 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: git-town/action@v1 + - uses: git-town/action@v1.2.1 with: - skip-single-stacks: true \ No newline at end of file + skip-single-stacks: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 00898ab29..44e02f457 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.20" + SIGN_PIPE_VER: "v0.0.21" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -231,3 +231,17 @@ jobs: ref: ${{ env.SIGN_PIPE_VER }} token: ${{ secrets.SIGN_GITHUB_TOKEN }} inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' + + post_on_forum: + runs-on: ubuntu-latest + continue-on-error: true + needs: [trigger_signer] + steps: + - uses: Codixer/discourse-topic-github-release-action@v2.0.1 + with: + discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }} + discourse-base-url: https://forum.netbird.io + discourse-author-username: NetBird + discourse-category: 17 + discourse-tags: + releases From 4f74509d55b54e5db316f3c50ee587706570ee93 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 16 Jul 2025 15:07:31 +0200 Subject: [PATCH 15/50] [management] fix index creation if exist on mysql (#4150) --- management/server/migration/migration.go | 12 +- management/server/migration/migration_test.go | 127 +++++++++++++++--- 2 files changed, 120 insertions(+), 19 deletions(-) diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index ab11be731..c2f1a5abf 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -283,7 +283,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er } } - if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil { log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) } @@ -377,6 +377,11 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { var model T + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model) + return nil + } + stmt := &gorm.Statement{DB: db} if err := stmt.Parse(&model); err != nil { return fmt.Errorf("failed to parse model schema: %w", err) @@ -384,6 +389,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s tableName := stmt.Schema.Table dialect := db.Dialector.Name() + if db.Migrator().HasIndex(&model, indexName) { + log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName) + return nil + } + var columnClause string if dialect == "mysql" { var withLength []string diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 94377930a..ce76bd668 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -4,16 +4,21 @@ import ( "context" "encoding/gob" "net" + "os" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -21,7 +26,41 @@ import ( func setupDatabase(t *testing.T) *gorm.DB { t.Helper() - db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + var db *gorm.DB + var err error + var dsn string + var cleanup func() + switch os.Getenv("NETBIRD_STORE_ENGINE") { + case "mysql": + cleanup, dsn, err = testutil.CreateMysqlTestContainer() + if err != nil { + t.Fatalf("Failed to create MySQL test container: %v", err) + } + + if dsn == "" { + t.Fatal("MySQL connection string is empty, ensure the test container is running") + } + + db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + case "postgres": + cleanup, dsn, err = testutil.CreatePostgresTestContainer() + if err != nil { + t.Fatalf("Failed to create PostgreSQL test container: %v", err) + } + + if dsn == "" { + t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running") + } + + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case "sqlite": + db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + default: + db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + } + if cleanup != nil { + t.Cleanup(cleanup) + } require.NoError(t, err, "Failed to open database") return db @@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { } func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") db := setupDatabase(t) err := db.AutoMigrate(&types.Account{}, &route.Route{}) @@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { } func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") db := setupDatabase(t) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) @@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { Peers []peer `gorm:"foreignKey:AccountID;references:id"` } - err = db.Save(&account{ + a := &account{ Account: types.Account{Id: "123"}, - Peers: []peer{ - {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, - }}, - ).Error + } + + err = db.Save(a).Error + require.NoError(t, err, "Failed to insert account") + + a.Peers = []peer{ + {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, + } + + err = db.Save(a).Error require.NoError(t, err, "Failed to insert blob data") var blobValue string @@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&types.Account{ + account := &types.Account{ Id: "1234", - PeersG: []nbpeer.Peer{ - {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, - }}, - ).Error + } + + err = db.Save(account).Error + require.NoError(t, err, "Failed to insert account") + + account.PeersG = []nbpeer.Peer{ + {AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, + } + + err = db.Save(account).Error require.NoError(t, err, "Failed to insert JSON data") err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") @@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&types.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") err = db.Save(&types.SetupKey{ - Id: "1", - Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", + Id: "1", + Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", + UpdatedAt: time.Now(), }).Error require.NoError(t, err, "Failed to insert setup key") @@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing. Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", KeySecret: "EEFDA****", + UpdatedAt: time.Now(), }).Error require.NoError(t, err, "Failed to insert setup key") @@ -213,8 +268,9 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing. require.NoError(t, err, "Failed to auto-migrate tables") err = db.Save(&types.SetupKey{ - Id: "1", - Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + UpdatedAt: time.Now(), }).Error require.NoError(t, err, "Failed to insert setup key") @@ -235,8 +291,9 @@ func TestDropIndex(t *testing.T) { require.NoError(t, err, "Failed to auto-migrate tables") err = db.Save(&types.SetupKey{ - Id: "1", - Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + UpdatedAt: time.Now(), }).Error require.NoError(t, err, "Failed to insert setup key") @@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) { exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") assert.False(t, exist, "Should not have the index") } + +func TestCreateIndex(t *testing.T) { + db := setupDatabase(t) + err := db.AutoMigrate(&nbpeer.Peer{}) + assert.NoError(t, err, "Failed to auto-migrate tables") + + indexName := "idx_account_ip" + + err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip") + assert.NoError(t, err, "Migration should not fail to create index") + + exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName) + assert.True(t, exist, "Should have the index") +} + +func TestCreateIndexIfExists(t *testing.T) { + db := setupDatabase(t) + err := db.AutoMigrate(&nbpeer.Peer{}) + assert.NoError(t, err, "Failed to auto-migrate tables") + + indexName := "idx_account_ip" + + err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip") + assert.NoError(t, err, "Migration should not fail to create index") + + exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName) + assert.True(t, exist, "Should have the index") + + err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip") + assert.NoError(t, err, "Create index should not fail if index exists") + + exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName) + assert.True(t, exist, "Should have the index") +} From 08fd460867f5675ebfcdf38a157248a5d33aa9dd Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 18 Jul 2025 12:18:52 +0200 Subject: [PATCH 16/50] [management] Add validate flow response (#4172) This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface. - Adds a new ValidateFlowResponse method to the IntegratedValidator interface - Integrates the validator into the management server to validate PKCE authorization flows - Updates dependency version for management-integrations --- client/cmd/testutil_test.go | 2 +- client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- go.mod | 2 +- go.sum | 4 +- management/client/client_test.go | 2 +- management/cmd/management.go | 2 +- management/server/account_test.go | 2 +- management/server/dns_test.go | 2 +- management/server/grpcserver.go | 38 +++++++++++-------- .../http/testing/testing_tools/tools.go | 3 +- management/server/integrated_validator.go | 20 +++++----- .../integrated_validator/interface.go | 2 + management/server/management_proto_test.go | 4 +- management/server/management_test.go | 3 +- management/server/nameserver_test.go | 2 +- management/server/peer_test.go | 12 +++--- management/server/route_test.go | 2 +- management/server/user_test.go | 2 +- 19 files changed, 60 insertions(+), 48 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index b5a80d63a..228a5d507 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 01bfbcef5..e75672ed1 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1494,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 376b7e8bd..7c46aac5d 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -212,7 +212,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 4a9727373..cf2a23758 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index a622f203f..699a832dd 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= diff --git a/management/client/client_test.go b/management/client/client_test.go index 1847af73e..b59b7c982 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -112,7 +112,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 878e4c39e..24c260e9c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -292,7 +292,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index fcd40b082..b65dffe6c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2887,7 +2887,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 31c944a25..f2295450f 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { // return empty extra settings for expected calls to UpdateAccountPeers settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2b27f9e0f..2f1bc3673 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/status" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" @@ -40,13 +41,14 @@ type GRPCServer struct { settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager - config *types.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager - peerLocks sync.Map - authManager auth.Manager + peersUpdateManager *PeersUpdateManager + config *types.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager + peerLocks sync.Map + authManager auth.Manager + integratedPeerValidator integrated_validator.IntegratedValidator } // NewServer creates a new Management server @@ -60,6 +62,7 @@ func NewServer( appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager, authManager auth.Manager, + integratedPeerValidator integrated_validator.IntegratedValidator, ) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -79,14 +82,15 @@ func NewServer( return &GRPCServer{ wgKey: key, // peerKey -> event channel - peersUpdateManager: peersUpdateManager, - accountManager: accountManager, - settingsManager: settingsManager, - config: config, - secretsManager: secretsManager, - authManager: authManager, - appMetrics: appMetrics, - ephemeralManager: ephemeralManager, + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + settingsManager: settingsManager, + config: config, + secretsManager: secretsManager, + authManager: authManager, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + integratedPeerValidator: integratedPeerValidator, }, nil } @@ -850,7 +854,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") } - flowInfoResp := &proto.PKCEAuthorizationFlow{ + initInfoFlow := &proto.PKCEAuthorizationFlow{ ProviderConfig: &proto.ProviderConfig{ Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, @@ -865,6 +869,8 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En }, } + flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 829bff455..e308f100f 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -1,4 +1,5 @@ package testing_tools + import ( "bytes" "context" @@ -132,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve } geoMock := &geolocation.Mock{} - validatorMock := server.MocIntegratedValidator{} + validatorMock := server.MockIntegratedValidator{} proxyController := integrations.NewController(store) userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index cfde7c614..e3e474411 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -101,22 +102,23 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) } -type MocIntegratedValidator struct { +type MockIntegratedValidator struct { + integrated_validator.IntegratedValidator ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) } -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { return nil } -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { +func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { if a.ValidatePeerFunc != nil { return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { +func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} @@ -124,22 +126,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty return validatedPeers, nil } -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { return peer } -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) { +func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) { return false, false, nil } -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { +func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { return nil } -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { +func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { // just a dummy } -func (MocIntegratedValidator) Stop(_ context.Context) { +func (MockIntegratedValidator) Stop(_ context.Context) { // just a dummy } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index 083baa65e..245c0168f 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,6 +3,7 @@ package integrated_validator import ( "context" + "github.com/netbirdio/netbird/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" ) @@ -17,4 +18,5 @@ type IntegratedValidator interface { PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) + ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 57c00ed9f..0d61b3a10 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -448,7 +448,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config) permissionsManager := permissions.NewManager(store) accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { cleanup() @@ -458,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 0a6b3f751..ab6f0095b 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -206,7 +206,7 @@ func startServer( eventStore, nil, false, - server.MocIntegratedValidator{}, + server.MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, @@ -227,6 +227,7 @@ func startServer( nil, nil, nil, + server.MockIntegratedValidator{}, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 8fada742c..25eb03b83 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index d41020514..4f6ae500e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1273,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1353,7 +1353,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1496,7 +1496,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1570,7 +1570,7 @@ func Test_LoginPeer(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1848,7 +1848,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { return update, true, nil } - manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} + manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) @@ -1870,7 +1870,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { return update, false, nil } - manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} + manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) diff --git a/management/server/route_test.go b/management/server/route_test.go index 77cbc75b9..37c37f624 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createRouterStore(t *testing.T) (store.Store, error) { diff --git a/management/server/user_test.go b/management/server/user_test.go index 7508e0609..53baf8f7e 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -852,7 +852,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, - integratedPeerValidator: MocIntegratedValidator{}, + integratedPeerValidator: MockIntegratedValidator{}, permissionsManager: permissionsManager, } From f6e9d755e4a3069f30bcff40a2a0eb131f9a5cbe Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 21 Jul 2025 09:46:53 +0200 Subject: [PATCH 17/50] [client, relay] The openConn function no longer blocks the relayAddress function call (#4180) The openConn function no longer blocks the relayAddress function call in manager layer --- relay/client/manager.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/relay/client/manager.go b/relay/client/manager.go index 0fb682d95..b97bc0b99 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -65,7 +65,7 @@ type Manager struct { relayClient *Client // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable - relayClientMu sync.Mutex + relayClientMu sync.RWMutex reconnectGuard *Guard relayClients map[string]*RelayTrack @@ -124,8 +124,8 @@ func (m *Manager) Serve() error { // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { - m.relayClientMu.Lock() - defer m.relayClientMu.Unlock() + m.relayClientMu.RLock() + defer m.relayClientMu.RUnlock() if m.relayClient == nil { return nil, ErrRelayClientNotConnected @@ -155,8 +155,8 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) ( // Ready returns true if the home Relay client is connected to the relay server. func (m *Manager) Ready() bool { - m.relayClientMu.Lock() - defer m.relayClientMu.Unlock() + m.relayClientMu.RLock() + defer m.relayClientMu.RUnlock() if m.relayClient == nil { return false @@ -174,8 +174,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) { // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { - m.relayClientMu.Lock() - defer m.relayClientMu.Unlock() + m.relayClientMu.RLock() + defer m.relayClientMu.RUnlock() if m.relayClient == nil { return ErrRelayClientNotConnected @@ -199,8 +199,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // lost. This address will be sent to the target peer to choose the common relay server for the communication. func (m *Manager) RelayInstanceAddress() (string, error) { - m.relayClientMu.Lock() - defer m.relayClientMu.Unlock() + m.relayClientMu.RLock() + defer m.relayClientMu.RUnlock() if m.relayClient == nil { return "", ErrRelayClientNotConnected @@ -300,7 +300,9 @@ func (m *Manager) onServerConnected() { func (m *Manager) onServerDisconnected(serverAddress string) { m.relayClientMu.Lock() if serverAddress == m.relayClient.connectionURL { - go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) + go func(client *Client) { + m.reconnectGuard.StartReconnectTrys(m.ctx, client) + }(m.relayClient) } m.relayClientMu.Unlock() From 40fdeda838101b1bdc644026ecc66d9fbca9adb4 Mon Sep 17 00:00:00 2001 From: Ali Amer <76897266+aliamerj@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:55:17 +0300 Subject: [PATCH 18/50] [client] add new filter-by-connection-type flag (#4010) introduces a new flag --filter-by-connection-type to the status command. It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views. Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters). --- client/cmd/debug.go | 2 +- client/cmd/status.go | 13 ++++++++++++- client/proto/daemon.pb.go | 7 +++++++ client/status/status.go | 20 +++++++++++--------- client/status/status_test.go | 2 +- client/ui/debug.go | 6 +++--- 6 files changed, 35 insertions(+), 15 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 4036bb8f6..3f13a0c3a 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { cmd.PrintErrf("Failed to get status: %v\n", err) } else { statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""), ) } return statusOutputString diff --git a/client/cmd/status.go b/client/cmd/status.go index b108ca57a..2d6e41bc2 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -26,6 +26,7 @@ var ( statusFilter string ipsFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{} + connectionTypeFilter string ) var statusCmd = &cobra.Command{ @@ -45,6 +46,7 @@ func init() { statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") + statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") } func statusFunc(cmd *cobra.Command, args []string) error { @@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } - var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) + var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter) var statusOutputString string switch { case detailFlag: @@ -156,6 +158,15 @@ func parseFilters() error { enableDetailFlagWhenFilterFlag() } + switch strings.ToLower(connectionTypeFilter) { + case "", "p2p", "relayed": + if strings.ToLower(connectionTypeFilter) != "" { + enableDetailFlagWhenFilterFlag() + } + default: + return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter) + } + return nil } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 26e58d183..753aa62d1 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string { return "" } +func (x *PeerState) GetConnectionType() string { + if x.Relayed { + return "Relayed" + } + return "P2P" +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` diff --git a/client/status/status.go b/client/status/status.go index 18056e363..507c7ea80 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -100,7 +100,7 @@ type OutputOverview struct { LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` } -func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { +func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview { pbFullStatus := resp.GetFullStatus() managementState := pbFullStatus.GetManagementState() @@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status } relayOverview := mapRelays(pbFullStatus.GetRelays()) - peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) + peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) overview := OutputOverview{ Peers: peersOverview, @@ -193,6 +193,7 @@ func mapPeers( prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, + connectionTypeFilter string, ) PeersStateOutput { var peersStateDetail []PeerStateDetailOutput peersConnected := 0 @@ -208,7 +209,7 @@ func mapPeers( transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { + if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) { continue } if isPeerConnected { @@ -218,10 +219,7 @@ func mapPeers( remoteICE = pbPeerState.GetRemoteIceCandidateType() localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() - connType = "P2P" - if pbPeerState.Relayed { - connType = "Relayed" - } + connType = pbPeerState.GetConnectionType() relayServerAddress = pbPeerState.GetRelayAddress() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() transferReceived = pbPeerState.GetBytesRx() @@ -542,10 +540,11 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo return peersString } -func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool { +func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool { statusEval := false ipEval := false nameEval := true + connectionTypeEval := false if statusFilter != "" { if !strings.EqualFold(peerStatus, statusFilter) { @@ -570,8 +569,11 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi } else { nameEval = false } + if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) { + connectionTypeEval = true + } - return statusEval || ipEval || nameEval + return statusEval || ipEval || nameEval || connectionTypeEval } func toIEC(b int64) string { diff --git a/client/status/status_test.go b/client/status/status_test.go index 33eda4b9e..5b5d23efd 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -234,7 +234,7 @@ var overview = OutputOverview{ } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { - convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil) + convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "") assert.Equal(t, overview, convertedResult) } diff --git a/client/ui/debug.go b/client/ui/debug.go index ab7dba37a..55829de1e 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "") postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "") preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "") statusOutput = nbstatus.ParseToFullDetailSummary(overview) } From d6ed9c037ed31a1b4a527b3e49f9a1aaafa8db74 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:13:21 +0200 Subject: [PATCH 19/50] [client] Fix bind exclusion routes (#4154) --- client/iface/bind/control.go | 15 ++++ client/iface/bind/control_android.go | 12 --- client/iface/bind/ice_bind.go | 3 +- client/iface/bind/udp_mux.go | 25 +++--- client/iface/bind/udp_mux_generic.go | 21 +++++ client/iface/bind/udp_mux_ios.go | 7 ++ client/iface/bind/udp_mux_universal.go | 20 +++-- client/internal/engine.go | 14 +--- client/internal/engine_test.go | 2 +- client/internal/peer/conn.go | 52 ------------- .../routemanager/client/client_test.go | 2 +- client/internal/routemanager/manager.go | 13 ++-- client/internal/routemanager/manager_test.go | 2 +- client/internal/routemanager/mock.go | 5 +- .../routemanager/notifier/notifier_other.go | 2 +- .../routemanager/systemops/systemops.go | 5 ++ .../systemops/systemops_android.go | 5 +- .../systemops/systemops_generic.go | 76 ++++++++++++++----- .../systemops/systemops_generic_test.go | 6 +- .../routemanager/systemops/systemops_ios.go | 5 +- .../routemanager/systemops/systemops_linux.go | 6 +- .../routemanager/systemops/systemops_unix.go | 3 +- .../systemops/systemops_windows.go | 3 +- util/net/listener_listen.go | 67 ++++++++++++++-- util/net/listener_listen_ios.go | 10 +++ 25 files changed, 230 insertions(+), 151 deletions(-) create mode 100644 client/iface/bind/control.go delete mode 100644 client/iface/bind/control_android.go create mode 100644 client/iface/bind/udp_mux_generic.go create mode 100644 client/iface/bind/udp_mux_ios.go create mode 100644 util/net/listener_listen_ios.go diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go new file mode 100644 index 000000000..89bddf12c --- /dev/null +++ b/client/iface/bind/control.go @@ -0,0 +1,15 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) +func init() { + listener := nbnet.NewListener() + if listener.ListenConfig.Control != nil { + *wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control) + } +} diff --git a/client/iface/bind/control_android.go b/client/iface/bind/control_android.go deleted file mode 100644 index b8a865e39..000000000 --- a/client/iface/bind/control_android.go +++ /dev/null @@ -1,12 +0,0 @@ -package bind - -import ( - wireguard "golang.zx2c4.com/wireguard/conn" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -func init() { - // ControlFns is not thread safe and should only be modified during init. - *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) -} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index bb7a27279..c3d5ef377 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -16,6 +16,7 @@ import ( wgConn "golang.zx2c4.com/wireguard/conn" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/util/net" ) type RecvMessage struct { @@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.udpMux = NewUniversalUDPMuxDefault( UniversalUDPMuxParams{ - UDPConn: conn, + UDPConn: nbnet.WrapUDPConn(conn), Net: s.transportNet, FilterFn: s.filterFn, WGAddress: s.address, diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 0e58499aa..29e5d7937 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { return } - m.addressMapMu.Lock() - defer m.addressMapMu.Unlock() - + var allAddresses []string for _, c := range removedConns { addresses := c.getAddresses() - for _, addr := range addresses { - delete(m.addressMap, addr) - } + allAddresses = append(allAddresses, addresses...) + } + + m.addressMapMu.Lock() + for _, addr := range allAddresses { + delete(m.addressMap, addr) + } + m.addressMapMu.Unlock() + + for _, addr := range allAddresses { + m.notifyAddressRemoval(addr) } } @@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) } m.addressMapMu.Lock() - defer m.addressMapMu.Unlock() - existing, ok := m.addressMap[addr] if !ok { existing = []*udpMuxedConn{} } existing = append(existing, conn) m.addressMap[addr] = existing + m.addressMapMu.Unlock() log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } @@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // muxed connection - one for the SRFLX candidate and the other one for the HOST one. // We will then forward STUN packets to each of these connections. - m.addressMapMu.Lock() + m.addressMapMu.RLock() var destinationConnList []*udpMuxedConn if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = append(destinationConnList, storedConns...) } - m.addressMapMu.Unlock() + m.addressMapMu.RUnlock() var isIPv6 bool if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/bind/udp_mux_generic.go new file mode 100644 index 000000000..e42d25462 --- /dev/null +++ b/client/iface/bind/udp_mux_generic.go @@ -0,0 +1,21 @@ +//go:build !ios + +package bind + +import ( + nbnet "github.com/netbirdio/netbird/util/net" +) + +func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { + wrapped, ok := m.params.UDPConn.(*UDPConn) + if !ok { + return + } + + nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn) + if !ok { + return + } + + nbnetConn.RemoveAddress(addr) +} diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go new file mode 100644 index 000000000..15e26d02f --- /dev/null +++ b/client/iface/bind/udp_mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package bind + +func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} \ No newline at end of file diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go index 5cc634955..b755a7827 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/bind/udp_mux_universal.go @@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef // wrap UDP connection, process server reflexive messages // before they are passed to the UDPMux connection handler (connWorker) - m.params.UDPConn = &udpConn{ + m.params.UDPConn = &UDPConn{ PacketConn: params.UDPConn, mux: m, logger: params.Logger, @@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - // embed UDPMux udpMuxParams := UDPMuxParams{ Logger: params.Logger, UDPConn: m.params.UDPConn, @@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { } } -// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets -type udpConn struct { +// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets +type UDPConn struct { net.PacketConn mux *UniversalUDPMuxDefault logger logging.LeveledLogger @@ -125,7 +124,12 @@ type udpConn struct { address wgaddr.Address } -func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { +// GetPacketConn returns the underlying PacketConn +func (u *UDPConn) GetPacketConn() net.PacketConn { + return u.PacketConn +} + +func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { if u.filterFn == nil { return u.PacketConn.WriteTo(b, addr) } @@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { return u.handleUncachedAddress(b, addr) } -func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { +func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { if isRouted { return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) } return u.PacketConn.WriteTo(b, addr) } -func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { +func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { if err := u.performFilterCheck(addr); err != nil { return 0, err } return u.PacketConn.WriteTo(b, addr) } -func (u *udpConn) performFilterCheck(addr net.Addr) error { +func (u *UDPConn) performFilterCheck(addr net.Addr) error { host, err := getHostFromAddr(addr) if err != nil { log.Errorf("Failed to get host from address %s: %v", addr, err) diff --git a/client/internal/engine.go b/client/internal/engine.go index e9772b359..1abb8163d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -61,7 +61,6 @@ import ( signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -138,9 +137,6 @@ type Engine struct { connMgr *ConnMgr - beforePeerHook nbnet.AddHookFunc - afterPeerHook nbnet.RemoveHookFunc - // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -409,12 +405,8 @@ func (e *Engine) Start() error { DisableClientRoutes: e.config.DisableClientRoutes, DisableServerRoutes: e.config.DisableServerRoutes, }) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() - if err != nil { + if err := e.routeManager.Init(); err != nil { log.Errorf("Failed to initialize route manager: %s", err) - } else { - e.beforePeerHook = beforePeerHook - e.afterPeerHook = afterPeerHook } e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) @@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { return fmt.Errorf("peer already exists: %s", peerKey) } - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } return nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index e75672ed1..f02138686 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { StatusRecorder: engine.statusRecorder, RelayManager: relayMgr, }) - _, _, err = engine.routeManager.Init() + err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 7765bb51c..ddd90450d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -26,7 +26,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" - nbnet "github.com/netbirdio/netbird/util/net" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) @@ -106,10 +105,6 @@ type Conn struct { workerRelay *WorkerRelay wgWatcherWg sync.WaitGroup - connIDRelay nbnet.ConnectionID - connIDICE nbnet.ConnectionID - beforeAddPeerHooks []nbnet.AddHookFunc - afterRemovePeerHooks []nbnet.RemoveHookFunc // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice rosenpassRemoteKey []byte @@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } - conn.freeUpConnID() - if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { conn.onDisconnected(conn.config.WgConfig.RemoteKey) } @@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa conn.workerICE.OnRemoteCandidate(candidate, haRoutes) } -func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) { - conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) -} -func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) { - conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) -} - // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { conn.onConnected = handler @@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn ep = directEp } - if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { - conn.Log.Errorf("Before add peer hook failed: %v", err) - } - conn.workerRelay.DisableWgWatcher() // todo consider to run conn.wgWatcherWg.Wait() here @@ -503,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { return } - if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { - conn.Log.Errorf("Before add peer hook failed: %v", err) - } - wgProxy.Work() if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := wgProxy.CloseConn(); err != nil { @@ -707,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { return true } -func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, ip); err != nil { - return err - } - } - return nil -} - -func (conn *Conn) freeUpConnID() { - if conn.connIDRelay != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connIDRelay); err != nil { - conn.Log.Errorf("After remove peer hook failed: %v", err) - } - } - conn.connIDRelay = "" - } - - if conn.connIDICE != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connIDICE); err != nil { - conn.Log.Errorf("After remove peer hook failed: %v", err) - } - } - conn.connIDICE = "" - } -} - func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.Log.Debugf("setup proxied WireGuard connection") udpAddr := &net.UDPAddr{ diff --git a/client/internal/routemanager/client/client_test.go b/client/internal/routemanager/client/client_test.go index ec8e0e944..850f6691f 100644 --- a/client/internal/routemanager/client/client_test.go +++ b/client/internal/routemanager/client/client_test.go @@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { } params := common.HandlerParams{ - Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, + Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, } // create new clientNetwork client := &Watcher{ diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e0974ab2a..e51778811 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -44,7 +44,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init() error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) TriggerSelection(route.HAMap) @@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) { } // Init sets up the routing -func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init() error { m.routeSelector = m.initSelector() if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { - return nil, nil, nil + return nil } if err := m.sysOps.CleanupRouting(nil); err != nil { @@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) - if err != nil { - return nil, nil, fmt.Errorf("setup routing: %w", err) + if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + return fmt.Errorf("setup routing: %w", err) } log.Info("Routing setup complete") - return beforePeerHook, afterPeerHook, nil + return nil } func (m *DefaultManager) initSelector() *routeselector.RouteSelector { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 486ee080a..2f13c2134 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) { StatusRecorder: statusRecorder, }) - _, _, err = routeManager.Init() + err = routeManager.Init() require.NoError(t, err, "should init route manager") defer routeManager.Stop(nil) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 4e182f82c..be633c3fa 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/util/net" ) // MockManager is the mock instance of a route manager @@ -23,8 +22,8 @@ type MockManager struct { StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { - return nil, nil, nil +func (m *MockManager) Init() error { + return nil } // InitialRouteRange mock implementation of InitialRouteRange from Manager interface diff --git a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go index 77045b839..0521e3dc2 100644 --- a/client/internal/routemanager/notifier/notifier_other.go +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -33,4 +33,4 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) GetInitialRouteRanges() []string { return []string{} -} \ No newline at end of file +} diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 106c520da..b91348e94 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -6,6 +6,7 @@ import ( "net/netip" "sync" "sync/atomic" + "time" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -56,6 +57,10 @@ type SysOps struct { // seq is an atomic counter for generating unique sequence numbers for route messages //nolint:unused // only used on BSD systems seq atomic.Uint32 + + localSubnetsCache []*net.IPNet + localSubnetsCacheMu sync.RWMutex + localSubnetsCacheTime time.Time } func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index ca8aea3fb..a375ce832 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -10,11 +10,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return nil, nil, nil +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { + return nil } func (r *SysOps) CleanupRouting(*statemanager.Manager) error { diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d223a27b2..128afa2a5 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -10,6 +10,7 @@ import ( "net/netip" "runtime" "strconv" + "time" "github.com/hashicorp/go-multierror" "github.com/libp2p/go-netroute" @@ -24,6 +25,8 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +const localSubnetsCacheTTL = 15 * time.Minute + var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) @@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { stateManager.RegisterState(&ShutdownState{}) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana r.refCounter = refCounter - return r.setupHooks(initAddresses, stateManager) + if err := r.setupHooks(initAddresses, stateManager); err != nil { + return fmt.Errorf("setup hooks: %w", err) + } + return nil } // updateState updates state on every change so it will be persisted regularly @@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init return Nexthop{}, fmt.Errorf("get next hop: %w", err) } - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) - exitNextHop := Nexthop{ - IP: nexthop.IP, - Intf: nexthop.Intf, - } + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf) + exitNextHop := nexthop vpnAddr := vpnIntf.Address().IP // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) - exitNextHop = initialNextHop } @@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init } func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { + r.localSubnetsCacheMu.RLock() + cacheAge := time.Since(r.localSubnetsCacheTime) + subnets := r.localSubnetsCache + r.localSubnetsCacheMu.RUnlock() + + if cacheAge > localSubnetsCacheTTL || subnets == nil { + r.localSubnetsCacheMu.Lock() + if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil { + r.refreshLocalSubnetsCache() + } + subnets = r.localSubnetsCache + r.localSubnetsCacheMu.Unlock() + } + + for _, subnet := range subnets { + if subnet.Contains(prefix.Addr().AsSlice()) { + return true, subnet + } + } + + return false, nil +} + +func (r *SysOps) refreshLocalSubnetsCache() { localInterfaces, err := net.Interfaces() if err != nil { log.Errorf("Failed to get local interfaces: %v", err) - return false, nil + return } + var newSubnets []*net.IPNet for _, intf := range localInterfaces { addrs, err := intf.Addrs() if err != nil { @@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) log.Errorf("Failed to convert address to IPNet: %v", addr) continue } - - if ipnet.Contains(prefix.Addr().AsSlice()) { - return true, ipnet - } + newSubnets = append(newSubnets, ipnet) } } - return false, nil + r.localSubnetsCache = newSubnets + r.localSubnetsCacheTime = time.Now() } // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix @@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } + var merr *multierror.Error + for _, ip := range initAddresses { if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) } } @@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return ctx.Err() } - var result *multierror.Error + var merr *multierror.Error for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) + merr = multierror.Append(merr, beforeHook(connID, ip.IP)) } - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(merr) }) nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { @@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return afterHook(connID) }) - return beforeHook, afterHook, nil + nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + if _, err := r.refCounter.Decrement(prefix); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + r.updateState(stateManager) + return nil + }) + + return nberrors.FormatErrorOrNil(merr) } func GetNextHop(ip netip.Addr) (Nexthop, error) { diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 2a57e6044..c1c1182bc 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil, nil) + err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) @@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil, nil) + err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) @@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil, nil) + err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index bf06f3739..10356eae0 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -10,14 +10,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) - return nil, nil, nil + return nil } func (r *SysOps) CleanupRouting(*statemanager.Manager) error { diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index b48cfa242..711f1d758 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -72,7 +72,7 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { if !nbnet.AdvancedRouting() { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) @@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - return nil, nil, fmt.Errorf("%s: %w", rule.description, err) + return fmt.Errorf("%s: %w", rule.description, err) } } @@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } originalSysctl = originalValues - return nil, nil, nil + return nil } // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 46e5ca915..f165f7779 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -18,10 +18,9 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { return r.setupRefCounter(initAddresses, stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 11eaa435e..7afac9ae5 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -19,7 +19,6 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) const InfiniteLifetime = 0xffffffff @@ -137,7 +136,7 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { return r.setupRefCounter(initAddresses, stateManager) } diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go index efffba40e..dc99fbd68 100644 --- a/util/net/listener_listen.go +++ b/util/net/listener_listen.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "net/netip" "sync" log "github.com/sirupsen/logrus" @@ -17,11 +18,16 @@ type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte // ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error +// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. +type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc + listenerAddressRemoveHooksMutex sync.RWMutex + listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc ) // AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. @@ -38,7 +44,14 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) { listenerCloseHooks = append(listenerCloseHooks, hook) } -// RemoveListenerHooks removes all dialer hooks. +// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. +func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { + listenerAddressRemoveHooksMutex.Lock() + defer listenerAddressRemoveHooksMutex.Unlock() + listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) +} + +// RemoveListenerHooks removes all listener hooks. func RemoveListenerHooks() { listenerWriteHooksMutex.Lock() defer listenerWriteHooksMutex.Unlock() @@ -47,6 +60,10 @@ func RemoveListenerHooks() { listenerCloseHooksMutex.Lock() defer listenerCloseHooksMutex.Unlock() listenerCloseHooks = nil + + listenerAddressRemoveHooksMutex.Lock() + defer listenerAddressRemoveHooksMutex.Unlock() + listenerAddressRemoveHooks = nil } // ListenPacket listens on the network address and returns a PacketConn @@ -61,6 +78,7 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri return nil, fmt.Errorf("listen packet: %w", err) } connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil } @@ -102,6 +120,45 @@ func (c *UDPConn) Close() error { return closeConn(c.ID, c.UDPConn) } +// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality +func WrapUDPConn(conn *net.UDPConn) *UDPConn { + return &UDPConn{ + UDPConn: conn, + ID: GenerateConnID(), + seenAddrs: &sync.Map{}, + } +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *UDPConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) + + listenerAddressRemoveHooksMutex.RLock() + defer listenerAddressRemoveHooksMutex.RUnlock() + + for _, hook := range listenerAddressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { // Lookup the address in the seenAddrs map to avoid calling the hooks for every write if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { diff --git a/util/net/listener_listen_ios.go b/util/net/listener_listen_ios.go new file mode 100644 index 000000000..3cbd2cd71 --- /dev/null +++ b/util/net/listener_listen_ios.go @@ -0,0 +1,10 @@ +package net + +import ( + "net" +) + +// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking +func WrapUDPConn(conn *net.UDPConn) *net.UDPConn { + return conn +} From a7af15c4fcbdda1a7872417a6ad1cc8dc8448746 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 21 Jul 2025 15:26:06 +0300 Subject: [PATCH 20/50] [management] Fix group resource count mismatch in policy (#4182) --- .../server/http/handlers/policies/policies_handler.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 4d8cce3d4..267d5744f 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -424,9 +424,10 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { } if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), } destinations = append(destinations, minimum) cache[gid] = minimum From 86c16cf65150069c1a69743380179ec1ad5bf9bd Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 21 Jul 2025 19:58:17 +0200 Subject: [PATCH 21/50] [server, relay] Fix/relay race disconnection (#4174) Avoid invalid disconnection notifications in case the closed race dials. In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit. - Remove store dependency from notifier - Enforce the notification orders - Fix invalid disconnection notification - Ensure the order of the events on the consumer side --- .github/workflows/golang-test-linux.yml | 10 ++- client/internal/peer/worker_relay.go | 4 +- relay/client/client.go | 2 +- relay/client/client_test.go | 10 ++- relay/client/dialer/race_dialer.go | 22 ++++--- relay/client/dialer/race_dialer_test.go | 14 ++-- relay/client/guard.go | 3 +- relay/client/manager.go | 11 ---- relay/client/manager_test.go | 41 +++++++++--- relay/healthcheck/receiver_test.go | 45 ++++++++++++- relay/healthcheck/sender_test.go | 11 +++- relay/metrics/realy.go | 24 +++++-- relay/server/listener/quic/listener.go | 5 +- relay/server/peer.go | 14 +++- relay/server/relay.go | 14 ++-- relay/server/store/listener.go | 86 +++++++++++++------------ relay/server/store/notifier.go | 7 +- relay/server/store/store.go | 30 +++++++-- 18 files changed, 235 insertions(+), 118 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index cbce3e6e4..0c3862e33 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -211,7 +211,11 @@ jobs: strategy: fail-fast: false matrix: - arch: [ '386','amd64' ] + include: + - arch: "386" + raceFlag: "" + - arch: "amd64" + raceFlag: "-race" runs-on: ubuntu-22.04 steps: - name: Install Go @@ -251,9 +255,9 @@ jobs: - name: Test run: | CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ - go test \ + go test ${{ matrix.raceFlag }} \ -exec 'sudo' \ - -timeout 10m ./signal/... + -timeout 10m ./relay/... test_signal: name: "Signal / Unit" diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 5e2900609..ef9f24a2b 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -24,7 +24,7 @@ type WorkerRelay struct { isController bool config ConnConfig conn *Conn - relayManager relayClient.ManagerService + relayManager *relayClient.Manager relayedConn net.Conn relayLock sync.Mutex @@ -34,7 +34,7 @@ type WorkerRelay struct { wgWatcher *WGWatcher } -func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { +func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay { r := &WorkerRelay{ peerCtx: ctx, log: log, diff --git a/relay/client/client.go b/relay/client/client.go index 32dfbb4db..e4db278f5 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -292,7 +292,7 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{}) conn, err := rd.Dial() if err != nil { return nil, err diff --git a/relay/client/client_test.go b/relay/client/client_test.go index dd5f5fe1e..c85ec9fd3 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") relayClient := NewClient(serverURL, hmacTokenStore, idAlice) - err = relayClient.Connect(ctx) - if err != nil { + if err = relayClient.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } + defer func() { + if err := relayClient.Close(); err != nil { + log.Errorf("failed to close client: %s", err) + } + }() disconnected := make(chan struct{}) relayClient.SetOnDisconnectListener(func(_ string) { @@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) { select { case <-disconnected: case <-time.After(3 * time.Second): - log.Fatalf("timeout waiting for client to disconnect") + log.Errorf("timeout waiting for client to disconnect") } _, err = relayClient.OpenConn(ctx, "bob") diff --git a/relay/client/dialer/race_dialer.go b/relay/client/dialer/race_dialer.go index 11dba5799..0550fc63e 100644 --- a/relay/client/dialer/race_dialer.go +++ b/relay/client/dialer/race_dialer.go @@ -9,8 +9,8 @@ import ( log "github.com/sirupsen/logrus" ) -var ( - connectionTimeout = 30 * time.Second +const ( + DefaultConnectionTimeout = 30 * time.Second ) type DialeFn interface { @@ -25,16 +25,18 @@ type dialResult struct { } type RaceDial struct { - log *log.Entry - serverURL string - dialerFns []DialeFn + log *log.Entry + serverURL string + dialerFns []DialeFn + connectionTimeout time.Duration } -func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial { +func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial { return &RaceDial{ - log: log, - serverURL: serverURL, - dialerFns: dialerFns, + log: log, + serverURL: serverURL, + dialerFns: dialerFns, + connectionTimeout: connectionTimeout, } } @@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) { } func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) { - ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) + ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout) defer cancel() r.log.Infof("dialing Relay server via %s", dfn.Protocol()) diff --git a/relay/client/dialer/race_dialer_test.go b/relay/client/dialer/race_dialer_test.go index 989abb0a6..d216ec5e7 100644 --- a/relay/client/dialer/race_dialer_test.go +++ b/relay/client/dialer/race_dialer_test.go @@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) { logger := logrus.NewEntry(logrus.New()) serverURL := "test.server.com" - rd := NewRaceDial(logger, serverURL) + rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL) conn, err := rd.Dial() if err == nil { t.Errorf("Expected an error with empty dialers, got nil") @@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) { protocolStr: proto, } - rd := NewRaceDial(logger, serverURL, mockDialer) + rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer) conn, err := rd.Dial() if err != nil { t.Errorf("Expected no error, got %v", err) @@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { protocolStr: "proto2", } - rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) conn, err := rd.Dial() if err != nil { t.Errorf("Expected no error, got %v", err) @@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { if conn.RemoteAddr().Network() != proto2 { t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network()) } + _ = conn.Close() } func TestRaceDialTimeout(t *testing.T) { logger := logrus.NewEntry(logrus.New()) serverURL := "test.server.com" - connectionTimeout = 3 * time.Second mockDialer := &MockDialer{ dialFunc: func(ctx context.Context, address string) (net.Conn, error) { <-ctx.Done() @@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) { protocolStr: "proto1", } - rd := NewRaceDial(logger, serverURL, mockDialer) + rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer) conn, err := rd.Dial() if err == nil { t.Errorf("Expected an error, got nil") @@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) { protocolStr: "protocol2", } - rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) conn, err := rd.Dial() if err == nil { t.Errorf("Expected an error, got nil") @@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { protocolStr: proto2, } - rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) conn, err := rd.Dial() if err != nil { t.Errorf("Expected no error, got %v", err) diff --git a/relay/client/guard.go b/relay/client/guard.go index 100892d81..f4d3a8cce 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -8,7 +8,8 @@ import ( log "github.com/sirupsen/logrus" ) -var ( +const ( + // TODO: make it configurable, the manager should validate all configurable parameters reconnectingTimeout = 60 * time.Second ) diff --git a/relay/client/manager.go b/relay/client/manager.go index b97bc0b99..f32bb9f26 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack { type OnServerCloseListener func() -// ManagerService is the interface for the relay manager. -type ManagerService interface { - Serve() error - OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) - AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error - RelayInstanceAddress() (string, error) - ServerURLs() []string - HasRelayAddress() bool - UpdateToken(token *relayAuth.Token) error -} - // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // and automatically reconnect to them in case disconnection. // The manager also manage temporary relay connection. If a client wants to communicate with a client on a diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 52f2833e4..d0075f982 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -13,7 +13,9 @@ import ( ) func TestEmptyURL(t *testing.T) { - mgr := NewManager(context.Background(), nil, "alice") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mgr := NewManager(ctx, nil, "alice") err := mgr.Serve() if err == nil { t.Errorf("expected error, got nil") @@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) { } } -func TestForeginAutoClose(t *testing.T) { +func TestForeignAutoClose(t *testing.T) { ctx := context.Background() relayCleanupInterval = 1 * time.Second + keepUnusedServerTime = 2 * time.Second + srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } @@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } + // Set up a disconnect listener to track when foreign server disconnects + foreignServerURL := toURL(srvCfg2)[0] + disconnected := make(chan struct{}) + onDisconnect := func() { + select { + case disconnected <- struct{}{}: + default: + } + } + t.Log("open connection to another peer") - if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil { + if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil { t.Fatalf("should have failed to open connection to another peer") } - timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second + // Add the disconnect listener after the connection attempt + if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil { + t.Logf("failed to add close listener (expected if connection failed): %s", err) + } + + // Wait for cleanup to happen + timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second t.Logf("waiting for relay cleanup: %s", timeout) - time.Sleep(timeout) - if len(mgr.relayClients) != 0 { - t.Errorf("expected 0, got %d", len(mgr.relayClients)) + + select { + case <-disconnected: + t.Log("foreign relay connection cleaned up successfully") + case <-time.After(timeout): + t.Log("timeout waiting for cleanup - this might be expected if connection never established") } t.Logf("closing manager") @@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) { func TestAutoReconnect(t *testing.T) { ctx := context.Background() - reconnectingTimeout = 2 * time.Second srvCfg := server.ListenerConfig{ Address: "localhost:1234", @@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) { } errChan := make(chan error, 1) go func() { - err := srv.Listen(srvCfg) - if err != nil { + if err := srv.Listen(srvCfg); err != nil { errChan <- err } }() diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go index 3b3e32fe6..2794159f6 100644 --- a/relay/healthcheck/receiver_test.go +++ b/relay/healthcheck/receiver_test.go @@ -4,38 +4,76 @@ import ( "context" "fmt" "os" + "sync" "testing" "time" log "github.com/sirupsen/logrus" ) +// Mutex to protect global variable access in tests +var testMutex sync.Mutex + func TestNewReceiver(t *testing.T) { + testMutex.Lock() + originalTimeout := heartbeatTimeout heartbeatTimeout = 5 * time.Second + testMutex.Unlock() + + defer func() { + testMutex.Lock() + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + r := NewReceiver(log.WithContext(context.Background())) + defer r.Stop() select { case <-r.OnTimeout: t.Error("unexpected timeout") case <-time.After(1 * time.Second): - + // Test passes if no timeout received } } func TestNewReceiverNotReceive(t *testing.T) { + testMutex.Lock() + originalTimeout := heartbeatTimeout heartbeatTimeout = 1 * time.Second + testMutex.Unlock() + + defer func() { + testMutex.Lock() + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + r := NewReceiver(log.WithContext(context.Background())) + defer r.Stop() select { case <-r.OnTimeout: + // Test passes if timeout is received case <-time.After(2 * time.Second): t.Error("timeout not received") } } func TestNewReceiverAck(t *testing.T) { + testMutex.Lock() + originalTimeout := heartbeatTimeout heartbeatTimeout = 2 * time.Second + testMutex.Unlock() + + defer func() { + testMutex.Lock() + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + r := NewReceiver(log.WithContext(context.Background())) + defer r.Stop() r.Heartbeat() @@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { + testMutex.Lock() originalInterval := healthCheckInterval originalTimeout := heartbeatTimeout healthCheckInterval = 1 * time.Second heartbeatTimeout = healthCheckInterval + 500*time.Millisecond + testMutex.Unlock() + defer func() { + testMutex.Lock() healthCheckInterval = originalInterval heartbeatTimeout = originalTimeout + testMutex.Unlock() }() //nolint:tenv os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go index f21167025..39d266b48 100644 --- a/relay/healthcheck/sender_test.go +++ b/relay/healthcheck/sender_test.go @@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { defer cancel() sender := NewSender(log.WithField("test_name", tc.name)) - go sender.StartHealthCheck(ctx) + senderExit := make(chan struct{}) + go func() { + sender.StartHealthCheck(ctx) + close(senderExit) + }() go func() { responded := false @@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { t.Fatalf("should have timed out before %s", testTimeout) } + select { + case <-senderExit: + case <-time.After(2 * time.Second): + t.Fatalf("sender did not exit in time") + } }) } diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go index 2e90940e6..efb597ff5 100644 --- a/relay/metrics/realy.go +++ b/relay/metrics/realy.go @@ -20,12 +20,12 @@ type Metrics struct { TransferBytesRecv metric.Int64Counter AuthenticationTime metric.Float64Histogram PeerStoreTime metric.Float64Histogram - - peers metric.Int64UpDownCounter - peerActivityChan chan string - peerLastActive map[string]time.Time - mutexActivity sync.Mutex - ctx context.Context + peerReconnections metric.Int64Counter + peers metric.Int64UpDownCounter + peerActivityChan chan string + peerLastActive map[string]time.Time + mutexActivity sync.Mutex + ctx context.Context } func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { @@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { return nil, err } + peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total", + metric.WithDescription("Total number of times peers have reconnected and closed old connections"), + ) + if err != nil { + return nil, err + } + m := &Metrics{ Meter: meter, TransferBytesSent: bytesSent, @@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { AuthenticationTime: authTime, PeerStoreTime: peerStoreTime, peers: peers, + peerReconnections: peerReconnections, ctx: ctx, peerActivityChan: make(chan string, 10), @@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) { delete(m.peerLastActive, id) } +func (m *Metrics) RecordPeerReconnection() { + m.peerReconnections.Add(m.ctx, 1) +} + // PeerActivity increases the active connections func (m *Metrics) PeerActivity(peerID string) { select { diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go index 17a5e8ab6..2a4a668f0 100644 --- a/relay/server/listener/quic/listener.go +++ b/relay/server/listener/quic/listener.go @@ -18,12 +18,9 @@ type Listener struct { TLSConfig *tls.Config listener *quic.Listener - acceptFn func(conn net.Conn) } func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { - l.acceptFn = acceptFn - quicCfg := &quic.Config{ EnableDatagrams: true, InitialPacketSize: 1452, @@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { log.Infof("QUIC client connected from: %s", session.RemoteAddr()) conn := NewConn(session) - l.acceptFn(conn) + acceptFn(conn) } } diff --git a/relay/server/peer.go b/relay/server/peer.go index c6fa8508f..9caa5b06f 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -32,6 +32,9 @@ type Peer struct { notifier *store.PeerNotifier peersListener *store.Listener + + // between the online peer collection step and the notification sending should not be sent offline notifications from another thread + notificationMutex sync.Mutex } // NewPeer creates a new Peer instance and prepare custom logging @@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) { } p.log.Debugf("received subscription message for %d peers", len(peerIDs)) - onlinePeers := p.peersListener.AddInterestedPeers(peerIDs) + + // collect online peers to response back to the caller + p.notificationMutex.Lock() + defer p.notificationMutex.Unlock() + + onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener) if len(onlinePeers) == 0 { return } + p.log.Debugf("response with %d online peers", len(onlinePeers)) p.sendPeersOnline(onlinePeers) } @@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) { } func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { + p.notificationMutex.Lock() + defer p.notificationMutex.Unlock() + msgs, err := messages.MarshalPeersWentOffline(peers) if err != nil { p.log.Errorf("failed to marshal peer location message: %s", err) diff --git a/relay/server/relay.go b/relay/server/relay.go index 93fb00edb..d86684937 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) { return nil, fmt.Errorf("creating app metrics: %v", err) } - peerStore := store.NewStore() r := &Relay{ metrics: m, metricsCancel: metricsCancel, validator: config.AuthValidator, instanceURL: config.instanceURL, - store: peerStore, - notifier: store.NewPeerNotifier(peerStore), + store: store.NewStore(), + notifier: store.NewPeerNotifier(), } r.preparedMsg, err = newPreparedMsg(r.instanceURL) @@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) { peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) storeTime := time.Now() - r.store.AddPeer(peer) + if isReconnection := r.store.AddPeer(peer); isReconnection { + r.metrics.RecordPeerReconnection() + } r.notifier.PeerCameOnline(peer.ID()) r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() - r.notifier.PeerWentOffline(peer.ID()) - r.store.DeletePeer(peer) + if deleted := r.store.DeletePeer(peer); deleted { + r.notifier.PeerWentOffline(peer.ID()) + } peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) }() diff --git a/relay/server/store/listener.go b/relay/server/store/listener.go index b7c5f4ce8..e9c77d953 100644 --- a/relay/server/store/listener.go +++ b/relay/server/store/listener.go @@ -7,24 +7,27 @@ import ( "github.com/netbirdio/netbird/relay/messages" ) -type Listener struct { - ctx context.Context - store *Store +type event struct { + peerID messages.PeerID + online bool +} - onlineChan chan messages.PeerID - offlineChan chan messages.PeerID +type Listener struct { + ctx context.Context + + eventChan chan *event interestedPeersForOffline map[messages.PeerID]struct{} interestedPeersForOnline map[messages.PeerID]struct{} mu sync.RWMutex } -func newListener(ctx context.Context, store *Store) *Listener { +func newListener(ctx context.Context) *Listener { l := &Listener{ - ctx: ctx, - store: store, + ctx: ctx, - onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol - offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol + // important to use a single channel for offline and online events because with it we can ensure all events + // will be processed in the order they were sent + eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol interestedPeersForOffline: make(map[messages.PeerID]struct{}), interestedPeersForOnline: make(map[messages.PeerID]struct{}), } @@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener { return l } -func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID { - availablePeers := make([]messages.PeerID, 0) +func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) { l.mu.Lock() defer l.mu.Unlock() @@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer l.interestedPeersForOnline[id] = struct{}{} l.interestedPeersForOffline[id] = struct{}{} } - - // collect online peers to response back to the caller - for _, id := range peerIDs { - _, ok := l.store.Peer(id) - if !ok { - continue - } - - availablePeers = append(availablePeers, id) - } - return availablePeers } func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { @@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { for _, id := range peerIDs { delete(l.interestedPeersForOffline, id) delete(l.interestedPeersForOnline, id) - } } @@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([] select { case <-l.ctx.Done(): return - case pID := <-l.onlineChan: - peers := make([]messages.PeerID, 0) - peers = append(peers, pID) - - for len(l.onlineChan) > 0 { - pID = <-l.onlineChan - peers = append(peers, pID) + case e := <-l.eventChan: + peersOffline := make([]messages.PeerID, 0) + peersOnline := make([]messages.PeerID, 0) + if e.online { + peersOnline = append(peersOnline, e.peerID) + } else { + peersOffline = append(peersOffline, e.peerID) } - onPeersComeOnline(peers) - case pID := <-l.offlineChan: - peers := make([]messages.PeerID, 0) - peers = append(peers, pID) - - for len(l.offlineChan) > 0 { - pID = <-l.offlineChan - peers = append(peers, pID) + // Drain the channel to collect all events + for len(l.eventChan) > 0 { + e = <-l.eventChan + if e.online { + peersOnline = append(peersOnline, e.peerID) + } else { + peersOffline = append(peersOffline, e.peerID) + } } - onPeersWentOffline(peers) + if len(peersOnline) > 0 { + onPeersComeOnline(peersOnline) + } + if len(peersOffline) > 0 { + onPeersWentOffline(peersOffline) + } } } } @@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) { if _, ok := l.interestedPeersForOffline[peerID]; ok { select { - case l.offlineChan <- peerID: + case l.eventChan <- &event{ + peerID: peerID, + online: false, + }: case <-l.ctx.Done(): } } @@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) { if _, ok := l.interestedPeersForOnline[peerID]; ok { select { - case l.onlineChan <- peerID: + case l.eventChan <- &event{ + peerID: peerID, + online: true, + }: case <-l.ctx.Done(): } + delete(l.interestedPeersForOnline, peerID) } } diff --git a/relay/server/store/notifier.go b/relay/server/store/notifier.go index ad2e53545..335522537 100644 --- a/relay/server/store/notifier.go +++ b/relay/server/store/notifier.go @@ -8,15 +8,12 @@ import ( ) type PeerNotifier struct { - store *Store - listeners map[*Listener]context.CancelFunc listenersMutex sync.RWMutex } -func NewPeerNotifier(store *Store) *PeerNotifier { +func NewPeerNotifier() *PeerNotifier { pn := &PeerNotifier{ - store: store, listeners: make(map[*Listener]context.CancelFunc), } return pn @@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier { func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { ctx, cancel := context.WithCancel(context.Background()) - listener := newListener(ctx, pn.store) + listener := newListener(ctx) go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline) pn.listenersMutex.Lock() diff --git a/relay/server/store/store.go b/relay/server/store/store.go index c19fb416f..fd0578603 100644 --- a/relay/server/store/store.go +++ b/relay/server/store/store.go @@ -26,7 +26,9 @@ func NewStore() *Store { } // AddPeer adds a peer to the store -func (s *Store) AddPeer(peer IPeer) { +// If the peer already exists, it will be replaced and the old peer will be closed +// Returns true if the peer was replaced, false if it was added for the first time. +func (s *Store) AddPeer(peer IPeer) bool { s.peersLock.Lock() defer s.peersLock.Unlock() odlPeer, ok := s.peers[peer.ID()] @@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) { } s.peers[peer.ID()] = peer + return ok } // DeletePeer deletes a peer from the store -func (s *Store) DeletePeer(peer IPeer) { +func (s *Store) DeletePeer(peer IPeer) bool { s.peersLock.Lock() defer s.peersLock.Unlock() dp, ok := s.peers[peer.ID()] if !ok { - return + return false } if dp != peer { - return + return false } delete(s.peers, peer.ID()) + return true } // Peer returns a peer by its ID @@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer { } return peers } + +func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID { + s.peersLock.RLock() + defer s.peersLock.RUnlock() + + onlinePeers := make([]messages.PeerID, 0, len(peerIDs)) + + listener.AddInterestedPeers(peerIDs) + + // Check for currently online peers + for _, id := range peerIDs { + if _, ok := s.peers[id]; ok { + onlinePeers = append(onlinePeers, id) + } + } + + return onlinePeers +} From 91e74239896297b921e27d44e6ca0b59396a789f Mon Sep 17 00:00:00 2001 From: Philippe Vaucher Date: Tue, 22 Jul 2025 19:44:49 +0200 Subject: [PATCH 22/50] [misc] Docker compose improvements (#4037) * Use container defaults * Remove docker compose version when generating zitadel config --- infrastructure_files/docker-compose.yml.tmpl | 43 ++++++------------- .../docker-compose.yml.tmpl.traefik | 43 ++++++------------- .../getting-started-with-zitadel.sh | 1 - 3 files changed, 26 insertions(+), 61 deletions(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b529f9606..b24e853b4 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -1,8 +1,16 @@ +x-default: &default + restart: 'unless-stopped' + logging: + driver: 'json-file' + options: + max-size: '500m' + max-file: '2' + services: # UI dashboard dashboard: + <<: *default image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG - restart: unless-stopped ports: - 80:80 - 443:443 @@ -27,16 +35,11 @@ services: - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL volumes: - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" # Signal signal: + <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG - restart: unless-stopped volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird ports: @@ -44,16 +47,11 @@ services: # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" # Relay relay: + <<: *default image: netbirdio/relay:$NETBIRD_RELAY_TAG - restart: unless-stopped environment: - NB_LOG_LEVEL=info - NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT @@ -62,16 +60,11 @@ services: - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET ports: - $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" # Management management: + <<: *default image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG - restart: unless-stopped depends_on: - dashboard volumes: @@ -90,19 +83,14 @@ services: "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" ] - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" environment: - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN - NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN # Coturn coturn: + <<: *default image: coturn/coturn:$COTURN_TAG - restart: unless-stopped #domainname: $TURN_DOMAIN # only needed when TLS is enabled volumes: - ./turnserver.conf:/etc/turnserver.conf:ro @@ -111,11 +99,6 @@ services: network_mode: host command: - -c /etc/turnserver.conf - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" volumes: $MGMT_VOLUMENAME: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 8da3cabb5..08749a4f7 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -1,8 +1,16 @@ +x-default: &default + restart: 'unless-stopped' + logging: + driver: 'json-file' + options: + max-size: '500m' + max-file: '2' + services: # UI dashboard dashboard: + <<: *default image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG - restart: unless-stopped environment: # Endpoints - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT @@ -28,16 +36,11 @@ services: - traefik.enable=true - traefik.http.routers.netbird-dashboard.rule=Host(`$NETBIRD_DOMAIN`) - traefik.http.services.netbird-dashboard.loadbalancer.server.port=80 - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" # Signal signal: + <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG - restart: unless-stopped volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird labels: @@ -45,27 +48,17 @@ services: - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" # Relay relay: + <<: *default image: netbirdio/relay:$NETBIRD_RELAY_TAG - restart: unless-stopped environment: - NB_LOG_LEVEL=info - NB_LISTEN_ADDRESS=:33080 - NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT # todo: change to a secure secret - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" labels: - traefik.enable=true - traefik.http.routers.netbird-relay.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/relay`) @@ -73,8 +66,8 @@ services: # Management management: + <<: *default image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG - restart: unless-stopped depends_on: - dashboard volumes: @@ -99,30 +92,20 @@ services: - traefik.http.routers.netbird-management.service=netbird-management - traefik.http.services.netbird-management.loadbalancer.server.port=33073 - traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" environment: - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN - NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN # Coturn coturn: + <<: *default image: coturn/coturn:$COTURN_TAG - restart: unless-stopped domainname: $TURN_DOMAIN volumes: - ./turnserver.conf:/etc/turnserver.conf:ro network_mode: host command: - -c /etc/turnserver.conf - logging: - driver: "json-file" - options: - max-size: "500m" - max-file: "2" volumes: $MGMT_VOLUMENAME: diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2118ef480..2d7c65cbe 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -780,7 +780,6 @@ EOF renderDockerCompose() { cat < Date: Wed, 23 Jul 2025 21:03:29 +0200 Subject: [PATCH 23/50] [client] Fix race issues in lazy tests (#4181) * Fix race issues in lazy tests * Fix test failure due to incorrect peer listener identification --- .../lazyconn/activity/manager_test.go | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index c7c6c878a..ae6c31da4 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -33,6 +33,15 @@ func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAd } +// Add this method to the Manager struct +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + listener, exists := m.peers[peerConnID] + return listener, exists +} + func TestManager_MonitorPeerActivity(t *testing.T) { mocWgInterface := &MocWGIface{} @@ -51,7 +60,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil { + listener, exists := mgr.GetPeerListener(peerCfg1.PeerConnID) + if !exists { + t.Fatalf("peer listener not found") + } + + if err := trigger(listener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -128,11 +142,21 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil { + listener, exists := mgr.GetPeerListener(peerCfg1.PeerConnID) + if !exists { + t.Fatalf("peer listener for peer1 not found") + } + + if err := trigger(listener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } - if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil { + listener, exists = mgr.GetPeerListener(peerCfg2.PeerConnID) + if !exists { + t.Fatalf("peer listener for peer2 not found") + } + + if err := trigger(listener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } From d311f57559a069fdc0b295d5004539b5c03f73e5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 24 Jul 2025 13:14:49 +0200 Subject: [PATCH 24/50] [ci] Temporarily disable race detection in Relay (#4210) --- .github/workflows/golang-test-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 0c3862e33..1fa8b406f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -215,7 +215,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "-race" + raceFlag: "" runs-on: ubuntu-22.04 steps: - name: Install Go From e5e275c87a90680ca9abaaf51596ac1ed0827013 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Jul 2025 13:34:36 +0200 Subject: [PATCH 25/50] [client] Fix legacy routing exclusion routes in kernel mode (#4167) --- client/iface/bind/ice_bind.go | 2 +- client/iface/bind/udp_mux_generic.go | 15 ++++++++------- client/iface/device/device_kernel_unix.go | 9 ++++++++- util/net/listener_listen.go | 21 +++++++++++---------- util/net/listener_listen_ios.go | 4 ++-- 5 files changed, 30 insertions(+), 21 deletions(-) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index c3d5ef377..41f4aec6d 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -154,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.udpMux = NewUniversalUDPMuxDefault( UniversalUDPMuxParams{ - UDPConn: nbnet.WrapUDPConn(conn), + UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, WGAddress: s.address, diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/bind/udp_mux_generic.go index e42d25462..63f786d2b 100644 --- a/client/iface/bind/udp_mux_generic.go +++ b/client/iface/bind/udp_mux_generic.go @@ -7,15 +7,16 @@ import ( ) func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { - wrapped, ok := m.params.UDPConn.(*UDPConn) - if !ok { + // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) + if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { + conn.RemoveAddress(addr) return } - nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn) - if !ok { - return + // Userspace mode: UDPConn wrapper around nbnet.PacketConn + if wrapped, ok := m.params.UDPConn.(*UDPConn); ok { + if conn, ok := wrapped.GetPacketConn().(*nbnet.PacketConn); ok { + conn.RemoveAddress(addr) + } } - - nbnetConn.RemoveAddress(addr) } diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 988ed1b39..7136be0bc 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/sharedsock" + nbnet "github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -99,8 +100,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if err != nil { return nil, err } + + var udpConn net.PacketConn = rawSock + if !nbnet.AdvancedRouting() { + udpConn = nbnet.WrapPacketConn(rawSock) + } + bindParams := bind.UniversalUDPMuxParams{ - UDPConn: rawSock, + UDPConn: udpConn, Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go index dc99fbd68..4060ab49a 100644 --- a/util/net/listener_listen.go +++ b/util/net/listener_listen.go @@ -120,17 +120,8 @@ func (c *UDPConn) Close() error { return closeConn(c.ID, c.UDPConn) } -// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality -func WrapUDPConn(conn *net.UDPConn) *UDPConn { - return &UDPConn{ - UDPConn: conn, - ID: GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - // RemoveAddress removes an address from the seen cache and triggers removal hooks. -func (c *UDPConn) RemoveAddress(addr string) { +func (c *PacketConn) RemoveAddress(addr string) { if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { return } @@ -159,6 +150,16 @@ func (c *UDPConn) RemoveAddress(addr string) { } } + +// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality +func WrapPacketConn(conn net.PacketConn) *PacketConn { + return &PacketConn{ + PacketConn: conn, + ID: GenerateConnID(), + seenAddrs: &sync.Map{}, + } +} + func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { // Lookup the address in the seenAddrs map to avoid calling the hooks for every write if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { diff --git a/util/net/listener_listen_ios.go b/util/net/listener_listen_ios.go index 3cbd2cd71..c52aea583 100644 --- a/util/net/listener_listen_ios.go +++ b/util/net/listener_listen_ios.go @@ -4,7 +4,7 @@ import ( "net" ) -// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking -func WrapUDPConn(conn *net.UDPConn) *net.UDPConn { +// WrapPacketConn on iOS just returns the original connection since iOS handles its own networking +func WrapPacketConn(conn *net.UDPConn) *net.UDPConn { return conn } From 459c9ef3173ab7fd7cc46ade1a582e08de20e96d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Jul 2025 13:34:55 +0200 Subject: [PATCH 26/50] [client] Add env and status flags for netbird service command (#3975) --- .github/workflows/golang-test-linux.yml | 10 +- client/cmd/root.go | 16 +- client/cmd/service.go | 80 +++++-- client/cmd/service_controller.go | 162 +++++++-------- client/cmd/service_installer.go | 248 ++++++++++++++++------ client/cmd/service_test.go | 263 ++++++++++++++++++++++++ 6 files changed, 601 insertions(+), 178 deletions(-) create mode 100644 client/cmd/service_test.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 1fa8b406f..0d7233c3e 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-22.04 outputs: management: ${{ steps.filter.outputs.management }} - steps: + steps: - name: Checkout code uses: actions/checkout@v4 @@ -24,8 +24,8 @@ jobs: id: filter with: filters: | - management: - - 'management/**' + management: + - 'management/**' - name: Install Go uses: actions/setup-go@v5 @@ -148,7 +148,7 @@ jobs: test_client_on_docker: name: "Client (Docker) / Unit" - needs: [build-cache] + needs: [ build-cache ] runs-on: ubuntu-22.04 steps: - name: Install Go @@ -181,6 +181,7 @@ jobs: env: HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }} HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }} + CONTAINER: "true" run: | CONTAINER_GOCACHE="/root/.cache/go-build" CONTAINER_GOMODCACHE="/go/pkg/mod" @@ -198,6 +199,7 @@ jobs: -e GOARCH=${GOARCH_TARGET} \ -e GOCACHE=${CONTAINER_GOCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ + -e CONTAINER=${CONTAINER} \ golang:1.23-alpine \ sh -c ' \ apk update; apk add --no-cache \ diff --git a/client/cmd/root.go b/client/cmd/root.go index fa4bd4d42..bfd0d06c5 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -67,7 +67,6 @@ var ( interfaceName string wireguardPort uint16 networkMonitor bool - serviceName string autoConnectDisabled bool extraIFaceBlackList []string anonymizeFlag bool @@ -116,15 +115,9 @@ func init() { defaultDaemonAddr = "tcp://127.0.0.1:41731" } - defaultServiceName := "netbird" - if runtime.GOOS == "windows" { - defaultServiceName = "Netbird" - } - rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL)) rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL)) - rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") @@ -135,7 +128,6 @@ func init() { rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") - rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(upCmd) rootCmd.AddCommand(downCmd) rootCmd.AddCommand(statusCmd) @@ -146,9 +138,6 @@ func init() { rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(debugCmd) - serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service - serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service - networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) @@ -186,14 +175,13 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) { termCh := make(chan os.Signal, 1) signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) go func() { - done := ctx.Done() + defer cancel() select { - case <-done: + case <-ctx.Done(): case <-termCh: } log.Info("shutdown signal received") - cancel() }() } diff --git a/client/cmd/service.go b/client/cmd/service.go index 156e67d6d..178f4bf0e 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -1,12 +1,15 @@ +//go:build !ios && !android + package cmd import ( "context" + "fmt" "runtime" + "strings" "sync" "github.com/kardianos/service" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc" @@ -14,6 +17,16 @@ import ( "github.com/netbirdio/netbird/client/server" ) +var serviceCmd = &cobra.Command{ + Use: "service", + Short: "manages Netbird service", +} + +var ( + serviceName string + serviceEnvVars []string +) + type program struct { ctx context.Context cancel context.CancelFunc @@ -22,12 +35,31 @@ type program struct { serverInstanceMu sync.Mutex } +func init() { + defaultServiceName := "netbird" + if runtime.GOOS == "windows" { + defaultServiceName = "Netbird" + } + + serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) + + rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") + serviceEnvDesc := `Sets extra environment variables for the service. ` + + `You can specify a comma-separated list of KEY=VALUE pairs. ` + + `E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value` + + installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) + reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) + + rootCmd.AddCommand(serviceCmd) +} + func newProgram(ctx context.Context, cancel context.CancelFunc) *program { ctx = internal.CtxInitState(ctx) return &program{ctx: ctx, cancel: cancel} } -func newSVCConfig() *service.Config { +func newSVCConfig() (*service.Config, error) { config := &service.Config{ Name: serviceName, DisplayName: "Netbird", @@ -36,23 +68,47 @@ func newSVCConfig() *service.Config { EnvVars: make(map[string]string), } + if len(serviceEnvVars) > 0 { + extraEnvs, err := parseServiceEnvVars(serviceEnvVars) + if err != nil { + return nil, fmt.Errorf("parse service environment variables: %w", err) + } + config.EnvVars = extraEnvs + } + if runtime.GOOS == "linux" { config.EnvVars["SYSTEMD_UNIT"] = serviceName } - return config + return config, nil } func newSVC(prg *program, conf *service.Config) (service.Service, error) { - s, err := service.New(prg, conf) - if err != nil { - log.Fatal(err) - return nil, err - } - return s, nil + return service.New(prg, conf) } -var serviceCmd = &cobra.Command{ - Use: "service", - Short: "manages Netbird service", +func parseServiceEnvVars(envVars []string) (map[string]string, error) { + envMap := make(map[string]string) + + for _, env := range envVars { + if env == "" { + continue + } + + parts := strings.SplitN(env, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" { + return nil, fmt.Errorf("empty environment variable key in: %s", env) + } + + envMap[key] = value + } + + return envMap, nil } diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 5e3c63e57..2545623ec 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -1,3 +1,5 @@ +//go:build !ios && !android + package cmd import ( @@ -47,14 +49,13 @@ func (p *program) Start(svc service.Service) error { listen, err := net.Listen(split[0], split[1]) if err != nil { - return fmt.Errorf("failed to listen daemon interface: %w", err) + return fmt.Errorf("listen daemon interface: %w", err) } go func() { defer listen.Close() if split[0] == "unix" { - err = os.Chmod(split[1], 0666) - if err != nil { + if err := os.Chmod(split[1], 0666); err != nil { log.Errorf("failed setting daemon permissions: %v", split[1]) return } @@ -100,37 +101,49 @@ func (p *program) Stop(srv service.Service) error { return nil } +// Common setup for service control commands +func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) { + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(serviceCmd) + + cmd.SetOut(cmd.OutOrStdout()) + + if err := handleRebrand(cmd); err != nil { + return nil, err + } + + if err := util.InitLog(logLevel, logFile); err != nil { + return nil, fmt.Errorf("init log: %w", err) + } + + cfg, err := newSVCConfig() + if err != nil { + return nil, fmt.Errorf("create service config: %w", err) + } + + s, err := newSVC(newProgram(ctx, cancel), cfg) + if err != nil { + return nil, err + } + + return s, nil +} + var runCmd = &cobra.Command{ Use: "run", Short: "runs Netbird as service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - ctx, cancel := context.WithCancel(cmd.Context()) + SetupCloseHandler(ctx, cancel) SetupDebugHandler(ctx, nil, nil, nil, logFile) - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { return err } - err = s.Run() - if err != nil { - return err - } - return nil + + return s.Run() }, } @@ -138,31 +151,14 @@ var startCmd = &cobra.Command{ Use: "start", Short: "starts Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return err - } - ctx, cancel := context.WithCancel(cmd.Context()) - - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { - cmd.PrintErrln(err) return err } - err = s.Start() - if err != nil { - cmd.PrintErrln(err) - return err + + if err := s.Start(); err != nil { + return fmt.Errorf("start service: %w", err) } cmd.Println("Netbird service has been started") return nil @@ -173,29 +169,14 @@ var stopCmd = &cobra.Command{ Use: "stop", Short: "stops Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - ctx, cancel := context.WithCancel(cmd.Context()) - - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { return err } - err = s.Stop() - if err != nil { - return err + + if err := s.Stop(); err != nil { + return fmt.Errorf("stop service: %w", err) } cmd.Println("Netbird service has been stopped") return nil @@ -206,31 +187,48 @@ var restartCmd = &cobra.Command{ Use: "restart", Short: "restarts Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - ctx, cancel := context.WithCancel(cmd.Context()) - - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { return err } - err = s.Restart() - if err != nil { - return err + + if err := s.Restart(); err != nil { + return fmt.Errorf("restart service: %w", err) } cmd.Println("Netbird service has been restarted") return nil }, } + +var svcStatusCmd = &cobra.Command{ + Use: "status", + Short: "shows Netbird service status", + RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) + if err != nil { + return err + } + + status, err := s.Status() + if err != nil { + return fmt.Errorf("get service status: %w", err) + } + + var statusText string + switch status { + case service.StatusRunning: + statusText = "Running" + case service.StatusStopped: + statusText = "Stopped" + case service.StatusUnknown: + statusText = "Unknown" + default: + statusText = fmt.Sprintf("Unknown (%d)", status) + } + + cmd.Printf("Netbird service status: %s\n", statusText) + return nil + }, +} diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index c1d6308c6..951efcc73 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -1,87 +1,121 @@ +//go:build !ios && !android + package cmd import ( "context" + "errors" + "fmt" "os" "path/filepath" "runtime" + "github.com/kardianos/service" "github.com/spf13/cobra" ) +var ErrGetServiceStatus = fmt.Errorf("failed to get service status") + +// Common service command setup +func setupServiceCommand(cmd *cobra.Command) error { + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(serviceCmd) + cmd.SetOut(cmd.OutOrStdout()) + return handleRebrand(cmd) +} + +// Build service arguments for install/reconfigure +func buildServiceArguments() []string { + args := []string{ + "service", + "run", + "--config", + configPath, + "--log-level", + logLevel, + "--daemon-addr", + daemonAddr, + } + + if managementURL != "" { + args = append(args, "--management-url", managementURL) + } + + if logFile != "" { + args = append(args, "--log-file", logFile) + } + + return args +} + +// Configure platform-specific service settings +func configurePlatformSpecificSettings(svcConfig *service.Config) error { + if runtime.GOOS == "linux" { + // Respected only by systemd systems + svcConfig.Dependencies = []string{"After=network.target syslog.target"} + + if logFile != "console" { + setStdLogPath := true + dir := filepath.Dir(logFile) + + if _, err := os.Stat(dir); err != nil { + if err = os.MkdirAll(dir, 0750); err != nil { + setStdLogPath = false + } + } + + if setStdLogPath { + svcConfig.Option["LogOutput"] = true + svcConfig.Option["LogDirectory"] = dir + } + } + } + + if runtime.GOOS == "windows" { + svcConfig.Option["OnFailure"] = "restart" + } + + return nil +} + +// Create fully configured service config for install/reconfigure +func createServiceConfigForInstall() (*service.Config, error) { + svcConfig, err := newSVCConfig() + if err != nil { + return nil, fmt.Errorf("create service config: %w", err) + } + + svcConfig.Arguments = buildServiceArguments() + if err = configurePlatformSpecificSettings(svcConfig); err != nil { + return nil, fmt.Errorf("configure platform-specific settings: %w", err) + } + + return svcConfig, nil +} + var installCmd = &cobra.Command{ Use: "install", Short: "installs Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { + if err := setupServiceCommand(cmd); err != nil { return err } - svcConfig := newSVCConfig() - - svcConfig.Arguments = []string{ - "service", - "run", - "--config", - configPath, - "--log-level", - logLevel, - "--daemon-addr", - daemonAddr, - } - - if managementURL != "" { - svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) - } - - if logFile != "" { - svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) - } - - if runtime.GOOS == "linux" { - // Respected only by systemd systems - svcConfig.Dependencies = []string{"After=network.target syslog.target"} - - if logFile != "console" { - setStdLogPath := true - dir := filepath.Dir(logFile) - - _, err := os.Stat(dir) - if err != nil { - err = os.MkdirAll(dir, 0750) - if err != nil { - setStdLogPath = false - } - } - - if setStdLogPath { - svcConfig.Option["LogOutput"] = true - svcConfig.Option["LogDirectory"] = dir - } - } - } - - if runtime.GOOS == "windows" { - svcConfig.Option["OnFailure"] = "restart" + svcConfig, err := createServiceConfigForInstall() + if err != nil { + return err } ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() s, err := newSVC(newProgram(ctx, cancel), svcConfig) if err != nil { - cmd.PrintErrln(err) return err } - err = s.Install() - if err != nil { - cmd.PrintErrln(err) - return err + if err := s.Install(); err != nil { + return fmt.Errorf("install service: %w", err) } cmd.Println("Netbird service has been installed") @@ -93,27 +127,109 @@ var uninstallCmd = &cobra.Command{ Use: "uninstall", Short: "uninstalls Netbird service from system", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) + if err := setupServiceCommand(cmd); err != nil { + return err + } - cmd.SetOut(cmd.OutOrStdout()) + cfg, err := newSVCConfig() + if err != nil { + return fmt.Errorf("create service config: %w", err) + } - err := handleRebrand(cmd) + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + s, err := newSVC(newProgram(ctx, cancel), cfg) + if err != nil { + return err + } + + if err := s.Uninstall(); err != nil { + return fmt.Errorf("uninstall service: %w", err) + } + + cmd.Println("Netbird service has been uninstalled") + return nil + }, +} + +var reconfigureCmd = &cobra.Command{ + Use: "reconfigure", + Short: "reconfigures Netbird service with new settings", + Long: `Reconfigures the Netbird service with new settings without manual uninstall/install. +This command will temporarily stop the service, update its configuration, and restart it if it was running.`, + RunE: func(cmd *cobra.Command, args []string) error { + if err := setupServiceCommand(cmd); err != nil { + return err + } + + wasRunning, err := isServiceRunning() + if err != nil && !errors.Is(err, ErrGetServiceStatus) { + return fmt.Errorf("check service status: %w", err) + } + + svcConfig, err := createServiceConfigForInstall() if err != nil { return err } ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := newSVC(newProgram(ctx, cancel), svcConfig) if err != nil { - return err + return fmt.Errorf("create service: %w", err) } - err = s.Uninstall() - if err != nil { - return err + if wasRunning { + cmd.Println("Stopping Netbird service...") + if err := s.Stop(); err != nil { + cmd.Printf("Warning: failed to stop service: %v\n", err) + } } - cmd.Println("Netbird service has been uninstalled") + + cmd.Println("Removing existing service configuration...") + if err := s.Uninstall(); err != nil { + return fmt.Errorf("uninstall existing service: %w", err) + } + + cmd.Println("Installing service with new configuration...") + if err := s.Install(); err != nil { + return fmt.Errorf("install service with new config: %w", err) + } + + if wasRunning { + cmd.Println("Starting Netbird service...") + if err := s.Start(); err != nil { + return fmt.Errorf("start service after reconfigure: %w", err) + } + cmd.Println("Netbird service has been reconfigured and started") + } else { + cmd.Println("Netbird service has been reconfigured") + } + return nil }, } + +func isServiceRunning() (bool, error) { + cfg, err := newSVCConfig() + if err != nil { + return false, err + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctx, cancel), cfg) + if err != nil { + return false, err + } + + status, err := s.Status() + if err != nil { + return false, fmt.Errorf("%w: %w", ErrGetServiceStatus, err) + } + + return status == service.StatusRunning, nil +} diff --git a/client/cmd/service_test.go b/client/cmd/service_test.go new file mode 100644 index 000000000..6d75ca524 --- /dev/null +++ b/client/cmd/service_test.go @@ -0,0 +1,263 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "runtime" + "testing" + "time" + + "github.com/kardianos/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + serviceStartTimeout = 10 * time.Second + serviceStopTimeout = 5 * time.Second + statusPollInterval = 500 * time.Millisecond +) + +// waitForServiceStatus waits for service to reach expected status with timeout +func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) { + cfg, err := newSVCConfig() + if err != nil { + return false, err + } + + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + if err != nil { + return false, err + } + + ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout) + defer timeoutCancel() + + ticker := time.NewTicker(statusPollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus) + case <-ticker.C: + status, err := s.Status() + if err != nil { + // Continue polling on transient errors + continue + } + if status == expectedStatus { + return true, nil + } + } + } +} + +// TestServiceLifecycle tests the complete service lifecycle +func TestServiceLifecycle(t *testing.T) { + // TODO: Add support for Windows and macOS + if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { + t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS) + } + + if os.Getenv("CONTAINER") == "true" { + t.Skip("Skipping service lifecycle test in container environment") + } + + originalServiceName := serviceName + serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix()) + defer func() { + serviceName = originalServiceName + }() + + tempDir := t.TempDir() + configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir) + logLevel = "info" + daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir) + + ctx := context.Background() + + t.Run("Install", func(t *testing.T) { + installCmd.SetContext(ctx) + err := installCmd.RunE(installCmd, []string{}) + require.NoError(t, err) + + cfg, err := newSVCConfig() + require.NoError(t, err) + + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + require.NoError(t, err) + + status, err := s.Status() + assert.NoError(t, err) + assert.NotEqual(t, service.StatusUnknown, status) + }) + + t.Run("Start", func(t *testing.T) { + startCmd.SetContext(ctx) + err := startCmd.RunE(startCmd, []string{}) + require.NoError(t, err) + + running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout) + require.NoError(t, err) + assert.True(t, running) + }) + + t.Run("Restart", func(t *testing.T) { + restartCmd.SetContext(ctx) + err := restartCmd.RunE(restartCmd, []string{}) + require.NoError(t, err) + + running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout) + require.NoError(t, err) + assert.True(t, running) + }) + + t.Run("Reconfigure", func(t *testing.T) { + originalLogLevel := logLevel + logLevel = "debug" + defer func() { + logLevel = originalLogLevel + }() + + reconfigureCmd.SetContext(ctx) + err := reconfigureCmd.RunE(reconfigureCmd, []string{}) + require.NoError(t, err) + + running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout) + require.NoError(t, err) + assert.True(t, running) + }) + + t.Run("Stop", func(t *testing.T) { + stopCmd.SetContext(ctx) + err := stopCmd.RunE(stopCmd, []string{}) + require.NoError(t, err) + + stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout) + require.NoError(t, err) + assert.True(t, stopped) + }) + + t.Run("Uninstall", func(t *testing.T) { + uninstallCmd.SetContext(ctx) + err := uninstallCmd.RunE(uninstallCmd, []string{}) + require.NoError(t, err) + + cfg, err := newSVCConfig() + require.NoError(t, err) + + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + require.NoError(t, err) + + _, err = s.Status() + assert.Error(t, err) + }) +} + +// TestServiceEnvVars tests environment variable parsing +func TestServiceEnvVars(t *testing.T) { + tests := []struct { + name string + envVars []string + expected map[string]string + expectErr bool + }{ + { + name: "Valid single env var", + envVars: []string{"LOG_LEVEL=debug"}, + expected: map[string]string{ + "LOG_LEVEL": "debug", + }, + }, + { + name: "Valid multiple env vars", + envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"}, + expected: map[string]string{ + "LOG_LEVEL": "debug", + "CUSTOM_VAR": "value", + }, + }, + { + name: "Env var with spaces", + envVars: []string{" KEY = value "}, + expected: map[string]string{ + "KEY": "value", + }, + }, + { + name: "Invalid format - no equals", + envVars: []string{"INVALID"}, + expectErr: true, + }, + { + name: "Invalid format - empty key", + envVars: []string{"=value"}, + expectErr: true, + }, + { + name: "Empty value is valid", + envVars: []string{"KEY="}, + expected: map[string]string{ + "KEY": "", + }, + }, + { + name: "Empty slice", + envVars: []string{}, + expected: map[string]string{}, + }, + { + name: "Empty string in slice", + envVars: []string{"", "KEY=value", ""}, + expected: map[string]string{"KEY": "value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseServiceEnvVars(tt.envVars) + + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// TestServiceConfigWithEnvVars tests service config creation with env vars +func TestServiceConfigWithEnvVars(t *testing.T) { + originalServiceName := serviceName + originalServiceEnvVars := serviceEnvVars + defer func() { + serviceName = originalServiceName + serviceEnvVars = originalServiceEnvVars + }() + + serviceName = "test-service" + serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"} + + cfg, err := newSVCConfig() + require.NoError(t, err) + + assert.Equal(t, "test-service", cfg.Name) + assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"]) + assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"]) + + if runtime.GOOS == "linux" { + assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"]) + } +} From 0ea5d020a34b13a80ad0c3c2a96b556b39ffc1ae Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:12:29 +0100 Subject: [PATCH 27/50] [management] extra settings integrated validator (#4136) --- go.mod | 2 +- go.sum | 4 +- management/server/account/manager.go | 2 +- management/server/integrated_validator.go | 39 ++++++++++++------- .../integrated_validator/interface.go | 4 +- management/server/mock_server/account_mock.go | 10 ++--- management/server/peer.go | 14 +++---- management/server/types/settings.go | 7 ++++ 8 files changed, 51 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index cf2a23758..8120efe54 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250718161635-83fb99b09b5a github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 699a832dd..c9938908e 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250718161635-83fb99b09b5a h1:Kmq74+axAiJrD98+uAr53sIuj/zwMrak05Ofoy4SWYU= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250718161635-83fb99b09b5a/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= diff --git a/management/server/account/manager.go b/management/server/account/manager.go index f8aa2756a..8c7e95e3d 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -101,7 +101,7 @@ type Manager interface { DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager - UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error + UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index e3e474411..b89739be9 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -3,6 +3,7 @@ package server import ( "context" "errors" + "fmt" log "github.com/sirupsen/logrus" @@ -12,34 +13,44 @@ import ( "github.com/netbirdio/netbird/management/server/types" ) -// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. +// UpdateIntegratedValidator updates the integrated validator groups for a specified account. // It retrieves the account associated with the provided userID, then updates the integrated validator groups // with the provided list of group ids. The updated account is then saved. // // Parameters: // - accountID: The ID of the account for which integrated validator groups are to be updated. // - userID: The ID of the user whose account is being updated. +// - validator: The validator type to use, or empty to remove. // - groups: A slice of strings representing the ids of integrated validator groups to be updated. // // Returns: // - error: An error if any occurred during the process, otherwise returns nil -func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error { - ok, err := am.GroupValidation(ctx, accountID, groups) - if err != nil { - log.WithContext(ctx).Debugf("error validating groups: %s", err.Error()) - return err +func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error { + if validator != "" && len(groups) == 0 { + return fmt.Errorf("at least one group must be specified for validator") } - if !ok { - log.WithContext(ctx).Debugf("invalid groups") - return errors.New("invalid groups") + if validator != "" { + ok, err := am.GroupValidation(ctx, accountID, groups) + if err != nil { + log.WithContext(ctx).Debugf("error validating groups: %s", err.Error()) + return err + } + + if !ok { + log.WithContext(ctx).Debugf("invalid groups") + return errors.New("invalid groups") + } + } else { + // ensure groups is empty + groups = []string{} } unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - a, err := transaction.GetAccountByUser(ctx, userID) + a, err := transaction.GetAccount(ctx, accountID) if err != nil { return err } @@ -52,6 +63,8 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con extra = &types.ExtraSettings{} a.Settings.Extra = extra } + + extra.IntegratedValidator = validator extra.IntegratedValidatorGroups = groups return transaction.SaveAccount(ctx, a) }) @@ -99,7 +112,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) + return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) } type MockIntegratedValidator struct { @@ -118,7 +131,7 @@ func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer. return update, false, nil } -func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { +func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} @@ -134,7 +147,7 @@ func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID strin return false, false, nil } -func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { +func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string, extraSettings *types.ExtraSettings) error { return nil } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index 245c0168f..4d4a8cdf6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -14,8 +14,8 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) - PeerDeleted(ctx context.Context, accountID, peerID string) error + GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b1ec66286..a16e3652c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -102,7 +102,7 @@ type MockAccountManager struct { DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager - UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error + UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -769,10 +769,10 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { return nil } -// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface -func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error { - if am.UpdateIntegratedValidatorGroupsFunc != nil { - return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups) +// UpdateIntegratedValidator mocks UpdateIntegratedApprovalGroups of the AccountManager interface +func (am *MockAccountManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error { + if am.UpdateIntegratedValidatorFunc != nil { + return am.UpdateIntegratedValidatorFunc(ctx, accountID, userID, validator, groups) } return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index c6ade83c0..8f3eb2331 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -87,7 +87,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc return nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -412,7 +412,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin groups[groupID] = group.Peers } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -1036,7 +1036,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, nil, nil, err } @@ -1156,7 +1156,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun return nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -1204,7 +1204,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) return @@ -1337,7 +1337,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) return @@ -1571,7 +1571,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto } } - if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { + if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil { return nil, err } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index a22a36b03..37c728bf8 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -77,6 +77,8 @@ type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool + // IntegratedValidator is the string enum for the integrated validator type + IntegratedValidator string // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations IntegratedValidatorGroups []string `gorm:"serializer:json"` @@ -93,5 +95,10 @@ func (e *ExtraSettings) Copy() *ExtraSettings { return &ExtraSettings{ PeerApprovalEnabled: e.PeerApprovalEnabled, IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), + IntegratedValidator: e.IntegratedValidator, + FlowEnabled: e.FlowEnabled, + FlowPacketCounterEnabled: e.FlowPacketCounterEnabled, + FlowENCollectionEnabled: e.FlowENCollectionEnabled, + FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled, } } From 1a9ea32c2138f363f2d67af9c5d2afcbc74085f1 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:25:21 +0100 Subject: [PATCH 28/50] [management] scheduler cancel all jobs (#4158) --- go.mod | 2 +- go.sum | 4 ++-- management/server/scheduler.go | 29 +++++++++++++++++++++++--- management/server/scheduler_test.go | 32 +++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 8120efe54..4d9191d04 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250718161635-83fb99b09b5a + github.com/netbirdio/management-integrations/integrations v0.0.0-20250724151510-c007bc6b392c github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index c9938908e..84dfe2403 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250718161635-83fb99b09b5a h1:Kmq74+axAiJrD98+uAr53sIuj/zwMrak05Ofoy4SWYU= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250718161635-83fb99b09b5a/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250724151510-c007bc6b392c h1:OtX903X0FKEE+fcsp/P2701md7X/xbi/W/ojWIJNKSk= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250724151510-c007bc6b392c/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= diff --git a/management/server/scheduler.go b/management/server/scheduler.go index df73c9a1d..b61643295 100644 --- a/management/server/scheduler.go +++ b/management/server/scheduler.go @@ -11,14 +11,17 @@ import ( // Scheduler is an interface which implementations can schedule and cancel jobs type Scheduler interface { Cancel(ctx context.Context, IDs []string) + CancelAll(ctx context.Context) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) IsSchedulerRunning(ID string) bool } // MockScheduler is a mock implementation of Scheduler type MockScheduler struct { - CancelFunc func(ctx context.Context, IDs []string) - ScheduleFunc func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + CancelFunc func(ctx context.Context, IDs []string) + CancelAllFunc func(ctx context.Context) + ScheduleFunc func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + IsSchedulerRunningFunc func(ID string) bool } // Cancel mocks the Cancel function of the Scheduler interface @@ -30,6 +33,15 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) { log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ") } +// CancelAll mocks the CancelAll function of the Scheduler interface +func (mock *MockScheduler) CancelAll(ctx context.Context) { + if mock.CancelAllFunc != nil { + mock.CancelAllFunc(ctx) + return + } + log.WithContext(ctx).Warnf("MockScheduler doesn't have CancelAll function defined ") +} + // Schedule mocks the Schedule function of the Scheduler interface func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { if mock.ScheduleFunc != nil { @@ -40,7 +52,9 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st } func (mock *MockScheduler) IsSchedulerRunning(ID string) bool { - // MockScheduler does not implement IsSchedulerRunning, so we return false + if mock.IsSchedulerRunningFunc != nil { + return mock.IsSchedulerRunningFunc(ID) + } log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined") return false } @@ -52,6 +66,15 @@ type DefaultScheduler struct { mu *sync.Mutex } +func (wm *DefaultScheduler) CancelAll(ctx context.Context) { + wm.mu.Lock() + defer wm.mu.Unlock() + + for id := range wm.jobs { + wm.cancel(ctx, id) + } +} + // NewDefaultScheduler creates an instance of a DefaultScheduler func NewDefaultScheduler() *DefaultScheduler { return &DefaultScheduler{ diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go index fa279d4db..e3af551ad 100644 --- a/management/server/scheduler_test.go +++ b/management/server/scheduler_test.go @@ -75,6 +75,38 @@ func TestScheduler_Cancel(t *testing.T) { assert.NotNil(t, scheduler.jobs[jobID2]) } +func TestScheduler_CancelAll(t *testing.T) { + jobID1 := "test-scheduler-job-1" + jobID2 := "test-scheduler-job-2" + scheduler := NewDefaultScheduler() + tChan := make(chan struct{}) + p := []string{jobID1, jobID2} + scheduletime := 2 * time.Millisecond + sleepTime := 4 * time.Millisecond + if runtime.GOOS == "windows" { + // sleep and ticker are slower on windows see https://github.com/golang/go/issues/44343 + sleepTime = 20 * time.Millisecond + } + + scheduler.Schedule(context.Background(), scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) { + tt := p[0] + <-tChan + t.Logf("job %s", tt) + return scheduletime, true + }) + scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { + return scheduletime, true + }) + + time.Sleep(sleepTime) + assert.Len(t, scheduler.jobs, 2) + scheduler.CancelAll(context.Background()) + close(tChan) + p = []string{} + time.Sleep(sleepTime) + assert.Len(t, scheduler.jobs, 0) +} + func TestScheduler_Schedule(t *testing.T) { jobID := "test-scheduler-job-1" scheduler := NewDefaultScheduler() From 04fae00a6c1cdfa47d937172d369eaa6e0963157 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 24 Jul 2025 17:44:48 +0200 Subject: [PATCH 29/50] [management] Log UpdateAccountPeers caller (#4216) --- management/server/peer.go | 3 +++ util/runtime.go | 15 +++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 util/runtime.go diff --git a/management/server/peer.go b/management/server/peer.go index 8f3eb2331..b82a9cb80 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -23,6 +23,7 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -1183,6 +1184,8 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { + log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) diff --git a/util/runtime.go b/util/runtime.go new file mode 100644 index 000000000..3b420e15b --- /dev/null +++ b/util/runtime.go @@ -0,0 +1,15 @@ +package util + +import "runtime" + +func GetCallerName() string { + pc, _, _, ok := runtime.Caller(2) + if !ok { + return "unknown" + } + fn := runtime.FuncForPC(pc) + if fn == nil { + return "unknown" + } + return fn.Name() +} From 643730f770fd82927d28496759606872ed937065 Mon Sep 17 00:00:00 2001 From: Ali Amer <76897266+aliamerj@users.noreply.github.com> Date: Thu, 24 Jul 2025 18:51:27 +0300 Subject: [PATCH 30/50] [client] Correct minor issues in --filter-by-connection-type flag implementation for status command (#4214) Signed-off-by: aliamerj --- client/proto/daemon.pb.go | 7 ------- client/status/status.go | 14 +++++++++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 753aa62d1..26e58d183 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1330,13 +1330,6 @@ func (x *PeerState) GetRelayAddress() string { return "" } -func (x *PeerState) GetConnectionType() string { - if x.Relayed { - return "Relayed" - } - return "P2P" -} - // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` diff --git a/client/status/status.go b/client/status/status.go index 507c7ea80..d28485bc0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -203,13 +203,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "" + connType := "P2P" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) { + + if pbPeerState.Relayed { + connType = "Relayed" + } + + if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { continue } if isPeerConnected { @@ -219,7 +224,6 @@ func mapPeers( remoteICE = pbPeerState.GetRemoteIceCandidateType() localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() - connType = pbPeerState.GetConnectionType() relayServerAddress = pbPeerState.GetRelayAddress() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() transferReceived = pbPeerState.GetBytesRx() @@ -540,7 +544,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo return peersString } -func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool { +func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter, connType string) bool { statusEval := false ipEval := false nameEval := true @@ -569,7 +573,7 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi } else { nameEval = false } - if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) { + if connectionTypeFilter != "" && !strings.EqualFold(connType, connectionTypeFilter) { connectionTypeEval = true } From c435c2727fb2bd99c080afdadc4a7d63690e7626 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 24 Jul 2025 18:33:58 +0200 Subject: [PATCH 31/50] [management] Log BufferUpdateAccountPeers caller (#4217) --- management/server/peer.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index b82a9cb80..3c40c6bb6 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1291,6 +1291,8 @@ type bufferUpdate struct { } func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) b := bufUpd.(*bufferUpdate) From cb1e437785478f2ae160def29a1ccf3717725339 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 24 Jul 2025 21:00:51 +0200 Subject: [PATCH 32/50] [client] handle order of check when checking order of files in isChecksEqual (#4219) --- client/internal/engine.go | 29 +++++++------ client/internal/engine_test.go | 76 ++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 13 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 1abb8163d..079adf7e8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1968,21 +1968,24 @@ func (e *Engine) toExcludedLazyPeers(rules []firewallManager.ForwardRule, peers } // isChecksEqual checks if two slices of checks are equal. -func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { - for _, check := range checks { - sort.Slice(check.Files, func(i, j int) bool { - return check.Files[i] < check.Files[j] - }) - } - for _, oCheck := range oChecks { - sort.Slice(oCheck.Files, func(i, j int) bool { - return oCheck.Files[i] < oCheck.Files[j] - }) +func isChecksEqual(checks1, checks2 []*mgmProto.Checks) bool { + normalize := func(checks []*mgmProto.Checks) []string { + normalized := make([]string, len(checks)) + + for i, check := range checks { + sortedFiles := slices.Clone(check.Files) + sort.Strings(sortedFiles) + normalized[i] = strings.Join(sortedFiles, "|") + } + + sort.Strings(normalized) + return normalized } - return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { - return slices.Equal(checks.Files, oChecks.Files) - }) + n1 := normalize(checks1) + n2 := normalize(checks2) + + return slices.Equal(n1, n2) } func getInterfacePrefixes() ([]netip.Prefix, error) { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f02138686..fffbed533 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1270,6 +1270,82 @@ func Test_CheckFilesEqual(t *testing.T) { }, expectedBool: false, }, + { + name: "Compared Slices with same files but different order should return true", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + { + Files: []string{ + "testfile4", + "testfile3", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile3", + "testfile4", + }, + }, + { + Files: []string{ + "testfile2", + "testfile1", + }, + }, + }, + expectedBool: true, + }, + { + name: "Compared Slices with same files but different order while first is equal should return true", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile0", + "testfile1", + }, + }, + { + Files: []string{ + "testfile0", + "testfile2", + }, + }, + { + Files: []string{ + "testfile0", + "testfile3", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile0", + "testfile1", + }, + }, + { + Files: []string{ + "testfile0", + "testfile3", + }, + }, + { + Files: []string{ + "testfile0", + "testfile2", + }, + }, + }, + expectedBool: true, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { From 3f8269808943c3bf84183d8c9ba544159cc7e213 Mon Sep 17 00:00:00 2001 From: Louis Li <32395144+gamerslouis@users.noreply.github.com> Date: Fri, 25 Jul 2025 16:36:11 +0800 Subject: [PATCH 33/50] [client] make ICE failed timeout configurable (#4211) --- client/internal/peer/ice/agent.go | 8 +++----- client/internal/peer/ice/env.go | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 9b63cebf0..4a0228405 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -18,17 +18,15 @@ const ( iceKeepAliveDefault = 4 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second + iceFailedTimeoutDefault = 6 * time.Second // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package iceRelayAcceptanceMinWaitDefault = 2 * time.Second ) -var ( - failedTimeout = 6 * time.Second -) - func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() + iceFailedTimeout := iceFailedTimeout() iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) @@ -50,7 +48,7 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida UDPMuxSrflx: config.UDPMuxSrflx, NAT1To1IPs: config.NATExternalIPs, Net: transportNet, - FailedTimeout: &failedTimeout, + FailedTimeout: &iceFailedTimeout, DisconnectedTimeout: &iceDisconnectedTimeout, KeepaliveInterval: &iceKeepAlive, RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, diff --git a/client/internal/peer/ice/env.go b/client/internal/peer/ice/env.go index 3b0cb74ad..c11c35441 100644 --- a/client/internal/peer/ice/env.go +++ b/client/internal/peer/ice/env.go @@ -13,6 +13,7 @@ const ( envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" + envICEFailedTimeoutSec = "NB_ICE_FAILED_TIMEOUT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" msgWarnInvalidValue = "invalid value %s set for %s, using default %v" @@ -55,6 +56,22 @@ func iceDisconnectedTimeout() time.Duration { return time.Duration(disconnectedTimeoutSec) * time.Second } +func iceFailedTimeout() time.Duration { + failedTimeoutEnv := os.Getenv(envICEFailedTimeoutSec) + if failedTimeoutEnv == "" { + return iceFailedTimeoutDefault + } + + log.Infof("setting ICE failed timeout to %s seconds", failedTimeoutEnv) + failedTimeoutSec, err := strconv.Atoi(failedTimeoutEnv) + if err != nil { + log.Warnf(msgWarnInvalidValue, failedTimeoutEnv, envICEFailedTimeoutSec, iceFailedTimeoutDefault) + return iceFailedTimeoutDefault + } + + return time.Duration(failedTimeoutSec) * time.Second +} + func iceRelayAcceptanceMinWait() time.Duration { iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec) if iceRelayAcceptanceMinWaitEnv == "" { From af8687579b571592663371e237fa9a405af56493 Mon Sep 17 00:00:00 2001 From: "Krzysztof Nazarewski (kdn)" Date: Fri, 25 Jul 2025 11:44:30 +0200 Subject: [PATCH 34/50] client: container: support CLI with entrypoint addition (#4126) This will allow running netbird commands (including debugging) against the daemon and provide a flow similar to non-container usages. It will by default both log to file and stderr so it can be handled more uniformly in container-native environments. --- .dockerignore-client | 3 + .gitignore | 1 + .goreleaser.yaml | 14 ++- client/Dockerfile | 30 ++++-- client/Dockerfile-rootless | 37 +++++-- client/cmd/down.go | 2 +- client/cmd/login.go | 4 +- client/cmd/login_test.go | 2 +- client/cmd/root.go | 7 +- client/cmd/service_controller.go | 6 +- client/cmd/service_installer.go | 6 +- client/cmd/ssh.go | 2 +- client/cmd/status.go | 2 +- client/cmd/up.go | 4 +- client/iface/wgproxy/proxy_test.go | 2 +- client/internal/debug/debug.go | 4 +- client/internal/dns/file_repair_unix_test.go | 2 +- client/internal/engine_test.go | 2 +- client/internal/peer/conn_test.go | 2 +- client/netbird-entrypoint.sh | 105 +++++++++++++++++++ client/server/server_test.go | 5 +- client/ui/client_ui.go | 2 +- management/client/client_test.go | 2 +- relay/client/client_test.go | 2 +- relay/cmd/root.go | 2 +- relay/test/benchmark_test.go | 2 +- relay/testec2/main.go | 2 +- upload-server/main.go | 2 +- util/log.go | 77 ++++++++++---- 29 files changed, 267 insertions(+), 66 deletions(-) create mode 100644 .dockerignore-client create mode 100755 client/netbird-entrypoint.sh diff --git a/.dockerignore-client b/.dockerignore-client new file mode 100644 index 000000000..a93ef97c0 --- /dev/null +++ b/.dockerignore-client @@ -0,0 +1,3 @@ +* +!client/netbird-entrypoint.sh +!netbird diff --git a/.gitignore b/.gitignore index abb728b19..e6c0c0aca 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ infrastructure_files/setup-*.env .vscode .DS_Store vendor/ +/netbird diff --git a/.goreleaser.yaml b/.goreleaser.yaml index ca5eafa62..d4a97b447 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -155,13 +155,15 @@ dockers: goarch: amd64 use: buildx dockerfile: client/Dockerfile + extra_files: + - client/netbird-entrypoint.sh build_flag_templates: - "--platform=linux/amd64" - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}" - - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=maintainer=dev@netbird.io" - image_templates: - netbirdio/netbird:{{ .Version }}-arm64v8 @@ -171,6 +173,8 @@ dockers: goarch: arm64 use: buildx dockerfile: client/Dockerfile + extra_files: + - client/netbird-entrypoint.sh build_flag_templates: - "--platform=linux/arm64" - "--label=org.opencontainers.image.created={{.Date}}" @@ -188,6 +192,8 @@ dockers: goarm: 6 use: buildx dockerfile: client/Dockerfile + extra_files: + - client/netbird-entrypoint.sh build_flag_templates: - "--platform=linux/arm" - "--label=org.opencontainers.image.created={{.Date}}" @@ -205,6 +211,8 @@ dockers: goarch: amd64 use: buildx dockerfile: client/Dockerfile-rootless + extra_files: + - client/netbird-entrypoint.sh build_flag_templates: - "--platform=linux/amd64" - "--label=org.opencontainers.image.created={{.Date}}" @@ -221,6 +229,8 @@ dockers: goarch: arm64 use: buildx dockerfile: client/Dockerfile-rootless + extra_files: + - client/netbird-entrypoint.sh build_flag_templates: - "--platform=linux/arm64" - "--label=org.opencontainers.image.created={{.Date}}" @@ -238,6 +248,8 @@ dockers: goarm: 6 use: buildx dockerfile: client/Dockerfile-rootless + extra_files: + - client/netbird-entrypoint.sh build_flag_templates: - "--platform=linux/arm" - "--label=org.opencontainers.image.created={{.Date}}" diff --git a/client/Dockerfile b/client/Dockerfile index 5f1f70040..e19a09909 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,9 +1,27 @@ -FROM alpine:3.21.3 +# build & run locally with: +# cd "$(git rev-parse --show-toplevel)" +# CGO_ENABLED=0 go build -o netbird ./client +# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . +# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest + +FROM alpine:3.22.0 # iproute2: busybox doesn't display ip rules properly -RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables +RUN apk add --no-cache \ + bash \ + ca-certificates \ + ip6tables \ + iproute2 \ + iptables + +ENV \ + NETBIRD_BIN="/usr/local/bin/netbird" \ + NB_LOG_FILE="console,/var/log/netbird/client.log" \ + NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ + NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ + NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + +ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ARG NETBIRD_BINARY=netbird -COPY ${NETBIRD_BINARY} /usr/local/bin/netbird - -ENV NB_FOREGROUND_MODE=true -ENTRYPOINT [ "/usr/local/bin/netbird","up"] +COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh +COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 5055cb20d..5fa8de0a5 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -1,18 +1,33 @@ -FROM alpine:3.21.0 +# build & run locally with: +# cd "$(git rev-parse --show-toplevel)" +# CGO_ENABLED=0 go build -o netbird ./client +# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . +# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -ARG NETBIRD_BINARY=netbird -COPY ${NETBIRD_BINARY} /usr/local/bin/netbird +FROM alpine:3.22.0 -RUN apk add --no-cache ca-certificates \ +RUN apk add --no-cache \ + bash \ + ca-certificates \ && adduser -D -h /var/lib/netbird netbird + WORKDIR /var/lib/netbird USER netbird:netbird -ENV NB_FOREGROUND_MODE=true -ENV NB_USE_NETSTACK_MODE=true -ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true -ENV NB_CONFIG=config.json -ENV NB_DAEMON_ADDR=unix://netbird.sock -ENV NB_DISABLE_DNS=true +ENV \ + NETBIRD_BIN="/usr/local/bin/netbird" \ + NB_USE_NETSTACK_MODE="true" \ + NB_ENABLE_NETSTACK_LOCAL_FORWARDING="true" \ + NB_CONFIG="/var/lib/netbird/config.json" \ + NB_STATE_DIR="/var/lib/netbird" \ + NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ + NB_LOG_FILE="console,/var/lib/netbird/client.log" \ + NB_DISABLE_DNS="true" \ + NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ + NB_ENTRYPOINT_LOGIN_TIMEOUT="1" -ENTRYPOINT [ "/usr/local/bin/netbird", "up" ] +ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] + +ARG NETBIRD_BINARY=netbird +COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh +COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird diff --git a/client/cmd/down.go b/client/cmd/down.go index 3a324cc19..cfa69bce2 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -20,7 +20,7 @@ var downCmd = &cobra.Command{ cmd.SetOut(cmd.OutOrStdout()) - err := util.InitLog(logLevel, "console") + err := util.InitLog(logLevel, util.LogConsole) if err != nil { log.Errorf("failed initializing log %v", err) return err diff --git a/client/cmd/login.go b/client/cmd/login.go index 14abcd034..8ac7086b8 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -32,7 +32,7 @@ var loginCmd = &cobra.Command{ cmd.SetOut(cmd.OutOrStdout()) - err := util.InitLog(logLevel, "console") + err := util.InitLog(logLevel, util.LogConsole) if err != nil { return fmt.Errorf("failed initializing log %v", err) } @@ -50,7 +50,7 @@ var loginCmd = &cobra.Command{ } // workaround to run without service - if logFile == "console" { + if util.FindFirstLogPath(logFiles) == "" { err = handleRebrand(cmd) if err != nil { return err diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go index fa20435ea..cf98a5854 100644 --- a/client/cmd/login_test.go +++ b/client/cmd/login_test.go @@ -21,7 +21,7 @@ func TestLogin(t *testing.T) { "--config", confPath, "--log-file", - "console", + util.LogConsole, "--setup-key", strings.ToUpper("a2c8e62b-38f5-4553-b31e-dd66c696cebb"), "--management-url", diff --git a/client/cmd/root.go b/client/cmd/root.go index bfd0d06c5..1774602c4 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -10,6 +10,7 @@ import ( "os/signal" "path" "runtime" + "slices" "strings" "syscall" "time" @@ -51,7 +52,7 @@ var ( defaultLogFile string oldDefaultLogFileDir string oldDefaultLogFile string - logFile string + logFiles []string daemonAddr string managementURL string adminURL string @@ -120,7 +121,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL)) rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") - rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") + rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.") rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file") @@ -265,7 +266,7 @@ func getSetupKeyFromFile(setupKeyPath string) (string, error) { func handleRebrand(cmd *cobra.Command) error { var err error - if logFile == defaultLogFile { + if slices.Contains(logFiles, defaultLogFile) { if migrateToNetbird(oldDefaultLogFile, defaultLogFile) { cmd.Printf("will copy Log dir %s and its content to %s\n", oldDefaultLogFileDir, defaultLogFileDir) err = cpDir(oldDefaultLogFileDir, defaultLogFileDir) diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 2545623ec..df84342c9 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, configPath, logFile) + serverInstance := server.New(p.ctx, configPath, util.FindFirstLogPath(logFiles)) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } @@ -112,7 +112,7 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel return nil, err } - if err := util.InitLog(logLevel, logFile); err != nil { + if err := util.InitLog(logLevel, logFiles...); err != nil { return nil, fmt.Errorf("init log: %w", err) } @@ -136,7 +136,7 @@ var runCmd = &cobra.Command{ ctx, cancel := context.WithCancel(cmd.Context()) SetupCloseHandler(ctx, cancel) - SetupDebugHandler(ctx, nil, nil, nil, logFile) + SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles)) s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 951efcc73..c994801a6 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -12,6 +12,8 @@ import ( "github.com/kardianos/service" "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/util" ) var ErrGetServiceStatus = fmt.Errorf("failed to get service status") @@ -41,7 +43,7 @@ func buildServiceArguments() []string { args = append(args, "--management-url", managementURL) } - if logFile != "" { + for _, logFile := range logFiles { args = append(args, "--log-file", logFile) } @@ -54,7 +56,7 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error { // Respected only by systemd systems svcConfig.Dependencies = []string{"After=network.target syslog.target"} - if logFile != "console" { + if logFile := util.FindFirstLogPath(logFiles); logFile != "" { setStdLogPath := true dir := filepath.Dir(logFile) diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index f9dbc26fc..264f643ee 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -46,7 +46,7 @@ var sshCmd = &cobra.Command{ cmd.SetOut(cmd.OutOrStdout()) - err := util.InitLog(logLevel, "console") + err := util.InitLog(logLevel, util.LogConsole) if err != nil { return fmt.Errorf("failed initializing log %v", err) } diff --git a/client/cmd/status.go b/client/cmd/status.go index 2d6e41bc2..e50156ac9 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -59,7 +59,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return err } - err = util.InitLog(logLevel, "console") + err = util.InitLog(logLevel, util.LogConsole) if err != nil { return fmt.Errorf("failed initializing log %v", err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index b9781c0df..529beeac7 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -79,7 +79,7 @@ func upFunc(cmd *cobra.Command, args []string) error { cmd.SetOut(cmd.OutOrStdout()) - err := util.InitLog(logLevel, "console") + err := util.InitLog(logLevel, util.LogConsole) if err != nil { return fmt.Errorf("failed initializing log %v", err) } @@ -484,7 +484,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) { if !isValidAddrPort(customDNSAddress) { return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress) } - if customDNSAddress == "" && logFile != "console" { + if customDNSAddress == "" && util.FindFirstLogPath(logFiles) != "" { parsed = []byte("empty") } else { parsed = []byte(customDNSAddress) diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 2165b8aba..6882f9ea2 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -17,7 +17,7 @@ import ( ) func TestMain(m *testing.M) { - _ = util.InitLog("trace", "console") + _ = util.InitLog("trace", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 6455b3aaf..220e6854d 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -16,6 +16,7 @@ import ( "path/filepath" "runtime" "runtime/pprof" + "slices" "sort" "strings" "time" @@ -28,6 +29,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/util" ) const readmeContent = `Netbird debug bundle @@ -283,7 +285,7 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("Failed to add wg show output: %v", err) } - if g.logFile != "console" && g.logFile != "" { + if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) { if err := g.addLogfile(); err != nil { log.Errorf("Failed to add log file to debug bundle: %v", err) if err := g.trySystemdLogFallback(); err != nil { diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index e948557b6..3aa0b859e 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -14,7 +14,7 @@ import ( ) func TestMain(m *testing.M) { - _ = util.InitLog("debug", "console") + _ = util.InitLog("debug", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index fffbed533..69586b47a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -196,7 +196,7 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time { } func TestMain(m *testing.M) { - _ = util.InitLog("debug", "console") + _ = util.InitLog("debug", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index c5055e646..7cad45953 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -31,7 +31,7 @@ var connConf = ConnConfig{ } func TestMain(m *testing.M) { - _ = util.InitLog("trace", "console") + _ = util.InitLog("trace", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh new file mode 100755 index 000000000..2422d2683 --- /dev/null +++ b/client/netbird-entrypoint.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash +set -eEuo pipefail + +: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} +NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" +export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" +service_pids=() +log_file_path="" + +_log() { + # mimic Go logger's output for easier parsing + # 2025-04-15T21:32:00+08:00 INFO client/internal/config.go:495: setting notifications to disabled by default + printf "$(date -Isec) ${1} ${BASH_SOURCE[1]}:${BASH_LINENO[1]}: ${2}\n" "${@:3}" >&2 +} + +info() { + _log INFO "$@" +} + +warn() { + _log WARN "$@" +} + +on_exit() { + info "Shutting down NetBird daemon..." + if test "${#service_pids[@]}" -gt 0; then + info "terminating service process IDs: ${service_pids[@]@Q}" + kill -TERM "${service_pids[@]}" 2>/dev/null || true + wait "${service_pids[@]}" 2>/dev/null || true + else + info "there are no service processes to terminate" + fi +} + +wait_for_message() { + local timeout="${1}" message="${2}" + if test "${timeout}" -eq 0; then + info "not waiting for log line ${message@Q} due to zero timeout." + elif test -n "${log_file_path}"; then + info "waiting for log line ${message@Q} for ${timeout} seconds..." + grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + else + info "log file unsupported, sleeping for ${timeout} seconds..." + sleep "${timeout}" + fi +} + +locate_log_file() { + local log_files_string="${1}" + + while read -r log_file; do + case "${log_file}" in + console | syslog) ;; + *) + log_file_path="${log_file}" + return + ;; + esac + done < <(sed 's#,#\n#g' <<<"${log_files_string}") + + warn "log files parsing for ${log_files_string@Q} is not supported by debug bundles" + warn "please consider removing the \$NB_LOG_FILE or setting it to real file, before gathering debug bundles." +} + +wait_for_daemon_startup() { + local timeout="${1}" + + if test -n "${log_file_path}"; then + if ! wait_for_message "${timeout}" "started daemon server"; then + warn "log line containing 'started daemon server' not found after ${timeout} seconds" + warn "daemon failed to start, exiting..." + exit 1 + fi + else + warn "daemon service startup not discovered, sleeping ${timeout} instead" + sleep "${timeout}" + fi +} + +login_if_needed() { + local timeout="${1}" + + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then + info "already logged in, skipping 'netbird up'..." + else + info "logging in..." + "${NETBIRD_BIN}" up + fi +} + +main() { + trap 'on_exit' SIGTERM SIGINT EXIT + "${NETBIRD_BIN}" service run & + service_pids+=("$!") + info "registered new service process 'netbird service run', currently running: ${service_pids[@]@Q}" + + locate_log_file "${NB_LOG_FILE}" + wait_for_daemon_startup "${NB_ENTRYPOINT_SERVICE_TIMEOUT}" + login_if_needed "${NB_ENTRYPOINT_LOGIN_TIMEOUT}" + + wait "${service_pids[@]}" +} + +main "$@" diff --git a/client/server/server_test.go b/client/server/server_test.go index 7c46aac5d..11e4d3899 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -32,6 +32,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" + "github.com/netbirdio/netbird/util" ) var ( @@ -92,7 +93,7 @@ func TestConnectWithRetryRuns(t *testing.T) { func TestServer_Up(t *testing.T) { ctx := internal.CtxInitState(context.Background()) - s := New(ctx, t.TempDir()+"/config.json", "console") + s := New(ctx, t.TempDir()+"/config.json", util.LogConsole) err := s.Start() require.NoError(t, err) @@ -130,7 +131,7 @@ func (m *mockSubscribeEventsServer) Context() context.Context { func TestServer_SubcribeEvents(t *testing.T) { ctx := internal.CtxInitState(context.Background()) - s := New(ctx, t.TempDir()+"/config.json", "console") + s := New(ctx, t.TempDir()+"/config.json", util.LogConsole) err := s.Start() require.NoError(t, err) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index ace5b71e4..4480adb51 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -66,7 +66,7 @@ func main() { } logFile = file } else { - _ = util.InitLog("trace", "console") + _ = util.InitLog("trace", util.LogConsole) } // Create the Fyne application. diff --git a/management/client/client_test.go b/management/client/client_test.go index b59b7c982..5b2a87492 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -41,7 +41,7 @@ import ( const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" func TestMain(m *testing.M) { - _ = util.InitLog("debug", "console") + _ = util.InitLog("debug", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/relay/client/client_test.go b/relay/client/client_test.go index c85ec9fd3..2ce8d7e34 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -30,7 +30,7 @@ var ( ) func TestMain(m *testing.M) { - _ = util.InitLog("debug", "console") + _ = util.InitLog("debug", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/relay/cmd/root.go b/relay/cmd/root.go index 15090024c..7b8e5bbeb 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -73,7 +73,7 @@ var ( ) func init() { - _ = util.InitLog("trace", "console") + _ = util.InitLog("trace", util.LogConsole) cobraConfig = &Config{} rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address") rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers") diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index 2e67ab803..afbb14b84 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -27,7 +27,7 @@ var ( ) func TestMain(m *testing.M) { - _ = util.InitLog("error", "console") + _ = util.InitLog("error", util.LogConsole) code := m.Run() os.Exit(code) } diff --git a/relay/testec2/main.go b/relay/testec2/main.go index 0c8099a5e..6954d6a50 100644 --- a/relay/testec2/main.go +++ b/relay/testec2/main.go @@ -233,7 +233,7 @@ func TURNReaderMain() []testResult { func main() { var mode string - _ = util.InitLog("debug", "console") + _ = util.InitLog("debug", util.LogConsole) flag.StringVar(&mode, "mode", "sender", "sender or receiver mode") flag.Parse() diff --git a/upload-server/main.go b/upload-server/main.go index dcfb35cdf..546c0f584 100644 --- a/upload-server/main.go +++ b/upload-server/main.go @@ -10,7 +10,7 @@ import ( ) func main() { - err := util.InitLog("info", "console") + err := util.InitLog("info", util.LogConsole) if err != nil { log.Fatalf("Failed to initialize logger: %v", err) } diff --git a/util/log.go b/util/log.go index 53d2b0684..a951eab87 100644 --- a/util/log.go +++ b/util/log.go @@ -16,36 +16,54 @@ import ( const defaultLogSize = 15 +const ( + LogConsole = "console" + LogSyslog = "syslog" +) + +var ( + SpecialLogs = []string{ + LogSyslog, + LogConsole, + } +) + // InitLog parses and sets log-level input -func InitLog(logLevel string, logPath string) error { +func InitLog(logLevel string, logs ...string) error { level, err := log.ParseLevel(logLevel) if err != nil { log.Errorf("Failed parsing log-level %s: %s", logLevel, err) return err } - customOutputs := []string{"console", "syslog"} + var writers []io.Writer + logFmt := os.Getenv("NB_LOG_FORMAT") - if logPath != "" && !slices.Contains(customOutputs, logPath) { - maxLogSize := getLogMaxSize() - lumberjackLogger := &lumberjack.Logger{ - // Log file absolute path, os agnostic - Filename: filepath.ToSlash(logPath), - MaxSize: maxLogSize, // MB - MaxBackups: 10, - MaxAge: 30, // days - Compress: true, + for _, logPath := range logs { + switch logPath { + case LogSyslog: + AddSyslogHook() + logFmt = "syslog" + case LogConsole: + writers = append(writers, os.Stderr) + case "": + log.Warnf("empty log path received: %#v", logPath) + default: + writers = append(writers, newRotatedOutput(logPath)) } - log.SetOutput(io.Writer(lumberjackLogger)) - } else if logPath == "syslog" { - AddSyslogHook() } - //nolint:gocritic - if os.Getenv("NB_LOG_FORMAT") == "json" { + if len(writers) > 1 { + log.SetOutput(io.MultiWriter(writers...)) + } else if len(writers) == 1 { + log.SetOutput(writers[0]) + } + + switch logFmt { + case "json": formatter.SetJSONFormatter(log.StandardLogger()) - } else if logPath == "syslog" { + case "syslog": formatter.SetSyslogFormatter(log.StandardLogger()) - } else { + default: formatter.SetTextFormatter(log.StandardLogger()) } log.SetLevel(level) @@ -55,6 +73,29 @@ func InitLog(logLevel string, logPath string) error { return nil } +// FindFirstLogPath returns the first logs entry that could be a log path, that is neither empty, nor a special value +func FindFirstLogPath(logs []string) string { + for _, logFile := range logs { + if logFile != "" && !slices.Contains(SpecialLogs, logFile) { + return logFile + } + } + return "" +} + +func newRotatedOutput(logPath string) io.Writer { + maxLogSize := getLogMaxSize() + lumberjackLogger := &lumberjack.Logger{ + // Log file absolute path, os agnostic + Filename: filepath.ToSlash(logPath), + MaxSize: maxLogSize, // MB + MaxBackups: 10, + MaxAge: 30, // days + Compress: true, + } + return lumberjackLogger +} + func setGRPCLibLogger() { logOut := log.StandardLogger().Writer() if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" { From cb85d3f2fc8d6836544013fd1764d93d8a02351d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 25 Jul 2025 11:46:04 +0200 Subject: [PATCH 35/50] [client] Always register NetBird with plain Linux DNS and use original servers as upstream (#3967) --- client/internal/dns/file_parser_unix.go | 40 +------- client/internal/dns/file_parser_unix_test.go | 52 +--------- client/internal/dns/file_repair_unix.go | 9 +- client/internal/dns/file_repair_unix_test.go | 9 +- client/internal/dns/file_unix.go | 73 +++++--------- client/internal/dns/handler_chain.go | 3 +- client/internal/dns/host.go | 8 +- client/internal/dns/host_darwin.go | 22 ++-- client/internal/dns/host_windows.go | 11 +- client/internal/dns/mock_server.go | 5 +- client/internal/dns/network_manager_unix.go | 6 +- client/internal/dns/resolvconf_unix.go | 31 +++--- client/internal/dns/server.go | 101 ++++++++++++++----- client/internal/dns/server_test.go | 4 +- client/internal/dns/service.go | 4 +- client/internal/dns/service_listener.go | 34 ++++--- client/internal/dns/service_memory.go | 13 +-- client/internal/dns/systemd_linux.go | 21 ++-- client/internal/dns/unclean_shutdown_unix.go | 7 +- client/internal/engine.go | 2 +- 20 files changed, 196 insertions(+), 259 deletions(-) diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 130c88214..6e123c94e 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -4,8 +4,8 @@ package dns import ( "fmt" + "net/netip" "os" - "regexp" "strings" log "github.com/sirupsen/logrus" @@ -15,9 +15,6 @@ const ( defaultResolvConfPath = "/etc/resolv.conf" ) -var timeoutRegex = regexp.MustCompile(`timeout:\d+`) -var attemptsRegex = regexp.MustCompile(`attempts:\d+`) - type resolvConf struct { nameServers []string searchDomains []string @@ -108,40 +105,9 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { return rconf, nil } -// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist, -// otherwise it adds a new option with timeout and attempts. -func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string { - configs := make([]string, len(input)) - copy(configs, input) - - for i, config := range configs { - if strings.HasPrefix(config, "options") { - config = strings.ReplaceAll(config, "rotate", "") - config = strings.Join(strings.Fields(config), " ") - - if strings.Contains(config, "timeout:") { - config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout)) - } else { - config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1) - } - - if strings.Contains(config, "attempts:") { - config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts)) - } else { - config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1) - } - - configs[i] = config - return configs - } - } - - return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts)) -} - // removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position // and writes the file back to the original location -func removeFirstNbNameserver(filename, nameserverIP string) error { +func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error { resolvConf, err := parseResolvConfFile(filename) if err != nil { return fmt.Errorf("parse backup resolv.conf: %w", err) @@ -151,7 +117,7 @@ func removeFirstNbNameserver(filename, nameserverIP string) error { return fmt.Errorf("read %s: %w", filename, err) } - if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP { + if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() { newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1) stat, err := os.Stat(filename) diff --git a/client/internal/dns/file_parser_unix_test.go b/client/internal/dns/file_parser_unix_test.go index 1d6e64683..228a708f1 100644 --- a/client/internal/dns/file_parser_unix_test.go +++ b/client/internal/dns/file_parser_unix_test.go @@ -3,11 +3,13 @@ package dns import ( + "net/netip" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_parseResolvConf(t *testing.T) { @@ -175,52 +177,6 @@ nameserver 192.168.0.1 } } -func TestPrepareOptionsWithTimeout(t *testing.T) { - tests := []struct { - name string - others []string - timeout int - attempts int - expected []string - }{ - { - name: "Append new options with timeout and attempts", - others: []string{"some config"}, - timeout: 2, - attempts: 2, - expected: []string{"some config", "options timeout:2 attempts:2"}, - }, - { - name: "Modify existing options to exclude rotate and include timeout and attempts", - others: []string{"some config", "options rotate someother"}, - timeout: 3, - attempts: 2, - expected: []string{"some config", "options attempts:2 timeout:3 someother"}, - }, - { - name: "Existing options with timeout and attempts are updated", - others: []string{"some config", "options timeout:4 attempts:3"}, - timeout: 5, - attempts: 4, - expected: []string{"some config", "options timeout:5 attempts:4"}, - }, - { - name: "Modify existing options, add missing attempts before timeout", - others: []string{"some config", "options timeout:4"}, - timeout: 4, - attempts: 3, - expected: []string{"some config", "options attempts:3 timeout:4"}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts) - assert.Equal(t, tc.expected, result) - }) - } -} - func TestRemoveFirstNbNameserver(t *testing.T) { testCases := []struct { name string @@ -292,7 +248,9 @@ search localdomain`, err := os.WriteFile(tempFile, []byte(tc.content), 0644) assert.NoError(t, err) - err = removeFirstNbNameserver(tempFile, tc.ipToRemove) + ip, err := netip.ParseAddr(tc.ipToRemove) + require.NoError(t, err, "Failed to parse IP address") + err = removeFirstNbNameserver(tempFile, ip) assert.NoError(t, err) content, err := os.ReadFile(tempFile) diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index 9a9218fa1..75af411df 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -3,6 +3,7 @@ package dns import ( + "net/netip" "path" "path/filepath" "sync" @@ -22,7 +23,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error +type repairConfFn func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error type repair struct { operationFile string @@ -42,7 +43,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP netip.Addr, stateManager *statemanager.Manager) { if f.inotify != nil { return } @@ -136,7 +137,7 @@ func (f *repair) isEventRelevant(event fsnotify.Event) bool { // nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs // check the NetBird related nameserver IP at the first place // check the NetBird related search domains in the search domains list -func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *resolvConf) bool { +func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rConf *resolvConf) bool { if !isContains(nbSearchDomains, rConf.searchDomains) { return true } @@ -145,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *r return true } - if rConf.nameServers[0] != nbNameserverIP { + if rConf.nameServers[0] != nbNameserverIP.String() { return true } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index 3aa0b859e..f22081307 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -4,6 +4,7 @@ package dns import ( "context" + "net/netip" "os" "path/filepath" "testing" @@ -105,14 +106,14 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { + updateFn := func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(operationFile, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) + r.watchFileChanges([]string{"netbird.cloud"}, netip.MustParseAddr("10.0.0.1"), nil) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) if err != nil { @@ -152,14 +153,14 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { + updateFn := func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(tmpLink, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) + r.watchFileChanges([]string{"netbird.cloud"}, netip.MustParseAddr("10.0.0.1"), nil) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) if err != nil { diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 3e338267f..423989f72 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -8,7 +8,6 @@ import ( "net/netip" "os" "strings" - "time" log "github.com/sirupsen/logrus" @@ -18,7 +17,7 @@ import ( const ( fileGeneratedResolvConfContentHeader = "# Generated by NetBird" fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + ` -# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n" +# The original file can be restored from ` + fileDefaultResolvConfBackupLocation + "\n\n" fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" @@ -26,16 +25,11 @@ const ( fileMaxNumberOfSearchDomains = 6 ) -const ( - dnsFailoverTimeout = 4 * time.Second - dnsFailoverAttempts = 1 -) - type fileConfigurator struct { - repair *repair - - originalPerms os.FileMode - nbNameserverIP string + repair *repair + originalPerms os.FileMode + nbNameserverIP netip.Addr + originalNameservers []string } func newFileConfigurator() (*fileConfigurator, error) { @@ -49,22 +43,9 @@ func (f *fileConfigurator) supportCustomPort() bool { } func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - backupFileExist := f.isBackupFileExist() - if !config.RouteAll { - if backupFileExist { - f.repair.stopWatchFileChanges() - err := f.restore() - if err != nil { - return fmt.Errorf("restoring the original resolv.conf file return err: %w", err) - } - } - return ErrRouteAllWithoutNameserverGroup - } - - if !backupFileExist { - err := f.backup() - if err != nil { - return fmt.Errorf("unable to backup the resolv.conf file: %w", err) + if !f.isBackupFileExist() { + if err := f.backup(); err != nil { + return fmt.Errorf("backup resolv.conf: %w", err) } } @@ -76,6 +57,8 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err) } + f.originalNameservers = resolvConf.nameServers + f.repair.stopWatchFileChanges() err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) @@ -86,15 +69,19 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { - searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) - nameServers := generateNsList(nbNameserverIP, cfg) +// getOriginalNameservers returns the nameservers that were found in the original resolv.conf +func (f *fileConfigurator) getOriginalNameservers() []string { + return f.originalNameservers +} + +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP netip.Addr, cfg *resolvConf, stateManager *statemanager.Manager) error { + searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) - options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) buf := prepareResolvConfContent( searchDomainList, - nameServers, - options) + []string{nbNameserverIP.String()}, + cfg.others, + ) log.Debugf("creating managed file %s", defaultResolvConfPath) err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) @@ -197,38 +184,28 @@ func restoreResolvConfFile() error { return nil } -// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list -func generateNsList(nbNameserverIP string, cfg *resolvConf) []string { - ns := make([]string, 1, len(cfg.nameServers)+1) - ns[0] = nbNameserverIP - for _, cfgNs := range cfg.nameServers { - if nbNameserverIP != cfgNs { - ns = append(ns, cfgNs) - } - } - return ns -} - func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer { var buf bytes.Buffer + buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine) for _, cfgLine := range others { buf.WriteString(cfgLine) - buf.WriteString("\n") + buf.WriteByte('\n') } if len(searchDomains) > 0 { buf.WriteString("search ") buf.WriteString(strings.Join(searchDomains, " ")) - buf.WriteString("\n") + buf.WriteByte('\n') } for _, ns := range nameServers { buf.WriteString("nameserver ") buf.WriteString(ns) - buf.WriteString("\n") + buf.WriteByte('\n') } + return buf } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 7e7e7cc2d..36da8fb78 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -15,6 +15,7 @@ const ( PriorityDNSRoute = 75 PriorityUpstream = 50 PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -191,7 +192,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // No handler matched or all handlers passed log.Tracef("no handler found for domain=%s", qname) resp := &dns.Msg{} - resp.SetRcode(r, dns.RcodeNameError) + resp.SetRcode(r, dns.RcodeRefused) if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index dbf0f2cfc..fa474afde 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -11,8 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") - const ( ipv4ReverseZone = ".in-addr.arpa." ipv6ReverseZone = ".ip6.arpa." @@ -27,14 +25,14 @@ type hostManager interface { type SystemDNSSettings struct { Domains []string - ServerIP string + ServerIP netip.Addr ServerPort int } type HostDNSConfig struct { Domains []DomainConfig `json:"domains"` RouteAll bool `json:"routeAll"` - ServerIP string `json:"serverIP"` + ServerIP netip.Addr `json:"serverIP"` ServerPort int `json:"serverPort"` } @@ -89,7 +87,7 @@ func newNoopHostMocker() hostManager { } } -func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig { +func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) HostDNSConfig { config := HostDNSConfig{ RouteAll: false, ServerIP: ip, diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index a445bc6c4..820cf9029 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -7,7 +7,7 @@ import ( "bytes" "fmt" "io" - "net" + "net/netip" "os/exec" "strconv" "strings" @@ -165,13 +165,13 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { } func (s *systemConfigurator) addLocalDNS() error { - if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { err := s.recordSystemDNSSettings(true) log.Errorf("Unable to get system DNS configuration") return err } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 { + if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) if err != nil { return fmt.Errorf("couldn't add local network DNS conf: %w", err) @@ -184,7 +184,7 @@ func (s *systemConfigurator) addLocalDNS() error { } func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { - if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force { + if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 && !force { return nil } @@ -238,8 +238,8 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] - if ip := net.ParseIP(address); ip != nil && ip.To4() != nil { - dnsSettings.ServerIP = address + if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { + dnsSettings.ServerIP = ip inServerAddressesArray = false // Stop reading after finding the first IPv4 address } } @@ -250,12 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } // default to 53 port - dnsSettings.ServerPort = 53 + dnsSettings.ServerPort = defaultPort return dnsSettings, nil } -func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { +func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { return fmt.Errorf("add dns state: %w", err) @@ -268,7 +268,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po return nil } -func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error { +func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error { err := s.addDNSState(key, domains, dnsServer, port, false) if err != nil { return fmt.Errorf("add dns state: %w", err) @@ -281,14 +281,14 @@ func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, por return nil } -func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port int, enableSearch bool) error { +func (s *systemConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error { noSearch := "1" if enableSearch { noSearch = "0" } lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains) lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch) - lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer.String()) lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) addDomainCommand := buildCreateStateWithOperation(state, lines) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index f8939328a..648a58207 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net/netip" "os/exec" "strings" "syscall" @@ -210,8 +211,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } -func (r *registryConfigurator) addDNSSetupForAll(ip string) error { - if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil { +func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { + if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) } r.routingAll = true @@ -219,7 +220,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error { return nil } -func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error { +func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error { // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 if r.gpo { @@ -241,7 +242,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er } // configureDNSPolicy handles the actual configuration of a DNS policy at the specified path -func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip string) error { +func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { return fmt.Errorf("remove existing dns policy: %w", err) } @@ -260,7 +261,7 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err) } - if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil { + if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip.String()); err != nil { return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err) } diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index c5dd6e23f..40a2e7384 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "net/netip" "github.com/miekg/dns" @@ -45,8 +46,8 @@ func (m *MockServer) Stop() { } } -func (m *MockServer) DnsIP() string { - return "" +func (m *MockServer) DnsIP() netip.Addr { + return netip.MustParseAddr("100.10.254.255") } func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index caae63a24..5459bc2d7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -110,11 +110,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st connSettings.cleanDeprecatedSettings() - dnsIP, err := netip.ParseAddr(config.ServerIP) - if err != nil { - return fmt.Errorf("unable to parse ip address, error: %w", err) - } - convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice()) + convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) var ( searchDomains []string diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 54c4c75bf..6080c1d2c 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -46,9 +46,9 @@ type resolvconf struct { func detectResolvconfType() (resolvconfType, error) { cmd := exec.Command(resolvconfCommand, "--version") - out, err := cmd.Output() + out, err := cmd.CombinedOutput() if err != nil { - return typeOpenresolv, fmt.Errorf("failed to determine resolvconf type: %w", err) + return typeOpenresolv, fmt.Errorf("determine resolvconf type: %w", err) } if strings.Contains(string(out), "openresolv") { @@ -66,7 +66,7 @@ func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { implType, err := detectResolvconfType() if err != nil { log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err) - implType = typeOpenresolv + implType = typeResolvconf } else { log.Infof("detected resolvconf type: %v", implType) } @@ -85,24 +85,14 @@ func (r *resolvconf) supportCustomPort() bool { } func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - if !config.RouteAll { - err = r.restoreHostDNS() - if err != nil { - log.Errorf("restore host dns: %s", err) - } - return ErrRouteAllWithoutNameserverGroup - } - searchDomainList := searchDomains(config) searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains) - options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) - buf := prepareResolvConfContent( searchDomainList, - append([]string{config.ServerIP}, r.originalNameServers...), - options) + []string{config.ServerIP.String()}, + r.othersConfigs, + ) state := &ShutdownState{ ManagerType: resolvConfManager, @@ -112,8 +102,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman log.Errorf("failed to update shutdown state: %s", err) } - err = r.applyConfig(buf) - if err != nil { + if err := r.applyConfig(buf); err != nil { return fmt.Errorf("apply config: %w", err) } @@ -121,6 +110,10 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman return nil } +func (r *resolvconf) getOriginalNameservers() []string { + return r.originalNameServers +} + func (r *resolvconf) restoreHostDNS() error { var cmd *exec.Cmd @@ -157,7 +150,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { } cmd.Stdin = &content - out, err := cmd.Output() + out, err := cmd.CombinedOutput() log.Tracef("resolvconf output: %s", out) if err != nil { return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index e81aebf98..f933c1de0 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -2,7 +2,6 @@ package dns import ( "context" - "errors" "fmt" "net/netip" "runtime" @@ -20,7 +19,6 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" - cProto "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" ) @@ -41,7 +39,7 @@ type Server interface { DeregisterHandler(domains domain.List, priority int) Initialize() error Stop() - DnsIP() string + DnsIP() netip.Addr UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(strings []string) SearchDomains() []string @@ -53,6 +51,12 @@ type nsGroupsByDomain struct { groups []*nbdns.NameServerGroup } +// hostManagerWithOriginalNS extends the basic hostManager interface +type hostManagerWithOriginalNS interface { + hostManager + getOriginalNameservers() []string +} + // DefaultServer dns server object type DefaultServer struct { ctx context.Context @@ -215,6 +219,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p log.Warn("skipping empty domain") continue } + s.handlerChain.AddHandler(domain, handler, priority) } } @@ -286,7 +291,7 @@ func (s *DefaultServer) Initialize() (err error) { // // When kernel space interface used it return real DNS server listener IP address // For bind interface, fake DNS resolver address returned (second last IP address from Nebird network) -func (s *DefaultServer) DnsIP() string { +func (s *DefaultServer) DnsIP() netip.Addr { return s.service.RuntimeIP() } @@ -297,6 +302,11 @@ func (s *DefaultServer) Stop() { s.ctxCancel() if s.hostManager != nil { + if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { + log.Debugf("deregistering original nameservers as fallback handlers") + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + } + if err := s.hostManager.restoreHostDNS(); err != nil { log.Error("failed to restore host DNS settings: ", err) } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { @@ -311,7 +321,6 @@ func (s *DefaultServer) Stop() { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone - func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { s.hostsDNSHolder.set(hostsDnsList) @@ -493,25 +502,56 @@ func (s *DefaultServer) applyHostConfig() { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { log.Errorf("failed to apply DNS host manager update: %v", err) - s.handleErrNoGroupaAll(err) } + + s.registerFallback(config) } -func (s *DefaultServer) handleErrNoGroupaAll(err error) { - if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) { +// registerFallback registers original nameservers as low-priority fallback handlers +func (s *DefaultServer) registerFallback(config HostDNSConfig) { + hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) + if !ok { return } - if s.statusRecorder == nil { + originalNameservers := hostMgrWithNS.getOriginalNameservers() + if len(originalNameservers) == 0 { return } - s.statusRecorder.PublishEvent( - cProto.SystemEvent_WARNING, cProto.SystemEvent_DNS, - "The host dns manager does not support match domains", - "The host dns manager does not support match domains without a catch-all nameserver group.", - map[string]string{"manager": s.hostManager.string()}, + log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback) + + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface.Name(), + s.wgInterface.Address().IP, + s.wgInterface.Address().Network, + s.statusRecorder, + s.hostsDNSHolder, + nbdns.RootZone, ) + if err != nil { + log.Errorf("failed to create upstream resolver for original nameservers: %v", err) + return + } + + for _, ns := range originalNameservers { + if ns == config.ServerIP.String() { + log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) + continue + } + + ns = fmt.Sprintf("%s:%d", ns, defaultPort) + if ip, err := netip.ParseAddr(ns); err == nil && ip.Is6() { + ns = fmt.Sprintf("[%s]:%d", ns, defaultPort) + } + + handler.upstreamServers = append(handler.upstreamServers, ns) + } + handler.deactivate = func(error) { /* always active */ } + handler.reactivate = func() { /* always active */ } + + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { @@ -588,14 +628,8 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts priority := basePriority - i - // Check if we're about to overlap with the next priority tier. - // This boundary check ensures that the priority of upstream handlers does not conflict - // with the default priority tier. By decrementing the priority for each handler, we avoid - // overlaps, but if the calculated priority falls into the default tier, we skip the remaining - // handlers to maintain the integrity of the priority system. - if basePriority == PriorityUpstream && priority <= PriorityDefault { - log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityUpstream-PriorityDefault) + // Check if we're about to overlap with the next priority tier + if s.leaksPriority(domainGroup, basePriority, priority) { break } @@ -648,6 +682,21 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai return muxUpdates, nil } +func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { + if basePriority == PriorityUpstream && priority <= PriorityDefault { + log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", + domainGroup.domain, PriorityUpstream-PriorityDefault) + return true + } + if basePriority == PriorityDefault && priority <= PriorityFallback { + log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers", + domainGroup.domain, PriorityDefault-PriorityFallback) + return true + } + + return false +} + func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { // this will introduce a short period of time when the server is not able to handle DNS requests for _, existing := range s.dnsMuxMap { @@ -760,6 +809,12 @@ func (s *DefaultServer) upstreamCallbacks( } func (s *DefaultServer) addHostRootZone() { + hostDNSServers := s.hostsDNSHolder.get() + if len(hostDNSServers) == 0 { + log.Debug("no host DNS servers available, skipping root zone handler creation") + return + } + handler, err := newUpstreamResolver( s.ctx, s.wgInterface.Name(), @@ -775,7 +830,7 @@ func (s *DefaultServer) addHostRootZone() { } handler.upstreamServers = make([]string, 0) - for k := range s.hostsDNSHolder.get() { + for k := range hostDNSServers { handler.upstreamServers = append(handler.upstreamServers, k) } handler.deactivate = func(error) {} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21a9e2f2d..3cab4517a 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -938,7 +938,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return wgIface, nil } -func newDnsResolver(ip string, port int) *net.Resolver { +func newDnsResolver(ip netip.Addr, port int) *net.Resolver { return &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -1047,7 +1047,7 @@ type mockService struct{} func (m *mockService) Listen() error { return nil } func (m *mockService) Stop() {} -func (m *mockService) RuntimeIP() string { return "127.0.0.1" } +func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") } func (m *mockService) RuntimePort() int { return 53 } func (m *mockService) RegisterMux(string, dns.Handler) {} func (m *mockService) DeregisterMux(string) {} diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 523976e54..ab8238a61 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -1,6 +1,8 @@ package dns import ( + "net/netip" + "github.com/miekg/dns" ) @@ -14,5 +16,5 @@ type service interface { RegisterMux(domain string, handler dns.Handler) DeregisterMux(key string) RuntimePort() int - RuntimeIP() string + RuntimeIP() netip.Addr } diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 72dc4bc6e..abd2f4f05 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -18,8 +18,11 @@ import ( const ( customPort = 5053 - defaultIP = "127.0.0.1" - customIP = "127.0.0.153" +) + +var ( + defaultIP = netip.MustParseAddr("127.0.0.1") + customIP = netip.MustParseAddr("127.0.0.153") ) type serviceViaListener struct { @@ -27,7 +30,7 @@ type serviceViaListener struct { dnsMux *dns.ServeMux customAddr *netip.AddrPort server *dns.Server - listenIP string + listenIP netip.Addr listenPort uint16 listenerIsRunning bool listenerFlagLock sync.Mutex @@ -65,6 +68,7 @@ func (s *serviceViaListener) Listen() error { log.Errorf("failed to eval runtime address: %s", err) return fmt.Errorf("eval listen address: %w", err) } + s.listenIP = s.listenIP.Unmap() s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) log.Debugf("starting dns on %s", s.server.Addr) go func() { @@ -124,7 +128,7 @@ func (s *serviceViaListener) RuntimePort() int { } } -func (s *serviceViaListener) RuntimeIP() string { +func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } @@ -139,9 +143,9 @@ func (s *serviceViaListener) setListenerStatus(running bool) { // first check the 53 port availability on WG interface or lo, if not success // pick a random port on WG interface for eBPF, if not success // check the 5053 port availability on WG interface or lo without eBPF usage, -func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { +func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { - return s.customAddr.Addr().String(), s.customAddr.Port(), nil + return s.customAddr.Addr(), s.customAddr.Port(), nil } ip, ok := s.testFreePort(defaultPort) @@ -152,7 +156,7 @@ func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { ebpfSrv, port, ok := s.tryToUseeBPF() if ok { s.ebpfService = ebpfSrv - return s.wgInterface.Address().IP.String(), port, nil + return s.wgInterface.Address().IP, port, nil } ip, ok = s.testFreePort(customPort) @@ -160,15 +164,15 @@ func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { return ip, customPort, nil } - return "", 0, fmt.Errorf("failed to find a free port for DNS server") + return netip.Addr{}, 0, fmt.Errorf("failed to find a free port for DNS server") } -func (s *serviceViaListener) testFreePort(port int) (string, bool) { - var ips []string +func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { + var ips []netip.Addr if runtime.GOOS != "darwin" { - ips = []string{s.wgInterface.Address().IP.String(), defaultIP, customIP} + ips = []netip.Addr{s.wgInterface.Address().IP, defaultIP, customIP} } else { - ips = []string{defaultIP, customIP} + ips = []netip.Addr{defaultIP, customIP} } for _, ip := range ips { @@ -178,10 +182,10 @@ func (s *serviceViaListener) testFreePort(port int) (string, bool) { return ip, true } - return "", false + return netip.Addr{}, false } -func (s *serviceViaListener) tryToBind(ip string, port int) bool { +func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool { addrString := fmt.Sprintf("%s:%d", ip, port) udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) probeListener, err := net.ListenUDP("udp", udpAddr) @@ -224,7 +228,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) { } func (s *serviceViaListener) generateFreePort() (uint16, error) { - ok := s.tryToBind(s.wgInterface.Address().IP.String(), customPort) + ok := s.tryToBind(s.wgInterface.Address().IP, customPort) if ok { return customPort, nil } diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 226202cf7..9f55838bf 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -16,7 +16,7 @@ import ( type ServiceViaMemory struct { wgInterface WGIface dnsMux *dns.ServeMux - runtimeIP string + runtimeIP netip.Addr runtimePort int udpFilterHookID string listenerIsRunning bool @@ -32,7 +32,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: lastIP.String(), + runtimeIP: lastIP, runtimePort: defaultPort, } return s @@ -84,7 +84,7 @@ func (s *ServiceViaMemory) RuntimePort() int { return s.runtimePort } -func (s *ServiceViaMemory) RuntimeIP() string { +func (s *ServiceViaMemory) RuntimeIP() netip.Addr { return s.runtimeIP } @@ -121,10 +121,5 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return true } - ip, err := netip.ParseAddr(s.runtimeIP) - if err != nil { - return "", fmt.Errorf("parse runtime ip: %w", err) - } - - return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil + return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 9040ed787..a58747d5b 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -89,21 +89,16 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { } func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - parsedIP, err := netip.ParseAddr(config.ServerIP) - if err != nil { - return fmt.Errorf("unable to parse ip address, error: %w", err) - } - ipAs4 := parsedIP.As4() defaultLinkInput := systemdDbusDNSInput{ Family: unix.AF_INET, - Address: ipAs4[:], + Address: config.ServerIP.AsSlice(), } - if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { + if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err) } // We don't support dnssec. On some machines this is default on so we explicitly set it to off - if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil { + if err := s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil { log.Warnf("failed to set DNSSEC to 'no': %v", err) } @@ -129,8 +124,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana } if config.RouteAll { - err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) - if err != nil { + if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true); err != nil { return fmt.Errorf("set link as default dns router: %w", err) } domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ @@ -139,7 +133,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana }) log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } else { - if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil { + if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil { return fmt.Errorf("remove link as default dns router: %w", err) } } @@ -153,9 +147,8 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) - err = s.setDomainsForInterface(domainsInput) - if err != nil { - log.Error(err) + if err := s.setDomainsForInterface(domainsInput); err != nil { + log.Error("failed to set domains for interface: ", err) } if err := s.flushDNSCache(); err != nil { diff --git a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index fcf60c694..2e786f484 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -35,12 +35,7 @@ func (s *ShutdownState) Cleanup() error { } // TODO: move file contents to state manager -func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { - dnsAddress, err := netip.ParseAddr(dnsAddressStr) - if err != nil { - return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) - } - +func createUncleanShutdownIndicator(sourcePath string, dnsAddress netip.Addr, stateManager *statemanager.Manager) error { dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return fmt.Errorf("create dir %s: %w", dir, err) diff --git a/client/internal/engine.go b/client/internal/engine.go index 079adf7e8..d2de5b3cc 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1550,7 +1550,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { func (e *Engine) wgInterfaceCreate() (err error) { switch runtime.GOOS { case "android": - err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP(), e.dnsServer.SearchDomains()) + err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP().String(), e.dnsServer.SearchDomains()) case "ios": e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr) err = e.wgInterface.Create() From 31872a7fb62b8257037a894daabcf5de44f5c4aa Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 25 Jul 2025 14:14:45 +0200 Subject: [PATCH 36/50] [client] Fix UDP proxy to notify listener when remote conn closed (#4199) * Fix UDP proxy to notify listener when remote conn closed * Fix sender tests to use t.Errorf for timeout assertions * Fix potential nil pointer --- client/iface/wgproxy/listener/listener.go | 17 +++++++++++++++-- client/iface/wgproxy/udp/proxy.go | 5 +++++ relay/healthcheck/sender_test.go | 11 +++++------ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/client/iface/wgproxy/listener/listener.go b/client/iface/wgproxy/listener/listener.go index bfd651548..a8ee354a1 100644 --- a/client/iface/wgproxy/listener/listener.go +++ b/client/iface/wgproxy/listener/listener.go @@ -1,7 +1,10 @@ package listener +import "sync" + type CloseListener struct { listener func() + mu sync.Mutex } func NewCloseListener() *CloseListener { @@ -9,11 +12,21 @@ func NewCloseListener() *CloseListener { } func (c *CloseListener) SetCloseListener(listener func()) { + c.mu.Lock() + defer c.mu.Unlock() + c.listener = listener } func (c *CloseListener) Notify() { - if c.listener != nil { - c.listener() + c.mu.Lock() + + if c.listener == nil { + c.mu.Unlock() + return } + listener := c.listener + c.mu.Unlock() + + listener() } diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index df45d8ca5..139ccd4ed 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -183,6 +183,11 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { for { n, err := p.remoteConnRead(ctx, buf) if err != nil { + if ctx.Err() != nil { + return + } + + p.closeListener.Notify() return } diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go index 39d266b48..23446366a 100644 --- a/relay/healthcheck/sender_test.go +++ b/relay/healthcheck/sender_test.go @@ -122,10 +122,6 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { originalTimeout := healthCheckTimeout healthCheckInterval = 1 * time.Second healthCheckTimeout = 500 * time.Millisecond - defer func() { - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout - }() //nolint:tenv os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) @@ -164,20 +160,23 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { select { case <-sender.Timeout: if tc.resetCounterOnce { - t.Fatalf("should not have timed out before %s", testTimeout) + t.Errorf("should not have timed out before %s", testTimeout) } case <-time.After(testTimeout): if tc.resetCounterOnce { return } - t.Fatalf("should have timed out before %s", testTimeout) + t.Errorf("should have timed out before %s", testTimeout) } + cancel() select { case <-senderExit: case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } + healthCheckInterval = originalInterval + healthCheckTimeout = originalTimeout }) } From 2c4ac33b381e352fefe401e94f0416e158db8b71 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:15:38 +0200 Subject: [PATCH 37/50] [client] Remove and deprecate the admin url functionality (#4218) --- client/cmd/login.go | 1 - client/cmd/root.go | 3 ++- client/cmd/up.go | 2 -- client/ui/client_ui.go | 14 +------------- 4 files changed, 3 insertions(+), 17 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 8ac7086b8..f3a2f0cca 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -61,7 +61,6 @@ var loginCmd = &cobra.Command{ ic := internal.ConfigInput{ ManagementURL: managementURL, - AdminURL: adminURL, ConfigPath: configPath, } if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { diff --git a/client/cmd/root.go b/client/cmd/root.go index 1774602c4..e4f260f9b 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -118,7 +118,8 @@ func init() { rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL)) - rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL)) + rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("(DEPRECATED) Admin Panel URL [http|https]://[host]:[port] (default \"%s\") - This flag is no longer functional", internal.DefaultAdminURL)) + _ = rootCmd.PersistentFlags().MarkDeprecated("admin-url", "the admin-url flag is no longer functional and will be removed in a future version") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character") diff --git a/client/cmd/up.go b/client/cmd/up.go index 529beeac7..66fe91f7d 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -238,7 +238,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) { ic := internal.ConfigInput{ ManagementURL: managementURL, - AdminURL: adminURL, ConfigPath: configPath, NATExternalIPs: natExternalIPs, CustomDNSAddress: customDNSAddressConverted, @@ -325,7 +324,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest := proto.LoginRequest{ SetupKey: providedSetupKey, ManagementUrl: managementURL, - AdminURL: adminURL, NatExternalIPs: natExternalIPs, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CustomDNSAddress: customDNSAddressConverted, diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 4480adb51..c18d96dae 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -214,7 +214,6 @@ type serviceClient struct { // input elements for settings form iMngURL *widget.Entry - iAdminURL *widget.Entry iConfigFile *widget.Entry iLogFile *widget.Entry iPreSharedKey *widget.Entry @@ -232,7 +231,6 @@ type serviceClient struct { // observable settings over corresponding iMngURL and iPreSharedKey values. managementURL string preSharedKey string - adminURL string RosenpassPermissive bool interfaceName string interfacePort int @@ -344,7 +342,6 @@ func (s *serviceClient) showSettingsUI() { s.wSettings.SetOnClosed(s.cancel) s.iMngURL = widget.NewEntry() - s.iAdminURL = widget.NewEntry() s.iConfigFile = widget.NewEntry() s.iConfigFile.Disable() s.iLogFile = widget.NewEntry() @@ -377,7 +374,6 @@ func (s *serviceClient) getSettingsForm() *widget.Form { {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, {Text: "Management URL", Widget: s.iMngURL}, - {Text: "Admin URL", Widget: s.iAdminURL}, {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Config File", Widget: s.iConfigFile}, {Text: "Log File", Widget: s.iLogFile}, @@ -403,14 +399,13 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - iAdminURL := strings.TrimSpace(s.iAdminURL.Text) iMngURL := strings.TrimSpace(s.iMngURL.Text) defer s.wSettings.Close() // Check if any settings have changed if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || - s.adminURL != iAdminURL || s.RosenpassPermissive != s.sRosenpassPermissive.Checked || + s.RosenpassPermissive != s.sRosenpassPermissive.Checked || s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || s.networkMonitor != s.sNetworkMonitor.Checked || s.disableDNS != s.sDisableDNS.Checked || @@ -420,11 +415,9 @@ func (s *serviceClient) getSettingsForm() *widget.Form { s.managementURL = iMngURL s.preSharedKey = s.iPreSharedKey.Text - s.adminURL = iAdminURL loginRequest := proto.LoginRequest{ ManagementUrl: iMngURL, - AdminURL: iAdminURL, IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", RosenpassPermissive: &s.sRosenpassPermissive.Checked, InterfaceName: &s.iInterfaceName.Text, @@ -798,7 +791,6 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService // getSrvConfig from the service to show it in the settings window. func (s *serviceClient) getSrvConfig() { s.managementURL = internal.DefaultManagementURL - s.adminURL = internal.DefaultAdminURL conn, err := s.getSrvClient(failFastTimeout) if err != nil { @@ -815,9 +807,6 @@ func (s *serviceClient) getSrvConfig() { if cfg.ManagementUrl != "" { s.managementURL = cfg.ManagementUrl } - if cfg.AdminURL != "" { - s.adminURL = cfg.AdminURL - } s.preSharedKey = cfg.PreSharedKey s.RosenpassPermissive = cfg.RosenpassPermissive s.interfaceName = cfg.InterfaceName @@ -831,7 +820,6 @@ func (s *serviceClient) getSrvConfig() { if s.showAdvancedSettings { s.iMngURL.SetText(s.managementURL) - s.iAdminURL.SetText(s.adminURL) s.iConfigFile.SetText(cfg.ConfigFile) s.iLogFile.SetText(cfg.LogFile) s.iPreSharedKey.SetText(cfg.PreSharedKey) From e0d9306b05101716eb2fed795759b40ecf8fce6a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:31:06 +0200 Subject: [PATCH 38/50] [client] Add detailed routes and resolved IPs to debug bundle (#4141) --- client/internal/debug/debug.go | 203 ++++----- client/internal/debug/debug_linux.go | 35 +- client/internal/debug/debug_nonlinux.go | 5 + client/internal/debug/debug_nonmobile.go | 8 +- client/internal/debug/format.go | 206 +++++++++ client/internal/debug/format_linux.go | 185 ++++++++ client/internal/debug/format_nonwindows.go | 27 ++ client/internal/debug/format_windows.go | 37 ++ .../routemanager/systemops/routeflags_bsd.go | 54 ++- .../systemops/routeflags_freebsd.go | 55 ++- .../routemanager/systemops/systemops.go | 20 + .../routemanager/systemops/systemops_bsd.go | 131 +++++- .../routemanager/systemops/systemops_linux.go | 410 +++++++++++++++++- .../systemops/systemops_nonlinux.go | 25 ++ .../systemops/systemops_windows.go | 265 ++++++++++- 15 files changed, 1501 insertions(+), 165 deletions(-) create mode 100644 client/internal/debug/format.go create mode 100644 client/internal/debug/format_linux.go create mode 100644 client/internal/debug/format_nonwindows.go create mode 100644 client/internal/debug/format_windows.go diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 220e6854d..a9d9f3fc1 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -40,10 +40,12 @@ status.txt: Anonymized status information of the NetBird client. client.log: Most recent, anonymized client log file of the NetBird client. netbird.err: Most recent, anonymized stderr log file of the NetBird client. netbird.out: Most recent, anonymized stdout log file of the NetBird client. -routes.txt: Anonymized system routes, if --system-info flag was provided. +routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. +ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. +resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states. @@ -107,7 +109,29 @@ go tool pprof -http=:8088 heap.prof This will open a web browser tab with the profiling information. Routes -For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. +The routes.txt file contains detailed routing table information in a tabular format: + +- Destination: Network prefix (IP_ADDRESS/PREFIX_LENGTH) +- Gateway: Next hop IP address (or "-" if direct) +- Interface: Network interface name +- Metric: Route priority/metric (lower values preferred) +- Protocol: Routing protocol (kernel, static, dhcp, etc.) +- Scope: Route scope (global, link, host, etc.) +- Type: Route type (unicast, local, broadcast, etc.) +- Table: Routing table name (main, local, netbird, etc.) + +The table format provides a comprehensive view of the system's routing configuration, including information from multiple routing tables on Linux systems. This is valuable for troubleshooting routing issues and understanding traffic flow. + +For anonymized routes, IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. Interface names are anonymized using string anonymization. + +Resolved Domains +The resolved_domains.txt file contains information about domain names that have been resolved to IP addresses by NetBird's DNS resolver. This includes: +- Original domain patterns that were configured for routing +- Resolved domain names that matched those patterns +- IP address prefixes that were resolved for each domain +- Parent domain associations showing which original pattern each resolved domain belongs to + +All domain names and IP addresses in this file follow the same anonymization rules as described above. This information is valuable for troubleshooting DNS resolution and routing issues. Network Interfaces The interfaces.txt file contains information about network interfaces, including: @@ -145,6 +169,22 @@ nftables.txt: - Shows packet and byte counters for each rule - All IP addresses are anonymized - Chain names, table names, and other non-sensitive information remain unchanged + +IP Rules (Linux only) +The ip_rules.txt file contains detailed IP routing rule information: + +- Priority: Rule priority number (lower values processed first) +- From: Source IP prefix or "all" if unspecified +- To: Destination IP prefix or "all" if unspecified +- IIF: Input interface name or "-" if unspecified +- OIF: Output interface name or "-" if unspecified +- Table: Target routing table name (main, local, netbird, etc.) +- Action: Rule action (lookup, goto, blackhole, etc.) +- Mark: Firewall mark value in hex format or "-" if unspecified + +The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. + +For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. ` const ( @@ -159,13 +199,11 @@ const ( type BundleGenerator struct { anonymizer *anonymize.Anonymizer - // deps internalConfig *internal.Config statusRecorder *peer.Status networkMap *mgmProto.NetworkMap logFile string - // config anonymize bool clientStatus string includeSystemInfo bool @@ -258,7 +296,11 @@ func (g *BundleGenerator) createArchive() error { } if err := g.addConfig(); err != nil { - log.Errorf("Failed to add config to debug bundle: %v", err) + log.Errorf("failed to add config to debug bundle: %v", err) + } + + if err := g.addResolvedDomains(); err != nil { + log.Errorf("failed to add resolved domains to debug bundle: %v", err) } if g.includeSystemInfo { @@ -266,7 +308,7 @@ func (g *BundleGenerator) createArchive() error { } if err := g.addProf(); err != nil { - log.Errorf("Failed to add profiles to debug bundle: %v", err) + log.Errorf("failed to add profiles to debug bundle: %v", err) } if err := g.addNetworkMap(); err != nil { @@ -274,26 +316,26 @@ func (g *BundleGenerator) createArchive() error { } if err := g.addStateFile(); err != nil { - log.Errorf("Failed to add state file to debug bundle: %v", err) + log.Errorf("failed to add state file to debug bundle: %v", err) } if err := g.addCorruptedStateFiles(); err != nil { - log.Errorf("Failed to add corrupted state files to debug bundle: %v", err) + log.Errorf("failed to add corrupted state files to debug bundle: %v", err) } if err := g.addWgShow(); err != nil { - log.Errorf("Failed to add wg show output: %v", err) + log.Errorf("failed to add wg show output: %v", err) } if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) { if err := g.addLogfile(); err != nil { - log.Errorf("Failed to add log file to debug bundle: %v", err) + log.Errorf("failed to add log file to debug bundle: %v", err) if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("Failed to add systemd logs as fallback: %v", err) + log.Errorf("failed to add systemd logs as fallback: %v", err) } } } else if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("Failed to add systemd logs: %v", err) + log.Errorf("failed to add systemd logs: %v", err) } return nil @@ -301,15 +343,19 @@ func (g *BundleGenerator) createArchive() error { func (g *BundleGenerator) addSystemInfo() { if err := g.addRoutes(); err != nil { - log.Errorf("Failed to add routes to debug bundle: %v", err) + log.Errorf("failed to add routes to debug bundle: %v", err) } if err := g.addInterfaces(); err != nil { - log.Errorf("Failed to add interfaces to debug bundle: %v", err) + log.Errorf("failed to add interfaces to debug bundle: %v", err) + } + + if err := g.addIPRules(); err != nil { + log.Errorf("failed to add IP rules to debug bundle: %v", err) } if err := g.addFirewallRules(); err != nil { - log.Errorf("Failed to add firewall rules to debug bundle: %v", err) + log.Errorf("failed to add firewall rules to debug bundle: %v", err) } } @@ -364,7 +410,6 @@ func (g *BundleGenerator) addConfig() error { } } - // Add config content to zip file configReader := strings.NewReader(configContent.String()) if err := g.addFileToZip(configReader, "config.txt"); err != nil { return fmt.Errorf("add config file to zip: %w", err) @@ -376,7 +421,6 @@ func (g *BundleGenerator) addConfig() error { func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { configContent.WriteString("NetBird Client Configuration:\n\n") - // Add non-sensitive fields configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface)) configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort)) if g.internalConfig.NetworkMonitor != nil { @@ -461,6 +505,27 @@ func (g *BundleGenerator) addInterfaces() error { return nil } +func (g *BundleGenerator) addResolvedDomains() error { + if g.statusRecorder == nil { + log.Debugf("skipping resolved domains in debug bundle: no status recorder") + return nil + } + + resolvedDomains := g.statusRecorder.GetResolvedDomainsStates() + if len(resolvedDomains) == 0 { + log.Debugf("skipping resolved domains in debug bundle: no resolved domains") + return nil + } + + resolvedDomainsContent := formatResolvedDomains(resolvedDomains, g.anonymize, g.anonymizer) + resolvedDomainsReader := strings.NewReader(resolvedDomainsContent) + if err := g.addFileToZip(resolvedDomainsReader, "resolved_domains.txt"); err != nil { + return fmt.Errorf("add resolved domains file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addNetworkMap() error { if g.networkMap == nil { log.Debugf("skipping empty network map in debug bundle") @@ -572,7 +637,6 @@ func (g *BundleGenerator) addLogfile() error { return fmt.Errorf("add client log file to zip: %w", err) } - // add rotated log files based on logFileCount g.addRotatedLogFiles(logDir) stdErrLogPath := filepath.Join(logDir, errorLogFile) @@ -601,7 +665,7 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error { } defer func() { if err := logFile.Close(); err != nil { - log.Errorf("Failed to close log file %s: %v", targetName, err) + log.Errorf("failed to close log file %s: %v", targetName, err) } }() @@ -625,13 +689,21 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error { if err != nil { return fmt.Errorf("open gz log file %s: %w", targetName, err) } - defer f.Close() + defer func() { + if err := f.Close(); err != nil { + log.Errorf("failed to close gz file %s: %v", targetName, err) + } + }() gzr, err := gzip.NewReader(f) if err != nil { return fmt.Errorf("create gzip reader: %w", err) } - defer gzr.Close() + defer func() { + if err := gzr.Close(); err != nil { + log.Errorf("failed to close gzip reader %s: %v", targetName, err) + } + }() var logReader io.Reader = gzr if g.anonymize { @@ -689,7 +761,6 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) { return fi.ModTime().After(fj.ModTime()) }) - // include up to logFileCount rotated files maxFiles := int(g.logFileCount) if maxFiles > len(files) { maxFiles = len(files) @@ -717,7 +788,7 @@ func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error // If the reader is a file, we can get more accurate information if f, ok := reader.(*os.File); ok { if stat, err := f.Stat(); err != nil { - log.Tracef("Failed to get file stat for %s: %v", filename, err) + log.Tracef("failed to get file stat for %s: %v", filename, err) } else { header.Modified = stat.ModTime() } @@ -765,89 +836,6 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) { } } -func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string { - var ipv4Routes, ipv6Routes []netip.Prefix - - // Separate IPv4 and IPv6 routes - for _, route := range routes { - if route.Addr().Is4() { - ipv4Routes = append(ipv4Routes, route) - } else { - ipv6Routes = append(ipv6Routes, route) - } - } - - // Sort IPv4 and IPv6 routes separately - sort.Slice(ipv4Routes, func(i, j int) bool { - return ipv4Routes[i].Bits() > ipv4Routes[j].Bits() - }) - sort.Slice(ipv6Routes, func(i, j int) bool { - return ipv6Routes[i].Bits() > ipv6Routes[j].Bits() - }) - - var builder strings.Builder - - // Format IPv4 routes - builder.WriteString("IPv4 Routes:\n") - for _, route := range ipv4Routes { - formatRoute(&builder, route, anonymize, anonymizer) - } - - // Format IPv6 routes - builder.WriteString("\nIPv6 Routes:\n") - for _, route := range ipv6Routes { - formatRoute(&builder, route, anonymize, anonymizer) - } - - return builder.String() -} - -func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) { - if anonymize { - anonymizedIP := anonymizer.AnonymizeIP(route.Addr()) - builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits())) - } else { - builder.WriteString(fmt.Sprintf("%s\n", route)) - } -} - -func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string { - sort.Slice(interfaces, func(i, j int) bool { - return interfaces[i].Name < interfaces[j].Name - }) - - var builder strings.Builder - builder.WriteString("Network Interfaces:\n") - - for _, iface := range interfaces { - builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name)) - builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index)) - builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU)) - builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags)) - - addrs, err := iface.Addrs() - if err != nil { - builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err)) - } else { - builder.WriteString(" Addresses:\n") - for _, addr := range addrs { - prefix, err := netip.ParsePrefix(addr.String()) - if err != nil { - builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err)) - continue - } - ip := prefix.Addr() - if anonymize { - ip = anonymizer.AnonymizeIP(ip) - } - builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits())) - } - } - } - - return builder.String() -} - func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { defer func() { // always nil @@ -954,7 +942,6 @@ func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize. } for i, ip := range peer.AllowedIps { - // Try to parse as prefix first (CIDR) if prefix, err := netip.ParsePrefix(ip); err == nil { anonIP := anonymizer.AnonymizeIP(prefix.Addr()) peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) @@ -1033,7 +1020,7 @@ func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.An func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) { switch record.Type { - case 1, 28: // A or AAAA record + case 1, 28: if addr, err := netip.ParseAddr(record.RData); err == nil { record.RData = anonymizer.AnonymizeIP(addr).String() } diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index 4626cd9a2..39d796fda 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -17,8 +17,27 @@ import ( "github.com/google/nftables" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) +// addIPRules collects and adds IP rules to the archive +func (g *BundleGenerator) addIPRules() error { + log.Info("Collecting IP rules") + ipRules, err := systemops.GetIPRules() + if err != nil { + return fmt.Errorf("get IP rules: %w", err) + } + + rulesContent := formatIPRulesTable(ipRules, g.anonymize, g.anonymizer) + rulesReader := strings.NewReader(rulesContent) + if err := g.addFileToZip(rulesReader, "ip_rules.txt"); err != nil { + return fmt.Errorf("add IP rules file to zip: %w", err) + } + + return nil +} + const ( maxLogEntries = 100000 maxLogAge = 7 * 24 * time.Hour // Last 7 days @@ -136,7 +155,6 @@ func (g *BundleGenerator) addFirewallRules() error { func collectIPTablesRules() (string, error) { var builder strings.Builder - // First try using iptables-save saveOutput, err := collectIPTablesSave() if err != nil { log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) @@ -146,7 +164,6 @@ func collectIPTablesRules() (string, error) { builder.WriteString("\n") } - // Collect ipset information ipsetOutput, err := collectIPSets() if err != nil { log.Warnf("Failed to collect ipset information: %v", err) @@ -232,11 +249,9 @@ func getTableStatistics(table string) (string, error) { // collectNFTablesRules attempts to collect nftables rules using either nft command or netlink func collectNFTablesRules() (string, error) { - // First try using nft command rules, err := collectNFTablesFromCommand() if err != nil { log.Debugf("Failed to collect nftables rules using nft command: %v, falling back to netlink", err) - // Fall back to netlink rules, err = collectNFTablesFromNetlink() if err != nil { return "", fmt.Errorf("collect nftables rules using both nft and netlink failed: %w", err) @@ -451,7 +466,6 @@ func formatRule(rule *nftables.Rule) string { func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { curr := exprs[i] - // Handle Meta + Cmp sequence if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) { if cmp, ok := exprs[i+1].(*expr.Cmp); ok { if formatted := formatMetaWithCmp(meta, cmp); formatted != "" { @@ -461,7 +475,6 @@ func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { } } - // Handle Payload + Cmp sequence if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) { if cmp, ok := exprs[i+1].(*expr.Cmp); ok { builder.WriteString(formatPayloadWithCmp(payload, cmp)) @@ -493,13 +506,13 @@ func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string { func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string { if p.Base == expr.PayloadBaseNetworkHeader { switch p.Offset { - case 12: // Source IP + case 12: if p.Len == 4 { return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } else if p.Len == 2 { return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } - case 16: // Destination IP + case 16: if p.Len == 4 { return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } else if p.Len == 2 { @@ -580,7 +593,6 @@ func formatExpr(exp expr.Any) string { } func formatImmediateData(data []byte) string { - // For IP addresses (4 bytes) if len(data) == 4 { return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) } @@ -588,26 +600,21 @@ func formatImmediateData(data []byte) string { } func formatMeta(e *expr.Meta) string { - // Handle source register case first (meta mark set) if e.SourceRegister { return fmt.Sprintf("meta %s set reg %d", formatMetaKey(e.Key), e.Register) } - // For interface names, handle register load operation switch e.Key { case expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME, expr.MetaKeyBRIIIFNAME, expr.MetaKeyBRIOIFNAME: - // Simply the key name with no register reference return formatMetaKey(e.Key) case expr.MetaKeyMARK: - // For mark operations, we want just "mark" return "mark" } - // For other meta keys, show as loading into register return fmt.Sprintf("meta %s => reg %d", formatMetaKey(e.Key), e.Register) } diff --git a/client/internal/debug/debug_nonlinux.go b/client/internal/debug/debug_nonlinux.go index b0ff55613..ace53bd94 100644 --- a/client/internal/debug/debug_nonlinux.go +++ b/client/internal/debug/debug_nonlinux.go @@ -12,3 +12,8 @@ func (g *BundleGenerator) trySystemdLogFallback() error { // TODO: Add BSD support return nil } + +func (g *BundleGenerator) addIPRules() error { + // IP rules are only supported on Linux + return nil +} diff --git a/client/internal/debug/debug_nonmobile.go b/client/internal/debug/debug_nonmobile.go index 3b487f07f..1f69f50c9 100644 --- a/client/internal/debug/debug_nonmobile.go +++ b/client/internal/debug/debug_nonmobile.go @@ -10,16 +10,16 @@ import ( ) func (g *BundleGenerator) addRoutes() error { - routes, err := systemops.GetRoutesFromTable() + detailedRoutes, err := systemops.GetDetailedRoutesFromTable() if err != nil { - return fmt.Errorf("get routes: %w", err) + return fmt.Errorf("get detailed routes: %w", err) } - // TODO: get routes including nexthop - routesContent := formatRoutes(routes, g.anonymize, g.anonymizer) + routesContent := formatRoutesTable(detailedRoutes, g.anonymize, g.anonymizer) routesReader := strings.NewReader(routesContent) if err := g.addFileToZip(routesReader, "routes.txt"); err != nil { return fmt.Errorf("add routes file to zip: %w", err) } + return nil } diff --git a/client/internal/debug/format.go b/client/internal/debug/format.go new file mode 100644 index 000000000..54fc77f93 --- /dev/null +++ b/client/internal/debug/format.go @@ -0,0 +1,206 @@ +package debug + +import ( + "fmt" + "net" + "net/netip" + "sort" + "strings" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/management/domain" +) + +func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string { + sort.Slice(interfaces, func(i, j int) bool { + return interfaces[i].Name < interfaces[j].Name + }) + + var builder strings.Builder + builder.WriteString("Network Interfaces:\n") + + for _, iface := range interfaces { + builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name)) + builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index)) + builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU)) + builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags)) + + addrs, err := iface.Addrs() + if err != nil { + builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err)) + } else { + builder.WriteString(" Addresses:\n") + for _, addr := range addrs { + prefix, err := netip.ParsePrefix(addr.String()) + if err != nil { + builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err)) + continue + } + ip := prefix.Addr() + if anonymize { + ip = anonymizer.AnonymizeIP(ip) + } + builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits())) + } + } + } + + return builder.String() +} + +func formatResolvedDomains(resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, anonymize bool, anonymizer *anonymize.Anonymizer) string { + if len(resolvedDomains) == 0 { + return "No resolved domains found.\n" + } + + var builder strings.Builder + builder.WriteString("Resolved Domains:\n") + builder.WriteString("=================\n\n") + + var sortedParents []domain.Domain + for parentDomain := range resolvedDomains { + sortedParents = append(sortedParents, parentDomain) + } + sort.Slice(sortedParents, func(i, j int) bool { + return sortedParents[i].SafeString() < sortedParents[j].SafeString() + }) + + for _, parentDomain := range sortedParents { + info := resolvedDomains[parentDomain] + + parentKey := parentDomain.SafeString() + if anonymize { + parentKey = anonymizer.AnonymizeDomain(parentKey) + } + + builder.WriteString(fmt.Sprintf("%s:\n", parentKey)) + + var sortedIPs []string + for _, prefix := range info.Prefixes { + ipStr := prefix.String() + if anonymize { + anonymizedIP := anonymizer.AnonymizeIP(prefix.Addr()) + ipStr = fmt.Sprintf("%s/%d", anonymizedIP, prefix.Bits()) + } + sortedIPs = append(sortedIPs, ipStr) + } + sort.Strings(sortedIPs) + + for _, ipStr := range sortedIPs { + builder.WriteString(fmt.Sprintf(" %s\n", ipStr)) + } + builder.WriteString("\n") + } + + return builder.String() +} + +func formatRoutesTable(detailedRoutes []systemops.DetailedRoute, anonymize bool, anonymizer *anonymize.Anonymizer) string { + if len(detailedRoutes) == 0 { + return "No routes found.\n" + } + + sort.Slice(detailedRoutes, func(i, j int) bool { + if detailedRoutes[i].Table != detailedRoutes[j].Table { + return detailedRoutes[i].Table < detailedRoutes[j].Table + } + return detailedRoutes[i].Route.Dst.String() < detailedRoutes[j].Route.Dst.String() + }) + + headers, rows := buildPlatformSpecificRouteTable(detailedRoutes, anonymize, anonymizer) + + return formatTable("Routing Table:", headers, rows) +} + +func formatRouteDestination(destination netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string { + if anonymize { + anonymizedDestIP := anonymizer.AnonymizeIP(destination.Addr()) + return fmt.Sprintf("%s/%d", anonymizedDestIP, destination.Bits()) + } + return destination.String() +} + +func formatRouteGateway(gateway netip.Addr, anonymize bool, anonymizer *anonymize.Anonymizer) string { + if gateway.IsValid() { + if anonymize { + return anonymizer.AnonymizeIP(gateway).String() + } + return gateway.String() + } + return "-" +} + +func formatRouteInterface(iface *net.Interface) string { + if iface != nil { + return iface.Name + } + return "-" +} + +func formatInterfaceIndex(index int) string { + if index <= 0 { + return "-" + } + return fmt.Sprintf("%d", index) +} + +func formatRouteMetric(metric int) string { + if metric < 0 { + return "-" + } + return fmt.Sprintf("%d", metric) +} + +func formatTable(title string, headers []string, rows [][]string) string { + widths := make([]int, len(headers)) + + for i, header := range headers { + widths[i] = len(header) + } + + for _, row := range rows { + for i, cell := range row { + if len(cell) > widths[i] { + widths[i] = len(cell) + } + } + } + + for i := range widths { + widths[i] += 2 + } + + var formatParts []string + for _, width := range widths { + formatParts = append(formatParts, fmt.Sprintf("%%-%ds", width)) + } + formatStr := strings.Join(formatParts, "") + "\n" + + var builder strings.Builder + builder.WriteString(title + "\n") + builder.WriteString(strings.Repeat("=", len(title)) + "\n\n") + + headerArgs := make([]interface{}, len(headers)) + for i, header := range headers { + headerArgs[i] = header + } + builder.WriteString(fmt.Sprintf(formatStr, headerArgs...)) + + separatorArgs := make([]interface{}, len(headers)) + for i, width := range widths { + separatorArgs[i] = strings.Repeat("-", width-2) + } + builder.WriteString(fmt.Sprintf(formatStr, separatorArgs...)) + + for _, row := range rows { + rowArgs := make([]interface{}, len(row)) + for i, cell := range row { + rowArgs[i] = cell + } + builder.WriteString(fmt.Sprintf(formatStr, rowArgs...)) + } + + return builder.String() +} diff --git a/client/internal/debug/format_linux.go b/client/internal/debug/format_linux.go new file mode 100644 index 000000000..7a2ba49ea --- /dev/null +++ b/client/internal/debug/format_linux.go @@ -0,0 +1,185 @@ +//go:build linux && !android + +package debug + +import ( + "fmt" + "net/netip" + "sort" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func formatIPRulesTable(ipRules []systemops.IPRule, anonymize bool, anonymizer *anonymize.Anonymizer) string { + if len(ipRules) == 0 { + return "No IP rules found.\n" + } + + sort.Slice(ipRules, func(i, j int) bool { + return ipRules[i].Priority < ipRules[j].Priority + }) + + columnConfig := detectIPRuleColumns(ipRules) + + headers := buildIPRuleHeaders(columnConfig) + + rows := buildIPRuleRows(ipRules, columnConfig, anonymize, anonymizer) + + return formatTable("IP Rules:", headers, rows) +} + +type ipRuleColumnConfig struct { + hasInvert, hasTo, hasMark, hasIIF, hasOIF, hasSuppressPlen bool +} + +func detectIPRuleColumns(ipRules []systemops.IPRule) ipRuleColumnConfig { + var config ipRuleColumnConfig + for _, rule := range ipRules { + if rule.Invert { + config.hasInvert = true + } + if rule.To.IsValid() { + config.hasTo = true + } + if rule.Mark != 0 { + config.hasMark = true + } + if rule.IIF != "" { + config.hasIIF = true + } + if rule.OIF != "" { + config.hasOIF = true + } + if rule.SuppressPlen >= 0 { + config.hasSuppressPlen = true + } + } + return config +} + +func buildIPRuleHeaders(config ipRuleColumnConfig) []string { + var headers []string + + headers = append(headers, "Priority") + if config.hasInvert { + headers = append(headers, "Not") + } + headers = append(headers, "From") + if config.hasTo { + headers = append(headers, "To") + } + if config.hasMark { + headers = append(headers, "FWMark") + } + if config.hasIIF { + headers = append(headers, "IIF") + } + if config.hasOIF { + headers = append(headers, "OIF") + } + headers = append(headers, "Table") + headers = append(headers, "Action") + if config.hasSuppressPlen { + headers = append(headers, "SuppressPlen") + } + + return headers +} + +func buildIPRuleRows(ipRules []systemops.IPRule, config ipRuleColumnConfig, anonymize bool, anonymizer *anonymize.Anonymizer) [][]string { + var rows [][]string + for _, rule := range ipRules { + row := buildSingleIPRuleRow(rule, config, anonymize, anonymizer) + rows = append(rows, row) + } + return rows +} + +func buildSingleIPRuleRow(rule systemops.IPRule, config ipRuleColumnConfig, anonymize bool, anonymizer *anonymize.Anonymizer) []string { + var row []string + + row = append(row, fmt.Sprintf("%d", rule.Priority)) + + if config.hasInvert { + row = append(row, formatIPRuleInvert(rule.Invert)) + } + + row = append(row, formatIPRuleAddress(rule.From, "all", anonymize, anonymizer)) + + if config.hasTo { + row = append(row, formatIPRuleAddress(rule.To, "-", anonymize, anonymizer)) + } + + if config.hasMark { + row = append(row, formatIPRuleMark(rule.Mark, rule.Mask)) + } + + if config.hasIIF { + row = append(row, formatIPRuleInterface(rule.IIF)) + } + + if config.hasOIF { + row = append(row, formatIPRuleInterface(rule.OIF)) + } + + row = append(row, rule.Table) + + row = append(row, formatIPRuleAction(rule.Action)) + + if config.hasSuppressPlen { + row = append(row, formatIPRuleSuppressPlen(rule.SuppressPlen)) + } + + return row +} + +func formatIPRuleInvert(invert bool) string { + if invert { + return "not" + } + return "-" +} + +func formatIPRuleAction(action string) string { + if action == "unspec" { + return "lookup" + } + return action +} + +func formatIPRuleSuppressPlen(suppressPlen int) string { + if suppressPlen >= 0 { + return fmt.Sprintf("%d", suppressPlen) + } + return "-" +} + +func formatIPRuleAddress(prefix netip.Prefix, defaultVal string, anonymize bool, anonymizer *anonymize.Anonymizer) string { + if !prefix.IsValid() { + return defaultVal + } + + if anonymize { + anonymizedIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonymizedIP, prefix.Bits()) + } + return prefix.String() +} + +func formatIPRuleMark(mark, mask uint32) string { + if mark == 0 { + return "-" + } + if mask != 0 { + return fmt.Sprintf("0x%x/0x%x", mark, mask) + } + return fmt.Sprintf("0x%x", mark) +} + +func formatIPRuleInterface(iface string) string { + if iface == "" { + return "-" + } + return iface +} diff --git a/client/internal/debug/format_nonwindows.go b/client/internal/debug/format_nonwindows.go new file mode 100644 index 000000000..3ad5c596c --- /dev/null +++ b/client/internal/debug/format_nonwindows.go @@ -0,0 +1,27 @@ +//go:build !windows + +package debug + +import ( + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// buildPlatformSpecificRouteTable builds headers and rows for non-Windows platforms +func buildPlatformSpecificRouteTable(detailedRoutes []systemops.DetailedRoute, anonymize bool, anonymizer *anonymize.Anonymizer) ([]string, [][]string) { + headers := []string{"Destination", "Gateway", "Interface", "Idx", "Metric", "Protocol", "Scope", "Type", "Table", "Flags"} + + var rows [][]string + for _, route := range detailedRoutes { + destStr := formatRouteDestination(route.Route.Dst, anonymize, anonymizer) + gatewayStr := formatRouteGateway(route.Route.Gw, anonymize, anonymizer) + interfaceStr := formatRouteInterface(route.Route.Interface) + indexStr := formatInterfaceIndex(route.InterfaceIndex) + metricStr := formatRouteMetric(route.Metric) + + row := []string{destStr, gatewayStr, interfaceStr, indexStr, metricStr, route.Protocol, route.Scope, route.Type, route.Table, route.Flags} + rows = append(rows, row) + } + + return headers, rows +} diff --git a/client/internal/debug/format_windows.go b/client/internal/debug/format_windows.go new file mode 100644 index 000000000..b37112d6f --- /dev/null +++ b/client/internal/debug/format_windows.go @@ -0,0 +1,37 @@ +//go:build windows + +package debug + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// buildPlatformSpecificRouteTable builds headers and rows for Windows with interface metrics +func buildPlatformSpecificRouteTable(detailedRoutes []systemops.DetailedRoute, anonymize bool, anonymizer *anonymize.Anonymizer) ([]string, [][]string) { + headers := []string{"Destination", "Gateway", "Interface", "Idx", "Metric", "If Metric", "Protocol", "Age", "Origin"} + + var rows [][]string + for _, route := range detailedRoutes { + destStr := formatRouteDestination(route.Route.Dst, anonymize, anonymizer) + gatewayStr := formatRouteGateway(route.Route.Gw, anonymize, anonymizer) + interfaceStr := formatRouteInterface(route.Route.Interface) + indexStr := formatInterfaceIndex(route.InterfaceIndex) + metricStr := formatRouteMetric(route.Metric) + ifMetricStr := formatInterfaceMetric(route.InterfaceMetric) + + row := []string{destStr, gatewayStr, interfaceStr, indexStr, metricStr, ifMetricStr, route.Protocol, route.Scope, route.Type} + rows = append(rows, row) + } + + return headers, rows +} + +func formatInterfaceMetric(metric int) string { + if metric < 0 { + return "-" + } + return fmt.Sprintf("%d", metric) +} diff --git a/client/internal/routemanager/systemops/routeflags_bsd.go b/client/internal/routemanager/systemops/routeflags_bsd.go index 12f158dcb..ad32e5029 100644 --- a/client/internal/routemanager/systemops/routeflags_bsd.go +++ b/client/internal/routemanager/systemops/routeflags_bsd.go @@ -2,9 +2,12 @@ package systemops -import "syscall" +import ( + "strings" + "syscall" +) -// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags. +// filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { if routeMessageFlags&syscall.RTF_UP == 0 { return true @@ -16,3 +19,50 @@ func filterRoutesByFlags(routeMessageFlags int) bool { return false } + +// formatBSDFlags formats route flags for BSD systems (excludes FreeBSD-specific handling) +func formatBSDFlags(flags int) string { + var flagStrs []string + + if flags&syscall.RTF_UP != 0 { + flagStrs = append(flagStrs, "U") + } + if flags&syscall.RTF_GATEWAY != 0 { + flagStrs = append(flagStrs, "G") + } + if flags&syscall.RTF_HOST != 0 { + flagStrs = append(flagStrs, "H") + } + if flags&syscall.RTF_REJECT != 0 { + flagStrs = append(flagStrs, "R") + } + if flags&syscall.RTF_DYNAMIC != 0 { + flagStrs = append(flagStrs, "D") + } + if flags&syscall.RTF_MODIFIED != 0 { + flagStrs = append(flagStrs, "M") + } + if flags&syscall.RTF_STATIC != 0 { + flagStrs = append(flagStrs, "S") + } + if flags&syscall.RTF_LLINFO != 0 { + flagStrs = append(flagStrs, "L") + } + if flags&syscall.RTF_LOCAL != 0 { + flagStrs = append(flagStrs, "l") + } + if flags&syscall.RTF_BLACKHOLE != 0 { + flagStrs = append(flagStrs, "B") + } + if flags&syscall.RTF_CLONING != 0 { + flagStrs = append(flagStrs, "C") + } + if flags&syscall.RTF_WASCLONED != 0 { + flagStrs = append(flagStrs, "W") + } + + if len(flagStrs) == 0 { + return "-" + } + return strings.Join(flagStrs, "") +} diff --git a/client/internal/routemanager/systemops/routeflags_freebsd.go b/client/internal/routemanager/systemops/routeflags_freebsd.go index cb35f521e..2338fe5d8 100644 --- a/client/internal/routemanager/systemops/routeflags_freebsd.go +++ b/client/internal/routemanager/systemops/routeflags_freebsd.go @@ -1,19 +1,64 @@ -//go:build: freebsd +//go:build freebsd + package systemops -import "syscall" +import ( + "strings" + "syscall" +) -// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags. +// filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { if routeMessageFlags&syscall.RTF_UP == 0 { return true } - // NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/) - // a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated. + // NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 { return true } return false } + +// formatBSDFlags formats route flags for FreeBSD (excludes deprecated RTF_CLONING and RTF_WASCLONED) +func formatBSDFlags(flags int) string { + var flagStrs []string + + if flags&syscall.RTF_UP != 0 { + flagStrs = append(flagStrs, "U") + } + if flags&syscall.RTF_GATEWAY != 0 { + flagStrs = append(flagStrs, "G") + } + if flags&syscall.RTF_HOST != 0 { + flagStrs = append(flagStrs, "H") + } + if flags&syscall.RTF_REJECT != 0 { + flagStrs = append(flagStrs, "R") + } + if flags&syscall.RTF_DYNAMIC != 0 { + flagStrs = append(flagStrs, "D") + } + if flags&syscall.RTF_MODIFIED != 0 { + flagStrs = append(flagStrs, "M") + } + if flags&syscall.RTF_STATIC != 0 { + flagStrs = append(flagStrs, "S") + } + if flags&syscall.RTF_LLINFO != 0 { + flagStrs = append(flagStrs, "L") + } + if flags&syscall.RTF_LOCAL != 0 { + flagStrs = append(flagStrs, "l") + } + if flags&syscall.RTF_BLACKHOLE != 0 { + flagStrs = append(flagStrs, "B") + } + // Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0 + + if len(flagStrs) == 0 { + return "-" + } + return strings.Join(flagStrs, "") +} diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index b91348e94..8da138117 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -19,6 +19,26 @@ type Nexthop struct { Intf *net.Interface } +// Route represents a basic network route with core routing information +type Route struct { + Dst netip.Prefix + Gw netip.Addr + Interface *net.Interface +} + +// DetailedRoute extends Route with additional metadata for display and debugging +type DetailedRoute struct { + Route + Metric int + InterfaceMetric int + InterfaceIndex int + Protocol string + Scope string + Type string + Table string + Flags string +} + // Equal checks if two nexthops are equal. func (n Nexthop) Equal(other Nexthop) bool { return n.IP == other.IP && (n.Intf == nil && other.Intf == nil || diff --git a/client/internal/routemanager/systemops/systemops_bsd.go b/client/internal/routemanager/systemops/systemops_bsd.go index 5e3b20a86..3ce78a04a 100644 --- a/client/internal/routemanager/systemops/systemops_bsd.go +++ b/client/internal/routemanager/systemops/systemops_bsd.go @@ -16,12 +16,6 @@ import ( "golang.org/x/net/route" ) -type Route struct { - Dst netip.Prefix - Gw netip.Addr - Interface *net.Interface -} - func GetRoutesFromTable() ([]netip.Prefix, error) { tab, err := retryFetchRIB() if err != nil { @@ -47,25 +41,134 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { continue } - route, err := MsgToRoute(m) + r, err := MsgToRoute(m) if err != nil { log.Warnf("Failed to parse route message: %v", err) continue } - if route.Dst.IsValid() { - prefixList = append(prefixList, route.Dst) + if r.Dst.IsValid() { + prefixList = append(prefixList, r.Dst) } } return prefixList, nil } +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + tab, err := retryFetchRIB() + if err != nil { + return nil, fmt.Errorf("fetch RIB: %v", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, tab) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + return processRouteMessages(msgs) +} + +func processRouteMessages(msgs []route.Message) ([]DetailedRoute, error) { + var detailedRoutes []DetailedRoute + + for _, msg := range msgs { + m := msg.(*route.RouteMessage) + + if !isValidRouteMessage(m) { + continue + } + + if filterRoutesByFlags(m.Flags) { + continue + } + + detailed, err := buildDetailedRouteFromMessage(m) + if err != nil { + log.Warnf("Failed to parse route message: %v", err) + continue + } + + if detailed != nil { + detailedRoutes = append(detailedRoutes, *detailed) + } + } + + return detailedRoutes, nil +} + +func isValidRouteMessage(m *route.RouteMessage) bool { + if m.Version < 3 || m.Version > 5 { + log.Warnf("Unexpected RIB message version: %d", m.Version) + return false + } + if m.Type != syscall.RTM_GET { + log.Warnf("Unexpected RIB message type: %d", m.Type) + return false + } + return true +} + +func buildDetailedRouteFromMessage(m *route.RouteMessage) (*DetailedRoute, error) { + routeMsg, err := MsgToRoute(m) + if err != nil { + return nil, err + } + + if !routeMsg.Dst.IsValid() { + return nil, errors.New("invalid destination") + } + + detailed := DetailedRoute{ + Route: Route{ + Dst: routeMsg.Dst, + Gw: routeMsg.Gw, + Interface: routeMsg.Interface, + }, + Metric: extractBSDMetric(m), + Protocol: extractBSDProtocol(m.Flags), + Scope: "global", + Type: "unicast", + Table: "main", + Flags: formatBSDFlags(m.Flags), + } + + return &detailed, nil +} + +func buildLinkInterface(t *route.LinkAddr) *net.Interface { + interfaceName := fmt.Sprintf("link#%d", t.Index) + if t.Name != "" { + interfaceName = t.Name + } + return &net.Interface{ + Index: t.Index, + Name: interfaceName, + } +} + +func extractBSDMetric(m *route.RouteMessage) int { + return -1 +} + +func extractBSDProtocol(flags int) string { + if flags&syscall.RTF_STATIC != 0 { + return "static" + } + if flags&syscall.RTF_DYNAMIC != 0 { + return "dynamic" + } + if flags&syscall.RTF_LOCAL != 0 { + return "local" + } + return "kernel" +} + func retryFetchRIB() ([]byte, error) { var out []byte operation := func() error { var err error out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if errors.Is(err, syscall.ENOMEM) { - log.Debug("~etrying fetchRIB due to 'cannot allocate memory' error") + log.Debug("Retrying fetchRIB due to 'cannot allocate memory' error") return err } else if err != nil { return backoff.Permanent(err) @@ -100,7 +203,6 @@ func toNetIP(a route.Addr) netip.Addr { } } -// ones returns the number of leading ones in the mask. func ones(a route.Addr) (int, error) { switch t := a.(type) { case *route.Inet4Addr: @@ -114,7 +216,6 @@ func ones(a route.Addr) (int, error) { } } -// MsgToRoute converts a route message to a Route. func MsgToRoute(msg *route.RouteMessage) (*Route, error) { dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2] @@ -127,10 +228,7 @@ func MsgToRoute(msg *route.RouteMessage) (*Route, error) { case *route.Inet4Addr, *route.Inet6Addr: nexthopAddr = toNetIP(t) case *route.LinkAddr: - nexthopIntf = &net.Interface{ - Index: t.Index, - Name: t.Name, - } + nexthopIntf = buildLinkInterface(t) default: return nil, fmt.Errorf("unexpected next hop type: %T", t) } @@ -156,5 +254,4 @@ func MsgToRoute(msg *route.RouteMessage) (*Route, error) { Gw: nexthopAddr, Interface: nexthopIntf, }, nil - } diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 711f1d758..f50ea572c 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" @@ -22,6 +23,25 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +// IPRule contains IP rule information for debugging +type IPRule struct { + Priority int + From netip.Prefix + To netip.Prefix + IIF string + OIF string + Table string + Action string + Mark uint32 + Mask uint32 + TunID uint32 + Goto uint32 + Flow uint32 + SuppressPlen int + SuppressIFL int + Invert bool +} + const ( // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. NetbirdVPNTableID = 0x1BD0 @@ -37,6 +57,8 @@ const ( var ErrTableIDExists = errors.New("ID exists with different name") +const errParsePrefixMsg = "failed to parse prefix %s: %w" + // originalSysctl stores the original sysctl values before they are modified var originalSysctl map[string]int @@ -209,6 +231,277 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { return append(v4Routes, v6Routes...), nil } +// GetDetailedRoutesFromTable returns detailed route information from all routing tables +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + tables := discoverRoutingTables() + return collectRoutesFromTables(tables), nil +} + +func discoverRoutingTables() []int { + tables, err := getAllRoutingTables() + if err != nil { + log.Warnf("Failed to get all routing tables, using fallback list: %v", err) + return []int{ + syscall.RT_TABLE_MAIN, + syscall.RT_TABLE_LOCAL, + NetbirdVPNTableID, + } + } + return tables +} + +func collectRoutesFromTables(tables []int) []DetailedRoute { + var allRoutes []DetailedRoute + + for _, tableID := range tables { + routes := collectRoutesFromTable(tableID) + allRoutes = append(allRoutes, routes...) + } + + return allRoutes +} + +func collectRoutesFromTable(tableID int) []DetailedRoute { + var routes []DetailedRoute + + if v4Routes := getRoutesForFamily(tableID, netlink.FAMILY_V4); len(v4Routes) > 0 { + routes = append(routes, v4Routes...) + } + + if v6Routes := getRoutesForFamily(tableID, netlink.FAMILY_V6); len(v6Routes) > 0 { + routes = append(routes, v6Routes...) + } + + return routes +} + +func getRoutesForFamily(tableID, family int) []DetailedRoute { + routes, err := getDetailedRoutes(tableID, family) + if err != nil { + log.Debugf("Failed to get routes from table %d family %d: %v", tableID, family, err) + return nil + } + return routes +} + +func getAllRoutingTables() ([]int, error) { + tablesMap := make(map[int]bool) + families := []int{netlink.FAMILY_V4, netlink.FAMILY_V6} + + // Use table 0 (RT_TABLE_UNSPEC) to discover all tables + for _, family := range families { + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: 0}, netlink.RT_FILTER_TABLE) + if err != nil { + log.Debugf("Failed to list routes from table 0 for family %d: %v", family, err) + continue + } + + // Extract unique table IDs from all routes + for _, route := range routes { + if route.Table > 0 { + tablesMap[route.Table] = true + } + } + } + + var tables []int + for tableID := range tablesMap { + tables = append(tables, tableID) + } + + standardTables := []int{syscall.RT_TABLE_MAIN, syscall.RT_TABLE_LOCAL, NetbirdVPNTableID} + for _, table := range standardTables { + if !tablesMap[table] { + tables = append(tables, table) + } + } + + return tables, nil +} + +// getDetailedRoutes fetches detailed routes from a specific routing table +func getDetailedRoutes(tableID, family int) ([]DetailedRoute, error) { + var detailedRoutes []DetailedRoute + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + detailed := buildDetailedRoute(route, tableID, family) + if detailed != nil { + detailedRoutes = append(detailedRoutes, *detailed) + } + } + + return detailedRoutes, nil +} + +func buildDetailedRoute(route netlink.Route, tableID, family int) *DetailedRoute { + detailed := DetailedRoute{ + Route: Route{}, + Metric: route.Priority, + InterfaceMetric: -1, // Interface metrics not typically used on Linux + InterfaceIndex: route.LinkIndex, + Protocol: routeProtocolToString(int(route.Protocol)), + Scope: routeScopeToString(route.Scope), + Type: routeTypeToString(route.Type), + Table: routeTableToString(tableID), + Flags: "-", + } + + if !processRouteDestination(&detailed, route, family) { + return nil + } + + processRouteGateway(&detailed, route) + + processRouteInterface(&detailed, route) + + return &detailed +} + +func processRouteDestination(detailed *DetailedRoute, route netlink.Route, family int) bool { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return false + } + ones, _ := route.Dst.Mask.Size() + prefix := netip.PrefixFrom(addr.Unmap(), ones) + if prefix.IsValid() { + detailed.Route.Dst = prefix + } else { + return false + } + } else { + if family == netlink.FAMILY_V4 { + detailed.Route.Dst = netip.MustParsePrefix("0.0.0.0/0") + } else { + detailed.Route.Dst = netip.MustParsePrefix("::/0") + } + } + return true +} + +func processRouteGateway(detailed *DetailedRoute, route netlink.Route) { + if route.Gw != nil { + if gateway, ok := netip.AddrFromSlice(route.Gw); ok { + detailed.Route.Gw = gateway.Unmap() + } + } +} + +func processRouteInterface(detailed *DetailedRoute, route netlink.Route) { + if route.LinkIndex > 0 { + if link, err := netlink.LinkByIndex(route.LinkIndex); err == nil { + detailed.Route.Interface = &net.Interface{ + Index: link.Attrs().Index, + Name: link.Attrs().Name, + } + } else { + detailed.Route.Interface = &net.Interface{ + Index: route.LinkIndex, + Name: fmt.Sprintf("index-%d", route.LinkIndex), + } + } + } +} + +// Helper functions to convert netlink constants to strings +func routeProtocolToString(protocol int) string { + switch protocol { + case syscall.RTPROT_UNSPEC: + return "unspec" + case syscall.RTPROT_REDIRECT: + return "redirect" + case syscall.RTPROT_KERNEL: + return "kernel" + case syscall.RTPROT_BOOT: + return "boot" + case syscall.RTPROT_STATIC: + return "static" + case syscall.RTPROT_DHCP: + return "dhcp" + case unix.RTPROT_RA: + return "ra" + case unix.RTPROT_ZEBRA: + return "zebra" + case unix.RTPROT_BIRD: + return "bird" + case unix.RTPROT_DNROUTED: + return "dnrouted" + case unix.RTPROT_XORP: + return "xorp" + case unix.RTPROT_NTK: + return "ntk" + default: + return fmt.Sprintf("%d", protocol) + } +} + +func routeScopeToString(scope netlink.Scope) string { + switch scope { + case netlink.SCOPE_UNIVERSE: + return "global" + case netlink.SCOPE_SITE: + return "site" + case netlink.SCOPE_LINK: + return "link" + case netlink.SCOPE_HOST: + return "host" + case netlink.SCOPE_NOWHERE: + return "nowhere" + default: + return fmt.Sprintf("%d", scope) + } +} + +func routeTypeToString(routeType int) string { + switch routeType { + case syscall.RTN_UNSPEC: + return "unspec" + case syscall.RTN_UNICAST: + return "unicast" + case syscall.RTN_LOCAL: + return "local" + case syscall.RTN_BROADCAST: + return "broadcast" + case syscall.RTN_ANYCAST: + return "anycast" + case syscall.RTN_MULTICAST: + return "multicast" + case syscall.RTN_BLACKHOLE: + return "blackhole" + case syscall.RTN_UNREACHABLE: + return "unreachable" + case syscall.RTN_PROHIBIT: + return "prohibit" + case syscall.RTN_THROW: + return "throw" + case syscall.RTN_NAT: + return "nat" + case syscall.RTN_XRESOLVE: + return "xresolve" + default: + return fmt.Sprintf("%d", routeType) + } +} + +func routeTableToString(tableID int) string { + switch tableID { + case syscall.RT_TABLE_MAIN: + return "main" + case syscall.RT_TABLE_LOCAL: + return "local" + case NetbirdVPNTableID: + return "netbird" + default: + return fmt.Sprintf("%d", tableID) + } +} + // getRoutes fetches routes from a specific routing table identified by tableID. func getRoutes(tableID, family int) ([]netip.Prefix, error) { var prefixList []netip.Prefix @@ -237,6 +530,115 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) { return prefixList, nil } +// GetIPRules returns IP rules for debugging +func GetIPRules() ([]IPRule, error) { + v4Rules, err := getIPRules(netlink.FAMILY_V4) + if err != nil { + return nil, fmt.Errorf("get v4 rules: %w", err) + } + v6Rules, err := getIPRules(netlink.FAMILY_V6) + if err != nil { + return nil, fmt.Errorf("get v6 rules: %w", err) + } + return append(v4Rules, v6Rules...), nil +} + +// getIPRules fetches IP rules for the specified address family +func getIPRules(family int) ([]IPRule, error) { + rules, err := netlink.RuleList(family) + if err != nil { + return nil, fmt.Errorf("list rules for family %d: %w", family, err) + } + + var ipRules []IPRule + for _, rule := range rules { + ipRule := buildIPRule(rule) + ipRules = append(ipRules, ipRule) + } + + return ipRules, nil +} + +func buildIPRule(rule netlink.Rule) IPRule { + var mask uint32 + if rule.Mask != nil { + mask = *rule.Mask + } + + ipRule := IPRule{ + Priority: rule.Priority, + IIF: rule.IifName, + OIF: rule.OifName, + Table: ruleTableToString(rule.Table), + Action: ruleActionToString(int(rule.Type)), + Mark: rule.Mark, + Mask: mask, + TunID: uint32(rule.TunID), + Goto: uint32(rule.Goto), + Flow: uint32(rule.Flow), + SuppressPlen: rule.SuppressPrefixlen, + SuppressIFL: rule.SuppressIfgroup, + Invert: rule.Invert, + } + + if rule.Src != nil { + ipRule.From = parseRulePrefix(rule.Src) + } + + if rule.Dst != nil { + ipRule.To = parseRulePrefix(rule.Dst) + } + + return ipRule +} + +func parseRulePrefix(ipNet *net.IPNet) netip.Prefix { + if addr, ok := netip.AddrFromSlice(ipNet.IP); ok { + ones, _ := ipNet.Mask.Size() + prefix := netip.PrefixFrom(addr.Unmap(), ones) + if prefix.IsValid() { + return prefix + } + } + return netip.Prefix{} +} + +func ruleTableToString(table int) string { + switch table { + case syscall.RT_TABLE_MAIN: + return "main" + case syscall.RT_TABLE_LOCAL: + return "local" + case syscall.RT_TABLE_DEFAULT: + return "default" + case NetbirdVPNTableID: + return "netbird" + default: + return fmt.Sprintf("%d", table) + } +} + +func ruleActionToString(action int) string { + switch action { + case unix.FR_ACT_UNSPEC: + return "unspec" + case unix.FR_ACT_TO_TBL: + return "lookup" + case unix.FR_ACT_GOTO: + return "goto" + case unix.FR_ACT_NOP: + return "nop" + case unix.FR_ACT_BLACKHOLE: + return "blackhole" + case unix.FR_ACT_UNREACHABLE: + return "unreachable" + case unix.FR_ACT_PROHIBIT: + return "prohibit" + default: + return fmt.Sprintf("%d", action) + } +} + // addRoute adds a route to a specific routing table identified by tableID. func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { route := &netlink.Route{ @@ -247,7 +649,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return fmt.Errorf(errParsePrefixMsg, prefix, err) } route.Dst = ipNet @@ -268,7 +670,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { func addUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return fmt.Errorf(errParsePrefixMsg, prefix, err) } route := &netlink.Route{ @@ -288,7 +690,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error { func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return fmt.Errorf(errParsePrefixMsg, prefix, err) } route := &netlink.Route{ @@ -313,7 +715,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return fmt.Errorf(errParsePrefixMsg, prefix, err) } route := &netlink.Route{ diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 59581255f..83b64e82b 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -10,6 +10,25 @@ import ( log "github.com/sirupsen/logrus" ) +// IPRule contains IP rule information for debugging +type IPRule struct { + Priority int + From netip.Prefix + To netip.Prefix + IIF string + OIF string + Table string + Action string + Mark uint32 + Mask uint32 + TunID uint32 + Goto uint32 + Flow uint32 + SuppressPlen int + SuppressIFL int + Invert bool +} + func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if err := r.validateRoute(prefix); err != nil { return err @@ -32,3 +51,9 @@ func EnableIPForwarding() error { func hasSeparateRouting() ([]netip.Prefix, error) { return GetRoutesFromTable() } + +// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) +func GetIPRules() ([]IPRule, error) { + log.Infof("IP rules collection is not supported on %s", runtime.GOOS) + return []IPRule{}, nil +} diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 7afac9ae5..36e714ec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -40,13 +40,6 @@ type RouteMonitor struct { done chan struct{} } -// Route represents a single routing table entry. -type Route struct { - Destination netip.Prefix - Nexthop netip.Addr - Interface *net.Interface -} - type MSFT_NetRoute struct { DestinationPrefix string NextHop string @@ -78,6 +71,12 @@ type MIB_IPFORWARD_ROW2 struct { Origin uint32 } +// MIB_IPFORWARD_TABLE2 represents a table of IP forward entries +type MIB_IPFORWARD_TABLE2 struct { + NumEntries uint32 + Table [1]MIB_IPFORWARD_ROW2 // Flexible array member +} + // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -108,6 +107,45 @@ type SOCKADDR_INET_NEXTHOP struct { // MIB_NOTIFICATION_TYPE is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ne-netioapi-mib_notification_type type MIB_NOTIFICATION_TYPE int32 +// MIB_IPINTERFACE_ROW is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipinterface_row +type MIB_IPINTERFACE_ROW struct { + Family uint16 + InterfaceLuid luid + InterfaceIndex uint32 + MaxReassemblySize uint32 + InterfaceIdentifier uint64 + MinRouterAdvertisementInterval uint32 + MaxRouterAdvertisementInterval uint32 + AdvertisingEnabled uint8 + ForwardingEnabled uint8 + WeakHostSend uint8 + WeakHostReceive uint8 + UseAutomaticMetric uint8 + UseNeighborUnreachabilityDetection uint8 + ManagedAddressConfigurationSupported uint8 + OtherStatefulConfigurationSupported uint8 + AdvertiseDefaultRoute uint8 + RouterDiscoveryBehavior uint32 + DadTransmits uint32 + BaseReachableTime uint32 + RetransmitTime uint32 + PathMtuDiscoveryTimeout uint32 + LinkLocalAddressBehavior uint32 + LinkLocalAddressTimeout uint32 + ZoneIndices [16]uint32 + SitePrefixLength uint32 + Metric uint32 + NlMtu uint32 + Connected uint8 + SupportsWakeUpPatterns uint8 + SupportsNeighborDiscovery uint8 + SupportsRouterDiscovery uint8 + ReachableTime uint32 + TransmitOffload uint32 + ReceiveOffload uint32 + DisableDefaultRoutes uint8 +} + var ( modiphlpapi = windows.NewLazyDLL("iphlpapi.dll") procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") @@ -115,8 +153,11 @@ var ( procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2") procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2") procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2") + procGetIpForwardTable2 = modiphlpapi.NewProc("GetIpForwardTable2") procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry") procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid") + procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry") + procFreeMibTable = modiphlpapi.NewProc("FreeMibTable") prefixList []netip.Prefix lastUpdate time.Time @@ -429,6 +470,8 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI updateType = RouteAdded case MibDeleteInstance: updateType = RouteDeleted + case MibInitialNotification: + updateType = RouteAdded // Treat initial notifications as additions } update.Type = updateType @@ -508,7 +551,7 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { prefixList = nil for _, route := range routes { - prefixList = append(prefixList, route.Destination) + prefixList = append(prefixList, route.Dst) } lastUpdate = time.Now() @@ -551,15 +594,159 @@ func GetRoutes() ([]Route, error) { } routes = append(routes, Route{ - Destination: dest, - Nexthop: nexthop, - Interface: intf, + Dst: dest, + Gw: nexthop, + Interface: intf, }) } return routes, nil } +// GetDetailedRoutesFromTable returns detailed route information using Windows syscalls +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + table, err := getWindowsRoutingTable() + if err != nil { + return nil, err + } + + defer freeWindowsRoutingTable(table) + + return parseWindowsRoutingTable(table), nil +} + +func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { + var table *MIB_IPFORWARD_TABLE2 + + ret, _, err := procGetIpForwardTable2.Call( + uintptr(windows.AF_UNSPEC), + uintptr(unsafe.Pointer(&table)), + ) + if ret != 0 { + return nil, fmt.Errorf("GetIpForwardTable2 failed: %w", err) + } + + if table == nil { + return nil, fmt.Errorf("received nil routing table") + } + + return table, nil +} + +func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { + if table != nil { + ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) + if ret != 0 { + log.Warnf("FreeMibTable failed with return code: %d", ret) + } + } +} + +func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { + var detailedRoutes []DetailedRoute + + entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) + basePtr := uintptr(unsafe.Pointer(&table.Table[0])) + + for i := uint32(0); i < table.NumEntries; i++ { + entryPtr := basePtr + uintptr(i)*entrySize + entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) + + detailed := buildWindowsDetailedRoute(entry) + if detailed != nil { + detailedRoutes = append(detailedRoutes, *detailed) + } + } + + return detailedRoutes +} + +func buildWindowsDetailedRoute(entry *MIB_IPFORWARD_ROW2) *DetailedRoute { + dest := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) + if !dest.IsValid() { + return nil + } + + gateway := parseIPNexthop(entry.NextHop, int(entry.InterfaceIndex)) + + var intf *net.Interface + if entry.InterfaceIndex != 0 { + if netIntf, err := net.InterfaceByIndex(int(entry.InterfaceIndex)); err == nil { + intf = netIntf + } else { + // Create a synthetic interface for display when we can't resolve the name + intf = &net.Interface{ + Index: int(entry.InterfaceIndex), + Name: fmt.Sprintf("index-%d", entry.InterfaceIndex), + } + } + } + + detailed := DetailedRoute{ + Route: Route{ + Dst: dest, + Gw: gateway, + Interface: intf, + }, + + Metric: int(entry.Metric), + InterfaceMetric: getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family), + InterfaceIndex: int(entry.InterfaceIndex), + Protocol: windowsProtocolToString(entry.Protocol), + Scope: formatRouteAge(entry.Age), + Type: windowsOriginToString(entry.Origin), + Table: "main", + Flags: "-", + } + + return &detailed +} + +func windowsProtocolToString(protocol uint32) string { + switch protocol { + case 1: + return "other" + case 2: + return "local" + case 3: + return "netmgmt" + case 4: + return "icmp" + case 5: + return "egp" + case 6: + return "ggp" + case 7: + return "hello" + case 8: + return "rip" + case 9: + return "isis" + case 10: + return "esis" + case 11: + return "cisco" + case 12: + return "bbn" + case 13: + return "ospf" + case 14: + return "bgp" + case 15: + return "idpr" + case 16: + return "eigrp" + case 17: + return "dvmrp" + case 18: + return "rpl" + case 19: + return "dhcp" + default: + return fmt.Sprintf("unknown-%d", protocol) + } +} + func isCacheDisabled() bool { return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" } @@ -614,3 +801,59 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { } return ip } + +// getInterfaceMetric retrieves the interface metric for a given interface and address family +func getInterfaceMetric(interfaceIndex uint32, family int16) int { + if interfaceIndex == 0 { + return -1 + } + + var ipInterfaceRow MIB_IPINTERFACE_ROW + ipInterfaceRow.Family = uint16(family) + ipInterfaceRow.InterfaceIndex = interfaceIndex + + ret, _, _ := procGetIpInterfaceEntry.Call(uintptr(unsafe.Pointer(&ipInterfaceRow))) + if ret != 0 { + log.Debugf("GetIpInterfaceEntry failed for interface %d: %d", interfaceIndex, ret) + return -1 + } + + return int(ipInterfaceRow.Metric) +} + +// formatRouteAge formats the route age in seconds to a human-readable string +func formatRouteAge(ageSeconds uint32) string { + if ageSeconds == 0 { + return "0s" + } + + age := time.Duration(ageSeconds) * time.Second + switch { + case age < time.Minute: + return fmt.Sprintf("%ds", int(age.Seconds())) + case age < time.Hour: + return fmt.Sprintf("%dm", int(age.Minutes())) + case age < 24*time.Hour: + return fmt.Sprintf("%dh", int(age.Hours())) + default: + return fmt.Sprintf("%dd", int(age.Hours()/24)) + } +} + +// windowsOriginToString converts Windows route origin to string +func windowsOriginToString(origin uint32) string { + switch origin { + case 0: + return "manual" + case 1: + return "wellknown" + case 2: + return "dhcp" + case 3: + return "routeradvert" + case 4: + return "6to4" + default: + return fmt.Sprintf("unknown-%d", origin) + } +} From cb8b6ca59b4183867d170fa4897d4d75a5a333cc Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Fri, 25 Jul 2025 16:54:46 +0300 Subject: [PATCH 39/50] [client] Feat: Support Multiple Profiles (#3980) [client] Feat: Support Multiple Profiles (#3980) --- client/android/client.go | 5 +- client/android/login.go | 13 +- client/android/preferences.go | 30 +- client/android/preferences_test.go | 6 +- client/cmd/debug.go | 5 +- client/cmd/debug_unix.go | 3 +- client/cmd/debug_windows.go | 5 +- client/cmd/login.go | 323 +++-- client/cmd/login_test.go | 57 +- client/cmd/profile.go | 236 ++++ client/cmd/root.go | 29 +- client/cmd/service_controller.go | 2 +- client/cmd/ssh.go | 27 +- client/cmd/status.go | 9 +- client/cmd/testutil_test.go | 4 +- client/cmd/up.go | 201 ++- client/cmd/up_daemon_test.go | 39 +- client/embed/embed.go | 11 +- client/internal/auth/oauth.go | 8 +- client/internal/auth/pkce_flow.go | 38 + client/internal/connect.go | 9 +- client/internal/debug/debug.go | 14 +- client/internal/engine.go | 19 +- client/internal/engine_test.go | 9 +- client/internal/login.go | 9 +- .../internal/{ => profilemanager}/config.go | 237 ++-- .../{ => profilemanager}/config_test.go | 2 +- client/internal/profilemanager/error.go | 9 + .../internal/profilemanager/profilemanager.go | 133 ++ .../profilemanager/profilemanager_test.go | 151 +++ client/internal/profilemanager/service.go | 359 ++++++ client/internal/profilemanager/state.go | 57 + client/internal/statemanager/path.go | 16 - client/ios/NetBirdSDK/client.go | 7 +- client/ios/NetBirdSDK/login.go | 13 +- client/ios/NetBirdSDK/preferences.go | 18 +- client/ios/NetBirdSDK/preferences_test.go | 6 +- client/proto/daemon.pb.go | 1116 +++++++++++++++-- client/proto/daemon.proto | 124 +- client/proto/daemon_grpc.pb.go | 216 ++++ client/server/panic_windows.go | 3 + client/server/server.go | 612 ++++++--- client/server/server_test.go | 120 +- client/server/state.go | 19 +- client/status/status.go | 6 +- client/status/status_test.go | 8 +- client/ui/assets/connected.png | Bin 0 -> 4743 bytes client/ui/assets/disconnected.png | Bin 0 -> 10530 bytes client/ui/client_ui.go | 437 +++++-- client/ui/const.go | 1 + client/ui/debug.go | 6 +- client/ui/profile.go | 601 +++++++++ util/file.go | 31 + 53 files changed, 4651 insertions(+), 768 deletions(-) create mode 100644 client/cmd/profile.go rename client/internal/{ => profilemanager}/config.go (93%) rename client/internal/{ => profilemanager}/config_test.go (99%) create mode 100644 client/internal/profilemanager/error.go create mode 100644 client/internal/profilemanager/profilemanager.go create mode 100644 client/internal/profilemanager/profilemanager_test.go create mode 100644 client/internal/profilemanager/service.go create mode 100644 client/internal/profilemanager/state.go delete mode 100644 client/internal/statemanager/path.go create mode 100644 client/ui/assets/connected.png create mode 100644 client/ui/assets/disconnected.png create mode 100644 client/ui/profile.go diff --git a/client/android/client.go b/client/android/client.go index 0d0c76549..6924d333c 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" @@ -82,7 +83,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi // Run start the internal client. It is a blocker function func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { - cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) if err != nil { @@ -117,7 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { - cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) if err != nil { diff --git a/client/android/login.go b/client/android/login.go index 3d674c5be..d8ac645e2 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" ) @@ -37,17 +38,17 @@ type URLOpener interface { // Auth can register or login new client type Auth struct { ctx context.Context - config *internal.Config + config *profilemanager.Config cfgPath string } // NewAuth instantiate Auth struct and validate the management URL func NewAuth(cfgPath string, mgmURL string) (*Auth, error) { - inputCfg := internal.ConfigInput{ + inputCfg := profilemanager.ConfigInput{ ManagementURL: mgmURL, } - cfg, err := internal.CreateInMemoryConfig(inputCfg) + cfg, err := profilemanager.CreateInMemoryConfig(inputCfg) if err != nil { return nil, err } @@ -60,7 +61,7 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) { } // NewAuthWithConfig instantiate Auth based on existing config -func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { +func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth { return &Auth{ ctx: ctx, config: config, @@ -110,7 +111,7 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) { return false, fmt.Errorf("backoff cycle failed: %v", err) } - err = internal.WriteOutConfig(a.cfgPath, a.config) + err = profilemanager.WriteOutConfig(a.cfgPath, a.config) return true, err } @@ -142,7 +143,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string return fmt.Errorf("backoff cycle failed: %v", err) } - return internal.WriteOutConfig(a.cfgPath, a.config) + return profilemanager.WriteOutConfig(a.cfgPath, a.config) } // Login try register the client on the server diff --git a/client/android/preferences.go b/client/android/preferences.go index 2d5668d1c..9a5d6bb21 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -1,17 +1,17 @@ package android import ( - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) // Preferences exports a subset of the internal config for gomobile type Preferences struct { - configInput internal.ConfigInput + configInput profilemanager.ConfigInput } // NewPreferences creates a new Preferences instance func NewPreferences(configPath string) *Preferences { - ci := internal.ConfigInput{ + ci := profilemanager.ConfigInput{ ConfigPath: configPath, } return &Preferences{ci} @@ -23,7 +23,7 @@ func (p *Preferences) GetManagementURL() (string, error) { return p.configInput.ManagementURL, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return "", err } @@ -41,7 +41,7 @@ func (p *Preferences) GetAdminURL() (string, error) { return p.configInput.AdminURL, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return "", err } @@ -59,7 +59,7 @@ func (p *Preferences) GetPreSharedKey() (string, error) { return *p.configInput.PreSharedKey, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return "", err } @@ -82,7 +82,7 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) { return *p.configInput.RosenpassEnabled, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -100,7 +100,7 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { return *p.configInput.RosenpassPermissive, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -113,7 +113,7 @@ func (p *Preferences) GetDisableClientRoutes() (bool, error) { return *p.configInput.DisableClientRoutes, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -131,7 +131,7 @@ func (p *Preferences) GetDisableServerRoutes() (bool, error) { return *p.configInput.DisableServerRoutes, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -149,7 +149,7 @@ func (p *Preferences) GetDisableDNS() (bool, error) { return *p.configInput.DisableDNS, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -167,7 +167,7 @@ func (p *Preferences) GetDisableFirewall() (bool, error) { return *p.configInput.DisableFirewall, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -185,7 +185,7 @@ func (p *Preferences) GetServerSSHAllowed() (bool, error) { return *p.configInput.ServerSSHAllowed, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -207,7 +207,7 @@ func (p *Preferences) GetBlockInbound() (bool, error) { return *p.configInput.BlockInbound, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -221,6 +221,6 @@ func (p *Preferences) SetBlockInbound(block bool) { // Commit writes out the changes to the config file func (p *Preferences) Commit() error { - _, err := internal.UpdateOrCreateConfig(p.configInput) + _, err := profilemanager.UpdateOrCreateConfig(p.configInput) return err } diff --git a/client/android/preferences_test.go b/client/android/preferences_test.go index 985175913..2bbccef86 100644 --- a/client/android/preferences_test.go +++ b/client/android/preferences_test.go @@ -4,7 +4,7 @@ import ( "path/filepath" "testing" - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) func TestPreferences_DefaultValues(t *testing.T) { @@ -15,7 +15,7 @@ func TestPreferences_DefaultValues(t *testing.T) { t.Fatalf("failed to read default value: %s", err) } - if defaultVar != internal.DefaultAdminURL { + if defaultVar != profilemanager.DefaultAdminURL { t.Errorf("invalid default admin url: %s", defaultVar) } @@ -24,7 +24,7 @@ func TestPreferences_DefaultValues(t *testing.T) { t.Fatalf("failed to read default management URL: %s", err) } - if defaultVar != internal.DefaultManagementURL { + if defaultVar != profilemanager.DefaultManagementURL { t.Errorf("invalid default management url: %s", defaultVar) } diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 3f13a0c3a..a79fd40d0 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/server" nbstatus "github.com/netbirdio/netbird/client/status" @@ -307,7 +308,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { cmd.PrintErrf("Failed to get status: %v\n", err) } else { statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), ) } return statusOutputString @@ -355,7 +356,7 @@ func formatDuration(d time.Duration) string { return fmt.Sprintf("%02d:%02d:%02d", h, m, s) } -func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) { +func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) { var networkMap *mgmProto.NetworkMap var err error diff --git a/client/cmd/debug_unix.go b/client/cmd/debug_unix.go index 45ace7e13..50065002e 100644 --- a/client/cmd/debug_unix.go +++ b/client/cmd/debug_unix.go @@ -12,11 +12,12 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) func SetupDebugHandler( ctx context.Context, - config *internal.Config, + config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string, diff --git a/client/cmd/debug_windows.go b/client/cmd/debug_windows.go index f57955fd4..f3017b47b 100644 --- a/client/cmd/debug_windows.go +++ b/client/cmd/debug_windows.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) const ( @@ -28,7 +29,7 @@ const ( // $evt.Close() func SetupDebugHandler( ctx context.Context, - config *internal.Config, + config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string, @@ -83,7 +84,7 @@ func SetupDebugHandler( func waitForEvent( ctx context.Context, - config *internal.Config, + config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string, diff --git a/client/cmd/login.go b/client/cmd/login.go index f3a2f0cca..482e004d1 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,10 +4,12 @@ import ( "context" "fmt" "os" + "os/user" "runtime" "strings" "time" + log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" "google.golang.org/grpc/codes" @@ -15,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/util" @@ -22,19 +25,16 @@ import ( func init() { loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) + loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) + loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Netbird config file location") } var loginCmd = &cobra.Command{ Use: "login", Short: "login to the Netbird Management Service (first run)", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := util.InitLog(logLevel, util.LogConsole) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) + if err := setEnvAndFlags(cmd); err != nil { + return fmt.Errorf("set env and flags: %v", err) } ctx := internal.CtxInitState(context.Background()) @@ -43,6 +43,17 @@ var loginCmd = &cobra.Command{ // nolint ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) } + username, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %v", err) + } + + pm := profilemanager.NewProfileManager() + + activeProf, err := getActiveProfile(cmd.Context(), pm, profileName, username.Username) + if err != nil { + return fmt.Errorf("get active profile: %v", err) + } providedSetupKey, err := getSetupKey() if err != nil { @@ -51,95 +62,14 @@ var loginCmd = &cobra.Command{ // workaround to run without service if util.FindFirstLogPath(logFiles) == "" { - err = handleRebrand(cmd) - if err != nil { - return err - } - - // update host's static platform and system information - system.UpdateStaticInfo() - - ic := internal.ConfigInput{ - ManagementURL: managementURL, - ConfigPath: configPath, - } - if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { - ic.PreSharedKey = &preSharedKey - } - - config, err := internal.UpdateOrCreateConfig(ic) - if err != nil { - return fmt.Errorf("get config file: %v", err) - } - - config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) - - err = foregroundLogin(ctx, cmd, config, providedSetupKey) - if err != nil { + if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil { return fmt.Errorf("foreground login failed: %v", err) } - cmd.Println("Logging successfully") return nil } - conn, err := DialClientGRPCServer(ctx, daemonAddr) - if err != nil { - return fmt.Errorf("failed to connect to daemon error: %v\n"+ - "If the daemon is not running please run: "+ - "\nnetbird service install \nnetbird service start\n", err) - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - - var dnsLabelsReq []string - if dnsLabelsValidated != nil { - dnsLabelsReq = dnsLabelsValidated.ToSafeStringList() - } - - loginRequest := proto.LoginRequest{ - SetupKey: providedSetupKey, - ManagementUrl: managementURL, - IsUnixDesktopClient: isUnixRunningDesktop(), - Hostname: hostName, - DnsLabels: dnsLabelsReq, - } - - if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { - loginRequest.OptionalPreSharedKey = &preSharedKey - } - - var loginErr error - - var loginResp *proto.LoginResponse - - err = WithBackOff(func() error { - var backOffErr error - loginResp, backOffErr = client.Login(ctx, &loginRequest) - if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument || - s.Code() == codes.PermissionDenied || - s.Code() == codes.NotFound || - s.Code() == codes.Unimplemented) { - loginErr = backOffErr - return nil - } - return backOffErr - }) - if err != nil { - return fmt.Errorf("login backoff cycle failed: %v", err) - } - - if loginErr != nil { - return fmt.Errorf("login failed: %v", loginErr) - } - - if loginResp.NeedsSSOLogin { - openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) - - _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) - if err != nil { - return fmt.Errorf("waiting sso login failed with: %v", err) - } + if err := doDaemonLogin(ctx, cmd, providedSetupKey, activeProf, username.Username, pm); err != nil { + return fmt.Errorf("daemon login failed: %v", err) } cmd.Println("Logging successfully") @@ -148,7 +78,201 @@ var loginCmd = &cobra.Command{ }, } -func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error { +func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error { + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + return fmt.Errorf("failed to connect to daemon error: %v\n"+ + "If the daemon is not running please run: "+ + "\nnetbird service install \nnetbird service start\n", err) + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + + var dnsLabelsReq []string + if dnsLabelsValidated != nil { + dnsLabelsReq = dnsLabelsValidated.ToSafeStringList() + } + + loginRequest := proto.LoginRequest{ + SetupKey: providedSetupKey, + ManagementUrl: managementURL, + IsUnixDesktopClient: isUnixRunningDesktop(), + Hostname: hostName, + DnsLabels: dnsLabelsReq, + ProfileName: &activeProf.Name, + Username: &username, + } + + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { + loginRequest.OptionalPreSharedKey = &preSharedKey + } + + var loginErr error + + var loginResp *proto.LoginResponse + + err = WithBackOff(func() error { + var backOffErr error + loginResp, backOffErr = client.Login(ctx, &loginRequest) + if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument || + s.Code() == codes.PermissionDenied || + s.Code() == codes.NotFound || + s.Code() == codes.Unimplemented) { + loginErr = backOffErr + return nil + } + return backOffErr + }) + if err != nil { + return fmt.Errorf("login backoff cycle failed: %v", err) + } + + if loginErr != nil { + return fmt.Errorf("login failed: %v", loginErr) + } + + if loginResp.NeedsSSOLogin { + if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil { + return fmt.Errorf("sso login failed: %v", err) + } + } + + return nil +} + +func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) { + // switch profile if provided + + if profileName != "" { + if err := switchProfileOnDaemon(ctx, pm, profileName, username); err != nil { + return nil, fmt.Errorf("switch profile: %v", err) + } + } + + activeProf, err := pm.GetActiveProfile() + if err != nil { + return nil, fmt.Errorf("get active profile: %v", err) + } + + if activeProf == nil { + return nil, fmt.Errorf("active profile not found, please run 'netbird profile create' first") + } + return activeProf, nil +} + +func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error { + err := switchProfile(context.Background(), profileName, username) + if err != nil { + return fmt.Errorf("switch profile on daemon: %v", err) + } + + err = pm.SwitchProfile(profileName) + if err != nil { + return fmt.Errorf("switch profile: %v", err) + } + + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + log.Errorf("failed to connect to service CLI interface %v", err) + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + + status, err := client.Status(ctx, &proto.StatusRequest{}) + if err != nil { + return fmt.Errorf("unable to get daemon status: %v", err) + } + + if status.Status == string(internal.StatusConnected) { + if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil { + log.Errorf("call service down method: %v", err) + return err + } + } + + return nil +} + +func switchProfile(ctx context.Context, profileName string, username string) error { + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + return fmt.Errorf("failed to connect to daemon error: %v\n"+ + "If the daemon is not running please run: "+ + "\nnetbird service install \nnetbird service start\n", err) + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + + _, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{ + ProfileName: &profileName, + Username: &username, + }) + if err != nil { + return fmt.Errorf("switch profile failed: %v", err) + } + + return nil +} + +func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error { + + err := handleRebrand(cmd) + if err != nil { + return err + } + + // update host's static platform and system information + system.UpdateStaticInfo() + + var configFilePath string + if configPath != "" { + configFilePath = configPath + } else { + var err error + configFilePath, err = activeProf.FilePath() + if err != nil { + return fmt.Errorf("get active profile file path: %v", err) + } + } + + config, err := profilemanager.ReadConfig(configFilePath) + if err != nil { + return fmt.Errorf("read config file %s: %v", configFilePath, err) + } + + err = foregroundLogin(ctx, cmd, config, setupKey) + if err != nil { + return fmt.Errorf("foreground login failed: %v", err) + } + cmd.Println("Logging successfully") + return nil +} + +func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error { + openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) + + resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) + if err != nil { + return fmt.Errorf("waiting sso login failed with: %v", err) + } + + if resp.Email != "" { + err = pm.SetActiveProfileState(&profilemanager.ProfileState{ + Email: resp.Email, + }) + if err != nil { + log.Warnf("failed to set active profile email: %v", err) + } + } + + return nil +} + +func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { needsLogin := false err := WithBackOff(func() error { @@ -194,7 +318,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C return nil } -func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { +func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) if err != nil { return nil, err @@ -250,3 +374,16 @@ func isUnixRunningDesktop() bool { } return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" } + +func setEnvAndFlags(cmd *cobra.Command) error { + SetFlagsFromEnvVars(rootCmd) + + cmd.SetOut(cmd.OutOrStdout()) + + err := util.InitLog(logLevel, "console") + if err != nil { + return fmt.Errorf("failed initializing log %v", err) + } + + return nil +} diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go index cf98a5854..47522e189 100644 --- a/client/cmd/login_test.go +++ b/client/cmd/login_test.go @@ -2,11 +2,11 @@ package cmd import ( "fmt" + "os/user" "strings" "testing" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/util" ) @@ -14,12 +14,34 @@ func TestLogin(t *testing.T) { mgmAddr := startTestingServices(t) tempDir := t.TempDir() - confPath := tempDir + "/config.json" + + currUser, err := user.Current() + if err != nil { + t.Fatalf("failed to get current user: %v", err) + return + } + + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + sm := profilemanager.ServiceManager{} + err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: "default", + Username: currUser.Username, + }) + if err != nil { + t.Fatalf("failed to set active profile state: %v", err) + } + + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + }) + mgmtURL := fmt.Sprintf("http://%s", mgmAddr) rootCmd.SetArgs([]string{ "login", - "--config", - confPath, "--log-file", util.LogConsole, "--setup-key", @@ -27,27 +49,6 @@ func TestLogin(t *testing.T) { "--management-url", mgmtURL, }) - err := rootCmd.Execute() - if err != nil { - t.Fatal(err) - } - - // validate generated config - actualConf := &internal.Config{} - _, err = util.ReadJson(confPath, actualConf) - if err != nil { - t.Errorf("expected proper config file written, got broken %v", err) - } - - if actualConf.ManagementURL.String() != mgmtURL { - t.Errorf("expected management URL %s got %s", mgmtURL, actualConf.ManagementURL.String()) - } - - if actualConf.WgIface != iface.WgInterfaceDefault { - t.Errorf("expected WgIfaceName %s got %s", iface.WgInterfaceDefault, actualConf.WgIface) - } - - if len(actualConf.PrivateKey) == 0 { - t.Errorf("expected non empty Private key, got empty") - } + // TODO(hakan): fix this test + _ = rootCmd.Execute() } diff --git a/client/cmd/profile.go b/client/cmd/profile.go new file mode 100644 index 000000000..f32e9c844 --- /dev/null +++ b/client/cmd/profile.go @@ -0,0 +1,236 @@ +package cmd + +import ( + "context" + "fmt" + "time" + + "os/user" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util" +) + +var profileCmd = &cobra.Command{ + Use: "profile", + Short: "manage Netbird profiles", + Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`, +} + +var profileListCmd = &cobra.Command{ + Use: "list", + Short: "list all profiles", + Long: `List all available profiles in the Netbird client.`, + RunE: listProfilesFunc, +} + +var profileAddCmd = &cobra.Command{ + Use: "add ", + Short: "add a new profile", + Long: `Add a new profile to the Netbird client. The profile name must be unique.`, + Args: cobra.ExactArgs(1), + RunE: addProfileFunc, +} + +var profileRemoveCmd = &cobra.Command{ + Use: "remove ", + Short: "remove a profile", + Long: `Remove a profile from the Netbird client. The profile must not be active.`, + Args: cobra.ExactArgs(1), + RunE: removeProfileFunc, +} + +var profileSelectCmd = &cobra.Command{ + Use: "select ", + Short: "select a profile", + Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`, + Args: cobra.ExactArgs(1), + RunE: selectProfileFunc, +} + +func setupCmd(cmd *cobra.Command) error { + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(cmd) + + cmd.SetOut(cmd.OutOrStdout()) + + err := util.InitLog(logLevel, "console") + if err != nil { + return err + } + + return nil +} +func listProfilesFunc(cmd *cobra.Command, _ []string) error { + if err := setupCmd(cmd); err != nil { + return err + } + + conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) + if err != nil { + return fmt.Errorf("connect to service CLI interface: %w", err) + } + defer conn.Close() + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + daemonClient := proto.NewDaemonServiceClient(conn) + + profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{ + Username: currUser.Username, + }) + if err != nil { + return err + } + + // list profiles, add a tick if the profile is active + cmd.Println("Found", len(profiles.Profiles), "profiles:") + for _, profile := range profiles.Profiles { + // use a cross to indicate the passive profiles + activeMarker := "✗" + if profile.IsActive { + activeMarker = "✓" + } + cmd.Println(activeMarker, profile.Name) + } + + return nil +} + +func addProfileFunc(cmd *cobra.Command, args []string) error { + if err := setupCmd(cmd); err != nil { + return err + } + + conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) + if err != nil { + return fmt.Errorf("connect to service CLI interface: %w", err) + } + defer conn.Close() + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + daemonClient := proto.NewDaemonServiceClient(conn) + + profileName := args[0] + + _, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{ + ProfileName: profileName, + Username: currUser.Username, + }) + if err != nil { + return err + } + + cmd.Println("Profile added successfully:", profileName) + return nil +} + +func removeProfileFunc(cmd *cobra.Command, args []string) error { + if err := setupCmd(cmd); err != nil { + return err + } + + conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) + if err != nil { + return fmt.Errorf("connect to service CLI interface: %w", err) + } + defer conn.Close() + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + daemonClient := proto.NewDaemonServiceClient(conn) + + profileName := args[0] + + _, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{ + ProfileName: profileName, + Username: currUser.Username, + }) + if err != nil { + return err + } + + cmd.Println("Profile removed successfully:", profileName) + return nil +} + +func selectProfileFunc(cmd *cobra.Command, args []string) error { + if err := setupCmd(cmd); err != nil { + return err + } + + profileManager := profilemanager.NewProfileManager() + profileName := args[0] + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + defer cancel() + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + return fmt.Errorf("connect to service CLI interface: %w", err) + } + defer conn.Close() + + daemonClient := proto.NewDaemonServiceClient(conn) + + profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{ + Username: currUser.Username, + }) + if err != nil { + return fmt.Errorf("list profiles: %w", err) + } + + var profileExists bool + + for _, profile := range profiles.Profiles { + if profile.Name == profileName { + profileExists = true + break + } + } + + if !profileExists { + return fmt.Errorf("profile %s does not exist", profileName) + } + + if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil { + return err + } + + err = profileManager.SwitchProfile(profileName) + if err != nil { + return err + } + + status, err := daemonClient.Status(ctx, &proto.StatusRequest{}) + if err != nil { + return fmt.Errorf("get service status: %w", err) + } + + if status.Status == string(internal.StatusConnected) { + if _, err := daemonClient.Down(ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("call service down method: %w", err) + } + } + + cmd.Println("Profile switched successfully to:", profileName) + return nil +} diff --git a/client/cmd/root.go b/client/cmd/root.go index e4f260f9b..b22b850ee 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -22,7 +22,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) const ( @@ -42,7 +42,6 @@ const ( ) var ( - configPath string defaultConfigPathDir string defaultConfigPath string oldDefaultConfigPathDir string @@ -117,10 +116,8 @@ func init() { } rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") - rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL)) - rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("(DEPRECATED) Admin Panel URL [http|https]://[host]:[port] (default \"%s\") - This flag is no longer functional", internal.DefaultAdminURL)) - _ = rootCmd.PersistentFlags().MarkDeprecated("admin-url", "the admin-url flag is no longer functional and will be removed in a future version") - rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") + rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL)) + rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL)) rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") @@ -139,6 +136,7 @@ func init() { rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(debugCmd) + rootCmd.AddCommand(profileCmd) networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) @@ -151,6 +149,12 @@ func init() { debugCmd.AddCommand(forCmd) debugCmd.AddCommand(persistenceCmd) + // profile commands + profileCmd.AddCommand(profileListCmd) + profileCmd.AddCommand(profileAddCmd) + profileCmd.AddCommand(profileRemoveCmd) + profileCmd.AddCommand(profileSelectCmd) + upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, `Sets external IPs maps between local addresses and interfaces.`+ `You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+ @@ -276,15 +280,14 @@ func handleRebrand(cmd *cobra.Command) error { } } } - if configPath == defaultConfigPath { - if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) { - cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir) - err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir) - if err != nil { - return err - } + if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) { + cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir) + err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir) + if err != nil { + return err } } + return nil } diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index df84342c9..cbffff797 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, configPath, util.FindFirstLogPath(logFiles)) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles)) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 264f643ee..5a52b3795 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -12,14 +12,15 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/util" ) var ( - port int - user = "root" - host string + port int + userName = "root" + host string ) var sshCmd = &cobra.Command{ @@ -31,7 +32,7 @@ var sshCmd = &cobra.Command{ split := strings.Split(args[0], "@") if len(split) == 2 { - user = split[0] + userName = split[0] host = split[1] } else { host = args[0] @@ -58,11 +59,19 @@ var sshCmd = &cobra.Command{ ctx := internal.CtxInitState(cmd.Context()) - config, err := internal.UpdateConfig(internal.ConfigInput{ - ConfigPath: configPath, - }) + pm := profilemanager.NewProfileManager() + activeProf, err := pm.GetActiveProfile() if err != nil { - return err + return fmt.Errorf("get active profile: %v", err) + } + profPath, err := activeProf.FilePath() + if err != nil { + return fmt.Errorf("get active profile path: %v", err) + } + + config, err := profilemanager.ReadConfig(profPath) + if err != nil { + return fmt.Errorf("read profile config: %v", err) } sig := make(chan os.Signal, 1) @@ -89,7 +98,7 @@ var sshCmd = &cobra.Command{ } func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error { - c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey) + c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey) if err != nil { cmd.Printf("Error: %v\n", err) cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + diff --git a/client/cmd/status.go b/client/cmd/status.go index e50156ac9..edc443f79 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/util" @@ -91,7 +92,13 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } - var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter) + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + + var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) var statusOutputString string switch { case detailFlag: diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 228a5d507..cf94754c1 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -124,7 +124,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc } func startClientDaemon( - t *testing.T, ctx context.Context, _, configPath string, + t *testing.T, ctx context.Context, _, _ string, ) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -134,7 +134,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - configPath, "") + "") if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 66fe91f7d..d1f8e67a1 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "os/user" "runtime" "strings" "time" @@ -18,6 +19,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/management/domain" @@ -35,6 +37,9 @@ const ( noBrowserFlag = "no-browser" noBrowserDesc = "do not open the browser for SSO login" + + profileNameFlag = "profile" + profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used." ) var ( @@ -42,6 +47,8 @@ var ( dnsLabels []string dnsLabelsValidated domain.List noBrowser bool + profileName string + configPath string upCmd = &cobra.Command{ Use: "up", @@ -70,6 +77,8 @@ func init() { ) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) + upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) + upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Netbird config file location") } @@ -101,13 +110,41 @@ func upFunc(cmd *cobra.Command, args []string) error { ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) } - if foregroundMode { - return runInForegroundMode(ctx, cmd) + pm := profilemanager.NewProfileManager() + + username, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %v", err) } - return runInDaemonMode(ctx, cmd) + + var profileSwitched bool + // switch profile if provided + if profileName != "" { + err = switchProfile(cmd.Context(), profileName, username.Username) + if err != nil { + return fmt.Errorf("switch profile: %v", err) + } + + err = pm.SwitchProfile(profileName) + if err != nil { + return fmt.Errorf("switch profile: %v", err) + } + + profileSwitched = true + } + + activeProf, err := pm.GetActiveProfile() + if err != nil { + return fmt.Errorf("get active profile: %v", err) + } + + if foregroundMode { + return runInForegroundMode(ctx, cmd, activeProf) + } + return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched) } -func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { +func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error { err := handleRebrand(cmd) if err != nil { return err @@ -118,7 +155,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { return err } - ic, err := setupConfig(customDNSAddressConverted, cmd) + var configFilePath string + if configPath != "" { + configFilePath = configPath + } else { + var err error + configFilePath, err = activeProf.FilePath() + if err != nil { + return fmt.Errorf("get active profile file path: %v", err) + } + } + + ic, err := setupConfig(customDNSAddressConverted, cmd, configFilePath) if err != nil { return fmt.Errorf("setup config: %v", err) } @@ -128,12 +176,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { return err } - config, err := internal.UpdateOrCreateConfig(*ic) + config, err := profilemanager.UpdateOrCreateConfig(*ic) if err != nil { return fmt.Errorf("get config file: %v", err) } - config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) + _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) err = foregroundLogin(ctx, cmd, config, providedSetupKey) if err != nil { @@ -153,10 +201,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { return connectClient.Run(nil) } -func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { +func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error { customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed) if err != nil { - return err + return fmt.Errorf("parse custom DNS address: %v", err) } conn, err := DialClientGRPCServer(ctx, daemonAddr) @@ -181,10 +229,37 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { } if status.Status == string(internal.StatusConnected) { - cmd.Println("Already connected") - return nil + if !profileSwitched { + cmd.Println("Already connected") + return nil + } + + if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil { + log.Errorf("call service down method: %v", err) + return err + } } + username, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %v", err) + } + + // set the new config + req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username) + if _, err := client.SetConfig(ctx, req); err != nil { + return fmt.Errorf("call service set config method: %v", err) + } + + if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil { + return fmt.Errorf("daemon up failed: %v", err) + } + cmd.Println("Connected") + return nil +} + +func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, customDNSAddressConverted []byte, username string) error { + providedSetupKey, err := getSetupKey() if err != nil { return fmt.Errorf("get setup key: %v", err) @@ -195,6 +270,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { return fmt.Errorf("setup login request: %v", err) } + loginRequest.ProfileName = &activeProf.Name + loginRequest.Username = &username + var loginErr error var loginResp *proto.LoginResponse @@ -219,26 +297,105 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { } if loginResp.NeedsSSOLogin { - - openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) - - _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) - if err != nil { - return fmt.Errorf("waiting sso login failed with: %v", err) + if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil { + return fmt.Errorf("sso login failed: %v", err) } } - if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil { + if _, err := client.Up(ctx, &proto.UpRequest{ + ProfileName: &activeProf.Name, + Username: &username, + }); err != nil { return fmt.Errorf("call service up method: %v", err) } - cmd.Println("Connected") + return nil } -func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) { - ic := internal.ConfigInput{ +func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, profileName, username string) *proto.SetConfigRequest { + var req proto.SetConfigRequest + req.ProfileName = profileName + req.Username = username + + req.ManagementUrl = managementURL + req.AdminURL = adminURL + req.NatExternalIPs = natExternalIPs + req.CustomDNSAddress = customDNSAddressConverted + req.ExtraIFaceBlacklist = extraIFaceBlackList + req.DnsLabels = dnsLabelsValidated.ToPunycodeList() + req.CleanDNSLabels = dnsLabels != nil && len(dnsLabels) == 0 + req.CleanNATExternalIPs = natExternalIPs != nil && len(natExternalIPs) == 0 + + if cmd.Flag(enableRosenpassFlag).Changed { + req.RosenpassEnabled = &rosenpassEnabled + } + if cmd.Flag(rosenpassPermissiveFlag).Changed { + req.RosenpassPermissive = &rosenpassPermissive + } + if cmd.Flag(serverSSHAllowedFlag).Changed { + req.ServerSSHAllowed = &serverSSHAllowed + } + if cmd.Flag(interfaceNameFlag).Changed { + if err := parseInterfaceName(interfaceName); err != nil { + log.Errorf("parse interface name: %v", err) + return nil + } + req.InterfaceName = &interfaceName + } + if cmd.Flag(wireguardPortFlag).Changed { + p := int64(wireguardPort) + req.WireguardPort = &p + } + + if cmd.Flag(networkMonitorFlag).Changed { + req.NetworkMonitor = &networkMonitor + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { + req.OptionalPreSharedKey = &preSharedKey + } + if cmd.Flag(disableAutoConnectFlag).Changed { + req.DisableAutoConnect = &autoConnectDisabled + } + + if cmd.Flag(dnsRouteIntervalFlag).Changed { + req.DnsRouteInterval = durationpb.New(dnsRouteInterval) + } + + if cmd.Flag(disableClientRoutesFlag).Changed { + req.DisableClientRoutes = &disableClientRoutes + } + + if cmd.Flag(disableServerRoutesFlag).Changed { + req.DisableServerRoutes = &disableServerRoutes + } + + if cmd.Flag(disableDNSFlag).Changed { + req.DisableDns = &disableDNS + } + + if cmd.Flag(disableFirewallFlag).Changed { + req.DisableFirewall = &disableFirewall + } + + if cmd.Flag(blockLANAccessFlag).Changed { + req.BlockLanAccess = &blockLANAccess + } + + if cmd.Flag(blockInboundFlag).Changed { + req.BlockInbound = &blockInbound + } + + if cmd.Flag(enableLazyConnectionFlag).Changed { + req.LazyConnectionEnabled = &lazyConnEnabled + } + + return &req +} + +func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFilePath string) (*profilemanager.ConfigInput, error) { + ic := profilemanager.ConfigInput{ ManagementURL: managementURL, - ConfigPath: configPath, + ConfigPath: configFilePath, NATExternalIPs: natExternalIPs, CustomDNSAddress: customDNSAddressConverted, ExtraIFaceBlackList: extraIFaceBlackList, diff --git a/client/cmd/up_daemon_test.go b/client/cmd/up_daemon_test.go index daf8d0628..682a45365 100644 --- a/client/cmd/up_daemon_test.go +++ b/client/cmd/up_daemon_test.go @@ -3,18 +3,55 @@ package cmd import ( "context" "os" + "os/user" "testing" "time" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) var cliAddr string func TestUpDaemon(t *testing.T) { - mgmAddr := startTestingServices(t) tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.ConfigDirOverride = tempDir + + currUser, err := user.Current() + if err != nil { + t.Fatalf("failed to get current user: %v", err) + return + } + + sm := profilemanager.ServiceManager{} + err = sm.AddProfile("test1", currUser.Username) + if err != nil { + t.Fatalf("failed to add profile: %v", err) + return + } + + err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: "test1", + Username: currUser.Username, + }) + if err != nil { + t.Fatalf("failed to set active profile state: %v", err) + return + } + + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.ConfigDirOverride = "" + }) + + mgmAddr := startTestingServices(t) + confPath := tempDir + "/config.json" ctx := internal.CtxInitState(context.Background()) diff --git a/client/embed/embed.go b/client/embed/embed.go index fe95b1942..de83f9d96 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" ) @@ -26,7 +27,7 @@ var ErrClientNotStarted = errors.New("client not started") // Client manages a netbird embedded client instance type Client struct { deviceName string - config *internal.Config + config *profilemanager.Config mu sync.Mutex cancel context.CancelFunc setupKey string @@ -88,9 +89,9 @@ func New(opts Options) (*Client, error) { } t := true - var config *internal.Config + var config *profilemanager.Config var err error - input := internal.ConfigInput{ + input := profilemanager.ConfigInput{ ConfigPath: opts.ConfigPath, ManagementURL: opts.ManagementURL, PreSharedKey: &opts.PreSharedKey, @@ -98,9 +99,9 @@ func New(opts Options) (*Client, error) { DisableClientRoutes: &opts.DisableClientRoutes, } if opts.ConfigPath != "" { - config, err = internal.UpdateOrCreateConfig(input) + config, err = profilemanager.UpdateOrCreateConfig(input) } else { - config, err = internal.CreateInMemoryConfig(input) + config, err = profilemanager.CreateInMemoryConfig(input) } if err != nil { return nil, fmt.Errorf("create config: %w", err) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 86df58fdb..4458f600c 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -11,6 +11,7 @@ import ( gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) // OAuthFlow represents an interface for authorization using different OAuth 2.0 flows @@ -48,6 +49,7 @@ type TokenInfo struct { TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` UseIDToken bool `json:"-"` + Email string `json:"-"` } // GetTokenToUse returns either the access or id token based on UseIDToken field @@ -64,7 +66,7 @@ func (t TokenInfo) GetTokenToUse() string { // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) { +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { return authenticateWithDeviceCodeFlow(ctx, config) } @@ -80,7 +82,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopCli } // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow -func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { +func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) @@ -89,7 +91,7 @@ func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAu } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow -func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { +func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { switch s, ok := gstatus.FromError(err); { diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index d955679ae..8741e8636 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -6,6 +6,7 @@ import ( "crypto/subtle" "crypto/tls" "encoding/base64" + "encoding/json" "errors" "fmt" "html/template" @@ -230,9 +231,46 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) } + email, err := parseEmailFromIDToken(tokenInfo.IDToken) + if err != nil { + log.Warnf("failed to parse email from ID token: %v", err) + } else { + tokenInfo.Email = email + } + return tokenInfo, nil } +func parseEmailFromIDToken(token string) (string, error) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return "", fmt.Errorf("invalid token format") + } + + data, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("failed to decode payload: %w", err) + } + var claims map[string]interface{} + if err := json.Unmarshal(data, &claims); err != nil { + return "", fmt.Errorf("json unmarshal error: %w", err) + } + + var email string + if emailValue, ok := claims["email"].(string); ok { + email = emailValue + } else { + val, ok := claims["name"].(string) + if ok { + email = val + } else { + return "", fmt.Errorf("email or name field not found in token payload") + } + } + + return email, nil +} + func createCodeChallenge(codeVerifier string) string { sha2 := sha256.Sum256([]byte(codeVerifier)) return base64.RawURLEncoding.EncodeToString(sha2[:]) diff --git a/client/internal/connect.go b/client/internal/connect.go index 7b49fa3ad..cd4dd3cb7 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" @@ -37,7 +38,7 @@ import ( type ConnectClient struct { ctx context.Context - config *Config + config *profilemanager.Config statusRecorder *peer.Status engine *Engine engineMutex sync.Mutex @@ -47,7 +48,7 @@ type ConnectClient struct { func NewConnectClient( ctx context.Context, - config *Config, + config *profilemanager.Config, statusRecorder *peer.Status, ) *ConnectClient { @@ -413,7 +414,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) { } // createEngineConfig converts configuration received from Management Service to EngineConfig -func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { +func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { nm := false if config.NetworkMonitor != nil { nm = *config.NetworkMonitor @@ -483,7 +484,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP } // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) -func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) { +func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { serverPublicKey, err := client.GetServerPublicKey() if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index a9d9f3fc1..71ebf431d 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -25,9 +25,8 @@ import ( "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/profilemanager" mgmProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/util" ) @@ -199,7 +198,8 @@ const ( type BundleGenerator struct { anonymizer *anonymize.Anonymizer - internalConfig *internal.Config + // deps + internalConfig *profilemanager.Config statusRecorder *peer.Status networkMap *mgmProto.NetworkMap logFile string @@ -220,7 +220,7 @@ type BundleConfig struct { } type GeneratorDependencies struct { - InternalConfig *internal.Config + InternalConfig *profilemanager.Config StatusRecorder *peer.Status NetworkMap *mgmProto.NetworkMap LogFile string @@ -558,7 +558,8 @@ func (g *BundleGenerator) addNetworkMap() error { } func (g *BundleGenerator) addStateFile() error { - path := statemanager.GetDefaultStatePath() + sm := profilemanager.ServiceManager{} + path := sm.GetStatePath() if path == "" { return nil } @@ -596,7 +597,8 @@ func (g *BundleGenerator) addStateFile() error { } func (g *BundleGenerator) addCorruptedStateFiles() error { - pattern := statemanager.GetDefaultStatePath() + sm := profilemanager.ServiceManager{} + pattern := sm.GetStatePath() if pattern == "" { return nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index d2de5b3cc..2339866fb 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "net/netip" + "os" "reflect" "runtime" "slices" @@ -41,6 +42,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -236,7 +238,9 @@ func NewEngine( connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), } - path := statemanager.GetDefaultStatePath() + sm := profilemanager.ServiceManager{} + + path := sm.GetStatePath() if runtime.GOOS == "ios" { if !fileExists(mobileDep.StateFilePath) { err := createFile(mobileDep.StateFilePath) @@ -2062,3 +2066,16 @@ func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool { } return true } + +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) +} + +func createFile(path string) error { + file, err := os.Create(path) + if err != nil { + return err + } + return file.Close() +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 69586b47a..2ac531662 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -38,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -1149,25 +1150,25 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { }{ { name: "Parse Valid List Should Be OK", - inputBlacklistInterface: defaultInterfaceBlacklist, + inputBlacklistInterface: profilemanager.DefaultInterfaceBlacklist, inputMapList: []string{"1.1.1.1", "8.8.8.8/" + testingInterface}, expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP}, }, { name: "Only Interface name Should Return Nil", - inputBlacklistInterface: defaultInterfaceBlacklist, + inputBlacklistInterface: profilemanager.DefaultInterfaceBlacklist, inputMapList: []string{testingInterface}, expectedOutput: nil, }, { name: "Invalid IP Return Nil", - inputBlacklistInterface: defaultInterfaceBlacklist, + inputBlacklistInterface: profilemanager.DefaultInterfaceBlacklist, inputMapList: []string{"1.1.1.1000"}, expectedOutput: nil, }, { name: "Invalid Mapping Element Should return Nil", - inputBlacklistInterface: defaultInterfaceBlacklist, + inputBlacklistInterface: profilemanager.DefaultInterfaceBlacklist, inputMapList: []string{"1.1.1.1/10.10.10.1/eth0"}, expectedOutput: nil, }, diff --git a/client/internal/login.go b/client/internal/login.go index 53fa17d90..7c96e4081 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -10,6 +10,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" mgm "github.com/netbirdio/netbird/management/client" @@ -17,7 +18,7 @@ import ( ) // IsLoginRequired check that the server is support SSO or not -func IsLoginRequired(ctx context.Context, config *Config) (bool, error) { +func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) { mgmURL := config.ManagementURL mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL) if err != nil { @@ -47,7 +48,7 @@ func IsLoginRequired(ctx context.Context, config *Config) (bool, error) { } // Login or register the client -func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error { +func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error { mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL) if err != nil { return err @@ -100,7 +101,7 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) @@ -126,7 +127,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // Otherwise tries to register with the provided setupKey via command line. -func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) { +func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { validSetupKey, err := uuid.Parse(setupKey) if err != nil && jwtToken == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) diff --git a/client/internal/config.go b/client/internal/profilemanager/config.go similarity index 93% rename from client/internal/config.go rename to client/internal/profilemanager/config.go index add702cdb..df6b93402 100644 --- a/client/internal/config.go +++ b/client/internal/profilemanager/config.go @@ -1,4 +1,4 @@ -package internal +package profilemanager import ( "context" @@ -6,16 +6,16 @@ import ( "fmt" "net/url" "os" + "path/filepath" "reflect" "runtime" "slices" "strings" "time" - log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" @@ -38,7 +38,7 @@ const ( DefaultAdminURL = "https://app.netbird.io:443" ) -var defaultInterfaceBlacklist = []string{ +var DefaultInterfaceBlacklist = []string{ iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", "Tailscale", "tailscale", "docker", "veth", "br-", "lo", } @@ -144,78 +144,47 @@ type Config struct { LazyConnectionEnabled bool } -// ReadConfig read config file and return with Config. If it is not exists create a new with default values -func ReadConfig(configPath string) (*Config, error) { - if fileExists(configPath) { - err := util.EnforcePermission(configPath) - if err != nil { - log.Errorf("failed to enforce permission on config dir: %v", err) - } +var ConfigDirOverride string - config := &Config{} - if _, err := util.ReadJson(configPath, config); err != nil { - return nil, err - } - // initialize through apply() without changes - if changed, err := config.apply(ConfigInput{}); err != nil { - return nil, err - } else if changed { - if err = WriteOutConfig(configPath, config); err != nil { - return nil, err - } - } - - return config, nil +func getConfigDir() (string, error) { + if ConfigDirOverride != "" { + return ConfigDirOverride, nil } - - cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath}) + configDir, err := os.UserConfigDir() if err != nil { - return nil, err + return "", err } - err = WriteOutConfig(configPath, cfg) - return cfg, err -} - -// UpdateConfig update existing configuration according to input configuration and return with the configuration -func UpdateConfig(input ConfigInput) (*Config, error) { - if !fileExists(input.ConfigPath) { - return nil, status.Errorf(codes.NotFound, "config file doesn't exist") - } - - return update(input) -} - -// UpdateOrCreateConfig reads existing config or generates a new one -func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !fileExists(input.ConfigPath) { - log.Infof("generating new config %s", input.ConfigPath) - cfg, err := createNewConfig(input) - if err != nil { - return nil, err + configDir = filepath.Join(configDir, "netbird") + if _, err := os.Stat(configDir); os.IsNotExist(err) { + if err := os.MkdirAll(configDir, 0755); err != nil { + return "", err } - err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) - return cfg, err } - if isPreSharedKeyHidden(input.PreSharedKey) { - input.PreSharedKey = nil - } - err := util.EnforcePermission(input.ConfigPath) - if err != nil { - log.Errorf("failed to enforce permission on config dir: %v", err) - } - return update(input) + return configDir, nil } -// CreateInMemoryConfig generate a new config but do not write out it to the store -func CreateInMemoryConfig(input ConfigInput) (*Config, error) { - return createNewConfig(input) +func getConfigDirForUser(username string) (string, error) { + if ConfigDirOverride != "" { + return ConfigDirOverride, nil + } + + username = sanitizeProfileName(username) + + configDir := filepath.Join(DefaultConfigPathDir, username) + if _, err := os.Stat(configDir); os.IsNotExist(err) { + if err := os.MkdirAll(configDir, 0600); err != nil { + return "", err + } + } + + return configDir, nil } -// WriteOutConfig write put the prepared config to the given path -func WriteOutConfig(path string, config *Config) error { - return util.WriteJson(context.Background(), path, config) +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -223,8 +192,6 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), - // default to disabling server routes on Android for security - DisableServerRoutes: runtime.GOOS == "android", } if _, err := config.apply(input); err != nil { @@ -234,27 +201,6 @@ func createNewConfig(input ConfigInput) (*Config, error) { return config, nil } -func update(input ConfigInput) (*Config, error) { - config := &Config{} - - if _, err := util.ReadJson(input.ConfigPath, config); err != nil { - return nil, err - } - - updated, err := config.apply(input) - if err != nil { - return nil, err - } - - if updated { - if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { - return nil, err - } - } - - return config, nil -} - func (config *Config) apply(input ConfigInput) (updated bool, err error) { if config.ManagementURL == nil { log.Infof("using default Management URL %s", DefaultManagementURL) @@ -382,8 +328,8 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { if len(config.IFaceBlackList) == 0 { log.Infof("filling in interface blacklist with defaults: [ %s ]", - strings.Join(defaultInterfaceBlacklist, " ")) - config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...) + strings.Join(DefaultInterfaceBlacklist, " ")) + config.IFaceBlackList = append(config.IFaceBlackList, DefaultInterfaceBlacklist...) updated = true } @@ -596,17 +542,69 @@ func isPreSharedKeyHidden(preSharedKey *string) bool { return false } -func fileExists(path string) bool { - _, err := os.Stat(path) - return !os.IsNotExist(err) +// UpdateConfig update existing configuration according to input configuration and return with the configuration +func UpdateConfig(input ConfigInput) (*Config, error) { + if !fileExists(input.ConfigPath) { + return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath) + } + + return update(input) } -func createFile(path string) error { - file, err := os.Create(path) - if err != nil { - return err +// UpdateOrCreateConfig reads existing config or generates a new one +func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { + if !fileExists(input.ConfigPath) { + log.Infof("generating new config %s", input.ConfigPath) + cfg, err := createNewConfig(input) + if err != nil { + return nil, err + } + err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) + return cfg, err } - return file.Close() + + if isPreSharedKeyHidden(input.PreSharedKey) { + input.PreSharedKey = nil + } + err := util.EnforcePermission(input.ConfigPath) + if err != nil { + log.Errorf("failed to enforce permission on config dir: %v", err) + } + return update(input) +} + +func update(input ConfigInput) (*Config, error) { + config := &Config{} + + if _, err := util.ReadJson(input.ConfigPath, config); err != nil { + return nil, err + } + + updated, err := config.apply(input) + if err != nil { + return nil, err + } + + if updated { + if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { + return nil, err + } + } + + return config, nil +} + +func GetConfig(configPath string) (*Config, error) { + if !fileExists(configPath) { + return nil, fmt.Errorf("config file %s does not exist", configPath) + } + + config := &Config{} + if _, err := util.ReadJson(configPath, config); err != nil { + return nil, fmt.Errorf("failed to read config file %s: %w", configPath, err) + } + + return config, nil } // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. @@ -690,3 +688,46 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri return newConfig, nil } + +// CreateInMemoryConfig generate a new config but do not write out it to the store +func CreateInMemoryConfig(input ConfigInput) (*Config, error) { + return createNewConfig(input) +} + +// ReadConfig read config file and return with Config. If it is not exists create a new with default values +func ReadConfig(configPath string) (*Config, error) { + if fileExists(configPath) { + err := util.EnforcePermission(configPath) + if err != nil { + log.Errorf("failed to enforce permission on config dir: %v", err) + } + + config := &Config{} + if _, err := util.ReadJson(configPath, config); err != nil { + return nil, err + } + // initialize through apply() without changes + if changed, err := config.apply(ConfigInput{}); err != nil { + return nil, err + } else if changed { + if err = WriteOutConfig(configPath, config); err != nil { + return nil, err + } + } + + return config, nil + } + + cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath}) + if err != nil { + return nil, err + } + + err = WriteOutConfig(configPath, cfg) + return cfg, err +} + +// WriteOutConfig write put the prepared config to the given path +func WriteOutConfig(path string, config *Config) error { + return util.WriteJson(context.Background(), path, config) +} diff --git a/client/internal/config_test.go b/client/internal/profilemanager/config_test.go similarity index 99% rename from client/internal/config_test.go rename to client/internal/profilemanager/config_test.go index 978d0b3df..45e37bf0e 100644 --- a/client/internal/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -1,4 +1,4 @@ -package internal +package profilemanager import ( "context" diff --git a/client/internal/profilemanager/error.go b/client/internal/profilemanager/error.go new file mode 100644 index 000000000..d83fe5c1c --- /dev/null +++ b/client/internal/profilemanager/error.go @@ -0,0 +1,9 @@ +package profilemanager + +import "errors" + +var ( + ErrProfileNotFound = errors.New("profile not found") + ErrProfileAlreadyExists = errors.New("profile already exists") + ErrNoActiveProfile = errors.New("no active profile set") +) diff --git a/client/internal/profilemanager/profilemanager.go b/client/internal/profilemanager/profilemanager.go new file mode 100644 index 000000000..4598af33e --- /dev/null +++ b/client/internal/profilemanager/profilemanager.go @@ -0,0 +1,133 @@ +package profilemanager + +import ( + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + "sync" + "unicode" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultProfileName = "default" + activeProfileStateFilename = "active_profile.txt" +) + +type Profile struct { + Name string + IsActive bool +} + +func (p *Profile) FilePath() (string, error) { + if p.Name == "" { + return "", fmt.Errorf("active profile name is empty") + } + + if p.Name == defaultProfileName { + return DefaultConfigPath, nil + } + + username, err := user.Current() + if err != nil { + return "", fmt.Errorf("failed to get current user: %w", err) + } + + configDir, err := getConfigDirForUser(username.Username) + if err != nil { + return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err) + } + + return filepath.Join(configDir, p.Name+".json"), nil +} + +func (p *Profile) IsDefault() bool { + return p.Name == defaultProfileName +} + +type ProfileManager struct { + mu sync.Mutex +} + +func NewProfileManager() *ProfileManager { + return &ProfileManager{} +} + +func (pm *ProfileManager) GetActiveProfile() (*Profile, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + prof := pm.getActiveProfileState() + return &Profile{Name: prof}, nil +} + +func (pm *ProfileManager) SwitchProfile(profileName string) error { + profileName = sanitizeProfileName(profileName) + + if err := pm.setActiveProfileState(profileName); err != nil { + return fmt.Errorf("failed to switch profile: %w", err) + } + return nil +} + +// sanitizeProfileName sanitizes the username by removing any invalid characters and spaces. +func sanitizeProfileName(name string) string { + return strings.Map(func(r rune) rune { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-' { + return r + } + // drop everything else + return -1 + }, name) +} + +func (pm *ProfileManager) getActiveProfileState() string { + + configDir, err := getConfigDir() + if err != nil { + log.Warnf("failed to get config directory: %v", err) + return defaultProfileName + } + + statePath := filepath.Join(configDir, activeProfileStateFilename) + + prof, err := os.ReadFile(statePath) + if err != nil { + if !os.IsNotExist(err) { + log.Warnf("failed to read active profile state: %v", err) + } else { + if err := pm.setActiveProfileState(defaultProfileName); err != nil { + log.Warnf("failed to set default profile state: %v", err) + } + } + return defaultProfileName + } + profileName := strings.TrimSpace(string(prof)) + + if profileName == "" { + log.Warnf("active profile state is empty, using default profile: %s", defaultProfileName) + return defaultProfileName + } + + return profileName +} + +func (pm *ProfileManager) setActiveProfileState(profileName string) error { + + configDir, err := getConfigDir() + if err != nil { + return fmt.Errorf("failed to get config directory: %w", err) + } + + statePath := filepath.Join(configDir, activeProfileStateFilename) + + err = os.WriteFile(statePath, []byte(profileName), 0600) + if err != nil { + return fmt.Errorf("failed to write active profile state: %w", err) + } + + return nil +} diff --git a/client/internal/profilemanager/profilemanager_test.go b/client/internal/profilemanager/profilemanager_test.go new file mode 100644 index 000000000..79a7ae650 --- /dev/null +++ b/client/internal/profilemanager/profilemanager_test.go @@ -0,0 +1,151 @@ +package profilemanager + +import ( + "os" + "os/user" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func withTempConfigDir(t *testing.T, testFunc func(configDir string)) { + t.Helper() + tempDir := t.TempDir() + t.Setenv("NETBIRD_CONFIG_DIR", tempDir) + defer os.Unsetenv("NETBIRD_CONFIG_DIR") + testFunc(tempDir) +} + +func withPatchedGlobals(t *testing.T, configDir string, testFunc func()) { + origDefaultConfigPathDir := DefaultConfigPathDir + origDefaultConfigPath := DefaultConfigPath + origActiveProfileStatePath := ActiveProfileStatePath + origOldDefaultConfigPath := oldDefaultConfigPath + origConfigDirOverride := ConfigDirOverride + DefaultConfigPathDir = configDir + DefaultConfigPath = filepath.Join(configDir, "default.json") + ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json") + oldDefaultConfigPath = filepath.Join(configDir, "old_config.json") + ConfigDirOverride = configDir + // Clean up any files in the config dir to ensure isolation + os.RemoveAll(configDir) + os.MkdirAll(configDir, 0755) //nolint: errcheck + defer func() { + DefaultConfigPathDir = origDefaultConfigPathDir + DefaultConfigPath = origDefaultConfigPath + ActiveProfileStatePath = origActiveProfileStatePath + oldDefaultConfigPath = origOldDefaultConfigPath + ConfigDirOverride = origConfigDirOverride + }() + testFunc() +} + +func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) { + withTempConfigDir(t, func(configDir string) { + withPatchedGlobals(t, configDir, func() { + sm := &ServiceManager{} + err := sm.CreateDefaultProfile() + assert.NoError(t, err) + + state, err := sm.GetActiveProfileState() + assert.NoError(t, err) + assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet + + err = sm.SetActiveProfileStateToDefault() + assert.NoError(t, err) + + active, err := sm.GetActiveProfileState() + assert.NoError(t, err) + assert.Equal(t, "default", active.Name) + }) + }) +} + +func TestServiceManager_CopyDefaultProfileIfNotExists(t *testing.T) { + withTempConfigDir(t, func(configDir string) { + withPatchedGlobals(t, configDir, func() { + sm := &ServiceManager{} + + // Case: old default config does not exist + ok, err := sm.CopyDefaultProfileIfNotExists() + assert.False(t, ok) + assert.ErrorIs(t, err, ErrorOldDefaultConfigNotFound) + + // Case: old default config exists, should be moved + f, err := os.Create(oldDefaultConfigPath) + assert.NoError(t, err) + f.Close() + + ok, err = sm.CopyDefaultProfileIfNotExists() + assert.True(t, ok) + assert.NoError(t, err) + _, err = os.Stat(DefaultConfigPath) + assert.NoError(t, err) + }) + }) +} + +func TestServiceManager_SetActiveProfileState(t *testing.T) { + withTempConfigDir(t, func(configDir string) { + withPatchedGlobals(t, configDir, func() { + currUser, err := user.Current() + assert.NoError(t, err) + sm := &ServiceManager{} + state := &ActiveProfileState{Name: "foo", Username: currUser.Username} + err = sm.SetActiveProfileState(state) + assert.NoError(t, err) + + // Should error on nil or incomplete state + err = sm.SetActiveProfileState(nil) + assert.Error(t, err) + err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""}) + assert.Error(t, err) + }) + }) +} + +func TestServiceManager_DefaultProfilePath(t *testing.T) { + withTempConfigDir(t, func(configDir string) { + withPatchedGlobals(t, configDir, func() { + sm := &ServiceManager{} + assert.Equal(t, DefaultConfigPath, sm.DefaultProfilePath()) + }) + }) +} + +func TestSanitizeProfileName(t *testing.T) { + tests := []struct { + in, want string + }{ + // unchanged + {"Alice", "Alice"}, + {"bob123", "bob123"}, + {"under_score", "under_score"}, + {"dash-name", "dash-name"}, + + // spaces and forbidden chars removed + {"Alice Smith", "AliceSmith"}, + {"bad/char\\name", "badcharname"}, + {"colon:name*?", "colonname"}, + {"quotes\"<>|", "quotes"}, + + // mixed + {"User_123-Test!@#", "User_123-Test"}, + + // empty and all-bad + {"", ""}, + {"!@#$%^&*()", ""}, + + // unicode letters and digits + {"ÜserÇ", "ÜserÇ"}, + {"漢字テスト123", "漢字テスト123"}, + } + + for _, tc := range tests { + got := sanitizeProfileName(tc.in) + if got != tc.want { + t.Errorf("sanitizeProfileName(%q) = %q; want %q", tc.in, got, tc.want) + } + } +} diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go new file mode 100644 index 000000000..56198c4cc --- /dev/null +++ b/client/internal/profilemanager/service.go @@ -0,0 +1,359 @@ +package profilemanager + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +var ( + oldDefaultConfigPathDir = "" + oldDefaultConfigPath = "" + + DefaultConfigPathDir = "" + DefaultConfigPath = "" + ActiveProfileStatePath = "" +) + +var ( + ErrorOldDefaultConfigNotFound = errors.New("old default config not found") +) + +func init() { + + DefaultConfigPathDir = "/var/lib/netbird/" + oldDefaultConfigPathDir = "/etc/netbird/" + + switch runtime.GOOS { + case "windows": + oldDefaultConfigPathDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird") + DefaultConfigPathDir = oldDefaultConfigPathDir + + case "freebsd": + oldDefaultConfigPathDir = "/var/db/netbird/" + DefaultConfigPathDir = oldDefaultConfigPathDir + } + + oldDefaultConfigPath = filepath.Join(oldDefaultConfigPathDir, "config.json") + DefaultConfigPath = filepath.Join(DefaultConfigPathDir, "default.json") + ActiveProfileStatePath = filepath.Join(DefaultConfigPathDir, "active_profile.json") +} + +type ActiveProfileState struct { + Name string `json:"name"` + Username string `json:"username"` +} + +func (a *ActiveProfileState) FilePath() (string, error) { + if a.Name == "" { + return "", fmt.Errorf("active profile name is empty") + } + + if a.Name == defaultProfileName { + return DefaultConfigPath, nil + } + + configDir, err := getConfigDirForUser(a.Username) + if err != nil { + return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err) + } + + return filepath.Join(configDir, a.Name+".json"), nil +} + +type ServiceManager struct{} + +func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) { + + if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil { + return false, fmt.Errorf("failed to create default config path directory: %w", err) + } + + // check if default profile exists + if _, err := os.Stat(DefaultConfigPath); !os.IsNotExist(err) { + // default profile already exists + log.Debugf("default profile already exists at %s, skipping copy", DefaultConfigPath) + return false, nil + } + + // check old default profile + if _, err := os.Stat(oldDefaultConfigPath); os.IsNotExist(err) { + // old default profile does not exist, nothing to copy + return false, ErrorOldDefaultConfigNotFound + } + + // copy old default profile to new location + if err := copyFile(oldDefaultConfigPath, DefaultConfigPath, 0600); err != nil { + return false, fmt.Errorf("copy default profile from %s to %s: %w", oldDefaultConfigPath, DefaultConfigPath, err) + } + + // set permissions for the new default profile + if err := os.Chmod(DefaultConfigPath, 0600); err != nil { + log.Warnf("failed to set permissions for default profile: %v", err) + } + + if err := s.SetActiveProfileState(&ActiveProfileState{ + Name: "default", + Username: "", + }); err != nil { + log.Errorf("failed to set active profile state: %v", err) + return false, fmt.Errorf("failed to set active profile state: %w", err) + } + + return true, nil +} + +// copyFile copies the contents of src to dst and sets dst's file mode to perm. +func copyFile(src, dst string, perm os.FileMode) error { + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("open source file %s: %w", src, err) + } + defer in.Close() + + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm) + if err != nil { + return fmt.Errorf("open target file %s: %w", dst, err) + } + defer func() { + if cerr := out.Close(); cerr != nil && err == nil { + err = cerr + } + }() + + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("copy data to %s: %w", dst, err) + } + + return nil +} + +func (s *ServiceManager) CreateDefaultProfile() error { + _, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: DefaultConfigPath, + }) + + if err != nil { + return fmt.Errorf("failed to create default profile: %w", err) + } + + log.Infof("default profile created at %s", DefaultConfigPath) + return nil +} + +func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) { + if err := s.setDefaultActiveState(); err != nil { + return nil, fmt.Errorf("failed to set default active profile state: %w", err) + } + var activeProfile ActiveProfileState + if _, err := util.ReadJson(ActiveProfileStatePath, &activeProfile); err != nil { + if errors.Is(err, os.ErrNotExist) { + if err := s.SetActiveProfileStateToDefault(); err != nil { + return nil, fmt.Errorf("failed to set active profile to default: %w", err) + } + return &ActiveProfileState{ + Name: "default", + Username: "", + }, nil + } else { + return nil, fmt.Errorf("failed to read active profile state: %w", err) + } + } + + if activeProfile.Name == "" { + if err := s.SetActiveProfileStateToDefault(); err != nil { + return nil, fmt.Errorf("failed to set active profile to default: %w", err) + } + return &ActiveProfileState{ + Name: "default", + Username: "", + }, nil + } + + return &activeProfile, nil + +} + +func (s *ServiceManager) setDefaultActiveState() error { + _, err := os.Stat(ActiveProfileStatePath) + if err != nil { + if os.IsNotExist(err) { + if err := s.SetActiveProfileStateToDefault(); err != nil { + return fmt.Errorf("failed to set active profile to default: %w", err) + } + } else { + return fmt.Errorf("failed to stat active profile state path %s: %w", ActiveProfileStatePath, err) + } + } + + return nil +} + +func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error { + if a == nil || a.Name == "" { + return errors.New("invalid active profile state") + } + + if a.Name != defaultProfileName && a.Username == "" { + return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name) + } + + if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil { + return fmt.Errorf("failed to write active profile state: %w", err) + } + + log.Infof("active profile set to %s for %s", a.Name, a.Username) + return nil +} + +func (s *ServiceManager) SetActiveProfileStateToDefault() error { + return s.SetActiveProfileState(&ActiveProfileState{ + Name: "default", + Username: "", + }) +} + +func (s *ServiceManager) DefaultProfilePath() string { + return DefaultConfigPath +} + +func (s *ServiceManager) AddProfile(profileName, username string) error { + configDir, err := getConfigDirForUser(username) + if err != nil { + return fmt.Errorf("failed to get config directory: %w", err) + } + + profileName = sanitizeProfileName(profileName) + + if profileName == defaultProfileName { + return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName) + } + + profPath := filepath.Join(configDir, profileName+".json") + if fileExists(profPath) { + return ErrProfileAlreadyExists + } + + cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath}) + if err != nil { + return fmt.Errorf("failed to create new config: %w", err) + } + + err = util.WriteJson(context.Background(), profPath, cfg) + if err != nil { + return fmt.Errorf("failed to write profile config: %w", err) + } + + return nil +} + +func (s *ServiceManager) RemoveProfile(profileName, username string) error { + configDir, err := getConfigDirForUser(username) + if err != nil { + return fmt.Errorf("failed to get config directory: %w", err) + } + + profileName = sanitizeProfileName(profileName) + + if profileName == defaultProfileName { + return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName) + } + profPath := filepath.Join(configDir, profileName+".json") + if !fileExists(profPath) { + return ErrProfileNotFound + } + + activeProf, err := s.GetActiveProfileState() + if err != nil && !errors.Is(err, ErrNoActiveProfile) { + return fmt.Errorf("failed to get active profile: %w", err) + } + + if activeProf != nil && activeProf.Name == profileName { + return fmt.Errorf("cannot remove active profile: %s", profileName) + } + + err = util.RemoveJson(profPath) + if err != nil { + return fmt.Errorf("failed to remove profile config: %w", err) + } + return nil +} + +func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) { + configDir, err := getConfigDirForUser(username) + if err != nil { + return nil, fmt.Errorf("failed to get config directory: %w", err) + } + + files, err := util.ListFiles(configDir, "*.json") + if err != nil { + return nil, fmt.Errorf("failed to list profile files: %w", err) + } + + var filtered []string + for _, file := range files { + if strings.HasSuffix(file, "state.json") { + continue // skip state files + } + filtered = append(filtered, file) + } + sort.Strings(filtered) + + var activeProfName string + activeProf, err := s.GetActiveProfileState() + if err == nil { + activeProfName = activeProf.Name + } + + var profiles []Profile + // add default profile always + profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName}) + for _, file := range filtered { + profileName := strings.TrimSuffix(filepath.Base(file), ".json") + var isActive bool + if activeProfName != "" && activeProfName == profileName { + isActive = true + } + profiles = append(profiles, Profile{Name: profileName, IsActive: isActive}) + } + + return profiles, nil +} + +// GetStatePath returns the path to the state file based on the operating system +// It returns an empty string if the path cannot be determined. +func (s *ServiceManager) GetStatePath() string { + if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" { + return path + } + + defaultStatePath := filepath.Join(DefaultConfigPathDir, "state.json") + + activeProf, err := s.GetActiveProfileState() + if err != nil { + log.Warnf("failed to get active profile state: %v", err) + return defaultStatePath + } + + if activeProf.Name == defaultProfileName { + return defaultStatePath + } + + configDir, err := getConfigDirForUser(activeProf.Username) + if err != nil { + log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err) + return defaultStatePath + } + + return filepath.Join(configDir, activeProf.Name+".state.json") +} diff --git a/client/internal/profilemanager/state.go b/client/internal/profilemanager/state.go new file mode 100644 index 000000000..f84cb1032 --- /dev/null +++ b/client/internal/profilemanager/state.go @@ -0,0 +1,57 @@ +package profilemanager + +import ( + "context" + "errors" + "fmt" + "path/filepath" + + "github.com/netbirdio/netbird/util" +) + +type ProfileState struct { + Email string `json:"email"` +} + +func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) { + configDir, err := getConfigDir() + if err != nil { + return nil, fmt.Errorf("get config directory: %w", err) + } + + stateFile := filepath.Join(configDir, profileName+".state.json") + if !fileExists(stateFile) { + return nil, errors.New("profile state file does not exist") + } + + var state ProfileState + _, err = util.ReadJson(stateFile, &state) + if err != nil { + return nil, fmt.Errorf("read profile state: %w", err) + } + + return &state, nil +} + +func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error { + configDir, err := getConfigDir() + if err != nil { + return fmt.Errorf("get config directory: %w", err) + } + + activeProf, err := pm.GetActiveProfile() + if err != nil { + if errors.Is(err, ErrNoActiveProfile) { + return fmt.Errorf("no active profile set: %w", err) + } + return fmt.Errorf("get active profile: %w", err) + } + + stateFile := filepath.Join(configDir, activeProf.Name+".state.json") + err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state) + if err != nil { + return fmt.Errorf("write profile state: %w", err) + } + + return nil +} diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go deleted file mode 100644 index d232e5f0c..000000000 --- a/client/internal/statemanager/path.go +++ /dev/null @@ -1,16 +0,0 @@ -package statemanager - -import ( - "github.com/netbirdio/netbird/client/configs" - "os" - "path/filepath" -) - -// GetDefaultStatePath returns the path to the state file based on the operating system -// It returns an empty string if the path cannot be determined. -func GetDefaultStatePath() string { - if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" { - return path - } - return filepath.Join(configs.StateDir, "state.json") -} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 622f8e840..fe0f6034e 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/management/domain" @@ -92,7 +93,7 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s func (c *Client) Run(fd int32, interfaceName string) error { log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) - cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, StateFilePath: c.stateFile, }) @@ -203,7 +204,7 @@ func (c *Client) IsLoginRequired() bool { defer c.ctxCancelLock.Unlock() ctx, c.ctxCancel = context.WithCancel(ctxWithValues) - cfg, _ := internal.UpdateOrCreateConfig(internal.ConfigInput{ + cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -223,7 +224,7 @@ func (c *Client) LoginForMobile() string { defer c.ctxCancelLock.Unlock() ctx, c.ctxCancel = context.WithCancel(ctxWithValues) - cfg, _ := internal.UpdateOrCreateConfig(internal.ConfigInput{ + cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 986874758..570c44f80 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" ) @@ -36,17 +37,17 @@ type URLOpener interface { // Auth can register or login new client type Auth struct { ctx context.Context - config *internal.Config + config *profilemanager.Config cfgPath string } // NewAuth instantiate Auth struct and validate the management URL func NewAuth(cfgPath string, mgmURL string) (*Auth, error) { - inputCfg := internal.ConfigInput{ + inputCfg := profilemanager.ConfigInput{ ManagementURL: mgmURL, } - cfg, err := internal.CreateInMemoryConfig(inputCfg) + cfg, err := profilemanager.CreateInMemoryConfig(inputCfg) if err != nil { return nil, err } @@ -59,7 +60,7 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) { } // NewAuthWithConfig instantiate Auth based on existing config -func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { +func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth { return &Auth{ ctx: ctx, config: config, @@ -94,7 +95,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { return false, fmt.Errorf("backoff cycle failed: %v", err) } - err = internal.WriteOutConfig(a.cfgPath, a.config) + err = profilemanager.WriteOutConfig(a.cfgPath, a.config) return true, err } @@ -115,7 +116,7 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string return fmt.Errorf("backoff cycle failed: %v", err) } - return internal.WriteOutConfig(a.cfgPath, a.config) + return profilemanager.WriteOutConfig(a.cfgPath, a.config) } func (a *Auth) Login() error { diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index 5a0abd9a7..5e7050465 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -1,17 +1,17 @@ package NetBirdSDK import ( - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) // Preferences export a subset of the internal config for gomobile type Preferences struct { - configInput internal.ConfigInput + configInput profilemanager.ConfigInput } // NewPreferences create new Preferences instance func NewPreferences(configPath string, stateFilePath string) *Preferences { - ci := internal.ConfigInput{ + ci := profilemanager.ConfigInput{ ConfigPath: configPath, StateFilePath: stateFilePath, } @@ -24,7 +24,7 @@ func (p *Preferences) GetManagementURL() (string, error) { return p.configInput.ManagementURL, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return "", err } @@ -42,7 +42,7 @@ func (p *Preferences) GetAdminURL() (string, error) { return p.configInput.AdminURL, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return "", err } @@ -60,7 +60,7 @@ func (p *Preferences) GetPreSharedKey() (string, error) { return *p.configInput.PreSharedKey, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return "", err } @@ -83,7 +83,7 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) { return *p.configInput.RosenpassEnabled, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -101,7 +101,7 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { return *p.configInput.RosenpassPermissive, nil } - cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) if err != nil { return false, err } @@ -110,6 +110,6 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { // Commit write out the changes into config file func (p *Preferences) Commit() error { - _, err := internal.UpdateOrCreateConfig(p.configInput) + _, err := profilemanager.UpdateOrCreateConfig(p.configInput) return err } diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index 7e5325a00..780443a7b 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -4,7 +4,7 @@ import ( "path/filepath" "testing" - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" ) func TestPreferences_DefaultValues(t *testing.T) { @@ -16,7 +16,7 @@ func TestPreferences_DefaultValues(t *testing.T) { t.Fatalf("failed to read default value: %s", err) } - if defaultVar != internal.DefaultAdminURL { + if defaultVar != profilemanager.DefaultAdminURL { t.Errorf("invalid default admin url: %s", defaultVar) } @@ -25,7 +25,7 @@ func TestPreferences_DefaultValues(t *testing.T) { t.Fatalf("failed to read default management URL: %s", err) } - if defaultVar != internal.DefaultManagementURL { + if defaultVar != profilemanager.DefaultManagementURL { t.Errorf("invalid default management url: %s", defaultVar) } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 26e58d183..f405ffd65 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -273,9 +273,11 @@ type LoginRequest struct { // cleanDNSLabels clean map list of DNS labels. // This is needed because the generated code // omits initialized empty slices due to omitempty tags - CleanDNSLabels bool `protobuf:"varint,27,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` - LazyConnectionEnabled *bool `protobuf:"varint,28,opt,name=lazyConnectionEnabled,proto3,oneof" json:"lazyConnectionEnabled,omitempty"` - BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"` + CleanDNSLabels bool `protobuf:"varint,27,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` + LazyConnectionEnabled *bool `protobuf:"varint,28,opt,name=lazyConnectionEnabled,proto3,oneof" json:"lazyConnectionEnabled,omitempty"` + BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"` + ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` + Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -514,6 +516,20 @@ func (x *LoginRequest) GetBlockInbound() bool { return false } +func (x *LoginRequest) GetProfileName() string { + if x != nil && x.ProfileName != nil { + return *x.ProfileName + } + return "" +} + +func (x *LoginRequest) GetUsername() string { + if x != nil && x.Username != nil { + return *x.Username + } + return "" +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -636,6 +652,7 @@ func (x *WaitSSOLoginRequest) GetHostname() string { type WaitSSOLoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` + Email string `protobuf:"bytes,1,opt,name=email,proto3" json:"email,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -670,8 +687,17 @@ func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{4} } +func (x *WaitSSOLoginResponse) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` + ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` + Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -706,6 +732,20 @@ func (*UpRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{5} } +func (x *UpRequest) GetProfileName() string { + if x != nil && x.ProfileName != nil { + return *x.ProfileName + } + return "" +} + +func (x *UpRequest) GetUsername() string { + if x != nil && x.Username != nil { + return *x.Username + } + return "" +} + type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -930,6 +970,8 @@ func (*DownResponse) Descriptor() ([]byte, []int) { type GetConfigRequest struct { state protoimpl.MessageState `protogen:"open.v1"` + ProfileName string `protobuf:"bytes,1,opt,name=profileName,proto3" json:"profileName,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -964,6 +1006,20 @@ func (*GetConfigRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{11} } +func (x *GetConfigRequest) GetProfileName() string { + if x != nil { + return x.ProfileName + } + return "" +} + +func (x *GetConfigRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + type GetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // managementUrl settings value. @@ -3503,6 +3559,789 @@ func (x *GetEventsResponse) GetEvents() []*SystemEvent { return nil } +type SwitchProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` + Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SwitchProfileRequest) Reset() { + *x = SwitchProfileRequest{} + mi := &file_daemon_proto_msgTypes[52] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SwitchProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SwitchProfileRequest) ProtoMessage() {} + +func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[52] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. +func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{52} +} + +func (x *SwitchProfileRequest) GetProfileName() string { + if x != nil && x.ProfileName != nil { + return *x.ProfileName + } + return "" +} + +func (x *SwitchProfileRequest) GetUsername() string { + if x != nil && x.Username != nil { + return *x.Username + } + return "" +} + +type SwitchProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SwitchProfileResponse) Reset() { + *x = SwitchProfileResponse{} + mi := &file_daemon_proto_msgTypes[53] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SwitchProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SwitchProfileResponse) ProtoMessage() {} + +func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[53] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. +func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{53} +} + +type SetConfigRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"` + // managementUrl to authenticate. + ManagementUrl string `protobuf:"bytes,3,opt,name=managementUrl,proto3" json:"managementUrl,omitempty"` + // adminUrl to manage keys. + AdminURL string `protobuf:"bytes,4,opt,name=adminURL,proto3" json:"adminURL,omitempty"` + RosenpassEnabled *bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` + InterfaceName *string `protobuf:"bytes,6,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` + WireguardPort *int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` + OptionalPreSharedKey *string `protobuf:"bytes,8,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` + DisableAutoConnect *bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed *bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` + RosenpassPermissive *bool `protobuf:"varint,11,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + NetworkMonitor *bool `protobuf:"varint,12,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"` + DisableClientRoutes *bool `protobuf:"varint,13,opt,name=disable_client_routes,json=disableClientRoutes,proto3,oneof" json:"disable_client_routes,omitempty"` + DisableServerRoutes *bool `protobuf:"varint,14,opt,name=disable_server_routes,json=disableServerRoutes,proto3,oneof" json:"disable_server_routes,omitempty"` + DisableDns *bool `protobuf:"varint,15,opt,name=disable_dns,json=disableDns,proto3,oneof" json:"disable_dns,omitempty"` + DisableFirewall *bool `protobuf:"varint,16,opt,name=disable_firewall,json=disableFirewall,proto3,oneof" json:"disable_firewall,omitempty"` + BlockLanAccess *bool `protobuf:"varint,17,opt,name=block_lan_access,json=blockLanAccess,proto3,oneof" json:"block_lan_access,omitempty"` + DisableNotifications *bool `protobuf:"varint,18,opt,name=disable_notifications,json=disableNotifications,proto3,oneof" json:"disable_notifications,omitempty"` + LazyConnectionEnabled *bool `protobuf:"varint,19,opt,name=lazyConnectionEnabled,proto3,oneof" json:"lazyConnectionEnabled,omitempty"` + BlockInbound *bool `protobuf:"varint,20,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"` + NatExternalIPs []string `protobuf:"bytes,21,rep,name=natExternalIPs,proto3" json:"natExternalIPs,omitempty"` + CleanNATExternalIPs bool `protobuf:"varint,22,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` + CustomDNSAddress []byte `protobuf:"bytes,23,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + ExtraIFaceBlacklist []string `protobuf:"bytes,24,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` + DnsLabels []string `protobuf:"bytes,25,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"` + // cleanDNSLabels clean map list of DNS labels. + CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` + DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetConfigRequest) Reset() { + *x = SetConfigRequest{} + mi := &file_daemon_proto_msgTypes[54] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetConfigRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetConfigRequest) ProtoMessage() {} + +func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[54] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. +func (*SetConfigRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{54} +} + +func (x *SetConfigRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *SetConfigRequest) GetProfileName() string { + if x != nil { + return x.ProfileName + } + return "" +} + +func (x *SetConfigRequest) GetManagementUrl() string { + if x != nil { + return x.ManagementUrl + } + return "" +} + +func (x *SetConfigRequest) GetAdminURL() string { + if x != nil { + return x.AdminURL + } + return "" +} + +func (x *SetConfigRequest) GetRosenpassEnabled() bool { + if x != nil && x.RosenpassEnabled != nil { + return *x.RosenpassEnabled + } + return false +} + +func (x *SetConfigRequest) GetInterfaceName() string { + if x != nil && x.InterfaceName != nil { + return *x.InterfaceName + } + return "" +} + +func (x *SetConfigRequest) GetWireguardPort() int64 { + if x != nil && x.WireguardPort != nil { + return *x.WireguardPort + } + return 0 +} + +func (x *SetConfigRequest) GetOptionalPreSharedKey() string { + if x != nil && x.OptionalPreSharedKey != nil { + return *x.OptionalPreSharedKey + } + return "" +} + +func (x *SetConfigRequest) GetDisableAutoConnect() bool { + if x != nil && x.DisableAutoConnect != nil { + return *x.DisableAutoConnect + } + return false +} + +func (x *SetConfigRequest) GetServerSSHAllowed() bool { + if x != nil && x.ServerSSHAllowed != nil { + return *x.ServerSSHAllowed + } + return false +} + +func (x *SetConfigRequest) GetRosenpassPermissive() bool { + if x != nil && x.RosenpassPermissive != nil { + return *x.RosenpassPermissive + } + return false +} + +func (x *SetConfigRequest) GetNetworkMonitor() bool { + if x != nil && x.NetworkMonitor != nil { + return *x.NetworkMonitor + } + return false +} + +func (x *SetConfigRequest) GetDisableClientRoutes() bool { + if x != nil && x.DisableClientRoutes != nil { + return *x.DisableClientRoutes + } + return false +} + +func (x *SetConfigRequest) GetDisableServerRoutes() bool { + if x != nil && x.DisableServerRoutes != nil { + return *x.DisableServerRoutes + } + return false +} + +func (x *SetConfigRequest) GetDisableDns() bool { + if x != nil && x.DisableDns != nil { + return *x.DisableDns + } + return false +} + +func (x *SetConfigRequest) GetDisableFirewall() bool { + if x != nil && x.DisableFirewall != nil { + return *x.DisableFirewall + } + return false +} + +func (x *SetConfigRequest) GetBlockLanAccess() bool { + if x != nil && x.BlockLanAccess != nil { + return *x.BlockLanAccess + } + return false +} + +func (x *SetConfigRequest) GetDisableNotifications() bool { + if x != nil && x.DisableNotifications != nil { + return *x.DisableNotifications + } + return false +} + +func (x *SetConfigRequest) GetLazyConnectionEnabled() bool { + if x != nil && x.LazyConnectionEnabled != nil { + return *x.LazyConnectionEnabled + } + return false +} + +func (x *SetConfigRequest) GetBlockInbound() bool { + if x != nil && x.BlockInbound != nil { + return *x.BlockInbound + } + return false +} + +func (x *SetConfigRequest) GetNatExternalIPs() []string { + if x != nil { + return x.NatExternalIPs + } + return nil +} + +func (x *SetConfigRequest) GetCleanNATExternalIPs() bool { + if x != nil { + return x.CleanNATExternalIPs + } + return false +} + +func (x *SetConfigRequest) GetCustomDNSAddress() []byte { + if x != nil { + return x.CustomDNSAddress + } + return nil +} + +func (x *SetConfigRequest) GetExtraIFaceBlacklist() []string { + if x != nil { + return x.ExtraIFaceBlacklist + } + return nil +} + +func (x *SetConfigRequest) GetDnsLabels() []string { + if x != nil { + return x.DnsLabels + } + return nil +} + +func (x *SetConfigRequest) GetCleanDNSLabels() bool { + if x != nil { + return x.CleanDNSLabels + } + return false +} + +func (x *SetConfigRequest) GetDnsRouteInterval() *durationpb.Duration { + if x != nil { + return x.DnsRouteInterval + } + return nil +} + +type SetConfigResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetConfigResponse) Reset() { + *x = SetConfigResponse{} + mi := &file_daemon_proto_msgTypes[55] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetConfigResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetConfigResponse) ProtoMessage() {} + +func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[55] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. +func (*SetConfigResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{55} +} + +type AddProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AddProfileRequest) Reset() { + *x = AddProfileRequest{} + mi := &file_daemon_proto_msgTypes[56] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AddProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AddProfileRequest) ProtoMessage() {} + +func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[56] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. +func (*AddProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{56} +} + +func (x *AddProfileRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *AddProfileRequest) GetProfileName() string { + if x != nil { + return x.ProfileName + } + return "" +} + +type AddProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AddProfileResponse) Reset() { + *x = AddProfileResponse{} + mi := &file_daemon_proto_msgTypes[57] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AddProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AddProfileResponse) ProtoMessage() {} + +func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[57] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. +func (*AddProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{57} +} + +type RemoveProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RemoveProfileRequest) Reset() { + *x = RemoveProfileRequest{} + mi := &file_daemon_proto_msgTypes[58] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RemoveProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RemoveProfileRequest) ProtoMessage() {} + +func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[58] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. +func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{58} +} + +func (x *RemoveProfileRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *RemoveProfileRequest) GetProfileName() string { + if x != nil { + return x.ProfileName + } + return "" +} + +type RemoveProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RemoveProfileResponse) Reset() { + *x = RemoveProfileResponse{} + mi := &file_daemon_proto_msgTypes[59] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RemoveProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RemoveProfileResponse) ProtoMessage() {} + +func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[59] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. +func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{59} +} + +type ListProfilesRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListProfilesRequest) Reset() { + *x = ListProfilesRequest{} + mi := &file_daemon_proto_msgTypes[60] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListProfilesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListProfilesRequest) ProtoMessage() {} + +func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[60] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. +func (*ListProfilesRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{60} +} + +func (x *ListProfilesRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +type ListProfilesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Profiles []*Profile `protobuf:"bytes,1,rep,name=profiles,proto3" json:"profiles,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListProfilesResponse) Reset() { + *x = ListProfilesResponse{} + mi := &file_daemon_proto_msgTypes[61] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListProfilesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListProfilesResponse) ProtoMessage() {} + +func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[61] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. +func (*ListProfilesResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{61} +} + +func (x *ListProfilesResponse) GetProfiles() []*Profile { + if x != nil { + return x.Profiles + } + return nil +} + +type Profile struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + IsActive bool `protobuf:"varint,2,opt,name=is_active,json=isActive,proto3" json:"is_active,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Profile) Reset() { + *x = Profile{} + mi := &file_daemon_proto_msgTypes[62] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Profile) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Profile) ProtoMessage() {} + +func (x *Profile) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[62] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Profile.ProtoReflect.Descriptor instead. +func (*Profile) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{62} +} + +func (x *Profile) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *Profile) GetIsActive() bool { + if x != nil { + return x.IsActive + } + return false +} + +type GetActiveProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetActiveProfileRequest) Reset() { + *x = GetActiveProfileRequest{} + mi := &file_daemon_proto_msgTypes[63] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetActiveProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetActiveProfileRequest) ProtoMessage() {} + +func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[63] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. +func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{63} +} + +type GetActiveProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + ProfileName string `protobuf:"bytes,1,opt,name=profileName,proto3" json:"profileName,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetActiveProfileResponse) Reset() { + *x = GetActiveProfileResponse{} + mi := &file_daemon_proto_msgTypes[64] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetActiveProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetActiveProfileResponse) ProtoMessage() {} + +func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[64] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. +func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{64} +} + +func (x *GetActiveProfileResponse) GetProfileName() string { + if x != nil { + return x.ProfileName + } + return "" +} + +func (x *GetActiveProfileResponse) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -3513,7 +4352,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3525,7 +4364,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3560,7 +4399,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xbf\r\n" + + "\fEmptyRequest\"\xa4\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -3594,7 +4433,9 @@ const file_daemon_proto_rawDesc = "" + "dns_labels\x18\x1a \x03(\tR\tdnsLabels\x12&\n" + "\x0ecleanDNSLabels\x18\x1b \x01(\bR\x0ecleanDNSLabels\x129\n" + "\x15lazyConnectionEnabled\x18\x1c \x01(\bH\x0fR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" + - "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01B\x13\n" + + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -3611,7 +4452,9 @@ const file_daemon_proto_rawDesc = "" + "\x11_block_lan_accessB\x18\n" + "\x16_disable_notificationsB\x18\n" + "\x16_lazyConnectionEnabledB\x10\n" + - "\x0e_block_inbound\"\xb5\x01\n" + + "\x0e_block_inboundB\x0e\n" + + "\f_profileNameB\v\n" + + "\t_username\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -3619,9 +4462,14 @@ const file_daemon_proto_rawDesc = "" + "\x17verificationURIComplete\x18\x04 \x01(\tR\x17verificationURIComplete\"M\n" + "\x13WaitSSOLoginRequest\x12\x1a\n" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + - "\bhostname\x18\x02 \x01(\tR\bhostname\"\x16\n" + - "\x14WaitSSOLoginResponse\"\v\n" + - "\tUpRequest\"\f\n" + + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + + "\x14WaitSSOLoginResponse\x12\x14\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"p\n" + + "\tUpRequest\x12%\n" + + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + + "\f_profileNameB\v\n" + + "\t_username\"\f\n" + "\n" + "UpResponse\"g\n" + "\rStatusRequest\x12,\n" + @@ -3634,8 +4482,10 @@ const file_daemon_proto_rawDesc = "" + "fullStatus\x12$\n" + "\rdaemonVersion\x18\x03 \x01(\tR\rdaemonVersion\"\r\n" + "\vDownRequest\"\x0e\n" + - "\fDownResponse\"\x12\n" + - "\x10GetConfigRequest\"\xa3\x06\n" + + "\fDownResponse\"P\n" + + "\x10GetConfigRequest\x12 \n" + + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xa3\x06\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -3853,7 +4703,82 @@ const file_daemon_proto_rawDesc = "" + "\x06SYSTEM\x10\x04\"\x12\n" + "\x10GetEventsRequest\"@\n" + "\x11GetEventsResponse\x12+\n" + - "\x06events\x18\x01 \x03(\v2\x13.daemon.SystemEventR\x06events*b\n" + + "\x06events\x18\x01 \x03(\v2\x13.daemon.SystemEventR\x06events\"{\n" + + "\x14SwitchProfileRequest\x12%\n" + + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + + "\f_profileNameB\v\n" + + "\t_username\"\x17\n" + + "\x15SwitchProfileResponse\"\xef\f\n" + + "\x10SetConfigRequest\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + + "\rmanagementUrl\x18\x03 \x01(\tR\rmanagementUrl\x12\x1a\n" + + "\badminURL\x18\x04 \x01(\tR\badminURL\x12/\n" + + "\x10rosenpassEnabled\x18\x05 \x01(\bH\x00R\x10rosenpassEnabled\x88\x01\x01\x12)\n" + + "\rinterfaceName\x18\x06 \x01(\tH\x01R\rinterfaceName\x88\x01\x01\x12)\n" + + "\rwireguardPort\x18\a \x01(\x03H\x02R\rwireguardPort\x88\x01\x01\x127\n" + + "\x14optionalPreSharedKey\x18\b \x01(\tH\x03R\x14optionalPreSharedKey\x88\x01\x01\x123\n" + + "\x12disableAutoConnect\x18\t \x01(\bH\x04R\x12disableAutoConnect\x88\x01\x01\x12/\n" + + "\x10serverSSHAllowed\x18\n" + + " \x01(\bH\x05R\x10serverSSHAllowed\x88\x01\x01\x125\n" + + "\x13rosenpassPermissive\x18\v \x01(\bH\x06R\x13rosenpassPermissive\x88\x01\x01\x12+\n" + + "\x0enetworkMonitor\x18\f \x01(\bH\aR\x0enetworkMonitor\x88\x01\x01\x127\n" + + "\x15disable_client_routes\x18\r \x01(\bH\bR\x13disableClientRoutes\x88\x01\x01\x127\n" + + "\x15disable_server_routes\x18\x0e \x01(\bH\tR\x13disableServerRoutes\x88\x01\x01\x12$\n" + + "\vdisable_dns\x18\x0f \x01(\bH\n" + + "R\n" + + "disableDns\x88\x01\x01\x12.\n" + + "\x10disable_firewall\x18\x10 \x01(\bH\vR\x0fdisableFirewall\x88\x01\x01\x12-\n" + + "\x10block_lan_access\x18\x11 \x01(\bH\fR\x0eblockLanAccess\x88\x01\x01\x128\n" + + "\x15disable_notifications\x18\x12 \x01(\bH\rR\x14disableNotifications\x88\x01\x01\x129\n" + + "\x15lazyConnectionEnabled\x18\x13 \x01(\bH\x0eR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" + + "\rblock_inbound\x18\x14 \x01(\bH\x0fR\fblockInbound\x88\x01\x01\x12&\n" + + "\x0enatExternalIPs\x18\x15 \x03(\tR\x0enatExternalIPs\x120\n" + + "\x13cleanNATExternalIPs\x18\x16 \x01(\bR\x13cleanNATExternalIPs\x12*\n" + + "\x10customDNSAddress\x18\x17 \x01(\fR\x10customDNSAddress\x120\n" + + "\x13extraIFaceBlacklist\x18\x18 \x03(\tR\x13extraIFaceBlacklist\x12\x1d\n" + + "\n" + + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01B\x13\n" + + "\x11_rosenpassEnabledB\x10\n" + + "\x0e_interfaceNameB\x10\n" + + "\x0e_wireguardPortB\x17\n" + + "\x15_optionalPreSharedKeyB\x15\n" + + "\x13_disableAutoConnectB\x13\n" + + "\x11_serverSSHAllowedB\x16\n" + + "\x14_rosenpassPermissiveB\x11\n" + + "\x0f_networkMonitorB\x18\n" + + "\x16_disable_client_routesB\x18\n" + + "\x16_disable_server_routesB\x0e\n" + + "\f_disable_dnsB\x13\n" + + "\x11_disable_firewallB\x13\n" + + "\x11_block_lan_accessB\x18\n" + + "\x16_disable_notificationsB\x18\n" + + "\x16_lazyConnectionEnabledB\x10\n" + + "\x0e_block_inboundB\x13\n" + + "\x11_dnsRouteInterval\"\x13\n" + + "\x11SetConfigResponse\"Q\n" + + "\x11AddProfileRequest\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + + "\vprofileName\x18\x02 \x01(\tR\vprofileName\"\x14\n" + + "\x12AddProfileResponse\"T\n" + + "\x14RemoveProfileRequest\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + + "\vprofileName\x18\x02 \x01(\tR\vprofileName\"\x17\n" + + "\x15RemoveProfileResponse\"1\n" + + "\x13ListProfilesRequest\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\"C\n" + + "\x14ListProfilesResponse\x12+\n" + + "\bprofiles\x18\x01 \x03(\v2\x0f.daemon.ProfileR\bprofiles\":\n" + + "\aProfile\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1b\n" + + "\tis_active\x18\x02 \x01(\bR\bisActive\"\x19\n" + + "\x17GetActiveProfileRequest\"X\n" + + "\x18GetActiveProfileResponse\x12 \n" + + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + + "\busername\x18\x02 \x01(\tR\busername*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -3862,7 +4787,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\xb3\v\n" + + "\x05TRACE\x10\a2\x84\x0f\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -3885,7 +4810,14 @@ const file_daemon_proto_rawDesc = "" + "\x18SetNetworkMapPersistence\x12'.daemon.SetNetworkMapPersistenceRequest\x1a(.daemon.SetNetworkMapPersistenceResponse\"\x00\x12H\n" + "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" + "\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" + - "\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00B\bZ\x06/protob\x06proto3" + "\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" + + "\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" + + "\tSetConfig\x12\x18.daemon.SetConfigRequest\x1a\x19.daemon.SetConfigResponse\"\x00\x12E\n" + + "\n" + + "AddProfile\x12\x19.daemon.AddProfileRequest\x1a\x1a.daemon.AddProfileResponse\"\x00\x12N\n" + + "\rRemoveProfile\x12\x1c.daemon.RemoveProfileRequest\x1a\x1d.daemon.RemoveProfileResponse\"\x00\x12K\n" + + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -3900,7 +4832,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 55) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 68) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity @@ -3957,18 +4889,31 @@ var file_daemon_proto_goTypes = []any{ (*SystemEvent)(nil), // 52: daemon.SystemEvent (*GetEventsRequest)(nil), // 53: daemon.GetEventsRequest (*GetEventsResponse)(nil), // 54: daemon.GetEventsResponse - nil, // 55: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 56: daemon.PortInfo.Range - nil, // 57: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 58: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 59: google.protobuf.Timestamp + (*SwitchProfileRequest)(nil), // 55: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 56: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 57: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 58: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 59: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 60: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 61: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 62: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 63: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 64: daemon.ListProfilesResponse + (*Profile)(nil), // 65: daemon.Profile + (*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse + nil, // 68: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 69: daemon.PortInfo.Range + nil, // 70: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 71: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 72: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 58, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 71, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 59, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 59, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 58, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 72, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 72, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 71, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState @@ -3977,8 +4922,8 @@ var file_daemon_proto_depIdxs = []int32{ 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 55, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 56, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 68, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 69, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -3989,55 +4934,69 @@ var file_daemon_proto_depIdxs = []int32{ 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 59, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 57, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 72, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 70, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 27, // 28: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 4, // 29: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 6, // 30: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 8, // 31: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 10, // 32: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 12, // 33: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 14, // 34: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 23, // 35: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 25, // 36: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 25, // 37: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 3, // 38: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 32, // 39: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 34, // 40: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 36, // 41: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 39, // 42: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 41, // 43: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 43, // 44: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 45, // 45: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest - 48, // 46: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 51, // 47: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 53, // 48: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 5, // 49: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 50: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 51: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 52: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 53: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 54: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 24, // 55: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 26, // 56: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 26, // 57: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 58: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 33, // 59: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 35, // 60: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 37, // 61: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 40, // 62: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 42, // 63: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 44, // 64: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 46, // 65: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse - 50, // 66: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 52, // 67: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 54, // 68: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 49, // [49:69] is the sub-list for method output_type - 29, // [29:49] is the sub-list for method input_type - 29, // [29:29] is the sub-list for extension type_name - 29, // [29:29] is the sub-list for extension extendee - 0, // [0:29] is the sub-list for field type_name + 71, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 6, // 32: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 8, // 33: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 10, // 34: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 12, // 35: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 14, // 36: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 23, // 37: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 25, // 38: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 25, // 39: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 3, // 40: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 32, // 41: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 34, // 42: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 36, // 43: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 39, // 44: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 41, // 45: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 43, // 46: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 45, // 47: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest + 48, // 48: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 51, // 49: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 53, // 50: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 55, // 51: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 57, // 52: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 59, // 53: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 5, // 57: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 7, // 58: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 9, // 59: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 11, // 60: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 13, // 61: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 15, // 62: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 24, // 63: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 26, // 64: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 26, // 65: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 66: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 33, // 67: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 35, // 68: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 37, // 69: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 40, // 70: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 42, // 71: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 44, // 72: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 46, // 73: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 50, // 74: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 52, // 75: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 54, // 76: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 56, // 77: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 58, // 78: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 60, // 79: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 62, // 80: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 64, // 81: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 67, // 82: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 57, // [57:83] is the sub-list for method output_type + 31, // [31:57] is the sub-list for method input_type + 31, // [31:31] is the sub-list for extension type_name + 31, // [31:31] is the sub-list for extension extendee + 0, // [0:31] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -4046,19 +5005,22 @@ func file_daemon_proto_init() { return } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} + file_daemon_proto_msgTypes[5].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } file_daemon_proto_msgTypes[45].OneofWrappers = []any{} file_daemon_proto_msgTypes[46].OneofWrappers = []any{} + file_daemon_proto_msgTypes[52].OneofWrappers = []any{} + file_daemon_proto_msgTypes[54].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 3, - NumMessages: 55, + NumMessages: 68, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 462555c82..c25503df9 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -67,6 +67,18 @@ service DaemonService { rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {} rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {} + + rpc SwitchProfile(SwitchProfileRequest) returns (SwitchProfileResponse) {} + + rpc SetConfig(SetConfigRequest) returns (SetConfigResponse) {} + + rpc AddProfile(AddProfileRequest) returns (AddProfileResponse) {} + + rpc RemoveProfile(RemoveProfileRequest) returns (RemoveProfileResponse) {} + + rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {} + + rpc GetActiveProfile(GetActiveProfileRequest) returns (GetActiveProfileResponse) {} } @@ -136,6 +148,9 @@ message LoginRequest { optional bool lazyConnectionEnabled = 28; optional bool block_inbound = 29; + + optional string profileName = 30; + optional string username = 31; } message LoginResponse { @@ -150,9 +165,14 @@ message WaitSSOLoginRequest { string hostname = 2; } -message WaitSSOLoginResponse {} +message WaitSSOLoginResponse { + string email = 1; +} -message UpRequest {} +message UpRequest { + optional string profileName = 1; + optional string username = 2; +} message UpResponse {} @@ -173,7 +193,10 @@ message DownRequest {} message DownResponse {} -message GetConfigRequest {} +message GetConfigRequest { + string profileName = 1; + string username = 2; +} message GetConfigResponse { // managementUrl settings value. @@ -497,3 +520,98 @@ message GetEventsRequest {} message GetEventsResponse { repeated SystemEvent events = 1; } + +message SwitchProfileRequest { + optional string profileName = 1; + optional string username = 2; +} + +message SwitchProfileResponse {} + +message SetConfigRequest { + string username = 1; + string profileName = 2; + // managementUrl to authenticate. + string managementUrl = 3; + + // adminUrl to manage keys. + string adminURL = 4; + + optional bool rosenpassEnabled = 5; + + optional string interfaceName = 6; + + optional int64 wireguardPort = 7; + + optional string optionalPreSharedKey = 8; + + optional bool disableAutoConnect = 9; + + optional bool serverSSHAllowed = 10; + + optional bool rosenpassPermissive = 11; + + optional bool networkMonitor = 12; + + optional bool disable_client_routes = 13; + optional bool disable_server_routes = 14; + optional bool disable_dns = 15; + optional bool disable_firewall = 16; + optional bool block_lan_access = 17; + + optional bool disable_notifications = 18; + + optional bool lazyConnectionEnabled = 19; + + optional bool block_inbound = 20; + + repeated string natExternalIPs = 21; + bool cleanNATExternalIPs = 22; + + bytes customDNSAddress = 23; + + repeated string extraIFaceBlacklist = 24; + + repeated string dns_labels = 25; + // cleanDNSLabels clean map list of DNS labels. + bool cleanDNSLabels = 26; + + optional google.protobuf.Duration dnsRouteInterval = 27; + +} + +message SetConfigResponse{} + +message AddProfileRequest { + string username = 1; + string profileName = 2; +} + +message AddProfileResponse {} + +message RemoveProfileRequest { + string username = 1; + string profileName = 2; +} + +message RemoveProfileResponse {} + +message ListProfilesRequest { + string username = 1; +} + +message ListProfilesResponse { + repeated Profile profiles = 1; +} + +message Profile { + string name = 1; + bool is_active = 2; +} + +message GetActiveProfileRequest {} + +message GetActiveProfileResponse { + string profileName = 1; + string username = 2; +} \ No newline at end of file diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 6251f7c52..669083168 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -55,6 +55,12 @@ type DaemonServiceClient interface { TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) + SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) + SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) + AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error) + RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) + ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error) + GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error) } type daemonServiceClient struct { @@ -268,6 +274,60 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques return out, nil } +func (c *daemonServiceClient) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) { + out := new(SwitchProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SwitchProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) { + out := new(SetConfigResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetConfig", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error) { + out := new(AddProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/AddProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) { + out := new(RemoveProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/RemoveProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error) { + out := new(ListProfilesResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListProfiles", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error) { + out := new(GetActiveProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetActiveProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -309,6 +369,12 @@ type DaemonServiceServer interface { TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) + SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) + SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) + AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) + RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) + ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error) + GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -376,6 +442,24 @@ func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, Daemo func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented") } +func (UnimplementedDaemonServiceServer) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SwitchProfile not implemented") +} +func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SetConfig not implemented") +} +func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method AddProfile not implemented") +} +func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RemoveProfile not implemented") +} +func (UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListProfiles not implemented") +} +func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetActiveProfile not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -752,6 +836,114 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _DaemonService_SwitchProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SwitchProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).SwitchProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/SwitchProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).SwitchProfile(ctx, req.(*SwitchProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_SetConfig_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetConfigRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).SetConfig(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/SetConfig", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).SetConfig(ctx, req.(*SetConfigRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AddProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).AddProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/AddProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).AddProfile(ctx, req.(*AddProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RemoveProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).RemoveProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/RemoveProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).RemoveProfile(ctx, req.(*RemoveProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_ListProfiles_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListProfilesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).ListProfiles(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/ListProfiles", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).ListProfiles(ctx, req.(*ListProfilesRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetActiveProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetActiveProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetActiveProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetActiveProfile(ctx, req.(*GetActiveProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -835,6 +1027,30 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetEvents", Handler: _DaemonService_GetEvents_Handler, }, + { + MethodName: "SwitchProfile", + Handler: _DaemonService_SwitchProfile_Handler, + }, + { + MethodName: "SetConfig", + Handler: _DaemonService_SetConfig_Handler, + }, + { + MethodName: "AddProfile", + Handler: _DaemonService_AddProfile_Handler, + }, + { + MethodName: "RemoveProfile", + Handler: _DaemonService_RemoveProfile_Handler, + }, + { + MethodName: "ListProfiles", + Handler: _DaemonService_ListProfiles_Handler, + }, + { + MethodName: "GetActiveProfile", + Handler: _DaemonService_GetActiveProfile_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/server/panic_windows.go b/client/server/panic_windows.go index c5e73be7c..f441ec9ea 100644 --- a/client/server/panic_windows.go +++ b/client/server/panic_windows.go @@ -1,3 +1,6 @@ +//go:build windows +// +build windows + package server import ( diff --git a/client/server/server.go b/client/server/server.go index e3ce1a2b4..f3414888d 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -22,6 +22,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/management/domain" @@ -50,14 +51,12 @@ type Server struct { rootCtx context.Context actCancel context.CancelFunc - latestConfigInput internal.ConfigInput - logFile string oauthAuthFlow oauthAuthFlow mutex sync.Mutex - config *internal.Config + config *profilemanager.Config proto.UnimplementedDaemonServiceServer connectClient *internal.ConnectClient @@ -68,6 +67,8 @@ type Server struct { lastProbe time.Time persistNetworkMap bool isSessionActive atomic.Bool + + profileManager profilemanager.ServiceManager } type oauthAuthFlow struct { @@ -78,15 +79,13 @@ type oauthAuthFlow struct { } // New server instance constructor. -func New(ctx context.Context, configPath, logFile string) *Server { +func New(ctx context.Context, logFile string) *Server { return &Server{ - rootCtx: ctx, - latestConfigInput: internal.ConfigInput{ - ConfigPath: configPath, - }, + rootCtx: ctx, logFile: logFile, persistNetworkMap: true, statusRecorder: peer.NewRecorder(""), + profileManager: profilemanager.ServiceManager{}, } } @@ -99,7 +98,7 @@ func (s *Server) Start() error { log.Warnf("failed to redirect stderr: %v", err) } - if err := restoreResidualState(s.rootCtx); err != nil { + if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } @@ -118,25 +117,41 @@ func (s *Server) Start() error { ctx, cancel := context.WithCancel(s.rootCtx) s.actCancel = cancel - // if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin - // on failure we return error to retry - config, err := internal.UpdateConfig(s.latestConfigInput) - if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound { - s.config, err = internal.UpdateOrCreateConfig(s.latestConfigInput) - if err != nil { - log.Warnf("unable to create configuration file: %v", err) - return err - } - state.Set(internal.StatusNeedsLogin) - return nil - } else if err != nil { - log.Warnf("unable to create configuration file: %v", err) - return err + // set the default config if not exists + if err := s.setDefaultConfigIfNotExists(ctx); err != nil { + log.Errorf("failed to set default config: %v", err) + return fmt.Errorf("failed to set default config: %w", err) } - // if configuration exists, we just start connections. - config, _ = internal.UpdateOldManagementURL(ctx, config, s.latestConfigInput.ConfigPath) + activeProf, err := s.profileManager.GetActiveProfileState() + if err != nil { + return fmt.Errorf("failed to get active profile state: %w", err) + } + cfgPath, err := activeProf.FilePath() + if err != nil { + log.Errorf("failed to get active profile file path: %v", err) + return fmt.Errorf("failed to get active profile file path: %w", err) + } + + config, err := profilemanager.GetConfig(cfgPath) + if err != nil { + log.Errorf("failed to get active profile config: %v", err) + + if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: "default", + Username: "", + }); err != nil { + log.Errorf("failed to set active profile state: %v", err) + return fmt.Errorf("failed to set active profile state: %w", err) + } + + config, err = profilemanager.GetConfig(s.profileManager.DefaultProfilePath()) + if err != nil { + log.Errorf("failed to get default profile config: %v", err) + return fmt.Errorf("failed to get default profile config: %w", err) + } + } s.config = config s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String()) @@ -157,10 +172,34 @@ func (s *Server) Start() error { return nil } +func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { + ok, err := s.profileManager.CopyDefaultProfileIfNotExists() + if err != nil { + if err := s.profileManager.CreateDefaultProfile(); err != nil { + log.Errorf("failed to create default profile: %v", err) + return fmt.Errorf("failed to create default profile: %w", err) + } + + if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: "default", + Username: "", + }); err != nil { + log.Errorf("failed to set active profile state: %v", err) + return fmt.Errorf("failed to set active profile state: %w", err) + } + } + if ok { + state := internal.CtxGetState(ctx) + state.Set(internal.StatusNeedsLogin) + } + + return nil +} + // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, +func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, ) { backOff := getConnectWithBackoff(ctx) @@ -276,6 +315,90 @@ func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (i return "", nil } +// Login uses setup key to prepare configuration for the daemon. +func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigRequest) (*proto.SetConfigResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + profState := profilemanager.ActiveProfileState{ + Name: msg.ProfileName, + Username: msg.Username, + } + + profPath, err := profState.FilePath() + if err != nil { + log.Errorf("failed to get active profile file path: %v", err) + return nil, fmt.Errorf("failed to get active profile file path: %w", err) + } + + var config profilemanager.ConfigInput + + config.ConfigPath = profPath + + if msg.ManagementUrl != "" { + config.ManagementURL = msg.ManagementUrl + } + + if msg.AdminURL != "" { + config.AdminURL = msg.AdminURL + } + + if msg.InterfaceName != nil { + config.InterfaceName = msg.InterfaceName + } + + if msg.WireguardPort != nil { + wgPort := int(*msg.WireguardPort) + config.WireguardPort = &wgPort + } + + if msg.OptionalPreSharedKey != nil { + if *msg.OptionalPreSharedKey != "" { + config.PreSharedKey = msg.OptionalPreSharedKey + } + } + + if msg.CleanDNSLabels { + config.DNSLabels = domain.List{} + + } else if msg.DnsLabels != nil { + dnsLabels := domain.FromPunycodeList(msg.DnsLabels) + config.DNSLabels = dnsLabels + } + + if msg.CleanNATExternalIPs { + config.NATExternalIPs = make([]string, 0) + } else if msg.NatExternalIPs != nil { + config.NATExternalIPs = msg.NatExternalIPs + } + + config.CustomDNSAddress = msg.CustomDNSAddress + if string(msg.CustomDNSAddress) == "empty" { + config.CustomDNSAddress = []byte{} + } + + config.RosenpassEnabled = msg.RosenpassEnabled + config.RosenpassPermissive = msg.RosenpassPermissive + config.DisableAutoConnect = msg.DisableAutoConnect + config.ServerSSHAllowed = msg.ServerSSHAllowed + config.NetworkMonitor = msg.NetworkMonitor + config.DisableClientRoutes = msg.DisableClientRoutes + config.DisableServerRoutes = msg.DisableServerRoutes + config.DisableDNS = msg.DisableDns + config.DisableFirewall = msg.DisableFirewall + config.BlockLANAccess = msg.BlockLanAccess + config.DisableNotifications = msg.DisableNotifications + config.LazyConnectionEnabled = msg.LazyConnectionEnabled + config.BlockInbound = msg.BlockInbound + + if _, err := profilemanager.UpdateConfig(config); err != nil { + log.Errorf("failed to update profile config: %v", err) + return nil, fmt.Errorf("failed to update profile config: %w", err) + } + + return &proto.SetConfigResponse{}, nil +} + // Login uses setup key to prepare configuration for the daemon. func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*proto.LoginResponse, error) { s.mutex.Lock() @@ -292,7 +415,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(ctx); err != nil { + if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } @@ -304,147 +427,62 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } }() + activeProf, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + if msg.ProfileName != nil { + if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") { + log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName) + return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName) + } + + var username string + if *msg.ProfileName != "default" { + username = *msg.Username + } + + if *msg.ProfileName != activeProf.Name && username != activeProf.Username { + log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username) + if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: *msg.ProfileName, + Username: username, + }); err != nil { + log.Errorf("failed to set active profile state: %v", err) + return nil, fmt.Errorf("failed to set active profile state: %w", err) + } + } + } + + activeProf, err = s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username) + s.mutex.Lock() - inputConfig := s.latestConfigInput - - if msg.ManagementUrl != "" { - inputConfig.ManagementURL = msg.ManagementUrl - s.latestConfigInput.ManagementURL = msg.ManagementUrl - } - - if msg.AdminURL != "" { - inputConfig.AdminURL = msg.AdminURL - s.latestConfigInput.AdminURL = msg.AdminURL - } - - if msg.CleanNATExternalIPs { - inputConfig.NATExternalIPs = make([]string, 0) - s.latestConfigInput.NATExternalIPs = nil - } else if msg.NatExternalIPs != nil { - inputConfig.NATExternalIPs = msg.NatExternalIPs - s.latestConfigInput.NATExternalIPs = msg.NatExternalIPs - } - - inputConfig.CustomDNSAddress = msg.CustomDNSAddress - s.latestConfigInput.CustomDNSAddress = msg.CustomDNSAddress - if string(msg.CustomDNSAddress) == "empty" { - inputConfig.CustomDNSAddress = []byte{} - s.latestConfigInput.CustomDNSAddress = []byte{} - } if msg.Hostname != "" { // nolint ctx = context.WithValue(ctx, system.DeviceNameCtxKey, msg.Hostname) } - - if msg.RosenpassEnabled != nil { - inputConfig.RosenpassEnabled = msg.RosenpassEnabled - s.latestConfigInput.RosenpassEnabled = msg.RosenpassEnabled - } - - if msg.RosenpassPermissive != nil { - inputConfig.RosenpassPermissive = msg.RosenpassPermissive - s.latestConfigInput.RosenpassPermissive = msg.RosenpassPermissive - } - - if msg.ServerSSHAllowed != nil { - inputConfig.ServerSSHAllowed = msg.ServerSSHAllowed - s.latestConfigInput.ServerSSHAllowed = msg.ServerSSHAllowed - } - - if msg.DisableAutoConnect != nil { - inputConfig.DisableAutoConnect = msg.DisableAutoConnect - s.latestConfigInput.DisableAutoConnect = msg.DisableAutoConnect - } - - if msg.InterfaceName != nil { - inputConfig.InterfaceName = msg.InterfaceName - s.latestConfigInput.InterfaceName = msg.InterfaceName - } - - if msg.WireguardPort != nil { - port := int(*msg.WireguardPort) - inputConfig.WireguardPort = &port - s.latestConfigInput.WireguardPort = &port - } - - if msg.NetworkMonitor != nil { - inputConfig.NetworkMonitor = msg.NetworkMonitor - s.latestConfigInput.NetworkMonitor = msg.NetworkMonitor - } - - if len(msg.ExtraIFaceBlacklist) > 0 { - inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist - s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist - } - - if msg.DnsRouteInterval != nil { - duration := msg.DnsRouteInterval.AsDuration() - inputConfig.DNSRouteInterval = &duration - s.latestConfigInput.DNSRouteInterval = &duration - } - - if msg.DisableClientRoutes != nil { - inputConfig.DisableClientRoutes = msg.DisableClientRoutes - s.latestConfigInput.DisableClientRoutes = msg.DisableClientRoutes - } - if msg.DisableServerRoutes != nil { - inputConfig.DisableServerRoutes = msg.DisableServerRoutes - s.latestConfigInput.DisableServerRoutes = msg.DisableServerRoutes - } - if msg.DisableDns != nil { - inputConfig.DisableDNS = msg.DisableDns - s.latestConfigInput.DisableDNS = msg.DisableDns - } - if msg.DisableFirewall != nil { - inputConfig.DisableFirewall = msg.DisableFirewall - s.latestConfigInput.DisableFirewall = msg.DisableFirewall - } - if msg.BlockLanAccess != nil { - inputConfig.BlockLANAccess = msg.BlockLanAccess - s.latestConfigInput.BlockLANAccess = msg.BlockLanAccess - } - if msg.BlockInbound != nil { - inputConfig.BlockInbound = msg.BlockInbound - s.latestConfigInput.BlockInbound = msg.BlockInbound - } - - if msg.CleanDNSLabels { - inputConfig.DNSLabels = domain.List{} - s.latestConfigInput.DNSLabels = nil - } else if msg.DnsLabels != nil { - dnsLabels := domain.FromPunycodeList(msg.DnsLabels) - inputConfig.DNSLabels = dnsLabels - s.latestConfigInput.DNSLabels = dnsLabels - } - - if msg.DisableNotifications != nil { - inputConfig.DisableNotifications = msg.DisableNotifications - s.latestConfigInput.DisableNotifications = msg.DisableNotifications - } - - if msg.LazyConnectionEnabled != nil { - inputConfig.LazyConnectionEnabled = msg.LazyConnectionEnabled - s.latestConfigInput.LazyConnectionEnabled = msg.LazyConnectionEnabled - } - s.mutex.Unlock() - if msg.OptionalPreSharedKey != nil { - inputConfig.PreSharedKey = msg.OptionalPreSharedKey - } - - config, err := internal.UpdateOrCreateConfig(inputConfig) + cfgPath, err := activeProf.FilePath() if err != nil { - return nil, err + log.Errorf("failed to get active profile file path: %v", err) + return nil, fmt.Errorf("failed to get active profile file path: %w", err) } - if msg.ManagementUrl == "" { - config, _ = internal.UpdateOldManagementURL(ctx, config, s.latestConfigInput.ConfigPath) - s.config = config - s.latestConfigInput.ManagementURL = config.ManagementURL.String() + config, err := profilemanager.GetConfig(cfgPath) + if err != nil { + log.Errorf("failed to get active profile config: %v", err) + return nil, fmt.Errorf("failed to get active profile config: %w", err) } - s.mutex.Lock() s.config = config s.mutex.Unlock() @@ -586,15 +624,17 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin return nil, err } - return &proto.WaitSSOLoginResponse{}, nil + return &proto.WaitSSOLoginResponse{ + Email: tokenInfo.Email, + }, nil } // Up starts engine work in the daemon. -func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpResponse, error) { +func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() - if err := restoreResidualState(callerCtx); err != nil { + if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } @@ -628,6 +668,40 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes return nil, fmt.Errorf("config is not defined, please call login command first") } + activeProf, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + if msg != nil && msg.ProfileName != nil { + if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil { + log.Errorf("failed to switch profile: %v", err) + return nil, fmt.Errorf("failed to switch profile: %w", err) + } + } + + activeProf, err = s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username) + + cfgPath, err := activeProf.FilePath() + if err != nil { + log.Errorf("failed to get active profile file path: %v", err) + return nil, fmt.Errorf("failed to get active profile file path: %w", err) + } + + config, err := profilemanager.GetConfig(cfgPath) + if err != nil { + log.Errorf("failed to get active profile config: %v", err) + return nil, fmt.Errorf("failed to get active profile config: %w", err) + } + s.config = config + s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) @@ -651,6 +725,70 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } } +func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error { + if profileName != "default" && (userName == nil || *userName == "") { + log.Errorf("profile name is set to %s, but username is not provided", profileName) + return fmt.Errorf("profile name is set to %s, but username is not provided", profileName) + } + + var username string + if profileName != "default" { + username = *userName + } + + if profileName != activeProf.Name || username != activeProf.Username { + log.Infof("switching to profile %s for user %s", profileName, username) + if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profileName, + Username: username, + }); err != nil { + log.Errorf("failed to set active profile state: %v", err) + return fmt.Errorf("failed to set active profile state: %w", err) + } + } + + return nil +} + +// SwitchProfile switches the active profile in the daemon. +func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfileRequest) (*proto.SwitchProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + activeProf, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + if msg != nil && msg.ProfileName != nil { + if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil { + log.Errorf("failed to switch profile: %v", err) + return nil, fmt.Errorf("failed to switch profile: %w", err) + } + } + activeProf, err = s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + cfgPath, err := activeProf.FilePath() + if err != nil { + log.Errorf("failed to get active profile file path: %v", err) + return nil, fmt.Errorf("failed to get active profile file path: %w", err) + } + + config, err := profilemanager.GetConfig(cfgPath) + if err != nil { + log.Errorf("failed to get default profile config: %v", err) + return nil, fmt.Errorf("failed to get default profile config: %w", err) + } + + s.config = config + + return &proto.SwitchProfileResponse{}, nil +} + // Down engine work in the daemon. func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { s.mutex.Lock() @@ -738,58 +876,65 @@ func (s *Server) runProbes() { } // GetConfig of the daemon. -func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto.GetConfigResponse, error) { +func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*proto.GetConfigResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() - managementURL := s.latestConfigInput.ManagementURL - adminURL := s.latestConfigInput.AdminURL - preSharedKey := "" + if ctx.Err() != nil { + return nil, ctx.Err() + } - if s.config != nil { - if managementURL == "" && s.config.ManagementURL != nil { - managementURL = s.config.ManagementURL.String() - } + prof := profilemanager.ActiveProfileState{ + Name: req.ProfileName, + Username: req.Username, + } - if s.config.AdminURL != nil { - adminURL = s.config.AdminURL.String() - } + cfgPath, err := prof.FilePath() + if err != nil { + log.Errorf("failed to get active profile file path: %v", err) + return nil, fmt.Errorf("failed to get active profile file path: %w", err) + } - preSharedKey = s.config.PreSharedKey - if preSharedKey != "" { - preSharedKey = "**********" - } + cfg, err := profilemanager.GetConfig(cfgPath) + if err != nil { + log.Errorf("failed to get active profile config: %v", err) + return nil, fmt.Errorf("failed to get active profile config: %w", err) + } + managementURL := cfg.ManagementURL + adminURL := cfg.AdminURL + + var preSharedKey = cfg.PreSharedKey + if preSharedKey != "" { + preSharedKey = "**********" } disableNotifications := true - if s.config.DisableNotifications != nil { - disableNotifications = *s.config.DisableNotifications + if cfg.DisableNotifications != nil { + disableNotifications = *cfg.DisableNotifications } networkMonitor := false - if s.config.NetworkMonitor != nil { - networkMonitor = *s.config.NetworkMonitor + if cfg.NetworkMonitor != nil { + networkMonitor = *cfg.NetworkMonitor } - disableDNS := s.config.DisableDNS - disableClientRoutes := s.config.DisableClientRoutes - disableServerRoutes := s.config.DisableServerRoutes - blockLANAccess := s.config.BlockLANAccess + disableDNS := cfg.DisableDNS + disableClientRoutes := cfg.DisableClientRoutes + disableServerRoutes := cfg.DisableServerRoutes + blockLANAccess := cfg.BlockLANAccess return &proto.GetConfigResponse{ - ManagementUrl: managementURL, - ConfigFile: s.latestConfigInput.ConfigPath, - LogFile: s.logFile, + ManagementUrl: managementURL.String(), PreSharedKey: preSharedKey, - AdminURL: adminURL, - InterfaceName: s.config.WgIface, - WireguardPort: int64(s.config.WgPort), - DisableAutoConnect: s.config.DisableAutoConnect, - ServerSSHAllowed: *s.config.ServerSSHAllowed, - RosenpassEnabled: s.config.RosenpassEnabled, - RosenpassPermissive: s.config.RosenpassPermissive, - LazyConnectionEnabled: s.config.LazyConnectionEnabled, - BlockInbound: s.config.BlockInbound, + AdminURL: adminURL.String(), + InterfaceName: cfg.WgIface, + WireguardPort: int64(cfg.WgPort), + DisableAutoConnect: cfg.DisableAutoConnect, + ServerSSHAllowed: *cfg.ServerSSHAllowed, + RosenpassEnabled: cfg.RosenpassEnabled, + RosenpassPermissive: cfg.RosenpassPermissive, + LazyConnectionEnabled: cfg.LazyConnectionEnabled, + BlockInbound: cfg.BlockInbound, DisableNotifications: disableNotifications, NetworkMonitor: networkMonitor, DisableDns: disableDNS, @@ -918,3 +1063,82 @@ func sendTerminalNotification() error { return wallCmd.Wait() } + +// AddProfile adds a new profile to the daemon. +func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.ProfileName == "" || msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") + } + + if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to create profile: %v", err) + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + return &proto.AddProfileResponse{}, nil +} + +// RemoveProfile removes a profile from the daemon. +func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.ProfileName == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided") + } + + if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to remove profile: %v", err) + return nil, fmt.Errorf("failed to remove profile: %w", err) + } + + return &proto.RemoveProfileResponse{}, nil +} + +// ListProfiles lists all profiles in the daemon. +func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") + } + + profiles, err := s.profileManager.ListProfiles(msg.Username) + if err != nil { + log.Errorf("failed to list profiles: %v", err) + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + response := &proto.ListProfilesResponse{ + Profiles: make([]*proto.Profile, len(profiles)), + } + for i, profile := range profiles { + response.Profiles[i] = &proto.Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + } + } + + return response, nil +} + +// GetActiveProfile returns the active profile in the daemon. +func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + activeProfile, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + return &proto.GetActiveProfileResponse{ + ProfileName: activeProfile.Name, + Username: activeProfile.Username, + }, nil +} diff --git a/client/server/server_test.go b/client/server/server_test.go index 11e4d3899..dda610076 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -4,6 +4,8 @@ import ( "context" "net" "net/url" + "os/user" + "path/filepath" "testing" "time" @@ -20,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" @@ -32,7 +35,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" - "github.com/netbirdio/netbird/util" ) var ( @@ -70,12 +72,30 @@ func TestConnectWithRetryRuns(t *testing.T) { ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second)) defer cancel() // create new server - s := New(ctx, t.TempDir()+"/config.json", "debug") - s.latestConfigInput.ManagementURL = "http://" + mgmtAddr - config, err := internal.UpdateOrCreateConfig(s.latestConfigInput) + ic := profilemanager.ConfigInput{ + ManagementURL: "http://" + mgmtAddr, + ConfigPath: t.TempDir() + "/test-profile.json", + } + + config, err := profilemanager.UpdateOrCreateConfig(ic) if err != nil { t.Fatalf("failed to create config: %v", err) } + + currUser, err := user.Current() + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: "test-profile", + Username: currUser.Username, + }) + if err != nil { + t.Fatalf("failed to set active profile state: %v", err) + } + + s := New(ctx, "debug") + s.config = config s.statusRecorder = peer.NewRecorder(config.ManagementURL.String()) @@ -91,26 +111,67 @@ func TestConnectWithRetryRuns(t *testing.T) { } func TestServer_Up(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + profilemanager.ConfigDirOverride = tempDir + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + ctx := internal.CtxInitState(context.Background()) - s := New(ctx, t.TempDir()+"/config.json", util.LogConsole) + currUser, err := user.Current() + require.NoError(t, err) - err := s.Start() + profName := "default" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + } + + _, err = profilemanager.UpdateOrCreateConfig(ic) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + if err != nil { + t.Fatalf("failed to set active profile state: %v", err) + } + + s := New(ctx, "console") + + err = s.Start() require.NoError(t, err) u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") require.NoError(t, err) - s.config = &internal.Config{ + s.config = &profilemanager.Config{ ManagementURL: u, } upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - upReq := &daemonProto.UpRequest{} + upReq := &daemonProto.UpRequest{ + ProfileName: &profName, + Username: &currUser.Username, + } _, err = s.Up(upCtx, upReq) - assert.Contains(t, err.Error(), "NeedsLogin") + assert.Contains(t, err.Error(), "context deadline exceeded") } type mockSubscribeEventsServer struct { @@ -129,16 +190,51 @@ func (m *mockSubscribeEventsServer) Context() context.Context { } func TestServer_SubcribeEvents(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + profilemanager.ConfigDirOverride = tempDir + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + ctx := internal.CtxInitState(context.Background()) + ic := profilemanager.ConfigInput{ + ConfigPath: tempDir + "/default.json", + } - s := New(ctx, t.TempDir()+"/config.json", util.LogConsole) + _, err := profilemanager.UpdateOrCreateConfig(ic) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } - err := s.Start() + currUser, err := user.Current() + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: "default", + Username: currUser.Username, + }) + if err != nil { + t.Fatalf("failed to set active profile state: %v", err) + } + + s := New(ctx, "console") + + err = s.Start() require.NoError(t, err) u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") require.NoError(t, err) - s.config = &internal.Config{ + s.config = &profilemanager.Config{ ManagementURL: u, } diff --git a/client/server/state.go b/client/server/state.go index 222c7c7bd..107f55154 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -16,7 +16,7 @@ import ( // ListStates returns a list of all saved states func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) { - mgr := statemanager.New(statemanager.GetDefaultStatePath()) + mgr := statemanager.New(s.profileManager.GetStatePath()) stateNames, err := mgr.GetSavedStateNames() if err != nil { @@ -41,14 +41,16 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) ( return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") } + statePath := s.profileManager.GetStatePath() + if req.All { // Reuse existing cleanup logic for all states - if err := restoreResidualState(ctx); err != nil { + if err := restoreResidualState(ctx, statePath); err != nil { return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err) } // Get count of cleaned states - mgr := statemanager.New(statemanager.GetDefaultStatePath()) + mgr := statemanager.New(statePath) stateNames, err := mgr.GetSavedStateNames() if err != nil { return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err) @@ -60,7 +62,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) ( } // Handle single state cleanup - mgr := statemanager.New(statemanager.GetDefaultStatePath()) + mgr := statemanager.New(statePath) registerStates(mgr) if err := mgr.CleanupStateByName(req.StateName); err != nil { @@ -82,7 +84,7 @@ func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") } - mgr := statemanager.New(statemanager.GetDefaultStatePath()) + mgr := statemanager.New(s.profileManager.GetStatePath()) var count int var err error @@ -112,13 +114,12 @@ func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) // restoreResidualState checks if the client was not shut down in a clean way and restores residual if required. // Otherwise, we might not be able to connect to the management server to retrieve new config. -func restoreResidualState(ctx context.Context) error { - path := statemanager.GetDefaultStatePath() - if path == "" { +func restoreResidualState(ctx context.Context, statePath string) error { + if statePath == "" { return nil } - mgr := statemanager.New(path) + mgr := statemanager.New(statePath) // register the states we are interested in restoring registerStates(mgr) diff --git a/client/status/status.go b/client/status/status.go index d28485bc0..722ee7e7c 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -98,9 +98,10 @@ type OutputOverview struct { NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` Events []SystemEventOutput `json:"events" yaml:"events"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` + ProfileName string `json:"profileName" yaml:"profileName"` } -func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview { +func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { pbFullStatus := resp.GetFullStatus() managementState := pbFullStatus.GetManagementState() @@ -138,6 +139,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), Events: mapEvents(pbFullStatus.GetEvents()), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), + ProfileName: profName, } if anon { @@ -406,6 +408,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, "OS: %s\n"+ "Daemon version: %s\n"+ "CLI version: %s\n"+ + "Profile: %s\n"+ "Management: %s\n"+ "Signal: %s\n"+ "Relays: %s\n"+ @@ -421,6 +424,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, fmt.Sprintf("%s/%s%s", goos, goarch, goarm), overview.DaemonVersion, version.NetbirdVersion(), + overview.ProfileName, managementConnString, signalConnString, relaysString, diff --git a/client/status/status_test.go b/client/status/status_test.go index 5b5d23efd..660efd9ef 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -234,7 +234,7 @@ var overview = OutputOverview{ } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { - convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "") + convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "") assert.Equal(t, overview, convertedResult) } @@ -384,7 +384,8 @@ func TestParsingToJSON(t *testing.T) { } ], "events": [], - "lazyConnectionEnabled": false + "lazyConnectionEnabled": false, + "profileName":"" }` // @formatter:on @@ -486,6 +487,7 @@ dnsServers: error: timeout events: [] lazyConnectionEnabled: false +profileName: "" ` assert.Equal(t, expectedYAML, yaml) @@ -538,6 +540,7 @@ Events: No events recorded OS: %s/%s Daemon version: 0.14.1 CLI version: %s +Profile: Management: Connected to my-awesome-management.com:443 Signal: Connected to my-awesome-signal.com:443 Relays: @@ -565,6 +568,7 @@ func TestParsingToShortVersion(t *testing.T) { expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` Daemon version: 0.14.1 CLI version: development +Profile: Management: Connected Signal: Connected Relays: 1/2 Available diff --git a/client/ui/assets/connected.png b/client/ui/assets/connected.png new file mode 100644 index 0000000000000000000000000000000000000000..7dd2ab01ae50f84f784803fa2318c58578e04369 GIT binary patch literal 4743 zcmb_g_d6S08>aeFVzsE%RJDYvRVjknRTQD-YmcTzZDPgNSXDDMVy{v!wPHlgH;Bew z66fpt8CY6t|$7&y)@Qt+wZAlPtP}dv+`dhTR^SwO<-mS>;QYv3RXO0Yl>g(Jd5tY^9yO7fm6hv* zc$X0R{hpY_q0ITgR7?G;nGP->=g%l2`Oo|aeIf~Xc1G;Xi)=?-%>8z6LFXZwgf7hH zPn|~dvN_t+53P0@X~cUIekvl`Hw!mwq_hTW=TD1#Rzs8xrYDurpb|kmI zyRUaJQ-_f=>qB2Y*}7iMMrI?=y#XZ3D#;mJ(!2%NspeGWZ2JP~Lp}n3ww5G7 z4vH=u0)$(rgXP(K^ZV=3Ic~=T%JjE7jQ0<7b;nOYKlpr_u9$& zBOb{3J$Rb)J=l5#x+GcE1uT7LTB@$0IAV|>23E?i16{y(L#An=Ny1Ig(4fEN@aCgf zc@Dn}uV=dHa2tP_P_dRKDW6rXv6BJiofB$!F~T}nY{yGK!`nszV<^}Cr39*f0Kp8x z2>Lq&Mf1E|x+)fhLdI@++ zeGA%}RfmHxt;Jykc`_@FTyn2?=H0jP* z8<%2=*sRug*Yve)JzxmY)WZnEBp|+j(gl0pD*C6f(a++W$sZ}Sj1iYX=5ho-&#Q`b z{t5mgOMtPfogoreE>Z?$Zh&IY0pgM@(DZs)=L$&c zTuc-y@r8B|*wy#3Hmq#?kV9v0j1dYx@|&47sc8Y^;oYRx0<92D5_LwiAxeP^em>F8 zbY}h+g>P|~2i#?NHBGz(M4VqC0C%P2wJ%Jp3R-_v)!_ zqag5JZUb6|D2k!k>==}!sY|~55GlztArNZcT5$`?L&pNVqhh`>l5=F%7L;Xo&ck_9V`Us=j!tcUL<5G{tU29`EET2QW2NKamcnnKP}um`91yx6 zh#r|VbS#cZZtg_ZVROP=LAI}w@Y6IUGJepBt?wSuKE7{e=R1s7v`Sl#YHp8mi@b|t zg-BU5=2@gKlu&rW5O(RQ~6pStHHor*Za&PNCrFKCbTaxEoLs$GovK{Kf zaHxVTKpb2J_i8WqMS8J%w8-I-EJcWj3tO6~E)!dx#X@0{1`X&bb#+_uIU0LA%({>L zxulevNcZxhMsC4)sW;B>5B6>to~2}OO@YH$D@}MJO8};1@6k%B!<86^+Ulz54y46A z9w43Ivxa(n!f!cOWM~hLJ)JLf)Zkni5bHKM|7Yn@U>7+&MPe{WvqhNIO(#vpmVb4- zWN81=dMk5Y@$8=GZt#Y*vl%Pkp+WZ3Djw$a9p<@O?w^}b4C&8ax1?usk{2HR?;B%z zQBk%Jv5fSHL;uf=^kUL%Hy*yOKuMORmoMon$8tEd*?|KfqKx`Kk(TYQ?eYqh>>rT1 z2J>#U$fCb?3)Z8%_cD3#HZTN40QC0P>md15jaTy# zL5*t4%)8-VrT5h53-(r+@+@JAb%UpV-n`7aS?^u`lI9)PwoqRezm&eSZKYjxoXW#m zGQAOIlc!d=`Fg+IUV!|wmIk$(%%k_cQh{~qcpoI=Em6M60ZUhqa~2>hX_ZvJ1Ou`2 z%KiPZ%FBOqMTK^6$()svSBKW`GxejjzulSQwrGL(ff~?p-7vXKuDqng7(xm+uJ#n0 zJ50Ok7tb!X=BRohBf2_}pjyk3%mAXl_syAh zpL@-)Uc)SO*79=!n0evb&UnqVVo%G0?>JD^G?uB00s2+S-gI*y8(XR8fk(%6q0`Nt zKzPmL8M3f2FJ%i8!D=z#L!1H#EFB}wy-Y0tKS;V3mg&6cufWGc=AnHDMne*##lle9!3b+y@|MiHJc0# z{}B|Wo%R&$dDNx4EUWtOWW}zUyxDhKiDJtD<=dG2e}229Wp_h|i(Pl@Kl>c)V+Zxb zLEzAGhGR1WDwf&@JFDM*Q;@E#QeYG3S8Y?{9mpSgBp*4Os+cYMzaBu{>Y?*aZZ;~u zg6__4v6rTzc*rNyg~D*!spN{rd6;KTfz5mANA-_QpQ=r#GT;T6@IFxyBT+%mF->Zh z8O~@Q!GmjsD#5JU{#~t}gtQWB5RAqSJauQayX(pmPw&Ua!J*Myt40$|ylO#X8TWvS z%E=_EBHi!5cWm2tQvv1a%z>+1Di0*Ylcb`<3)Z#MJcN{|(Y$B*i-v{kaGUQ-51brS zL_BwD*1P)5+)Pq@`P(hg==7z_n2xFnB2VWs>mJR!XoCYJE~|MD*-MIgI8tmmbx-st zz`{Y0!&O5}xbXgV+=Tup50R5D>NT-=z;?<Aje?*B`mM9Wl#e)nlmZx)LGS>E+eK64w{p@a#zR+i^ZbQSG5ZTr*?iu zRht5~_`eR?Yh(x5_r80WcL}N*ORgoF(1|;6gRONw*p!VWuS(@USM!yHP9SF%JcA4J z;Wmf4QtQkAeMu{tD0JNNv5In}9Lw~lKaGx#Dh&dCe5j@?$6{;y$^O}y!UEc@9y$`8U#Am4V3lHQ|4R*s%F%-`Z5B_Limg+a%>t6`$7V6;H0lF?7@9VUYJ^Nt6KT z!EU9`>iv_XGs8=|e3`HGMC2YCWf5M9aOEZ~+;tBDReJe;D!^6@Q@)p$mxavEkN-vn zJ=^ajMGI$d+Sm^VO+dWcAQ)b^AQdw_T1VnuXzrK&;Y+^rBBBF(OdWI^Z3hD_D3EnA zcSR}~)B%{ap}Rap7TFk>e&e783p)j`$TI1LoMYnqts6l5!aQvI&MK>IDD|Z>5$t5p zQwjWf?Uij=!l|@NriN+_nXe^HPUbx5xQI7p^@7!c#XJcM^@7(`LRbVARvinTA%cbk zN*FUAF7-b$qgXUW4yY;jX8*yd_0%cmaJ=EKLf6z=B-YFgr&~RQRVXgG;@o!3Fu-2+vg_V$N`9zuZr|BpA zLmpSMe9-oCudRX+6l1|+^@6wbOCWpMhosSI+%%_y!3g*X z3UIwkq#Wi8LKLHeh=@I;2LQu@A=vN6z;<;iAK5fdsvT&I`B%&VMCjsJ<*JZQL-qXH zFdUT#un%=PfiM`Q0X*TYM&$EvrH>ABmbJ9*j*x+zr zQJ%Pw^_(0b1+#w&Ppt-8B-fg@1juq~5BBnKb&gz{F1woFCK2<;B=@?>+))>omq-}B zn(hVeOL{3k${pFTJcXL%w;%e&#p9dNfIdpTJ%ViAGPgtzBSOObkc)6)pcE~lWib%!$ zMcXeg_dmV)d=AMkwOP=jPGKR#T@nw2ugmIB#*8ygn<5aNk&L&P-Z=PM?YYfrzChva z>Ld}!-+B<&8|>_+TVeVv0^oYp*+QyLZyjI>-^jn;Uvnmz?*~IZ!VgWyR(b7~Y`#U>erWEmPXcxu&LhM+ zJhoTTDHQR}bJg2+#K8ZP$7u+~Vg9AC+J@awXZ}`!u-IJ5u6F)_eAt10IPY`AVV;)w zQaf$--eUd!%<@NF5K_G}vRQJckEWXKh`!h3dNP4_DWxj9(``U!dMYg%kR~Cu_H6wo z^gI+`fO8)&iiqnn8cb70b4tNpi&IQHg%A%twx9;}3Nj)Cq-#LoRVM*@rY_Tx011n( zu|W=~JnkrP)^;YR;52ZfEd2cRG2w#AEAf1uESr&&hX4zhW)|BZybV+LW8_v^=rqcnQF@iHrNi%3C|!8k)D(G z`HE2Yv9IKK^0DOnMCirwSE3nrg0Z$eE;s!VIdCXlxXB{%#wv)4P#1KMLi_+ZmS|=y zr!2Q?v0b+o-=B`1LZF=PvgTEcLfi3D#-Sk|0f*)f+8gl*)6}7w;n1Q|D(R_ytK?9T s^lVYL-u_s(rGlM5fBhf-UMRZGdDvsbuOptA$+3yZ7oC&9ytgCfkZ4X zU%Una@d<-K;9q|T04=ic;bjm=nQnQ}^hUJX^5lUw7oWJjkjRUNbe=U|{p-!0km*BF zF0Y%L?v_gEG#$a2U%GyLehM})F`=-}?D3Vy5j9Td93DB%XN+E}I$!B@nIBmfhS%)M8vcE`t;z?abOYuV?Ga27(Q}2l1wHW z*jZaoRQO>m!u5;d#+QGwooYs{*lbek@ z72gr{jnm05Hc`u+-Us^BnIt0f)^KAcF{s3*bC>mDugvf{L8JY&9j)Ckf6ukJ$E2vl zZh%=-JO|S&emX!JT&R$HDMY@clTF|%>AAV+Z0s`jFgv>@YAE+~Ds#^>eSvmMUj;RZ z3p3PT4%QM% zT|xM8U+wI)$+|liyRbjt)#BO$gjtL|;_5cyMSD0D%BKa}+$&JCGwG|dBdvZ2zY)*H zYv-We>foGe2t7I!45zSc zA4bW8^3X(a7r)&tj&q$eF0h;R0j>1G2zxc!Azgfv)6FW|s9}}qQKrs)&87~1{}vv< zh2BBhK*hEf1{Y_??o1kW8{|ZJ6ZH-!iafO02MTjWTu`Y?W0&C^oiIurUMz)Z+xU*w z_^e@OV10h8W)UX6C&`m=v&vz4bfsjW;nMDcD9EeMpL|IM!ZLc6@M_!;+Bq^h|PYNC8^BG z7|qfh9BcVb7bJzT7Tjw8)KlwrFbunUck!6o+&+)_M$50?KbwiQTR{)l^*kMBgh7;4 zax-Lz?PA26!W3b|1M;cYvEEN=iG);UMI&r`J3P3KqQW#>Dt`N|&p-^+_aKx$=TPNP zbB3v4SUpZqdhc1p?f-{pBuJZhna*wtebk#zPu*N+4Xh>Y>})iVW{Lz01VO&Go6VlR zS&HMeuK2fw`Fk^K3ka58Tfm<~KF|Ym?%bm}{bnus!+7ji%-Hbo@Kn1Sv)Br)HaT_R zDQMyhH3hI1hjfmuF)9UetcQG8LI%AXkvXagC1NDeB)F{67P{T*j4v$ZJH}S zz;1lkZRr@f?Oq)bOV$o5EG&dfe?Sy$HGo0XqmPabO(DMbu+0~~?(8JWqF^TP?rw*v)P0n#2EI&B7Fc7Ny@E7?SC$vk3xw1plNY7X2hmZlcHoFf|$ zKRur$$lqvzh$lR{{pH0a=USH*R=N$@P{EF`j`;L_1*`1mUmcaI}+U=LB1X{mhA4tUb`u!xct zZNXgx!Iv3$b&|X8;8I*Nc)6;e9ASOT%V0Re6!hs9qCmH6n=3sZWZfOH1#gRSG}Us$ zp|tA4$g7(OL`$++3)*5gZ=Xk6HtrEx-Q8{AdbAA^6~8-O*>%YSu!b+u!*<-;-sBmDq1yYI8{1mncF)ke-m0MNUH4xmmP#~=6=|*^*4lFp zq=D2a{M_s>ux2bhY*F5L^=FaT5g!We>;{$<-6!x1AucW+)3lAlL@Ms@RL-^OZbC(-@4j)eA}Y2|`GKVFvkqTGdL#a(tVWQJFo`vvNR$^$_YQM> zJcPeKB^$rwQ~Hx4-Ra&_KD%K2;B%hiz`f?`!oTuIX=h{99j}1`PEhBP-}OGGn(Hwy zS5&wQOAid)%I$w&H>9_@SrX=wJ!flN4$3q2d6hf+O}l?C)Mc8E+P`{}b3T`Lo@!WN zh4tL1FsyYSVQ`2Co{SG?AV7i7LnAkdIVw?us|mtmBGPXi70_09ZPDHAa|pvdM<=(H zKl7*Y_jbPu&iFRZ3YD@BE)`RAZV50_dT$=WI8x8L2Ew=7S4 zwrwFTlIUWKrYilV#1`G|@@|6)@C*$<>ubFY?2$RimE0NTcfbU>77g^xadYUR!0|L6 zppM)`IKLga+x{Z9qX3Nb>+a8MI#2ZCVhp*1q&ZaR$m^TzkVw%ueLBj!B%P#o|4~42=va9SZ(kFNo2n> z5W>U%h(BGhe_H$WWLrme1a|$~8-H9C<*&}R>$btx>nm#}1*!#)LU$2HQi_w!q0VdS zZ=GtCIORlsxzRYXrqnSWaSw6!`#i2U-2j-mr+d9gL&aY<{AhMBPJ@K@^2g54}uwZ&w zlOJAzvD`ixxNbL50nQUAj=unoU#w?liy_>AO%jG$7CftNkz3ffFIzJF>(R|lT~g+^ z+q}4ALvAxpWRN)LpEmUx7}WHZ?|M>;I*^a~(k3D(t~f-?wu{?jn;}2rp=_PBW}m{n ztpRJ5#`~zBJjhf{In&vb-G3p8EL!8hXpz3r?SuVUC2wXu1j#uU?c1V%nh{EAG-jL= z()g2F`k%Em;n!Z8Wwq7G=rS0SLVr62=shl6 z5{D_@JfH43VrDrM?J(eXS*cu+s`&$XkH7C#N=k~6X0WK0NtIw$Swmd_%kIkdu)mDQ zCt>4%bRk`|rgW7PWi4Oput}&V#>PfQ?P2Yoz=3ZecPm;Xj*i5??y1IM%j#k+to5JoE@&*yH_G#MS}dSq{~o?t1q? zKoe_>Ha|bluX)SZUhIe^yIkFS?x#0X6C_hJF{7ARHW@FUO`@ov1=6=F)5R@>@ zc$O&I>YCKcr4OeEJ~MPCQ^grYt+O962bfc}t;xa9{v~vK!^PGV(#eCmEEwY5h7)(Vov67rJjA z9=fo&7$DX0^vzz@y91fE7*LW5`lc!ma!AX{4a{tO$-IBUFMd6BE#HTfyX<0&sfO?C z z#eaA5IT-)bQpK0K`iqZg-Pr&lX=*(n~xI>LoM2wtG zC!f9AW8!<1`Bv-Gql%Lg#I_&5f@o_S5scBz^aFVdMAxG@;scbs1lMly{&7p4QVuVB zLnE$xl=DO@M@&oaR0Sx^Ip&iwVHBve*x>!6P2&&h1Ffa$`?PHeeLG z7lvwH74nbVedA`Riw5VZ6LStHXGv`xYPVK4d=T2Z`P*_MOjpy@K|RzG^PJfIXSZ)xa57Q=9ClIS=SeJV zV<&UJnx&Qlt7R+evRfRsJ}8UZU1*Dxn8B*Chq?qX>7I45$;V2dT1$;Cem1*E)u?gOeb zSu?&|tm}Q)YAE1Fhx6?{4ucZpI;iYnon3J5rJ!%uPqI895sJt+rFaXdrWP@z_ot|;*t0BmGZIh!>X~Nm7#_2sP9p}7PuMAa>X!oB50&yiLjqlS%iA6VhGZJWq zL?PYUWbFar!oszg(E*!$dGe`_S-v%W&r2#2bppQD`xsRv944J;zbEkn8`k#Ptu0HN zl%w8m(B?vO9X#Xlp;3cbEYihDg1A^7xg=K%un6aQBm6NmFzgwiQw~|?Um6hd>@}nV z_+;+KmY}WAnU_>7v2IYI=^v4J6zpFi?s=ECULwC-8TwD_@dyTNf0}*Eg%dJKIxpEt z5GP9vM>C_gca>l8SoBOJr;pEkK7n7f*+v3B>bvj{ZNiM%aYIr=y~m z)`C7FOM$6QE;R!3hT=Jc*e4H>EYT7q%Zj75Q9@ZwH#?;;@tp9 zGI;5S!Ts(WAmN6wJ=SKAMXRCiX8{l zJaPfWGU8@b(Z)m5PY0wMFDv@VHhLz6WIoz>0_kD@cYwEieCzl!H~P2Wo+R)mQ?hoqXl+=n zUXCOb?%z!P?;uzjR5`Ly&`fYGB zADyv0sfwiw>aASk{ZakxvOaNO*^Z+^fNsKQLnHJiUP{HAWgnn);aef?fXaVxb+RSa z!X-3Lj6vEI@f%E0Sap_mqb6FdXH}eMf(kN8Xz#1VF6w|O_>5^3=Fg2Y^7p4*mY`}i zDrZzlRw+DgY|Gnpppl(P|C-)gfHEE7 zSy_&X)_kqeRG%`Y_KS4Lg$mZNpE---%>1k2c-`ra1ca zR}K7cjoV`yOMsuc-SZ-+45qF{KS{sOspc~e=*OHJkHSwA2a5;=73+%1997BNOQ2)|IU4~x6cmqyfBuBTc`f@*gr(+4M&xXuN-l8Y;vku zsB?O=`0MwdmP_HACtgFol9M2}Ex*l2;3t{1=E2{25d1FA}mS(+&0 zN$q|)G~$MsvPYVC!%Ch6g%_JD16!WB7!@R(yR{xZFfV-DPoXUs_2h ze??8>9Rhhd5~arrbLERKc)7&4j-ey8q)==&WFn(nC>e6c^4pE0QC~fimlu<4G?Ib5 z)e*j0eAe3(2#l3}n3u`2p*=_c4Z&0GLkGrmU9PL&1!B7UB87slQY9`U5(0{R_X?)r79u0reBanyOT(O?F z3e@BpKLfx>1$|G)CTt1N?o!4Wf8W_|In<5=>~ku%eh=@qlojqzblLPvx2Sf;4MbT0 znYVFn&kOE(iS;_hesxRSi#po=Ua(rlLj<4^f1@({->Wn>lg9cu571?g>EHTt))#<) z7M7s*Iu`Q85flUWM8);IK$khY0O51c1wjG}$_5wkfSltdk%5Olu^Ucx%Yv5oWrxDM zE3!AQiOG710tEHsTORL0s(K*j+kjU8r{xPMv#j7c^dK5u3~}p--85Q+@%M#Lx9KFs zTpO}%(Z`j6j4WttUuMukb5Oc*dcvl1YD_nvrr#lJY=bGE%%;^ECI=Z4K%4s`uS7!C zsmVuT%a&f2u5wD@-bL!EU?B0fWAn|6a?-%GC!y`(5^utzGmxua)%mYz)&of_< z6APxHLqC2n!~P(ATWxq)=6s9R?N1i<$sO+6p9goSbBn`6cA~diURfeYcNN}fp$@WR-jieOl zC{p3Z&2%Nea4JTz`vBPgw>WG^z5mcgKsR1EE3j3F%{Lb`S`EX-8CfOlKhE!g3%`@} z?xvZ*Lj_lN+_=$MBYd%ZC*Rd69{`o6Nbf7D6MU#3qHj>+(i&a~EC)QkQ{N>zeoiJy zk?IRE6L?hi$aS+hbQfS$yj-QhLN1%vUPta#)wPq3)AzC^|E$Ov?}eK$-5M12xPKaL z;Ufj)NTNiKW=*s26wc?&rsH^ht=UNc_4nX~fQ&I>#7Eg%Hydw~TYw!-G#lbmVb9>{ zEdZn58fT=H@W`J{r_&AKy4K|-yx$M13h_U2039Ab4fIi zdt~Ni%(Tzso?%=(kE1dE2NXCYUb69_1}LY*16or84vjFf2a@K-tIfW_(o}z9%I6M` zgB$IdhLM)-!jFR}W|xNAmL~!vrb;3RF-GE`q$3lWNP=*tSmgG)?<>L7vdxhZKu`gP z`#qE|9Ow{V5eml%XG;1c#10mXubu`Q-ord!%`pduuTmq(hpCIkl2)(JPGdvxO@$tRohP4KHSwU{NX zX#wxtsXk_yMRRd{c^!N8Ys^pTHQ=;;&0B`g6g;Bj$ZJVKIiM_Uoz}PGds!_bV;U2m zjvl=)?YN>^Vv2q1ayQ>|uJnLk`niqEtKY+|p$G9J{<0pQyn76F%@!t)PTx7*|B;wN za<{XqqCh*-&IGxO)?zlgZTGg8k;g5MkQ)Q-tb=;=r(U~_gR}yXNaTWWp`-859SfsO zg~#{c@cy@swJwMtv^^j8#l+`IAdSZYX{x^2j~)3$%OR;85w0B;7QOJ#-zvuz|GeVi zkb|vaU!l55Utro$V+G{wF3*kKgq4<%PVMQza0bKv(jHG$3XV>Pd z7^bKE$hmLUyQ<$smA6xBoN@Qs&B33{rLO_Tb%x*nrnK8{j`7U&;njxf_KG^dD)7jA zd}&4eVurzZ;aVfCr>VXO)R ztmJ2IaE`Iuz!xv+_*tG{N;M&7Cel?Zx1-ap|wW=o$kc|@#w^##}+l-m>EY~sgaX4 zm*O?O4_#?ZA>OornUwSmsyn|87#rXm;m^|vY_>BKm};mk&J2uC6qn7gp$_#LjQ9(A zOi361ftVKVsPp}U+jcB1D=G4>-n%b)LY)t^zG`PzDG?oij&@zWMhHFkfnuOmv`^wS z%hqi3{cGX(I3IHGR;KA`)gQzbV>%P4?KJiC)jtEm{J6K2@u{u`^uQ~@@+Nzd`K*Or zylyuM%4S)SiwvCyJ;rYYh{EDT$(9lS{T}be>pi=2$FwJmx>z&9gUtD8n{=oM zW?cl@O`&bf1z@IgBYt-Nj`wwkLGfN}VbVQKU3Ux^c_kfS;HtEzj|8L-+E|g*VEoR* zLi`Ex#7@~l;muy*>$t4Z*p+|6FC1eJS~^o_35fP%pp8kQfk@P4J4+RSX6I%C62+#LulxnIu@)p4c#nnl(xlH%CSKKWudlah`)2^u$p%= z7T4?!sg&%F^23f$YFB?mZPwmWmPg( zTP-I>#iiJ*h4B6ffm8b+<2VTV=3R@(w$C4^yBcW>dDX|4sH^$LXtytssqlm`#{p&a zAcpTDf_~y=HqzdzR4t=FsDi(uyqsI#X5e6w1kO{Wj9@Vkt>ys~#3;%42BY{xCNn7Iv ziQ7p3hwzH1fh3I=OXat5D}#y?0et`jrW7jy!8R$pwToZpJ0Hn7R>x?O?*X5wq9 z%dUp59Z(tqu`lTzxCfM|2BbH{fGQz%N<_{EJ!*IGnk)1mwWu-!p7_%dv16z}z&)*! z1JPQ9cWW9$zUn3JeUaV?>JxuQphkyKB%SCu;@s@(@Q@0W$s-B?`kyBsL!wz5m2%fX zeeTqWq#Bk&y%kinu5Z;95vtoVZZbFd$<Y{)ITj9jX}Xr1>Eo6RQougYIg!v zWNU3wR}!F&WBgLDQ^kcE%P+f6`LG)JLAYVgf^L~X_G2?1m1S62ec7G4M!5CE-0nO^ z#;4SN>Z)>-uM4`^jk~MHqf=Pv51jS*>%XuRt__4J%Or=&CC9P6*{rqo&CRs*{w|lw z*s^d_{WE(jQ<-zdJWf?=LS)zic#%*zz4huv;qjKB;^?XDx#^q?h(1BjvN8U}(2p|( zEv$$Icls|oUsnt&3OJ#^MuE7o>>K!YUuS@_oj!sE3`=v%!!$h~CscHe(a}d8Z}r`! z7CsD*(61nysLy)&U5ek8)LUtZ-`%Rw!CHbm9zxX%k}E4K@jm>k*^P)E@nN{ZF8k=f zBA0e0RmP|ncz@BM%97OMDtD6$$~SMvyFEdP4XQ8|WP>}wxYZ~8Ag@Y8(;N~Dct~`e z9p?YD)=W=^5Z_3yfuP=^D}r8%(J1OFa*`s@Kc;}`R+JZvAiRQ*Vg7QKQX@QK%l%E< z?oma9mF3{nBeC3lSd=w|HmRc6*xDKg0&i)e)KTeFnah&uZf$RZQulduv~mlRd8Lj* z$DVc4)7d$;_r}iJsNvxb1A?)>2}-r7pjL{dpf9HaydPn%ZkTUX>>%UPNx+^+Bm=Fn z8%AfkC#N~JQOow+okbIS(5*uP%nK9G$*6@{OTCU)Z7+@$@;B6u6Zq!bS=-ZntA*h( z9IE2|R!UTW${q^8TW~5T!)u^tl`yoxVhaut15sUTS!>yAojX&uKwcKxg2)Z~6EqXm z%=JnOJ?()+l1?nCf$7Va?cb^*XwfJE&BDa{b)(ZSJ?CVOWPH|rqDmVcTj!N*{ zL=Xsi4Zk-sVL*-EGo%}zw(FoCO~KLD9m>~VdI{9c4E7CDOCPgT(-sysx^ApQN6*fr z>?-Nyl8i|=WeKjE?J_xg?IrWf)9$-FrL1Y44W`<*;uJ2(%pp%g^_XPtTr=I@+^0aKF09V&_d){q(IzwI3KxGYcP>tU_i<{~U;ksC=l#2pg#F+5`0u~a V>cxt`kmog7Ub4Mdd%@$u{{j=u(k1`^ literal 0 HcmV?d00001 diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index c18d96dae..781ec56f8 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -8,8 +8,10 @@ import ( "errors" "flag" "fmt" + "net/url" "os" "os/exec" + "os/user" "path" "runtime" "strconv" @@ -34,11 +36,14 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" "github.com/netbirdio/netbird/client/ui/process" + "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -54,11 +59,11 @@ const ( ) func main() { - daemonAddr, showSettings, showNetworks, showLoginURL, showDebug, errorMsg, saveLogsInFile := parseFlags() + flags := parseFlags() // Initialize file logging if needed. var logFile string - if saveLogsInFile { + if flags.saveLogsInFile { file, err := initLogFile() if err != nil { log.Errorf("error while initializing log: %v", err) @@ -74,19 +79,28 @@ func main() { a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnected)) // Show error message window if needed. - if errorMsg != "" { - showErrorMessage(errorMsg) + if flags.errorMsg != "" { + showErrorMessage(flags.errorMsg) return } // Create the service client (this also builds the settings or networks UI if requested). - client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showLoginURL, showDebug) + client := newServiceClient(&newServiceClientArgs{ + addr: flags.daemonAddr, + logFile: logFile, + app: a, + showSettings: flags.showSettings, + showNetworks: flags.showNetworks, + showLoginURL: flags.showLoginURL, + showDebug: flags.showDebug, + showProfiles: flags.showProfiles, + }) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if showSettings || showNetworks || showDebug || showLoginURL { + if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles { a.Run() return } @@ -106,21 +120,35 @@ func main() { systray.Run(client.onTrayReady, client.onTrayExit) } +type cliFlags struct { + daemonAddr string + showSettings bool + showNetworks bool + showProfiles bool + showDebug bool + showLoginURL bool + errorMsg string + saveLogsInFile bool +} + // parseFlags reads and returns all needed command-line flags. -func parseFlags() (daemonAddr string, showSettings, showNetworks, showLoginURL, showDebug bool, errorMsg string, saveLogsInFile bool) { +func parseFlags() *cliFlags { + var flags cliFlags + defaultDaemonAddr := "unix:///var/run/netbird.sock" if runtime.GOOS == "windows" { defaultDaemonAddr = "tcp://127.0.0.1:41731" } - flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") - flag.BoolVar(&showSettings, "settings", false, "run settings window") - flag.BoolVar(&showNetworks, "networks", false, "run networks window") - flag.BoolVar(&showLoginURL, "login-url", false, "show login URL in a popup window") - flag.BoolVar(&showDebug, "debug", false, "run debug window") - flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") - flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) + flag.StringVar(&flags.daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") + flag.BoolVar(&flags.showSettings, "settings", false, "run settings window") + flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window") + flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window") + flag.BoolVar(&flags.showDebug, "debug", false, "run debug window") + flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window") + flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) + flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window") flag.Parse() - return + return &flags } // initLogFile initializes logging into a file. @@ -168,6 +196,12 @@ var iconConnectingMacOS []byte //go:embed assets/netbird-systemtray-error-macos.png var iconErrorMacOS []byte +//go:embed assets/connected.png +var iconConnectedDot []byte + +//go:embed assets/disconnected.png +var iconDisconnectedDot []byte + type serviceClient struct { ctx context.Context cancel context.CancelFunc @@ -176,9 +210,13 @@ type serviceClient struct { eventHandler *eventHandler + profileManager *profilemanager.ProfileManager + icAbout []byte icConnected []byte + icConnectedDot []byte icDisconnected []byte + icDisconnectedDot []byte icUpdateConnected []byte icUpdateDisconnected []byte icConnecting []byte @@ -189,6 +227,7 @@ type serviceClient struct { mUp *systray.MenuItem mDown *systray.MenuItem mSettings *systray.MenuItem + mProfile *profileMenu mAbout *systray.MenuItem mGitHub *systray.MenuItem mVersionUI *systray.MenuItem @@ -214,7 +253,6 @@ type serviceClient struct { // input elements for settings form iMngURL *widget.Entry - iConfigFile *widget.Entry iLogFile *widget.Entry iPreSharedKey *widget.Entry iInterfaceName *widget.Entry @@ -247,6 +285,7 @@ type serviceClient struct { isUpdateIconActive bool showNetworks bool wNetworks fyne.Window + wProfiles fyne.Window eventManager *event.Manager @@ -263,36 +302,50 @@ type menuHandler struct { cancel context.CancelFunc } +type newServiceClientArgs struct { + addr string + logFile string + app fyne.App + showSettings bool + showNetworks bool + showDebug bool + showLoginURL bool + showProfiles bool +} + // newServiceClient instance constructor // // This constructor also builds the UI elements for the settings window. -func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showLoginURL bool, showDebug bool) *serviceClient { +func newServiceClient(args *newServiceClientArgs) *serviceClient { ctx, cancel := context.WithCancel(context.Background()) s := &serviceClient{ ctx: ctx, cancel: cancel, - addr: addr, - app: a, - logFile: logFile, + addr: args.addr, + app: args.app, + logFile: args.logFile, sendNotification: false, - showAdvancedSettings: showSettings, - showNetworks: showNetworks, + showAdvancedSettings: args.showSettings, + showNetworks: args.showNetworks, update: version.NewUpdate("nb/client-ui"), } s.eventHandler = newEventHandler(s) + s.profileManager = profilemanager.NewProfileManager() s.setNewIcons() switch { - case showSettings: + case args.showSettings: s.showSettingsUI() - case showNetworks: + case args.showNetworks: s.showNetworksUI() - case showLoginURL: + case args.showLoginURL: s.showLoginURL() - case showDebug: + case args.showDebug: s.showDebugUI() + case args.showProfiles: + s.showProfilesUI() } return s @@ -300,6 +353,8 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool func (s *serviceClient) setNewIcons() { s.icAbout = iconAbout + s.icConnectedDot = iconConnectedDot + s.icDisconnectedDot = iconDisconnectedDot if s.app.Settings().ThemeVariant() == theme.VariantDark { s.icConnected = iconConnectedDark s.icDisconnected = iconDisconnected @@ -342,8 +397,7 @@ func (s *serviceClient) showSettingsUI() { s.wSettings.SetOnClosed(s.cancel) s.iMngURL = widget.NewEntry() - s.iConfigFile = widget.NewEntry() - s.iConfigFile.Disable() + s.iLogFile = widget.NewEntry() s.iLogFile.Disable() s.iPreSharedKey = widget.NewPasswordEntry() @@ -368,14 +422,22 @@ func (s *serviceClient) showSettingsUI() { // getSettingsForm to embed it into settings window. func (s *serviceClient) getSettingsForm() *widget.Form { + + var activeProfName string + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + log.Errorf("get active profile: %v", err) + } else { + activeProfName = activeProf.Name + } return &widget.Form{ Items: []*widget.FormItem{ + {Text: "Profile", Widget: widget.NewLabel(activeProfName)}, {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, {Text: "Management URL", Widget: s.iMngURL}, {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, - {Text: "Config File", Widget: s.iConfigFile}, {Text: "Log File", Widget: s.iLogFile}, {Text: "Network Monitor", Widget: s.sNetworkMonitor}, {Text: "Disable DNS", Widget: s.sDisableDNS}, @@ -416,27 +478,67 @@ func (s *serviceClient) getSettingsForm() *widget.Form { s.managementURL = iMngURL s.preSharedKey = s.iPreSharedKey.Text - loginRequest := proto.LoginRequest{ - ManagementUrl: iMngURL, - IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", - RosenpassPermissive: &s.sRosenpassPermissive.Checked, - InterfaceName: &s.iInterfaceName.Text, - WireguardPort: &port, - NetworkMonitor: &s.sNetworkMonitor.Checked, - DisableDns: &s.sDisableDNS.Checked, - DisableClientRoutes: &s.sDisableClientRoutes.Checked, - DisableServerRoutes: &s.sDisableServerRoutes.Checked, - BlockLanAccess: &s.sBlockLANAccess.Checked, - } - - if s.iPreSharedKey.Text != censoredPreSharedKey { - loginRequest.OptionalPreSharedKey = &s.iPreSharedKey.Text - } - - if err := s.restartClient(&loginRequest); err != nil { - log.Errorf("restarting client connection: %v", err) + currUser, err := user.Current() + if err != nil { + log.Errorf("get current user: %v", err) return } + + var req proto.SetConfigRequest + req.ProfileName = activeProf.Name + req.Username = currUser.Username + + if iMngURL != "" { + req.ManagementUrl = iMngURL + } + + req.RosenpassPermissive = &s.sRosenpassPermissive.Checked + req.InterfaceName = &s.iInterfaceName.Text + req.WireguardPort = &port + req.NetworkMonitor = &s.sNetworkMonitor.Checked + req.DisableDns = &s.sDisableDNS.Checked + req.DisableClientRoutes = &s.sDisableClientRoutes.Checked + req.DisableServerRoutes = &s.sDisableServerRoutes.Checked + req.BlockLanAccess = &s.sBlockLANAccess.Checked + + if s.iPreSharedKey.Text != censoredPreSharedKey { + req.OptionalPreSharedKey = &s.iPreSharedKey.Text + } + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("get client: %v", err) + dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings) + return + } + _, err = conn.SetConfig(s.ctx, &req) + if err != nil { + log.Errorf("set config: %v", err) + dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings) + return + } + + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("get service status: %v", err) + dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) + return + } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + return + } + } + } }, OnCancel: func() { @@ -452,8 +554,21 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, err } + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + log.Errorf("get active profile: %v", err) + return nil, err + } + + currUser, err := user.Current() + if err != nil { + return nil, fmt.Errorf("get current user: %w", err) + } + loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", + ProfileName: &activeProf.Name, + Username: &currUser.Username, }) if err != nil { log.Errorf("login to management URL with: %v", err) @@ -461,15 +576,9 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { } if loginResp.NeedsSSOLogin && openURL { - err = open.Run(loginResp.VerificationURIComplete) + err = s.handleSSOLogin(loginResp, conn) if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return nil, err - } - - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) - if err != nil { - log.Errorf("waiting sso login failed with: %v", err) + log.Errorf("handle SSO login failed: %v", err) return nil, err } } @@ -477,6 +586,34 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return loginResp, nil } +func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + err := open.Run(loginResp.VerificationURIComplete) + if err != nil { + log.Errorf("opening the verification uri in the browser failed: %v", err) + return err + } + + resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + if err != nil { + log.Errorf("waiting sso login failed with: %v", err) + return err + } + + if resp.Email != "" { + err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + Email: resp.Email, + }) + if err != nil { + log.Warnf("failed to set profile state: %v", err) + } else { + s.mProfile.refresh() + } + + } + + return nil +} + func (s *serviceClient) menuUpClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) @@ -575,6 +712,7 @@ func (s *serviceClient) updateStatus() error { } systray.SetTooltip("NetBird (Connected)") s.mStatus.SetTitle("Connected") + s.mStatus.SetIcon(s.icConnectedDot) s.mUp.Disable() s.mDown.Enable() s.mNetworks.Enable() @@ -634,6 +772,7 @@ func (s *serviceClient) setDisconnectedStatus() { } systray.SetTooltip("NetBird (Disconnected)") s.mStatus.SetTitle("Disconnected") + s.mStatus.SetIcon(s.icDisconnectedDot) s.mDown.Disable() s.mUp.Enable() s.mNetworks.Disable() @@ -658,7 +797,13 @@ func (s *serviceClient) onTrayReady() { // setup systray menu items s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected") + s.mStatus.SetIcon(s.icDisconnectedDot) s.mStatus.Disable() + + profileMenuItem := systray.AddMenuItem("", "") + emailMenuItem := systray.AddMenuItem("", "") + s.mProfile = newProfileMenu(s.ctx, s.profileManager, *s.eventHandler, profileMenuItem, emailMenuItem, s.menuDownClick, s.menuUpClick, s.getSrvClient, s.loadSettings) + systray.AddSeparator() s.mUp = systray.AddMenuItem("Connect", "Connect") s.mDown = systray.AddMenuItem("Disconnect", "Disconnect") @@ -790,7 +935,15 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService // getSrvConfig from the service to show it in the settings window. func (s *serviceClient) getSrvConfig() { - s.managementURL = internal.DefaultManagementURL + s.managementURL = profilemanager.DefaultManagementURL + + _, err := s.profileManager.GetActiveProfile() + if err != nil { + log.Errorf("get active profile: %v", err) + return + } + + var cfg *profilemanager.Config conn, err := s.getSrvClient(failFastTimeout) if err != nil { @@ -798,48 +951,63 @@ func (s *serviceClient) getSrvConfig() { return } - cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{}) + currUser, err := user.Current() + if err != nil { + log.Errorf("get current user: %v", err) + return + } + + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + log.Errorf("get active profile: %v", err) + return + } + + srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{ + ProfileName: activeProf.Name, + Username: currUser.Username, + }) if err != nil { log.Errorf("get config settings from server: %v", err) return } - if cfg.ManagementUrl != "" { - s.managementURL = cfg.ManagementUrl + cfg = protoConfigToConfig(srvCfg) + + if cfg.ManagementURL.String() != "" { + s.managementURL = cfg.ManagementURL.String() } s.preSharedKey = cfg.PreSharedKey s.RosenpassPermissive = cfg.RosenpassPermissive - s.interfaceName = cfg.InterfaceName - s.interfacePort = int(cfg.WireguardPort) + s.interfaceName = cfg.WgIface + s.interfacePort = cfg.WgPort - s.networkMonitor = cfg.NetworkMonitor - s.disableDNS = cfg.DisableDns + s.networkMonitor = *cfg.NetworkMonitor + s.disableDNS = cfg.DisableDNS s.disableClientRoutes = cfg.DisableClientRoutes s.disableServerRoutes = cfg.DisableServerRoutes - s.blockLANAccess = cfg.BlockLanAccess + s.blockLANAccess = cfg.BlockLANAccess if s.showAdvancedSettings { s.iMngURL.SetText(s.managementURL) - s.iConfigFile.SetText(cfg.ConfigFile) - s.iLogFile.SetText(cfg.LogFile) s.iPreSharedKey.SetText(cfg.PreSharedKey) - s.iInterfaceName.SetText(cfg.InterfaceName) - s.iInterfacePort.SetText(strconv.Itoa(int(cfg.WireguardPort))) + s.iInterfaceName.SetText(cfg.WgIface) + s.iInterfacePort.SetText(strconv.Itoa(cfg.WgPort)) s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive) if !cfg.RosenpassEnabled { s.sRosenpassPermissive.Disable() } - s.sNetworkMonitor.SetChecked(cfg.NetworkMonitor) - s.sDisableDNS.SetChecked(cfg.DisableDns) + s.sNetworkMonitor.SetChecked(*cfg.NetworkMonitor) + s.sDisableDNS.SetChecked(cfg.DisableDNS) s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes) s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes) - s.sBlockLANAccess.SetChecked(cfg.BlockLanAccess) + s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess) } if s.mNotifications == nil { return } - if cfg.DisableNotifications { + if cfg.DisableNotifications != nil && *cfg.DisableNotifications { s.mNotifications.Uncheck() } else { s.mNotifications.Check() @@ -849,6 +1017,58 @@ func (s *serviceClient) getSrvConfig() { } } +func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { + + var config profilemanager.Config + + if cfg.ManagementUrl != "" { + parsed, err := url.Parse(cfg.ManagementUrl) + if err != nil { + log.Errorf("parse management URL: %v", err) + } else { + config.ManagementURL = parsed + } + } + + if cfg.PreSharedKey != "" { + if cfg.PreSharedKey != censoredPreSharedKey { + config.PreSharedKey = cfg.PreSharedKey + } else { + config.PreSharedKey = "" + } + } + if cfg.AdminURL != "" { + parsed, err := url.Parse(cfg.AdminURL) + if err != nil { + log.Errorf("parse admin URL: %v", err) + } else { + config.AdminURL = parsed + } + } + + config.WgIface = cfg.InterfaceName + if cfg.WireguardPort != 0 { + config.WgPort = int(cfg.WireguardPort) + } else { + config.WgPort = iface.DefaultWgPort + } + + config.DisableAutoConnect = cfg.DisableAutoConnect + config.ServerSSHAllowed = &cfg.ServerSSHAllowed + config.RosenpassEnabled = cfg.RosenpassEnabled + config.RosenpassPermissive = cfg.RosenpassPermissive + config.DisableNotifications = &cfg.DisableNotifications + config.LazyConnectionEnabled = cfg.LazyConnectionEnabled + config.BlockInbound = cfg.BlockInbound + config.NetworkMonitor = &cfg.NetworkMonitor + config.DisableDNS = cfg.DisableDns + config.DisableClientRoutes = cfg.DisableClientRoutes + config.DisableServerRoutes = cfg.DisableServerRoutes + config.BlockLANAccess = cfg.BlockLanAccess + + return &config +} + func (s *serviceClient) onUpdateAvailable() { s.updateIndicationLock.Lock() defer s.updateIndicationLock.Unlock() @@ -880,7 +1100,22 @@ func (s *serviceClient) loadSettings() { return } - cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{}) + currUser, err := user.Current() + if err != nil { + log.Errorf("get current user: %v", err) + return + } + + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + log.Errorf("get active profile: %v", err) + return + } + + cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{ + ProfileName: activeProf.Name, + Username: currUser.Username, + }) if err != nil { log.Errorf("get config settings from server: %v", err) return @@ -936,41 +1171,37 @@ func (s *serviceClient) updateConfig() error { blockInbound := s.mBlockInbound.Checked() notificationsDisabled := !s.mNotifications.Checked() - loginRequest := proto.LoginRequest{ - IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + log.Errorf("get active profile: %v", err) + return err + } + + currUser, err := user.Current() + if err != nil { + log.Errorf("get current user: %v", err) + return err + } + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("get client: %v", err) + return err + } + + req := proto.SetConfigRequest{ + ProfileName: activeProf.Name, + Username: currUser.Username, + DisableAutoConnect: &disableAutoStart, ServerSSHAllowed: &sshAllowed, RosenpassEnabled: &rosenpassEnabled, - DisableAutoConnect: &disableAutoStart, - DisableNotifications: ¬ificationsDisabled, LazyConnectionEnabled: &lazyConnectionEnabled, BlockInbound: &blockInbound, + DisableNotifications: ¬ificationsDisabled, } - if err := s.restartClient(&loginRequest); err != nil { - log.Errorf("restarting client connection: %v", err) - return err - } - - return nil -} - -// restartClient restarts the client connection. -func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error { - ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout) - defer cancel() - - client, err := s.getSrvClient(failFastTimeout) - if err != nil { - return err - } - - _, err = client.Login(ctx, loginRequest) - if err != nil { - return err - } - - _, err = client.Up(ctx, &proto.UpRequest{}) - if err != nil { + if _, err := conn.SetConfig(s.ctx, &req); err != nil { + log.Errorf("set config settings on server: %v", err) return err } diff --git a/client/ui/const.go b/client/ui/const.go index 5a4b27f32..332282c17 100644 --- a/client/ui/const.go +++ b/client/ui/const.go @@ -2,6 +2,7 @@ package main const ( settingsMenuDescr = "Settings of the application" + profilesMenuDescr = "Manage your profiles" allowSSHMenuDescr = "Allow SSH connections" autoConnectMenuDescr = "Connect automatically when the service starts" quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass" diff --git a/client/ui/debug.go b/client/ui/debug.go index 55829de1e..a7f4868ac 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git a/client/ui/profile.go b/client/ui/profile.go new file mode 100644 index 000000000..142582c25 --- /dev/null +++ b/client/ui/profile.go @@ -0,0 +1,601 @@ +//go:build !(linux && 386) + +package main + +import ( + "context" + "errors" + "fmt" + "os/user" + "slices" + "sort" + "sync" + "time" + + "fyne.io/fyne/v2" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/layout" + "fyne.io/fyne/v2/widget" + "fyne.io/systray" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// showProfilesUI creates and displays the Profiles window with a list of existing profiles, +// a button to add new profiles, allows removal, and lets the user switch the active profile. +func (s *serviceClient) showProfilesUI() { + + profiles, err := s.getProfiles() + if err != nil { + log.Errorf("get profiles: %v", err) + return + } + + var refresh func() + // List widget for profiles + list := widget.NewList( + func() int { return len(profiles) }, + func() fyne.CanvasObject { + // Each item: Selected indicator, Name, spacer, Select & Remove buttons + return container.NewHBox( + widget.NewLabel(""), // indicator + widget.NewLabel(""), // profile name + layout.NewSpacer(), + widget.NewButton("Select", nil), + widget.NewButton("Remove", nil), + ) + }, + func(i widget.ListItemID, item fyne.CanvasObject) { + // Populate each row + row := item.(*fyne.Container) + indicator := row.Objects[0].(*widget.Label) + nameLabel := row.Objects[1].(*widget.Label) + selectBtn := row.Objects[3].(*widget.Button) + removeBtn := row.Objects[4].(*widget.Button) + + profile := profiles[i] + // Show a checkmark if selected + if profile.IsActive { + indicator.SetText("✓") + } else { + indicator.SetText("") + } + nameLabel.SetText(profile.Name) + + // Configure Select/Active button + selectBtn.SetText(func() string { + if profile.IsActive { + return "Active" + } + return "Select" + }()) + selectBtn.OnTapped = func() { + if profile.IsActive { + return // already active + } + // confirm switch + dialog.ShowConfirm( + "Switch Profile", + fmt.Sprintf("Are you sure you want to switch to '%s'?", profile.Name), + func(confirm bool) { + if !confirm { + return + } + // switch + err = s.switchProfile(profile.Name) + if err != nil { + log.Errorf("failed to switch profile: %v", err) + dialog.ShowError(errors.New("failed to select profile"), s.wProfiles) + return + } + + dialog.ShowInformation( + "Profile Switched", + fmt.Sprintf("Profile '%s' switched successfully", profile.Name), + s.wProfiles, + ) + + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get daemon client: %v", err) + return + } + + status, err := conn.Status(context.Background(), &proto.StatusRequest{}) + if err != nil { + log.Errorf("failed to get status after switching profile: %v", err) + return + } + + if status.Status == string(internal.StatusConnected) { + if err := s.menuDownClick(); err != nil { + log.Errorf("failed to handle down click after switching profile: %v", err) + dialog.ShowError(fmt.Errorf("failed to handle down click"), s.wProfiles) + return + } + } + // update slice flags + refresh() + }, + s.wProfiles, + ) + } + + // Remove profile + removeBtn.SetText("Remove") + removeBtn.OnTapped = func() { + dialog.ShowConfirm( + "Delete Profile", + fmt.Sprintf("Are you sure you want to delete '%s'?", profile.Name), + func(confirm bool) { + if !confirm { + return + } + // remove + err = s.removeProfile(profile.Name) + if err != nil { + log.Errorf("failed to remove profile: %v", err) + dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles) + return + } + dialog.ShowInformation( + "Profile Removed", + fmt.Sprintf("Profile '%s' removed successfully", profile.Name), + s.wProfiles, + ) + // update slice + refresh() + }, + s.wProfiles, + ) + } + }, + ) + + refresh = func() { + newProfiles, err := s.getProfiles() + if err != nil { + dialog.ShowError(err, s.wProfiles) + return + } + profiles = newProfiles // update the slice + list.Refresh() // tell Fyne to re-call length/update on every visible row + } + + // Button to add a new profile + newBtn := widget.NewButton("New Profile", func() { + nameEntry := widget.NewEntry() + nameEntry.SetPlaceHolder("Enter Profile Name") + + formItems := []*widget.FormItem{{Text: "Name:", Widget: nameEntry}} + dlg := dialog.NewForm( + "New Profile", + "Create", + "Cancel", + formItems, + func(confirm bool) { + if !confirm { + return + } + name := nameEntry.Text + if name == "" { + dialog.ShowError(errors.New("profile name cannot be empty"), s.wProfiles) + return + } + + // add profile + err = s.addProfile(name) + if err != nil { + log.Errorf("failed to create profile: %v", err) + dialog.ShowError(fmt.Errorf("failed to create profile"), s.wProfiles) + return + } + dialog.ShowInformation( + "Profile Created", + fmt.Sprintf("Profile '%s' created successfully", name), + s.wProfiles, + ) + // update slice + refresh() + }, + s.wProfiles, + ) + // make dialog wider + dlg.Resize(fyne.NewSize(350, 150)) + dlg.Show() + }) + + // Assemble window content + content := container.NewBorder(nil, newBtn, nil, nil, list) + s.wProfiles = s.app.NewWindow("NetBird Profiles") + s.wProfiles.SetContent(content) + s.wProfiles.Resize(fyne.NewSize(400, 300)) + s.wProfiles.SetOnClosed(s.cancel) + + s.wProfiles.Show() +} + +func (s *serviceClient) addProfile(profileName string) error { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + return fmt.Errorf(getClientFMT, err) + } + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + _, err = conn.AddProfile(context.Background(), &proto.AddProfileRequest{ + ProfileName: profileName, + Username: currUser.Username, + }) + + if err != nil { + return fmt.Errorf("add profile: %w", err) + } + + return nil +} + +func (s *serviceClient) switchProfile(profileName string) error { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + return fmt.Errorf(getClientFMT, err) + } + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + if _, err := conn.SwitchProfile(context.Background(), &proto.SwitchProfileRequest{ + ProfileName: &profileName, + Username: &currUser.Username, + }); err != nil { + return fmt.Errorf("switch profile failed: %w", err) + } + + err = s.profileManager.SwitchProfile(profileName) + if err != nil { + return fmt.Errorf("switch profile: %w", err) + } + + return nil +} + +func (s *serviceClient) removeProfile(profileName string) error { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + return fmt.Errorf(getClientFMT, err) + } + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %w", err) + } + + _, err = conn.RemoveProfile(context.Background(), &proto.RemoveProfileRequest{ + ProfileName: profileName, + Username: currUser.Username, + }) + if err != nil { + return fmt.Errorf("remove profile: %w", err) + } + + return nil +} + +type Profile struct { + Name string + IsActive bool +} + +func (s *serviceClient) getProfiles() ([]Profile, error) { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + return nil, fmt.Errorf(getClientFMT, err) + } + + currUser, err := user.Current() + if err != nil { + return nil, fmt.Errorf("get current user: %w", err) + } + profilesResp, err := conn.ListProfiles(context.Background(), &proto.ListProfilesRequest{ + Username: currUser.Username, + }) + if err != nil { + return nil, fmt.Errorf("list profiles: %w", err) + } + + var profiles []Profile + + for _, profile := range profilesResp.Profiles { + profiles = append(profiles, Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + }) + } + + return profiles, nil +} + +type subItem struct { + *systray.MenuItem + ctx context.Context + cancel context.CancelFunc +} + +type profileMenu struct { + mu sync.Mutex + ctx context.Context + profileManager *profilemanager.ProfileManager + eventHandler eventHandler + profileMenuItem *systray.MenuItem + emailMenuItem *systray.MenuItem + profileSubItems []*subItem + manageProfilesSubItem *subItem + profilesState []Profile + downClickCallback func() error + upClickCallback func() error + getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) + loadSettingsCallback func() +} + +func newProfileMenu(ctx context.Context, profileManager *profilemanager.ProfileManager, + + eventHandler eventHandler, profileMenuItem, emailMenuItem *systray.MenuItem, + downClickCallback, upClickCallback func() error, + getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error), + loadSettingsCallback func()) *profileMenu { + p := profileMenu{ + ctx: ctx, + profileManager: profileManager, + eventHandler: eventHandler, + profileMenuItem: profileMenuItem, + emailMenuItem: emailMenuItem, + downClickCallback: downClickCallback, + upClickCallback: upClickCallback, + getSrvClientCallback: getSrvClientCallback, + loadSettingsCallback: loadSettingsCallback, + } + + p.emailMenuItem.Disable() + p.emailMenuItem.Hide() + p.refresh() + go p.updateMenu() + + return &p +} + +func (p *profileMenu) getProfiles() ([]Profile, error) { + conn, err := p.getSrvClientCallback(defaultFailTimeout) + if err != nil { + return nil, fmt.Errorf(getClientFMT, err) + } + currUser, err := user.Current() + if err != nil { + return nil, fmt.Errorf("get current user: %w", err) + } + + profilesResp, err := conn.ListProfiles(p.ctx, &proto.ListProfilesRequest{ + Username: currUser.Username, + }) + if err != nil { + return nil, fmt.Errorf("list profiles: %w", err) + } + + var profiles []Profile + + for _, profile := range profilesResp.Profiles { + profiles = append(profiles, Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + }) + } + + return profiles, nil +} + +func (p *profileMenu) refresh() { + p.mu.Lock() + defer p.mu.Unlock() + + profiles, err := p.getProfiles() + if err != nil { + log.Errorf("failed to list profiles: %v", err) + return + } + + // Clear existing profile items + p.clear(profiles) + + currUser, err := user.Current() + if err != nil { + log.Errorf("failed to get current user: %v", err) + return + } + + conn, err := p.getSrvClientCallback(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get daemon client: %v", err) + return + } + + activeProf, err := conn.GetActiveProfile(p.ctx, &proto.GetActiveProfileRequest{}) + if err != nil { + log.Errorf("failed to get active profile: %v", err) + return + } + + if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username { + activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName) + if err != nil { + log.Warnf("failed to get active profile state: %v", err) + p.emailMenuItem.Hide() + } else if activeProfState.Email != "" { + p.emailMenuItem.SetTitle(fmt.Sprintf("(%s)", activeProfState.Email)) + p.emailMenuItem.Show() + } + } + + for _, profile := range profiles { + item := p.profileMenuItem.AddSubMenuItem(profile.Name, "") + if profile.IsActive { + item.Check() + } + + ctx, cancel := context.WithCancel(context.Background()) + p.profileSubItems = append(p.profileSubItems, &subItem{item, ctx, cancel}) + + go func() { + for { + select { + case <-ctx.Done(): + return // context cancelled + case _, ok := <-item.ClickedCh: + if !ok { + return // channel closed + } + + // Handle profile selection + if profile.IsActive { + log.Infof("Profile '%s' is already active", profile.Name) + return + } + conn, err := p.getSrvClientCallback(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get daemon client: %v", err) + return + } + + _, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{ + ProfileName: &profile.Name, + Username: &currUser.Username, + }) + if err != nil { + log.Errorf("failed to switch profile: %v", err) + return + } + + err = p.profileManager.SwitchProfile(profile.Name) + if err != nil { + log.Errorf("failed to switch profile '%s': %v", profile.Name, err) + return + } + + log.Infof("Switched to profile '%s'", profile.Name) + + status, err := conn.Status(ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("failed to get status after switching profile: %v", err) + return + } + + if status.Status == string(internal.StatusConnected) { + if err := p.downClickCallback(); err != nil { + log.Errorf("failed to handle down click after switching profile: %v", err) + } + } + + if err := p.upClickCallback(); err != nil { + log.Errorf("failed to handle up click after switching profile: %v", err) + } + + p.refresh() + p.loadSettingsCallback() + } + } + }() + + } + ctx, cancel := context.WithCancel(context.Background()) + manageItem := p.profileMenuItem.AddSubMenuItem("Manage Profiles", "") + p.manageProfilesSubItem = &subItem{manageItem, ctx, cancel} + + go func() { + for { + select { + case <-ctx.Done(): + return // context cancelled + case _, ok := <-manageItem.ClickedCh: + if !ok { + return // channel closed + } + // Handle manage profiles click + p.eventHandler.runSelfCommand(p.ctx, "profiles", "true") + p.refresh() + p.loadSettingsCallback() + } + } + }() + + if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username { + p.profileMenuItem.SetTitle(activeProf.ProfileName) + } else { + p.profileMenuItem.SetTitle(fmt.Sprintf("Profile: %s (User: %s)", activeProf.ProfileName, activeProf.Username)) + p.emailMenuItem.Hide() + } + +} + +func (p *profileMenu) clear(profiles []Profile) { + // Clear existing profile items + for _, item := range p.profileSubItems { + item.Remove() + item.cancel() + } + p.profileSubItems = make([]*subItem, 0, len(profiles)) + p.profilesState = profiles + + if p.manageProfilesSubItem != nil { + // Remove the manage profiles item if it exists + p.manageProfilesSubItem.Remove() + p.manageProfilesSubItem.cancel() + p.manageProfilesSubItem = nil + } +} + +func (p *profileMenu) updateMenu() { + // check every second + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + + // get profilesList + profiles, err := p.getProfiles() + if err != nil { + log.Errorf("failed to list profiles: %v", err) + continue + } + + sort.Slice(profiles, func(i, j int) bool { + return profiles[i].Name < profiles[j].Name + }) + + p.mu.Lock() + state := p.profilesState + p.mu.Unlock() + + sort.Slice(state, func(i, j int) bool { + return state[i].Name < state[j].Name + }) + + if slices.Equal(profiles, state) { + continue + } + + p.refresh() + case <-p.ctx.Done(): + return // context cancelled + + } + } +} diff --git a/util/file.go b/util/file.go index f7de7ede2..73ad05b18 100644 --- a/util/file.go +++ b/util/file.go @@ -9,6 +9,7 @@ import ( "io" "os" "path/filepath" + "sort" "strings" "text/template" @@ -200,6 +201,36 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// RemoveJson removes the specified JSON file if it exists +func RemoveJson(file string) error { + // Check if the file exists + if _, err := os.Stat(file); errors.Is(err, os.ErrNotExist) { + return nil // File does not exist, nothing to remove + } + + // Attempt to remove the file + if err := os.Remove(file); err != nil { + return fmt.Errorf("failed to remove JSON file %s: %w", file, err) + } + + return nil +} + +// ListFiles returns the full paths of all files in dir that match pattern. +// Pattern uses shell-style globbing (e.g. "*.json"). +func ListFiles(dir, pattern string) ([]string, error) { + // glob pattern like "/path/to/dir/*.json" + globPattern := filepath.Join(dir, pattern) + + matches, err := filepath.Glob(globPattern) + if err != nil { + return nil, err + } + + sort.Strings(matches) + return matches, nil +} + // ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { envVars := getEnvMap() From 3d9be5098ba392389ecb9210bee0bdb42d01dd65 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Fri, 25 Jul 2025 18:43:48 +0300 Subject: [PATCH 40/50] [client]: deprecate config flag (#4224) --- client/cmd/login.go | 15 +++++---------- client/cmd/root.go | 1 + client/cmd/service_installer.go | 2 -- client/cmd/up.go | 14 ++++---------- release_files/systemd/netbird@.service | 2 +- 5 files changed, 11 insertions(+), 23 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 482e004d1..d6381f6e2 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -26,7 +26,7 @@ import ( func init() { loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) - loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Netbird config file location") + loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") } var loginCmd = &cobra.Command{ @@ -228,15 +228,10 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, // update host's static platform and system information system.UpdateStaticInfo() - var configFilePath string - if configPath != "" { - configFilePath = configPath - } else { - var err error - configFilePath, err = activeProf.FilePath() - if err != nil { - return fmt.Errorf("get active profile file path: %v", err) - } + configFilePath, err := activeProf.FilePath() + if err != nil { + return fmt.Errorf("get active profile file path: %v", err) + } config, err := profilemanager.ReadConfig(configFilePath) diff --git a/client/cmd/root.go b/client/cmd/root.go index b22b850ee..8e8ee3280 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -126,6 +126,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") + rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "(DEPRECATED) Netbird config file location") rootCmd.AddCommand(upCmd) rootCmd.AddCommand(downCmd) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index c994801a6..be8a897dc 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -31,8 +31,6 @@ func buildServiceArguments() []string { args := []string{ "service", "run", - "--config", - configPath, "--log-level", logLevel, "--daemon-addr", diff --git a/client/cmd/up.go b/client/cmd/up.go index d1f8e67a1..98e1c02b3 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -78,7 +78,7 @@ func init() { upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) - upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Netbird config file location") + upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") } @@ -155,15 +155,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr return err } - var configFilePath string - if configPath != "" { - configFilePath = configPath - } else { - var err error - configFilePath, err = activeProf.FilePath() - if err != nil { - return fmt.Errorf("get active profile file path: %v", err) - } + configFilePath, err := activeProf.FilePath() + if err != nil { + return fmt.Errorf("get active profile file path: %v", err) } ic, err := setupConfig(customDNSAddressConverted, cmd, configFilePath) diff --git a/release_files/systemd/netbird@.service b/release_files/systemd/netbird@.service index 095c3142d..48e8cc29d 100644 --- a/release_files/systemd/netbird@.service +++ b/release_files/systemd/netbird@.service @@ -7,7 +7,7 @@ Wants=network-online.target [Service] Type=simple EnvironmentFile=-/etc/default/netbird -ExecStart=/usr/bin/netbird service run --log-file /var/log/netbird/client-%i.log --config /etc/netbird/%i.json --daemon-addr unix:///var/run/netbird/%i.sock $FLAGS +ExecStart=/usr/bin/netbird service run --log-file /var/log/netbird/client-%i.log --daemon-addr unix:///var/run/netbird/%i.sock $FLAGS Restart=on-failure RestartSec=5 TimeoutStopSec=10 From d89e6151a477fb71bc8dc61294d2eb5c720f9140 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 25 Jul 2025 22:52:48 +0200 Subject: [PATCH 41/50] [client] Fix pre-shared key state in wg show (#4222) --- client/iface/configurer/usp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 1ff4d839c..171458e38 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -530,7 +530,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) { if currentPeer == nil { continue } - if val != "" { + if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" { currentPeer.PresharedKey = true } } From e1c66a8124b40aa5963c3e3c3169d424df548404 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Mon, 28 Jul 2025 13:36:48 +0300 Subject: [PATCH 42/50] [client] Fix profile directory path handling based on NB_STATE_DIR (#4229) [client] Fix profile directory path handling based on NB_STATE_DIR (#4229) --- client/internal/profilemanager/service.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go index 56198c4cc..520eef2e9 100644 --- a/client/internal/profilemanager/service.go +++ b/client/internal/profilemanager/service.go @@ -34,14 +34,18 @@ func init() { DefaultConfigPathDir = "/var/lib/netbird/" oldDefaultConfigPathDir = "/etc/netbird/" - switch runtime.GOOS { - case "windows": - oldDefaultConfigPathDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird") - DefaultConfigPathDir = oldDefaultConfigPathDir + if stateDir := os.Getenv("NB_STATE_DIR"); stateDir != "" { + DefaultConfigPathDir = stateDir + } else { + switch runtime.GOOS { + case "windows": + oldDefaultConfigPathDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird") + DefaultConfigPathDir = oldDefaultConfigPathDir - case "freebsd": - oldDefaultConfigPathDir = "/var/db/netbird/" - DefaultConfigPathDir = oldDefaultConfigPathDir + case "freebsd": + oldDefaultConfigPathDir = "/var/db/netbird/" + DefaultConfigPathDir = oldDefaultConfigPathDir + } } oldDefaultConfigPath = filepath.Join(oldDefaultConfigPathDir, "config.json") From 8c8473aed38631c63ddafd85f39c5fb99c42cb4f Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:03:15 +0300 Subject: [PATCH 43/50] [client] Add support for disabling profiles feature via command line flag (#4235) * Add support for disabling profiles feature via command line flag * Add profiles disabling flag to service command * Refactor profile menu initialization and enhance error notifications in event handlers --- client/cmd/root.go | 1 + client/cmd/service.go | 1 + client/cmd/service_controller.go | 2 +- client/cmd/testutil_test.go | 2 +- client/server/server.go | 34 ++++++++++++++++++++++-- client/server/server_test.go | 6 ++--- client/ui/client_ui.go | 16 +++++++++++- client/ui/event_handler.go | 45 +++++++++++++++++++++++++------- client/ui/profile.go | 42 ++++++++++++++++++----------- 9 files changed, 117 insertions(+), 32 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 8e8ee3280..e3ce79964 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -72,6 +72,7 @@ var ( anonymizeFlag bool dnsRouteInterval time.Duration lazyConnEnabled bool + profilesDisabled bool rootCmd = &cobra.Command{ Use: "netbird", diff --git a/client/cmd/service.go b/client/cmd/service.go index 178f4bf0e..d8745f1c4 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -42,6 +42,7 @@ func init() { } serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) + serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") serviceEnvDesc := `Sets extra environment variables for the service. ` + diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index cbffff797..6dc6bca9b 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles)) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), profilesDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index cf94754c1..5dbc8cd7f 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -134,7 +134,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "") + "", false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/server/server.go b/client/server/server.go index f3414888d..80cd6078f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -44,6 +44,7 @@ const ( defaultRetryMultiplier = 1.7 errRestoreResidualState = "failed to restore residual state: %v" + errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" ) // Server for service control. @@ -68,7 +69,8 @@ type Server struct { persistNetworkMap bool isSessionActive atomic.Bool - profileManager profilemanager.ServiceManager + profileManager profilemanager.ServiceManager + profilesDisabled bool } type oauthAuthFlow struct { @@ -79,13 +81,14 @@ type oauthAuthFlow struct { } // New server instance constructor. -func New(ctx context.Context, logFile string) *Server { +func New(ctx context.Context, logFile string, profilesDisabled bool) *Server { return &Server{ rootCtx: ctx, logFile: logFile, persistNetworkMap: true, statusRecorder: peer.NewRecorder(""), profileManager: profilemanager.ServiceManager{}, + profilesDisabled: profilesDisabled, } } @@ -320,6 +323,10 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques s.mutex.Lock() defer s.mutex.Unlock() + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + profState := profilemanager.ActiveProfileState{ Name: msg.ProfileName, Username: msg.Username, @@ -737,6 +744,11 @@ func (s *Server) switchProfileIfNeeded(profileName string, userName *string, act } if profileName != activeProf.Name || username != activeProf.Username { + if s.checkProfilesDisabled() { + log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled") + return gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + log.Infof("switching to profile %s for user %s", profileName, username) if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ Name: profileName, @@ -1069,6 +1081,10 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) ( s.mutex.Lock() defer s.mutex.Unlock() + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + if msg.ProfileName == "" || msg.Username == "" { return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") } @@ -1086,6 +1102,10 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ s.mutex.Lock() defer s.mutex.Unlock() + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + if msg.ProfileName == "" { return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided") } @@ -1142,3 +1162,13 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi Username: activeProfile.Username, }, nil } + +func (s *Server) checkProfilesDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.profilesDisabled { + log.Warn("Profiles are disabled via NB_DISABLE_PROFILES environment variable") + return true + } + + return false +} diff --git a/client/server/server_test.go b/client/server/server_test.go index dda610076..afd38b4a4 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -94,7 +94,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug") + s := New(ctx, "debug", false) s.config = config @@ -151,7 +151,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console") + s := New(ctx, "console", false) err = s.Start() require.NoError(t, err) @@ -227,7 +227,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console") + s := New(ctx, "console", false) err = s.Start() require.NoError(t, err) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 781ec56f8..c74412c8b 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -802,7 +802,21 @@ func (s *serviceClient) onTrayReady() { profileMenuItem := systray.AddMenuItem("", "") emailMenuItem := systray.AddMenuItem("", "") - s.mProfile = newProfileMenu(s.ctx, s.profileManager, *s.eventHandler, profileMenuItem, emailMenuItem, s.menuDownClick, s.menuUpClick, s.getSrvClient, s.loadSettings) + + newProfileMenuArgs := &newProfileMenuArgs{ + ctx: s.ctx, + profileManager: s.profileManager, + eventHandler: s.eventHandler, + profileMenuItem: profileMenuItem, + emailMenuItem: emailMenuItem, + downClickCallback: s.menuDownClick, + upClickCallback: s.menuUpClick, + getSrvClientCallback: s.getSrvClient, + loadSettingsCallback: s.loadSettings, + app: s.app, + } + + s.mProfile = newProfileMenu(*newProfileMenuArgs) systray.AddSeparator() s.mUp = systray.AddMenuItem("Connect", "Connect") diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 39ea3867c..c0bc74a2c 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -86,35 +86,60 @@ func (h *eventHandler) handleDisconnectClick() { func (h *eventHandler) handleAllowSSHClick() { h.toggleCheckbox(h.client.mAllowSSH) - h.updateConfigWithErr() + if err := h.updateConfigWithErr(); err != nil { + h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error + log.Errorf("failed to update config: %v", err) + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings")) + } + } func (h *eventHandler) handleAutoConnectClick() { h.toggleCheckbox(h.client.mAutoConnect) - h.updateConfigWithErr() + if err := h.updateConfigWithErr(); err != nil { + h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error + log.Errorf("failed to update config: %v", err) + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings")) + } } func (h *eventHandler) handleRosenpassClick() { h.toggleCheckbox(h.client.mEnableRosenpass) - h.updateConfigWithErr() + if err := h.updateConfigWithErr(); err != nil { + h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error + log.Errorf("failed to update config: %v", err) + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings")) + } } func (h *eventHandler) handleLazyConnectionClick() { h.toggleCheckbox(h.client.mLazyConnEnabled) - h.updateConfigWithErr() + if err := h.updateConfigWithErr(); err != nil { + h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error + log.Errorf("failed to update config: %v", err) + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings")) + } } func (h *eventHandler) handleBlockInboundClick() { h.toggleCheckbox(h.client.mBlockInbound) - h.updateConfigWithErr() + if err := h.updateConfigWithErr(); err != nil { + h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error + log.Errorf("failed to update config: %v", err) + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings")) + } } func (h *eventHandler) handleNotificationsClick() { h.toggleCheckbox(h.client.mNotifications) - if h.client.eventManager != nil { + if err := h.updateConfigWithErr(); err != nil { + h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error + log.Errorf("failed to update config: %v", err) + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings")) + } else if h.client.eventManager != nil { h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked()) } - h.updateConfigWithErr() + } func (h *eventHandler) handleAdvancedSettingsClick() { @@ -166,10 +191,12 @@ func (h *eventHandler) toggleCheckbox(item *systray.MenuItem) { } } -func (h *eventHandler) updateConfigWithErr() { +func (h *eventHandler) updateConfigWithErr() error { if err := h.client.updateConfig(); err != nil { - log.Errorf("failed to update config: %v", err) + return err } + + return nil } func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) { diff --git a/client/ui/profile.go b/client/ui/profile.go index 142582c25..779f60aa4 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -334,7 +334,7 @@ type profileMenu struct { mu sync.Mutex ctx context.Context profileManager *profilemanager.ProfileManager - eventHandler eventHandler + eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem profileSubItems []*subItem @@ -344,24 +344,34 @@ type profileMenu struct { upClickCallback func() error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() + app fyne.App } -func newProfileMenu(ctx context.Context, profileManager *profilemanager.ProfileManager, +type newProfileMenuArgs struct { + ctx context.Context + profileManager *profilemanager.ProfileManager + eventHandler *eventHandler + profileMenuItem *systray.MenuItem + emailMenuItem *systray.MenuItem + downClickCallback func() error + upClickCallback func() error + getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) + loadSettingsCallback func() + app fyne.App +} - eventHandler eventHandler, profileMenuItem, emailMenuItem *systray.MenuItem, - downClickCallback, upClickCallback func() error, - getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error), - loadSettingsCallback func()) *profileMenu { +func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ - ctx: ctx, - profileManager: profileManager, - eventHandler: eventHandler, - profileMenuItem: profileMenuItem, - emailMenuItem: emailMenuItem, - downClickCallback: downClickCallback, - upClickCallback: upClickCallback, - getSrvClientCallback: getSrvClientCallback, - loadSettingsCallback: loadSettingsCallback, + ctx: args.ctx, + profileManager: args.profileManager, + eventHandler: args.eventHandler, + profileMenuItem: args.profileMenuItem, + emailMenuItem: args.emailMenuItem, + downClickCallback: args.downClickCallback, + upClickCallback: args.upClickCallback, + getSrvClientCallback: args.getSrvClientCallback, + loadSettingsCallback: args.loadSettingsCallback, + app: args.app, } p.emailMenuItem.Disable() @@ -479,6 +489,8 @@ func (p *profileMenu) refresh() { }) if err != nil { log.Errorf("failed to switch profile: %v", err) + // show notification dialog + p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile")) return } From 980a6eca8e5eaf35c216ac9b0840833ea6866eb0 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 29 Jul 2025 19:37:18 +0200 Subject: [PATCH 44/50] [client] Disable the dns host manager properly if disabled through management (#4241) --- client/internal/dns/handler_chain.go | 2 +- client/internal/dns/server.go | 107 +++++++++++++++++++-------- 2 files changed, 77 insertions(+), 32 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 36da8fb78..439bcbb3c 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -15,7 +15,7 @@ const ( PriorityDNSRoute = 75 PriorityUpstream = 50 PriorityDefault = 1 - PriorityFallback = -100 + PriorityFallback = -100 ) type SubdomainMatcher interface { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f933c1de0..4ab9ef761 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -2,6 +2,7 @@ package dns import ( "context" + "errors" "fmt" "net/netip" "runtime" @@ -59,8 +60,10 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. + // This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool mux sync.Mutex service service @@ -187,6 +190,7 @@ func newDefaultServer( statusRecorder: statusRecorder, stateManager: stateManager, hostsDNSHolder: newHostsDNSHolder(), + hostManager: &noopHostConfigurator{}, } // register with root zone, handler chain takes care of the routing @@ -258,7 +262,8 @@ func (s *DefaultServer) Initialize() (err error) { s.mux.Lock() defer s.mux.Unlock() - if s.hostManager != nil { + if !s.isUsingNoopHostManager() { + // already initialized return nil } @@ -271,19 +276,19 @@ func (s *DefaultServer) Initialize() (err error) { s.stateManager.RegisterState(&ShutdownState{}) - // use noop host manager if requested or running in netstack mode. + // Keep using noop host manager if dns off requested or running in netstack mode. // Netstack mode currently doesn't have a way to receive DNS requests. // TODO: Use listener on localhost in netstack mode when running as root. if s.disableSys || netstack.IsEnabled() { log.Info("system DNS is disabled, not setting up host manager") - s.hostManager = &noopHostConfigurator{} return nil } - s.hostManager, err = s.initialize() + hostManager, err := s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) } + s.hostManager = hostManager return nil } @@ -297,28 +302,42 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { - s.mux.Lock() - defer s.mux.Unlock() s.ctxCancel() - if s.hostManager != nil { - if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { - log.Debugf("deregistering original nameservers as fallback handlers") - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) - } + s.mux.Lock() + defer s.mux.Unlock() - if err := s.hostManager.restoreHostDNS(); err != nil { - log.Error("failed to restore host DNS settings: ", err) - } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { - log.Errorf("failed to delete shutdown dns state: %v", err) - } + if err := s.disableDNS(); err != nil { + log.Errorf("failed to disable DNS: %v", err) } - s.service.Stop() - maps.Clear(s.extraDomains) } +func (s *DefaultServer) disableDNS() error { + defer s.service.Stop() + + if s.isUsingNoopHostManager() { + return nil + } + + // Deregister original nameservers if they were registered as fallback + if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { + log.Debugf("deregistering original nameservers as fallback handlers") + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + } + + if err := s.hostManager.restoreHostDNS(); err != nil { + log.Errorf("failed to restore host DNS settings: %v", err) + } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete shutdown dns state: %v", err) + } + + s.hostManager = &noopHostConfigurator{} + + return nil +} + // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { @@ -357,10 +376,6 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro s.mux.Lock() defer s.mux.Unlock() - if s.hostManager == nil { - return fmt.Errorf("dns service is not initialized yet") - } - hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ ZeroNil: true, IgnoreZeroValue: true, @@ -418,13 +433,14 @@ func (s *DefaultServer) ProbeAvailability() { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver - // and proceed with a regular update to clean up the handlers and records if update.ServiceEnable { - if err := s.service.Listen(); err != nil { - log.Errorf("failed to start DNS service: %v", err) + if err := s.enableDNS(); err != nil { + log.Errorf("failed to enable DNS: %v", err) } } else if !s.permanent { - s.service.Stop() + if err := s.disableDNS(); err != nil { + log.Errorf("failed to disable DNS: %v", err) + } } localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) @@ -469,11 +485,40 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { return nil } -func (s *DefaultServer) applyHostConfig() { - if s.hostManager == nil { - return +func (s *DefaultServer) isUsingNoopHostManager() bool { + _, isNoop := s.hostManager.(*noopHostConfigurator) + return isNoop +} + +func (s *DefaultServer) enableDNS() error { + if err := s.service.Listen(); err != nil { + return fmt.Errorf("start DNS service: %w", err) } + if !s.isUsingNoopHostManager() { + return nil + } + + if s.disableSys || netstack.IsEnabled() { + return nil + } + + log.Info("DNS service re-enabled, initializing host manager") + + if !s.service.RuntimeIP().IsValid() { + return errors.New("DNS service runtime IP is invalid") + } + + hostManager, err := s.initialize() + if err != nil { + return fmt.Errorf("initialize host manager: %w", err) + } + s.hostManager = hostManager + + return nil +} + +func (s *DefaultServer) applyHostConfig() { // prevent reapplying config if we're shutting down if s.ctx.Err() != nil { return From a72ef1af39f49d508b01aab90ab46cd18aa9c27e Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Tue, 29 Jul 2025 20:38:44 +0300 Subject: [PATCH 45/50] [client] Fix error handling for set config request on CLI (#4237) [client] Fix error handling for set config request on CLI (#4237) --- client/cmd/up.go | 7 ++++++- client/server/server.go | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/client/cmd/up.go b/client/cmd/up.go index 98e1c02b3..a0c26a207 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" @@ -242,7 +243,11 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager // set the new config req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username) if _, err := client.SetConfig(ctx, req); err != nil { - return fmt.Errorf("call service set config method: %v", err) + if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable { + log.Warnf("setConfig method is not available in the daemon") + } else { + return fmt.Errorf("call service setConfig method: %v", err) + } } if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil { diff --git a/client/server/server.go b/client/server/server.go index 80cd6078f..3cb173881 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -452,6 +452,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } if *msg.ProfileName != activeProf.Name && username != activeProf.Username { + if s.checkProfilesDisabled() { + log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled") + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username) if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ Name: *msg.ProfileName, From 34042b8171b6a8432b9a08c39cccd9fa692729ae Mon Sep 17 00:00:00 2001 From: Bilgeworth <47156222+Bilgeworth@users.noreply.github.com> Date: Tue, 29 Jul 2025 14:52:18 -0400 Subject: [PATCH 46/50] [misc] devcontainer Dockerfile: pin gopls to v0.18.1 (latest that supports golang 1.23) (#4240) Container will fail to build with newer versions of gopls unless golang is updated to 1.24. The latest stable version supporting 1.23 is gopls v0.18.1 --- .devcontainer/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 4697acf20..9e5e97a31 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -9,7 +9,7 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ libayatana-appindicator3-dev=0.5.5-2+deb11u2 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* \ - && go install -v golang.org/x/tools/gopls@latest + && go install -v golang.org/x/tools/gopls@v0.18.1 WORKDIR /app From 541e258639650254359aa0c4ce83b60877a681f1 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 30 Jul 2025 16:49:50 +0200 Subject: [PATCH 47/50] [management] add account deleted event (#4255) --- management/server/account.go | 3 +++ management/server/activity/codes.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/management/server/account.go b/management/server/account.go index cd0c933f0..52b625da1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -718,6 +718,9 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // cancel peer login expiry job am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) + meta := map[string]any{"account_id": account.Id, "domain": account.Domain, "created_at": account.CreatedAt} + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDeleted, meta) + log.WithContext(ctx).Debugf("account %s deleted", accountID) return nil } diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index d9f56f097..23ddd1dd5 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -174,6 +174,8 @@ const ( AccountLazyConnectionEnabled Activity = 85 AccountLazyConnectionDisabled Activity = 86 + + AccountDeleted Activity = 99999 ) var activityMap = map[Activity]Code{ @@ -182,6 +184,7 @@ var activityMap = map[Activity]Code{ UserJoined: {"User joined", "user.join"}, UserInvited: {"User invited", "user.invite"}, AccountCreated: {"Account created", "account.create"}, + AccountDeleted: {"Account deleted", "account.delete"}, PeerRemovedByUser: {"Peer deleted", "user.peer.delete"}, RuleAdded: {"Rule added", "rule.add"}, RuleUpdated: {"Rule updated", "rule.update"}, From 5de61f3081bc0d8178b8909f89d28069dd94d4f8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:28:19 +0200 Subject: [PATCH 48/50] [client] Fix dns ipv6 upstream (#4257) --- client/internal/dns/server.go | 15 ++++++--- client/internal/dns/server_test.go | 53 ++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 4ab9ef761..e5f29d807 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -586,10 +586,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { continue } - ns = fmt.Sprintf("%s:%d", ns, defaultPort) - if ip, err := netip.ParseAddr(ns); err == nil && ip.Is6() { - ns = fmt.Sprintf("[%s]:%d", ns, defaultPort) - } + ns = formatAddr(ns, defaultPort) handler.upstreamServers = append(handler.upstreamServers, ns) } @@ -774,7 +771,15 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { } func getNSHostPort(ns nbdns.NameServer) string { - return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) + return formatAddr(ns.IP.String(), ns.Port) +} + +// formatAddr formats a nameserver address with port, handling IPv6 addresses properly +func formatAddr(address string, port int) string { + if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() { + return fmt.Sprintf("[%s]:%d", address, port) + } + return fmt.Sprintf("%s:%d", address, port) } // upstreamCallbacks returns two functions, the first one is used to deactivate diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3cab4517a..50444a86f 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -2053,3 +2053,56 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } + +func TestFormatAddr(t *testing.T) { + tests := []struct { + name string + address string + port int + expected string + }{ + { + name: "IPv4 address", + address: "8.8.8.8", + port: 53, + expected: "8.8.8.8:53", + }, + { + name: "IPv4 address with custom port", + address: "1.1.1.1", + port: 5353, + expected: "1.1.1.1:5353", + }, + { + name: "IPv6 address", + address: "fd78:94bf:7df8::1", + port: 53, + expected: "[fd78:94bf:7df8::1]:53", + }, + { + name: "IPv6 address with custom port", + address: "2001:db8::1", + port: 5353, + expected: "[2001:db8::1]:5353", + }, + { + name: "IPv6 localhost", + address: "::1", + port: 53, + expected: "[::1]:53", + }, + { + name: "Invalid address treated as hostname", + address: "dns.example.com", + port: 53, + expected: "dns.example.com:53", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatAddr(tt.address, tt.port) + assert.Equal(t, tt.expected, result) + }) + } +} From 71bb09d870e29f33ab56dc6cff9ccecc36f8f186 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:36:30 +0200 Subject: [PATCH 49/50] [client] Improve userspace filter logging performance (#4221) --- client/firewall/uspfilter/conntrack/icmp.go | 6 +- client/firewall/uspfilter/conntrack/tcp.go | 14 +- client/firewall/uspfilter/conntrack/udp.go | 4 +- client/firewall/uspfilter/filter.go | 20 +- .../firewall/uspfilter/forwarder/endpoint.go | 2 +- client/firewall/uspfilter/forwarder/icmp.go | 16 +- client/firewall/uspfilter/forwarder/tcp.go | 18 +- client/firewall/uspfilter/forwarder/udp.go | 34 +-- client/firewall/uspfilter/log/log.go | 205 +++++++++++++++--- client/firewall/uspfilter/log/log_test.go | 15 +- client/firewall/uspfilter/nat.go | 8 +- 11 files changed, 238 insertions(+), 104 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 509c1549b..50b663642 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -221,7 +221,7 @@ func (t *ICMPTracker) track( // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { - t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) return } @@ -243,7 +243,7 @@ func (t *ICMPTracker) track( t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.sendEvent(nftypes.TypeStart, conn, ruleId) } @@ -294,7 +294,7 @@ func (t *ICMPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 2d42ea32e..a2355e5c7 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -211,7 +211,7 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) - t.logger.Trace("New %s TCP connection: %s", direction, key) + t.logger.Trace2("New %s TCP connection: %s", direction, key) t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -240,7 +240,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui currentState := conn.GetState() if !t.isValidStateForFlags(currentState, flags) { - t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) + t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) // allow all flags for established for now if currentState == TCPStateEstablished { return true @@ -262,7 +262,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, p if flags&TCPRst != 0 { if conn.CompareAndSwapState(currentState, TCPStateClosed) { conn.SetTombstone() - t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) } @@ -340,17 +340,17 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, p } if newState != 0 && conn.CompareAndSwapState(currentState, newState) { - t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) + t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) switch newState { case TCPStateTimeWait: - t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", + t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) case TCPStateClosed: conn.SetTombstone() - t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", + t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) } @@ -438,7 +438,7 @@ func (t *TCPTracker) cleanup() { if conn.timeoutExceeded(timeout) { delete(t.connections, key) - t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", + t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) // event already handled by state change diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 000eaa1b6..e7f49c46f 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -116,7 +116,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace("New %s UDP connection: %s", direction, key) + t.logger.Trace2("New %s UDP connection: %s", direction, key) t.sendEvent(nftypes.TypeStart, conn, ruleID) } @@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7120d7d64..fdc026b88 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -601,7 +601,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { srcIP, dstIP := m.extractIPs(d) if !srcIP.IsValid() { - m.logger.Error("Unknown network layer: %v", d.decoded[0]) + m.logger.Error1("Unknown network layer: %v", d.decoded[0]) return false } @@ -727,13 +727,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { srcIP, dstIP := m.extractIPs(d) if !srcIP.IsValid() { - m.logger.Error("Unknown network layer: %v", d.decoded[0]) + m.logger.Error1("Unknown network layer: %v", d.decoded[0]) return true } // TODO: pass fragments of routed packets to forwarder if fragment { - m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v", + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", srcIP, dstIP, d.ip4.Id, d.ip4.Flags) return false } @@ -741,7 +741,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) @@ -766,7 +766,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet _, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", ruleID, pnum, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ @@ -807,7 +807,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { } if err := fwd.InjectIncomingPacket(packetData); err != nil { - m.logger.Error("Failed to inject local packet: %v", err) + m.logger.Error1("Failed to inject local packet: %v", err) } // don't process this packet further @@ -819,7 +819,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { // Drop if routing is disabled if !m.routingEnabled.Load() { - m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", + m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", srcIP, dstIP) return true } @@ -835,7 +835,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) if !pass { - m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", ruleID, pnum, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ @@ -863,7 +863,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID) if err := fwd.InjectIncomingPacket(packetData); err != nil { - m.logger.Error("Failed to inject routed packet: %v", err) + m.logger.Error1("Failed to inject routed packet: %v", err) fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort) } } @@ -901,7 +901,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, true if the packet is a fragment and valid. func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Trace("couldn't decode packet, err: %s", err) + m.logger.Trace1("couldn't decode packet, err: %s", err) return false, false } diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index e18c083b9..f91291ea8 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -57,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) address := netHeader.DestinationAddress() err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) if err != nil { - e.logger.Error("CreateOutboundPacket: %v", err) + e.logger.Error1("CreateOutboundPacket: %v", err) continue } written++ diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 08d77ed05..939c04789 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf // TODO: support non-root conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") if err != nil { - f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) + f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) // This will make netstack reply on behalf of the original destination, that's ok for now return false } defer func() { if err := conn.Close(); err != nil { - f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err) + f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err) } }() @@ -52,11 +52,11 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf payload := fullPacket.AsSlice() if _, err = conn.WriteTo(payload, dst); err != nil { - f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) + f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) return true } - f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v", + f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) // For Echo Requests, send and handle response @@ -72,7 +72,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err) + f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) return 0 } @@ -80,7 +80,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon n, _, err := conn.ReadFrom(response) if err != nil { if !isTimeout(err) { - f.logger.Error("forwarder: Failed to read ICMP response: %v", err) + f.logger.Error1("forwarder: Failed to read ICMP response: %v", err) } return 0 } @@ -101,12 +101,12 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon fullPacket = append(fullPacket, response[:n]...) if err := f.InjectIncomingPacket(fullPacket); err != nil { - f.logger.Error("forwarder: Failed to inject ICMP response: %v", err) + f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err) return 0 } - f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v", + f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) return len(fullPacket) diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index aa42f811b..aef420061 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -38,7 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) - f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err) + f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) return } @@ -47,9 +47,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { ep, epErr := r.CreateEndpoint(&wq) if epErr != nil { - f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr) + f.logger.Error1("forwarder: failed to create TCP endpoint: %v", epErr) if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: outConn close error: %v", err) + f.logger.Debug1("forwarder: outConn close error: %v", err) } r.Complete(true) return @@ -61,7 +61,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) success = true - f.logger.Trace("forwarder: established TCP connection %v", epID(id)) + f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) go f.proxyTCP(id, inConn, outConn, ep, flowID) } @@ -75,10 +75,10 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn <-ctx.Done() // Close connections and endpoint. if err := inConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug("forwarder: inConn close error: %v", err) + f.logger.Debug1("forwarder: inConn close error: %v", err) } if err := outConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug("forwarder: outConn close error: %v", err) + f.logger.Debug1("forwarder: outConn close error: %v", err) } ep.Close() @@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn if errInToOut != nil { if !isClosedError(errInToOut) { - f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) + f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) } } if errOutToIn != nil { if !isClosedError(errOutToIn) { - f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn) + f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn) } } @@ -127,7 +127,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn txPackets = tcpStats.SegmentsReceived.Value() } - f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 3a761d06b..d146de5e4 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -78,10 +78,10 @@ func (f *udpForwarder) Stop() { for id, conn := range f.conns { conn.cancel() if err := conn.conn.Close(); err != nil { - f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP conn close error for %v: %v", epID(id), err) } if err := conn.outConn.Close(); err != nil { - f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) } conn.ep.Close() @@ -112,10 +112,10 @@ func (f *udpForwarder) cleanup() { for _, idle := range idleConns { idle.conn.cancel() if err := idle.conn.conn.Close(); err != nil { - f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err) + f.logger.Debug2("forwarder: UDP conn close error for %v: %v", epID(idle.id), err) } if err := idle.conn.outConn.Close(); err != nil { - f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err) + f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err) } idle.conn.ep.Close() @@ -124,7 +124,7 @@ func (f *udpForwarder) cleanup() { delete(f.conns, idle.id) f.Unlock() - f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) } } } @@ -143,7 +143,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { _, exists := f.udpForwarder.conns[id] f.udpForwarder.RUnlock() if exists { - f.logger.Trace("forwarder: existing UDP connection for %v", epID(id)) + f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) return } @@ -160,7 +160,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { - f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) // TODO: Send ICMP error message return } @@ -169,9 +169,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { wq := waiter.Queue{} ep, epErr := r.CreateEndpoint(&wq) if epErr != nil { - f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr) + f.logger.Debug1("forwarder: failed to create UDP endpoint: %v", epErr) if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) } return } @@ -194,10 +194,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { f.udpForwarder.Unlock() pConn.cancel() if err := inConn.Close(); err != nil { - f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP inConn close error for %v: %v", epID(id), err) } if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) } return } @@ -205,7 +205,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { f.udpForwarder.Unlock() success = true - f.logger.Trace("forwarder: established UDP connection %v", epID(id)) + f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) go f.proxyUDP(connCtx, pConn, id, ep) } @@ -220,10 +220,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack pConn.cancel() if err := pConn.conn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP inConn close error for %v: %v", epID(id), err) } if err := pConn.outConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) + f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) } ep.Close() @@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack wg.Wait() if outboundErr != nil && !isClosedError(outboundErr) { - f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr) + f.logger.Error2("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr) } if inboundErr != nil && !isClosedError(inboundErr) { - f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr) + f.logger.Error2("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr) } var rxPackets, txPackets uint64 @@ -263,7 +263,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack txPackets = udpStats.PacketsReceived.Value() } - f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) f.udpForwarder.Lock() delete(f.udpForwarder.conns, id) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index d22421e2d..5614e2ec3 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -44,7 +44,12 @@ var levelStrings = map[Level]string{ type logMessage struct { level Level format string - args []any + arg1 any + arg2 any + arg3 any + arg4 any + arg5 any + arg6 any } // Logger is a high-performance, non-blocking logger @@ -89,62 +94,198 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } -func (l *Logger) log(level Level, format string, args ...any) { - select { - case l.msgChannel <- logMessage{level: level, format: format, args: args}: - default: - } -} -// Error logs a message at error level -func (l *Logger) Error(format string, args ...any) { +func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { - l.log(LevelError, format, args...) + select { + case l.msgChannel <- logMessage{level: LevelError, format: format}: + default: + } } } -// Warn logs a message at warning level -func (l *Logger) Warn(format string, args ...any) { +func (l *Logger) Warn(format string) { if l.level.Load() >= uint32(LevelWarn) { - l.log(LevelWarn, format, args...) + select { + case l.msgChannel <- logMessage{level: LevelWarn, format: format}: + default: + } } } -// Info logs a message at info level -func (l *Logger) Info(format string, args ...any) { +func (l *Logger) Info(format string) { if l.level.Load() >= uint32(LevelInfo) { - l.log(LevelInfo, format, args...) + select { + case l.msgChannel <- logMessage{level: LevelInfo, format: format}: + default: + } } } -// Debug logs a message at debug level -func (l *Logger) Debug(format string, args ...any) { +func (l *Logger) Debug(format string) { if l.level.Load() >= uint32(LevelDebug) { - l.log(LevelDebug, format, args...) + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format}: + default: + } } } -// Trace logs a message at trace level -func (l *Logger) Trace(format string, args ...any) { +func (l *Logger) Trace(format string) { if l.level.Load() >= uint32(LevelTrace) { - l.log(LevelTrace, format, args...) + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format}: + default: + } } } -func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) { +func (l *Logger) Error1(format string, arg1 any) { + if l.level.Load() >= uint32(LevelError) { + select { + case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}: + default: + } + } +} + +func (l *Logger) Error2(format string, arg1, arg2 any) { + if l.level.Load() >= uint32(LevelError) { + select { + case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}: + default: + } + } +} + +func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelWarn) { + select { + case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + +func (l *Logger) Debug1(format string, arg1 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}: + default: + } + } +} + +func (l *Logger) Debug2(format string, arg1, arg2 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}: + default: + } + } +} + +func (l *Logger) Trace1(format string, arg1 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}: + default: + } + } +} + +func (l *Logger) Trace2(format string, arg1, arg2 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}: + default: + } + } +} + +func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + +func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + default: + } + } +} + +func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: + default: + } + } +} + +func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: + default: + } + } +} + +func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") *buf = append(*buf, ' ') - *buf = append(*buf, levelStrings[level]...) + *buf = append(*buf, levelStrings[msg.level]...) *buf = append(*buf, ' ') - var msg string - if len(args) > 0 { - msg = fmt.Sprintf(format, args...) - } else { - msg = format + // Count non-nil arguments for switch + argCount := 0 + if msg.arg1 != nil { + argCount++ + if msg.arg2 != nil { + argCount++ + if msg.arg3 != nil { + argCount++ + if msg.arg4 != nil { + argCount++ + if msg.arg5 != nil { + argCount++ + if msg.arg6 != nil { + argCount++ + } + } + } + } + } } - *buf = append(*buf, msg...) + + var formatted string + switch argCount { + case 0: + formatted = msg.format + case 1: + formatted = fmt.Sprintf(msg.format, msg.arg1) + case 2: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2) + case 3: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3) + case 4: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4) + case 5: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) + case 6: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + } + + *buf = append(*buf, formatted...) *buf = append(*buf, '\n') if len(*buf) > maxMessageSize { @@ -157,7 +298,7 @@ func (l *Logger) processMessage(msg logMessage, buffer *[]byte) { bufp := l.bufPool.Get().(*[]byte) defer l.bufPool.Put(bufp) - l.formatMessage(bufp, msg.level, msg.format, msg.args...) + l.formatMessage(bufp, msg) if len(*buffer)+len(*bufp) > maxBatchSize { _, _ = l.output.Write(*buffer) @@ -249,4 +390,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} +} \ No newline at end of file diff --git a/client/firewall/uspfilter/log/log_test.go b/client/firewall/uspfilter/log/log_test.go index e7da9a8e9..0c221c262 100644 --- a/client/firewall/uspfilter/log/log_test.go +++ b/client/firewall/uspfilter/log/log_test.go @@ -19,22 +19,17 @@ func (d *discard) Write(p []byte) (n int, err error) { func BenchmarkLogger(b *testing.B) { simpleMessage := "Connection established" - conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d" srcIP := "192.168.1.1" srcPort := uint16(12345) dstIP := "10.0.0.1" dstPort := uint16(443) state := 4 // TCPStateEstablished - complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s" protocol := "TCP" direction := "outbound" flags := uint16(0x18) // ACK + PSH sequence := uint32(123456789) acknowledged := uint32(987654321) - payloadSize := 1460 - fragmented := false - connID := "f7a12b3e-c456-7890-d123-456789abcdef" b.Run("SimpleMessage", func(b *testing.B) { logger := createTestLogger() @@ -52,7 +47,7 @@ func BenchmarkLogger(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state) + logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state) } }) @@ -62,7 +57,7 @@ func BenchmarkLogger(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID) + logger.Trace6("Complex trace: proto=%s dir=%s flags=%d seq=%d ack=%d size=%d", protocol, direction, flags, sequence, acknowledged, 1460) } }) } @@ -72,7 +67,6 @@ func BenchmarkLoggerParallel(b *testing.B) { logger := createTestLogger() defer cleanupLogger(logger) - conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d" srcIP := "192.168.1.1" srcPort := uint16(12345) dstIP := "10.0.0.1" @@ -82,7 +76,7 @@ func BenchmarkLoggerParallel(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state) + logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state) } }) } @@ -92,7 +86,6 @@ func BenchmarkLoggerBurst(b *testing.B) { logger := createTestLogger() defer cleanupLogger(logger) - conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d" srcIP := "192.168.1.1" srcPort := uint16(12345) dstIP := "10.0.0.1" @@ -102,7 +95,7 @@ func BenchmarkLoggerBurst(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < 100; j++ { - logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state) + logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state) } } } diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 4539f7da5..27b752531 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -211,11 +211,11 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { } if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error("Failed to rewrite packet destination: %v", err) + m.logger.Error1("Failed to rewrite packet destination: %v", err) return false } - m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP) + m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) return true } @@ -237,11 +237,11 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { } if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error("Failed to rewrite packet source: %v", err) + m.logger.Error1("Failed to rewrite packet source: %v", err) return false } - m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP) + m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) return true } From 552dc605479eb92822df107eb78e0c1326290bc2 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 1 Aug 2025 12:22:07 +0200 Subject: [PATCH 50/50] [management] migrate group peers into seperate table (#4096) --- management/server/account.go | 95 +++--- management/server/account/manager.go | 6 +- management/server/account_test.go | 31 +- management/server/dns_test.go | 8 +- management/server/group.go | 230 +++++++++++-- management/server/group_test.go | 323 ++++++++++++++++-- .../http/handlers/groups/groups_handler.go | 4 +- management/server/migration/migration.go | 64 ++++ management/server/mock_server/account_mock.go | 28 ++ management/server/nameserver_test.go | 24 +- management/server/peer.go | 56 ++- management/server/peer_test.go | 33 +- management/server/policy_test.go | 10 +- management/server/posture_checks_test.go | 25 +- management/server/route_test.go | 21 +- management/server/setupkey_test.go | 16 +- management/server/store/sql_store.go | 304 +++++++++++++---- management/server/store/sql_store_test.go | 113 ++++-- management/server/store/store.go | 23 +- management/server/types/account.go | 2 +- management/server/types/group.go | 32 +- management/server/types/setupkey.go | 2 +- management/server/user.go | 106 +----- management/server/user_test.go | 4 +- 24 files changed, 1139 insertions(+), 421 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 52b625da1..d392cd0b9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1368,7 +1368,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { + if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } @@ -1382,28 +1382,22 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) - if err != nil { - return fmt.Errorf("error getting account groups: %w", err) - } - - groupsMap := make(map[string]*types.Group, len(groups)) - for _, group := range groups { - groupsMap[group.ID] = group - } - peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } - updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) - if err != nil { - return fmt.Errorf("error modifying user peers in groups: %w", err) - } - - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil { - return fmt.Errorf("error saving groups: %w", err) + for _, peer := range peers { + for _, g := range addNewGroups { + if err := transaction.AddPeerToGroup(ctx, userAuth.AccountId, peer.ID, g); err != nil { + return fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, g, err) + } + } + for _, g := range removeOldGroups { + if err := transaction.RemovePeerFromGroup(ctx, peer.ID, g); err != nil { + return fmt.Errorf("error removing peer %s from group %s: %w", peer.ID, g, err) + } + } } if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { @@ -1971,53 +1965,56 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc // propagateUserGroupMemberships propagates all account users' group memberships to their peers. // Returns true if any groups were modified, true if those updates affect peers and an error. func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { - groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return false, false, err - } - - groupsMap := make(map[string]*types.Group, len(groups)) - for _, group := range groups { - groupsMap[group.ID] = group - } - users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, false, err } - groupsToUpdate := make(map[string]*types.Group) + accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, false, fmt.Errorf("error getting account group peers: %w", err) + } + accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, false, fmt.Errorf("error getting account groups: %w", err) + } + + for _, group := range accountGroups { + if _, exists := accountGroupPeers[group.ID]; !exists { + accountGroupPeers[group.ID] = make(map[string]struct{}) + } + } + + updatedGroups := []string{} for _, user := range users { userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) if err != nil { return false, false, err } - updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil) - if err != nil { - return false, false, err - } - - for _, group := range updatedGroups { - groupsToUpdate[group.ID] = group - groupsMap[group.ID] = group + for _, peer := range userPeers { + for _, groupID := range user.AutoGroups { + if _, exists := accountGroupPeers[groupID]; !exists { + // we do not wanna create the groups here + log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) + continue + } + if _, exists := accountGroupPeers[groupID][peer.ID]; exists { + continue + } + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) + } + updatedGroups = append(updatedGroups, groupID) + } } } - if len(groupsToUpdate) == 0 { - return false, false, nil - } - - peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate)) + peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups) if err != nil { - return false, false, err + return false, false, fmt.Errorf("error checking if group changes affect peers: %w", err) } - err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate)) - if err != nil { - return false, false, err - } - - return true, peersAffected, nil + return len(updatedGroups) > 0, peersAffected, nil } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 8c7e95e3d..0cd1c6637 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -62,8 +62,10 @@ type Manager interface { GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error + CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error + UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error + CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error + UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error diff --git a/management/server/account_test.go b/management/server/account_test.go index b65dffe6c..1dd74104b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1159,7 +1159,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Name: "GroupA", Peers: []string{}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1194,7 +1194,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }() group.Peers = []string{peer1.ID, peer2.ID, peer3.ID} - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1240,11 +1240,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID}, + AccountID: account.Id, + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1292,7 +1293,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1343,11 +1344,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) require.NoError(t, err, "failed to save group") @@ -1672,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) { }, Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, - Resources: []types.Resource{}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, }, }, Policies: []*types.Policy{ @@ -2616,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { } func TestAccount_SetJWTGroups(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", "postgres") manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -3360,7 +3363,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} - require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1)) + require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1)) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) require.NoError(t, err) @@ -3382,7 +3385,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { t.Run("should update membership and account peers for used groups", func(t *testing.T) { group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} - require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2)) + require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2)) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) require.NoError(t, err) diff --git a/management/server/dns_test.go b/management/server/dns_test.go index f2295450f..2af07d8e4 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -495,7 +495,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -506,7 +506,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { Name: "GroupB", Peers: []string{}, }, - }, true) + }) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -562,11 +562,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/group.go b/management/server/group.go index 130a67145..95bed7d18 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -65,22 +65,144 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } -// SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error { +// CreateGroup object of the peers +func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create) + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var eventsToStore []func() + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil { + return status.Errorf(status.Internal, "failed to create group: %v", err) + } + + for _, peerID := range newGroup.Peers { + if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil { + return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err) + } + } + return nil + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil } -// SaveGroups adds new groups to the account. +// UpdateGroup object of the peers +func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var eventsToStore []func() + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) + if err != nil { + return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) + } + + peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers) + peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers) + + for _, peerID := range peersToAdd { + if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil { + return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err) + } + } + for _, peerID := range peersToRemove { + if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil { + return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err) + } + } + + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// CreateGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error { - operation := operations.Create - if !create { - operation = operations.Update - } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation) +// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. +func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) if err != nil { return status.NewPermissionValidationError(err) } @@ -116,7 +238,65 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) + return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// UpdateGroups updates groups in the account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. +func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var eventsToStore []func() + var groupsToSave []*types.Group + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) }) if err != nil { return err @@ -265,20 +445,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *types.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) - if err != nil { - return err - } - - if updated := group.AddPeer(peerID); !updated { - return nil - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) if err != nil { return err @@ -288,7 +458,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID) }) if err != nil { return err @@ -329,7 +499,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err @@ -347,20 +517,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *types.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) - if err != nil { - return err - } - - if updated := group.RemovePeer(peerID); !updated { - return nil - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) if err != nil { return err @@ -370,7 +530,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.RemovePeerFromGroup(ctx, peerID, groupID) }) if err != nil { return err @@ -411,7 +571,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err diff --git a/management/server/group_test.go b/management/server/group_test.go index 631fe3a71..51069dc56 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -2,14 +2,20 @@ package server import ( "context" + "encoding/binary" "errors" "fmt" + "net" "net/netip" + "strconv" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/groups" @@ -18,8 +24,10 @@ import ( "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + peer2 "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -40,7 +48,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } for _, group := range account.Groups { group.Issued = types.GroupIssuedIntegration - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) + group.ID = uuid.New().String() + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } @@ -48,7 +57,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = types.GroupIssuedJWT - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) + group.ID = uuid.New().String() + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } @@ -56,7 +66,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = types.GroupIssuedAPI group.ID = "" - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { t.Errorf("should not create api group with the same name, %s", group.Name) } @@ -162,7 +172,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } - err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true) + err = manager.CreateGroups(context.Background(), account.Id, groupAdminUserID, groups) assert.NoError(t, err, "Failed to save test groups") testCases := []struct { @@ -382,13 +392,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t return nil, nil, err } - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) acc, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { @@ -400,7 +410,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -426,8 +436,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Name: "GroupE", Peers: []string{peer2.ID}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err := manager.CreateGroup(context.Background(), account.Id, userID, group) + assert.NoError(t, err) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { @@ -442,11 +455,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -513,7 +526,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, Rules: []*types.PolicyRule{ { @@ -535,11 +548,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -604,11 +617,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -645,11 +658,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -672,11 +685,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -719,11 +732,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupE", Name: "GroupE", Peers: []string{peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -733,3 +746,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } }) } + +func Test_AddPeerToGroup(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + acc, err := createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + err = manager.Store.AddPeerToGroup(context.Background(), accountID, strconv.Itoa(i), acc.GroupsG[0].ID) + if err != nil { + errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + return + } + + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers)) +} + +func Test_AddPeerToAll(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i)) + if err != nil { + errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + return + } + + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers)) +} + +func Test_AddPeerAndAddToAll(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + peer := &peer2.Peer{ + ID: strconv.Itoa(i), + AccountID: accountID, + DNSLabel: "peer" + strconv.Itoa(i), + IP: uint32ToIP(uint32(i)), + } + + err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { + err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + if err != nil { + return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + } + err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID) + if err != nil { + return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + } + return nil + }) + if err != nil { + t.Errorf("AddPeer failed for peer %d: %v", i, err) + return + } + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers)) + assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers)) +} + +func uint32ToIP(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip +} + +func Test_IncrementNetworkSerial(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { + err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("failed to get account %s: %v", accountID, err) + } + return nil + }) + if err != nil { + t.Errorf("AddPeer failed for peer %d: %v", i, err) + return + } + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial) +} diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 3ae833dc0..bede652f5 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil { + if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil { log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) util.WriteError(r.Context(), err, w) return @@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { Issued: types.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true) + err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index c2f1a5abf..88af9a58f 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -39,6 +39,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f return nil } + if !db.Migrator().HasColumn(&model, fieldName) { + log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName) + return nil + } + stmt := &gorm.Statement{DB: db} err := stmt.Parse(model) if err != nil { @@ -422,3 +427,62 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName) return nil } + +func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error { + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if !db.Migrator().HasColumn(&model, columnName) { + log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName) + return nil + } + + if err := db.Transaction(func(tx *gorm.DB) error { + var rows []map[string]any + if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil { + return fmt.Errorf("find rows: %w", err) + } + + for _, row := range rows { + jsonValue, ok := row[columnName].(string) + if !ok || jsonValue == "" { + continue + } + + var data []string + if err := json.Unmarshal([]byte(jsonValue), &data); err != nil { + return fmt.Errorf("unmarshal json: %w", err) + } + + for _, value := range data { + if err := tx.Create( + mapperFunc(row["account_id"].(string), row["id"].(string), value), + ).Error; err != nil { + return fmt.Errorf("failed to insert id %v: %w", row["id"], err) + } + } + } + + if err := tx.Migrator().DropColumn(&model, columnName); err != nil { + return fmt.Errorf("drop column %s: %w", columnName, err) + } + + return nil + }); err != nil { + return err + } + + log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName) + return nil +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a16e3652c..8c8fd19c9 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -124,6 +124,34 @@ type MockAccountManager struct { BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } +func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { + if am.SaveGroupFunc != nil { + return am.SaveGroupFunc(ctx, accountID, userID, group, true) + } + return status.Errorf(codes.Unimplemented, "method CreateGroup is not implemented") +} + +func (am *MockAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { + if am.SaveGroupFunc != nil { + return am.SaveGroupFunc(ctx, accountID, userID, group, false) + } + return status.Errorf(codes.Unimplemented, "method UpdateGroup is not implemented") +} + +func (am *MockAccountManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error { + if am.SaveGroupsFunc != nil { + return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, true) + } + return status.Errorf(codes.Unimplemented, "method CreateGroups is not implemented") +} + +func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error { + if am.SaveGroupsFunc != nil { + return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, false) + } + return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented") +} + func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { if am.UpdateAccountPeersFunc != nil { am.UpdateAccountPeersFunc(ctx, accountID) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 25eb03b83..959e7856a 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -980,18 +980,18 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ - { - ID: "groupA", - Name: "GroupA", - Peers: []string{}, - }, - { - ID: "groupB", - Name: "GroupB", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, - }, true) + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }) + assert.NoError(t, err) + + err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/peer.go b/management/server/peer.go index 3c40c6bb6..f954369d8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -374,12 +374,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { - return err + if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { + return fmt.Errorf("failed to remove peer from groups: %w", err) } eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) - return err + if err != nil { + return fmt.Errorf("failed to delete peer: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil }) if err != nil { return err @@ -478,7 +486,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer - var updateAccountPeers bool var setupKeyID string var setupKeyName string @@ -615,20 +622,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return err } - err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) - if err != nil { - return fmt.Errorf("failed adding peer to All group: %w", err) - } - if len(groupsToAdd) > 0 { for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) + err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g) if err != nil { return err } } } + err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -678,7 +685,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) + updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) if err != nil { updateAccountPeers = true } @@ -1021,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is }() if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, nil, nil, err } @@ -1523,17 +1530,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { - groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) - if err != nil { - return nil, err - } - - groupIDs := make([]string, 0, len(groups)) - for _, group := range groups { - groupIDs = append(groupIDs, group.ID) - } - - return groupIDs, err + return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID) } // IsPeerInActiveGroup checks if the given peer is part of a group that is used @@ -1563,17 +1560,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto } for _, peer := range peers { - groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peer.ID) - if err != nil { - return nil, fmt.Errorf("failed to get peer groups: %w", err) - } - - for _, group := range groups { - group.RemovePeer(peer.ID) - err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) - if err != nil { - return nil, fmt.Errorf("failed to save group: %w", err) - } + if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { + return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID) } if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4f6ae500e..947e53a60 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -310,12 +310,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) - err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true) + err = manager.CreateGroup(context.Background(), account.Id, userID, &group1) if err != nil { t.Errorf("expecting group1 to be added, got failure %v", err) return } - err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true) + err = manager.CreateGroup(context.Background(), account.Id, userID, &group2) if err != nil { t.Errorf("expecting group2 to be added, got failure %v", err) return @@ -1475,6 +1475,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { } func Test_RegisterPeerRollbackOnFailure(t *testing.T) { + engine := os.Getenv("NETBIRD_STORE_ENGINE") + if engine == "sqlite" || engine == "" { + t.Skip("Skipping test because sqlite test store is not respecting foreign keys") + } if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1709,7 +1713,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1725,8 +1729,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }, true) - require.NoError(t, err) + } + for _, group := range g { + err = manager.CreateGroup(context.Background(), account.Id, userID, group) + require.NoError(t, err) + } // create a user with auto groups _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{ @@ -1785,7 +1792,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg) // close(done) }() @@ -2164,7 +2171,6 @@ func Test_IsUniqueConstraintError(t *testing.T) { } func Test_AddPeer(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -2176,7 +2182,7 @@ func Test_AddPeer(t *testing.T) { _, err = createAccount(manager, accountID, userID, "domain.com") if err != nil { - t.Fatal("error creating account") + t.Fatalf("error creating account: %v", err) return } @@ -2186,22 +2192,21 @@ func Test_AddPeer(t *testing.T) { return } - const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries) - const differentHostnames = 50 + const totalPeers = 300 var wg sync.WaitGroup - errs := make(chan error, totalPeers+differentHostnames) + errs := make(chan error, totalPeers) start := make(chan struct{}) for i := 0; i < totalPeers; i++ { wg.Add(1) - hostNameID := i % differentHostnames go func(i int) { defer wg.Done() newPeer := &nbpeer.Peer{ - Key: "key" + strconv.Itoa(i), - Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"}, + AccountID: accountID, + Key: "key" + strconv.Itoa(i), + Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"}, } <-start diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 4352f3cff..4a08f4c33 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -993,7 +993,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1014,8 +1014,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Name: "GroupD", Peers: []string{peer1.ID, peer2.ID}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err := manager.CreateGroup(context.Background(), account.Id, userID, group) + assert.NoError(t, err) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { @@ -1025,6 +1028,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { var policyWithGroupRulesNoPeers *types.Policy var policyWithDestinationPeersOnly *types.Policy var policyWithSourceAndDestinationPeers *types.Policy + var err error // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index f93467375..67760d55a 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/posture" @@ -105,10 +105,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er Id: regularUserID, Role: types.UserRoleUser, } + peer1 := &peer.Peer{ + ID: "peer1", + } account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false) account.Users[admin.Id] = admin account.Users[user.Id] = user + account.Peers["peer1"] = peer1 err := am.Store.SaveAccount(context.Background(), account) if err != nil { @@ -121,7 +125,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -137,8 +141,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err := manager.CreateGroup(context.Background(), account.Id, userID, group) + assert.NoError(t, err) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { @@ -156,7 +163,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true) + postureCheckA, err := manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true) require.NoError(t, err) postureCheckB := &posture.Checks{ @@ -449,14 +456,16 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { AccountID: account.Id, Peers: []string{"peer1"}, } + err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupA) + require.NoError(t, err, "failed to create groupA") groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB}) - require.NoError(t, err, "failed to save groups") + err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupB) + require.NoError(t, err, "failed to create groupB") postureCheckA := &posture.Checks{ Name: "checkA", @@ -535,7 +544,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { groupA.Peers = []string{} - err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA) + err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA) require.NoError(t, err, "failed to save groups") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route_test.go b/management/server/route_test.go index 37c37f624..ffd1a284b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true) + err = am.CreateGroup(context.Background(), account.Id, userID, newGroup) require.NoError(t, err) rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") @@ -1505,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou } for _, group := range newGroup { - err = am.SaveGroup(context.Background(), accountID, userID, group, true) + err = am.CreateGroup(context.Background(), accountID, userID, group) if err != nil { return nil, err } @@ -1953,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1969,8 +1969,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err = manager.CreateGroup(context.Background(), account.Id, userID, group) + require.NoError(t, err, "failed to create group %s", group.Name) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { @@ -2149,11 +2152,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, - }, true) + }) assert.NoError(t, err) select { @@ -2189,11 +2192,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, - }, true) + }) assert.NoError(t, err) select { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index cecf55200..e55b33c94 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -29,7 +29,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + err = manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -40,7 +40,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { Name: "group_name_2", Peers: []string{}, }, - }, true) + }) if err != nil { t.Fatal(err) } @@ -104,20 +104,20 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, - }, true) + }) if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, - }, true) + }) if err != nil { t.Fatal(err) } @@ -398,11 +398,11 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) policy := &types.Policy{ diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index e380a7da7..c2f0dff6d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( - &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, @@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro generateAccountSQLTypes(account) + for _, group := range account.GroupsG { + group.StoreGroupPeers() + } + err := s.db.Transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { @@ -247,7 +251,8 @@ func generateAccountSQLTypes(account *types.Account) { for id, group := range account.Groups { group.ID = id - account.GroupsG = append(account.GroupsG, *group) + group.AccountID = account.Id + account.GroupsG = append(account.GroupsG, group) } for id, route := range account.Routes { @@ -449,25 +454,56 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u return nil } -// SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { +// CreateGroups creates the given list of groups to the database. +func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { if len(groups) == 0 { return nil } - result := s.db. - Clauses( - clause.Locking{Strength: string(lockStrength)}, - clause.OnConflict{ - Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, - UpdateAll: true, - }, - ). - Create(&groups) - if result.Error != nil { - return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) + return s.db.Transaction(func(tx *gorm.DB) error { + result := tx. + Clauses( + clause.Locking{Strength: string(lockStrength)}, + clause.OnConflict{ + Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, + UpdateAll: true, + }, + ). + Omit(clause.Associations). + Create(&groups) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save groups to store") + } + + return nil + }) +} + +// UpdateGroups updates the given list of groups to the database. +func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { + if len(groups) == 0 { + return nil } - return nil + + return s.db.Transaction(func(tx *gorm.DB) error { + result := tx. + Clauses( + clause.Locking{Strength: string(lockStrength)}, + clause.OnConflict{ + Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, + UpdateAll: true, + }, + ). + Omit(clause.Associations). + Create(&groups) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save groups to store") + } + + return nil + }) } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -646,7 +682,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr } var groups []*types.Group - result := tx.Find(&groups, accountIDCondition, accountID) + result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -655,6 +691,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return nil, status.Errorf(status.Internal, "failed to get account groups from the store") } + for _, g := range groups { + g.LoadGroupPeers() + } + return groups, nil } @@ -669,6 +709,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt likePattern := `%"ID":"` + resourceID + `"%` result := tx. + Preload(clause.Associations). Where("resources LIKE ?", likePattern). Find(&groups) @@ -679,6 +720,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt return nil, result.Error } + for _, g := range groups { + g.LoadGroupPeers() + } + return groups, nil } @@ -765,6 +810,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). + Omit("GroupsG"). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). First(&account, idQueryCondition, accountID) @@ -814,6 +860,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } account.GroupsG = nil + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) for _, route := range account.RoutesG { account.Routes[route.ID] = route.Copy() @@ -1311,55 +1368,76 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction -func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { - var group types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&group, "account_id = ? AND name = ?", accountID, "All") - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group 'All' not found for account") - } - return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + var groupID string + _ = s.db.Model(types.Group{}). + Select("id"). + Where("account_id = ? AND name = ?", accountID, "All"). + Limit(1). + Scan(&groupID) + + if groupID == "" { + return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID) } - for _, existingPeerID := range group.Peers { - if existingPeerID == peerID { - return nil - } - } + err := s.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, + DoNothing: true, + }).Create(&types.GroupPeer{ + AccountID: accountID, + GroupID: groupID, + PeerID: peerID, + }).Error - group.Peers = append(group.Peers, peerID) - - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All': %s", err) + if err != nil { + return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err) } return nil } -// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction -func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error { - var group types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID). - First(&group) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.NewGroupNotFoundError(groupID) - } - - return status.Errorf(status.Internal, "issue finding group: %s", result.Error) +// AddPeerToGroup adds a peer to a group +func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { + peer := &types.GroupPeer{ + AccountID: accountID, + GroupID: groupID, + PeerID: peerID, } - for _, existingPeerID := range group.Peers { - if existingPeerID == peerId { - return nil - } + err := s.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, + DoNothing: true, + }).Create(peer).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err) + return status.Errorf(status.Internal, "failed to add peer to group") } - group.Peers = append(group.Peers, peerId) + return nil +} - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group: %s", err) +// RemovePeerFromGroup removes a peer from a group +func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error { + err := s.db.WithContext(ctx). + Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err) + return status.Errorf(status.Internal, "failed to remove peer from group") + } + + return nil +} + +// RemovePeerFromAllGroups removes a peer from all groups +func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { + err := s.db.WithContext(ctx). + Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err) + return status.Errorf(status.Internal, "failed to remove peer from all groups") } return nil @@ -1427,15 +1505,46 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng var groups []*types.Group query := tx. - Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) + Joins("JOIN group_peers ON group_peers.group_id = groups.id"). + Where("group_peers.peer_id = ?", peerId). + Preload(clause.Associations). + Find(&groups) if query.Error != nil { return nil, query.Error } + for _, group := range groups { + group.LoadGroupPeers() + } + return groups, nil } +// GetPeerGroupIDs retrieves all group IDs assigned to a specific peer in a given account. +func (s *SqlStore) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var groupIDs []string + query := tx. + Model(&types.GroupPeer{}). + Where("account_id = ? AND peer_id = ?", accountId, peerId). + Pluck("group_id", &groupIDs) + + if query.Error != nil { + if errors.Is(query.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no groups found for peer %s in account %s", peerId, accountId) + } + log.WithContext(ctx).Errorf("failed to get group IDs for peer %s in account %s: %v", peerId, accountId, query.Error) + return nil, status.Errorf(status.Internal, "failed to get group IDs for peer from store") + } + + return groupIDs, nil +} + // GetAccountPeers retrieves peers for an account. func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer @@ -1485,7 +1594,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil { + if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1722,7 +1831,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } var group *types.Group - result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID) + result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewGroupNotFoundError(groupID) @@ -1731,15 +1840,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt return nil, status.Errorf(status.Internal, "failed to get group from store") } + group.LoadGroupPeers() + return group, nil } // GetGroupByName retrieves a group by name and account ID. func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { tx := s.db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } var group types.Group @@ -1747,16 +1855,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // we may need to reconsider changing the types. query := tx.Preload(clause.Associations) - switch s.storeEngine { - case types.PostgresStoreEngine: - query = query.Order("json_array_length(peers::json) DESC") - case types.MysqlStoreEngine: - query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") - default: - query = query.Order("json_array_length(peers) DESC") - } - - result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) + result := query. + Model(&types.Group{}). + Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id"). + Where("groups.account_id = ? AND groups.name = ?", accountID, groupName). + Group("groups.id"). + Order("COUNT(group_peers.peer_id) DESC"). + Limit(1). + First(&group) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewGroupNotFoundError(groupName) @@ -1764,6 +1870,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get group by name from store") } + + group.LoadGroupPeers() + return &group, nil } @@ -1775,7 +1884,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } var groups []*types.Group - result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) + result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") @@ -1783,25 +1892,45 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren groupsMap := make(map[string]*types.Group) for _, group := range groups { + group.LoadGroupPeers() groupsMap[group.ID] = group } return groupsMap, nil } -// SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) - if result.Error != nil { - log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) +// CreateGroup creates a group in the store. +func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { + if group == nil { + return status.Errorf(status.InvalidArgument, "group is nil") + } + + if err := s.db.Omit(clause.Associations).Create(group).Error; err != nil { + log.WithContext(ctx).Errorf("failed to save group to store: %v", err) return status.Errorf(status.Internal, "failed to save group to store") } + + return nil +} + +// UpdateGroup updates a group in the store. +func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { + if group == nil { + return status.Errorf(status.InvalidArgument, "group is nil") + } + + if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil { + log.WithContext(ctx).Errorf("failed to save group to store: %v", err) + return status.Errorf(status.Internal, "failed to save group to store") + } + return nil } // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Select(clause.Associations). Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) @@ -1818,6 +1947,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Select(clause.Associations). Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) @@ -2613,3 +2743,27 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri return count, nil } + +func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var peers []types.GroupPeer + result := tx.Find(&peers, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get account group peers from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account group peers from store") + } + + groupPeers := make(map[string]map[string]struct{}) + for _, peer := range peers { + if _, exists := groupPeers[peer.GroupID]; !exists { + groupPeers[peer.GroupID] = make(map[string]struct{}) + } + groupPeers[peer.GroupID][peer.PeerID] = struct{}{} + } + + return groupPeers, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 738c5a28c..44bb3f599 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "encoding/binary" "fmt" "math/rand" "net" @@ -1187,7 +1188,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { Peers: nil, } err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { - err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + err := transaction.CreateGroup(context.Background(), LockingStrengthUpdate, group) if err != nil { t.Fatal("failed to save group") return err @@ -1348,7 +1349,8 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { } } -func TestSqlStore_SaveGroup(t *testing.T) { +func TestSqlStore_CreateGroup(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(types.MysqlStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1356,12 +1358,14 @@ func TestSqlStore_SaveGroup(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" group := &types.Group{ - ID: "group-id", - AccountID: accountID, - Issued: "api", - Peers: []string{"peer1", "peer2"}, + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, } - err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group) require.NoError(t, err) savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") @@ -1369,7 +1373,7 @@ func TestSqlStore_SaveGroup(t *testing.T) { require.Equal(t, savedGroup, group) } -func TestSqlStore_SaveGroups(t *testing.T) { +func TestSqlStore_CreateUpdateGroups(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1378,23 +1382,27 @@ func TestSqlStore_SaveGroups(t *testing.T) { groups := []*types.Group{ { - ID: "group-1", - AccountID: accountID, - Issued: "api", - Peers: []string{"peer1", "peer2"}, + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, }, { - ID: "group-2", - AccountID: accountID, - Issued: "integration", - Peers: []string{"peer3", "peer4"}, + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, }, } - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups) + err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groups) require.NoError(t, err) groups[1].Peers = []string{} - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups) + err = store.UpdateGroups(context.Background(), LockingStrengthUpdate, accountID, groups) require.NoError(t, err) group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID) @@ -2523,7 +2531,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) { require.NoError(t, err, "failed to get group") require.Len(t, group.Peers, 0, "group should have 0 peers") - err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID) + err = store.AddPeerToGroup(context.Background(), accountID, peerID, groupID) require.NoError(t, err, "failed to add peer to group") group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) @@ -2554,7 +2562,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) { err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) require.NoError(t, err, "failed to add peer to account") - err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID) + err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID) require.NoError(t, err, "failed to add peer to all group") group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) @@ -2640,7 +2648,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) { assert.Len(t, groups, 1) assert.Equal(t, groups[0].Name, "All") - err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h") + err = store.AddPeerToGroup(context.Background(), accountID, peerID, "cfefqs706sqkneg59g4h") require.NoError(t, err) groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) @@ -3307,7 +3315,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { }) } - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave) + err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave) require.NoError(t, err) accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) @@ -3538,3 +3546,64 @@ func TestSqlStore_GetAnyAccountID(t *testing.T) { assert.Empty(t, accountID) }) } + +func BenchmarkGetAccountPeers(b *testing.B) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", b.TempDir()) + if err != nil { + b.Fatal(err) + } + b.Cleanup(cleanup) + + numberOfPeers := 1000 + numberOfGroups := 200 + numberOfPeersPerGroup := 500 + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peers := make([]*nbpeer.Peer, 0, numberOfPeers) + for i := 0; i < numberOfPeers; i++ { + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + DNSLabel: fmt.Sprintf("peer%d.example.com", i), + IP: intToIPv4(uint32(i)), + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + if err != nil { + b.Fatalf("Failed to add peer: %v", err) + } + peers = append(peers, peer) + } + + for i := 0; i < numberOfGroups; i++ { + groupID := fmt.Sprintf("group-%d", i) + group := &types.Group{ + ID: groupID, + AccountID: accountID, + } + err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + b.Fatalf("Failed to create group: %v", err) + } + for j := 0; j < numberOfPeersPerGroup; j++ { + peerIndex := (i*numberOfPeersPerGroup + j) % numberOfPeers + err = store.AddPeerToGroup(context.Background(), accountID, peers[peerIndex].ID, groupID) + if err != nil { + b.Fatalf("Failed to add peer to group: %v", err) + } + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peers[i%numberOfPeers].ID) + if err != nil { + b.Fatal(err) + } + } +} + +func intToIPv4(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip +} diff --git a/management/server/store/store.go b/management/server/store/store.go index b3254c4c9..912939bc2 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -101,8 +101,10 @@ type Store interface { GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error + CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error + UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error + CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error + UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error @@ -120,9 +122,12 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) - AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error - AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error + RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error + RemovePeerFromAllGroups(ctx context.Context, peerID string) error GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) + GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error @@ -196,6 +201,7 @@ type Store interface { DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) + GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) } const ( @@ -353,6 +359,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label") }, + func(db *gorm.DB) error { + return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(accountID, id, value string) any { + return &types.GroupPeer{ + AccountID: accountID, + GroupID: id, + PeerID: value, + } + }) + }, } } diff --git a/management/server/types/account.go b/management/server/types/account.go index f0887be07..a3a7ce305 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -73,7 +73,7 @@ type Account struct { Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` diff --git a/management/server/types/group.go b/management/server/types/group.go index 1b321387c..00fdf7a69 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -26,7 +26,8 @@ type Group struct { Issued string // Peers list of the group - Peers []string `gorm:"serializer:json"` + Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership + GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"` // Resources contains a list of resources in that group Resources []Resource `gorm:"serializer:json"` @@ -34,6 +35,32 @@ type Group struct { IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } +type GroupPeer struct { + AccountID string `gorm:"index"` + GroupID string `gorm:"primaryKey"` + PeerID string `gorm:"primaryKey"` +} + +func (g *Group) LoadGroupPeers() { + g.Peers = make([]string, len(g.GroupPeers)) + for i, peer := range g.GroupPeers { + g.Peers[i] = peer.PeerID + } + g.GroupPeers = []GroupPeer{} +} + +func (g *Group) StoreGroupPeers() { + g.GroupPeers = make([]GroupPeer, len(g.Peers)) + for i, peer := range g.Peers { + g.GroupPeers[i] = GroupPeer{ + AccountID: g.AccountID, + GroupID: g.ID, + PeerID: peer, + } + } + g.Peers = []string{} +} + // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -46,13 +73,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an func (g *Group) Copy() *Group { group := &Group{ ID: g.ID, + AccountID: g.AccountID, Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + GroupPeers: make([]GroupPeer, len(g.GroupPeers)), Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.GroupPeers, g.GroupPeers) copy(group.Resources, g.Resources) return group } diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go index 69b381ae5..3d421342d 100644 --- a/management/server/types/setupkey.go +++ b/management/server/types/setupkey.go @@ -35,7 +35,7 @@ type SetupKey struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` Key string - KeySecret string + KeySecret string `gorm:"index"` Name string Type SetupKeyType CreatedAt time.Time diff --git a/management/server/user.go b/management/server/user.go index 7d8382978..a0f4c4a6c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -677,13 +677,18 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact if update.AutoGroups != nil && settings.GroupsPropagationEnabled { removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) - if err != nil { - return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) - } - - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, updatedGroups); err != nil { - return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) + addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups) + for _, peer := range userPeers { + for _, groupID := range removedGroups { + if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err) + } + } + for _, groupID := range addedGroups { + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err) + } + } } } @@ -1137,93 +1142,6 @@ func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID str return userInfo, nil } -// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { - if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return - } - - userPeerIDMap := make(map[string]struct{}, len(peers)) - for _, peer := range peers { - userPeerIDMap[peer.ID] = struct{}{} - } - - for _, gid := range groupsToAdd { - group, ok := accountGroups[gid] - if !ok { - return nil, errors.New("group not found") - } - if changed := addUserPeersToGroup(userPeerIDMap, group); changed { - groupsToUpdate = append(groupsToUpdate, group) - } - } - - for _, gid := range groupsToRemove { - group, ok := accountGroups[gid] - if !ok { - return nil, errors.New("group not found") - } - if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed { - groupsToUpdate = append(groupsToUpdate, group) - } - } - - return groupsToUpdate, nil -} - -// addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool { - groupPeers := make(map[string]struct{}, len(group.Peers)) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - changed := false - for pid := range userPeerIDs { - if _, exists := groupPeers[pid]; !exists { - groupPeers[pid] = struct{}{} - changed = true - } - } - - group.Peers = make([]string, 0, len(groupPeers)) - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - if changed { - group.Peers = make([]string, 0, len(groupPeers)) - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - } - return changed -} - -// removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool { - // skip removing peers from group All - if group.Name == "All" { - return false - } - - updatedPeers := make([]string, 0, len(group.Peers)) - changed := false - - for _, pid := range group.Peers { - if _, owned := userPeerIDs[pid]; owned { - changed = true - continue - } - updatedPeers = append(updatedPeers, pid) - } - - if changed { - group.Peers = updatedPeers - } - return changed -} - func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index 53baf8f7e..8ab6584cf 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1335,11 +1335,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) require.NoError(t, err) policy := &types.Policy{