From a6c59601f92c09a4b40a0c016c2fed7dcb8f2465 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Sun, 18 Aug 2024 14:19:31 +0200
Subject: [PATCH 01/89] Update Slack invite link (#2445)
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 370445412..1c5e76627 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@
-
+
From 049b5fb7ede553da0d812590d083b6c77e5ca4a2 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 19 Aug 2024 12:50:11 +0200
Subject: [PATCH 02/89] Split DB calls in peer login (#2439)
---
management/server/account.go | 22 ++++++
management/server/file_store.go | 29 ++++++++
management/server/peer.go | 116 +++++++++++++++++---------------
management/server/sql_store.go | 28 ++++++++
management/server/store.go | 2 +
5 files changed, 144 insertions(+), 53 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 972272746..4c150fd7e 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -2072,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
}
+func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
+ user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
+ if err != nil {
+ return false, err
+ }
+
+ err = checkIfPeerOwnerIsBlocked(peer, user)
+ if err != nil {
+ return false, err
+ }
+
+ if peerLoginExpired(ctx, peer, settings) {
+ err = am.handleExpiredPeer(ctx, user, peer)
+ if err != nil {
+ return false, err
+ }
+ return true, nil
+ }
+
+ return false, nil
+}
+
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
diff --git a/management/server/file_store.go b/management/server/file_store.go
index 6e3536bcd..1927568ef 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
+func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
+ accountID, ok := s.UserID2AccountID[userID]
+ if !ok {
+ return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
+ }
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ return account.Users[userID].Copy(), nil
+}
+
+func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
+
+ for _, group := range account.Groups {
+ groupsSlice = append(groupsSlice, group)
+ }
+
+ return groupsSlice, nil
+}
+
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()
diff --git a/management/server/peer.go b/management/server/peer.go
index 7afe6ee0d..93234d9de 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -549,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
- err = checkIfPeerOwnerIsBlocked(peer, account)
- if err != nil {
- return nil, nil, nil, err
+ if peer.UserID != "" {
+ log.Infof("Peer has no userID")
+
+ user, err := account.FindUser(peer.UserID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ err = checkIfPeerOwnerIsBlocked(peer, user)
+ if err != nil {
+ return nil, nil, nil, err
+ }
}
if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
- peer, updated := updatePeerMeta(peer, sync.Meta, account)
+ updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
@@ -624,31 +633,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login
if login.UserID == "" {
+ log.Info("Peer needs login")
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
}
}
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
+ defer unlockAccount()
+ unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
defer func() {
- if unlock != nil {
- unlock()
+ if unlockPeer != nil {
+ unlockPeer()
}
}()
- // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
- account, err := am.Store.GetAccount(ctx, accountID)
+ peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
- peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
- if err != nil {
- return nil, nil, nil, status.NewPeerNotRegisteredError()
- }
-
- err = checkIfPeerOwnerIsBlocked(peer, account)
+ settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -656,21 +662,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
- if peerLoginExpired(ctx, peer, account.Settings) {
- err = am.handleExpiredPeer(ctx, login, account, peer)
+
+ if login.UserID != "" {
+ changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil {
return nil, nil, nil, err
}
- updateRemotePeers = true
- shouldStorePeer = true
+ if changed {
+ shouldStorePeer = true
+ updateRemotePeers = true
+ }
}
- isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
+ groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
- peer, updated := updatePeerMeta(peer, login.Meta, account)
+ var grps []string
+ for _, group := range groups {
+ for _, id := range group.Peers {
+ if id == peer.ID {
+ grps = append(grps, group.ID)
+ break
+ }
+ }
+ }
+
+ isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
shouldStorePeer = true
}
@@ -687,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
- unlock()
- unlock = nil
+ unlockPeer()
+ unlockPeer = nil
+
+ account, err := am.Store.GetAccount(ctx, accountID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
@@ -746,36 +775,30 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
-func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
- err := checkAuth(ctx, login.UserID, peer)
+func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
+ err := checkAuth(ctx, user.Id, peer)
if err != nil {
return err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
- updatePeerLastLogin(peer, account)
-
- // sync user last login with peer last login
- user, err := account.FindUser(login.UserID)
- if err != nil {
- return status.Errorf(status.Internal, "couldn't find user")
- }
-
- err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
+ peer = peer.UpdateLastLogin()
+ err = am.Store.SavePeer(ctx, peer.AccountID, peer)
if err != nil {
return err
}
- am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
+ err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
+ if err != nil {
+ return err
+ }
+
+ am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil
}
-func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
+func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
if peer.AddedWithSSOLogin() {
- user, err := account.FindUser(peer.UserID)
- if err != nil {
- return status.Errorf(status.PermissionDenied, "user doesn't exist")
- }
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
@@ -805,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}
-func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
- peer.UpdateLastLogin()
- account.UpdatePeer(peer)
-}
-
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
@@ -908,14 +926,6 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
}
-func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
- if peer.UpdateMetaIfNew(meta) {
- account.UpdatePeer(peer)
- return peer, true
- }
- return peer, false
-}
-
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index c44ab7f09..912e31410 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
+func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
+ var user User
+ result := s.db.First(&user, idQueryCondition, userID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
+ }
+ log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "issue getting user from store")
+ }
+
+ return &user, nil
+}
+
+func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
+ var groups []*nbgroup.Group
+ result := s.db.Find(&groups, idQueryCondition, accountID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
+ }
+ log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "issue getting groups from store")
+ }
+
+ return groups, nil
+}
+
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
diff --git a/management/server/store.go b/management/server/store.go
index 864871c8e..a2b489391 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -41,6 +41,8 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
+ GetUserByUserID(ctx context.Context, userID string) (*User, error)
+ GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
From d2b04922e9a46e41581f9df5fbd194f1db265c4a Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Tue, 20 Aug 2024 11:46:58 +0200
Subject: [PATCH 03/89] Add script for loading tun module for synology (#2423)
---
release_files/install.sh | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/release_files/install.sh b/release_files/install.sh
index 198d74428..d9d436ba5 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -151,6 +151,22 @@ add_aur_repo() {
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
}
+prepare_tun_module() {
+ # Create the necessary file structure for /dev/net/tun
+ if [ ! -c /dev/net/tun ]; then
+ if [ ! -d /dev/net ]; then
+ mkdir -m 755 /dev/net
+ fi
+ mknod /dev/net/tun c 10 200
+ chmod 0755 /dev/net/tun
+ fi
+
+ # Load the tun module if not already loaded
+ if ! lsmod | grep -q "^tun\s"; then
+ insmod /lib/modules/tun.ko
+ fi
+}
+
install_native_binaries() {
# Checks for supported architecture
case "$ARCH" in
@@ -268,6 +284,10 @@ install_netbird() {
;;
esac
+ if [ "$OS_NAME" = "synology" ]; then
+ prepare_tun_module
+ fi
+
# Add package manager to config
${SUDO} mkdir -p "$CONFIG_FOLDER"
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
From 2a30db02bb04e88cc323e38b64197e8386ff95b4 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Tue, 20 Aug 2024 18:47:41 +0200
Subject: [PATCH 04/89] [misc] Use clearer wording on issue template (#2443)
---
.github/ISSUE_TEMPLATE/bug-issue-report.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md
index 789c61974..87f757f42 100644
--- a/.github/ISSUE_TEMPLATE/bug-issue-report.md
+++ b/.github/ISSUE_TEMPLATE/bug-issue-report.md
@@ -35,7 +35,7 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
If applicable, add the `netbird status -dA' command output.
-**Do you face any client issues on desktop?**
+**Do you face any (non-mobile) client issues?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
From 80b0db80bc1d1d706162b11fecc6d43d42fcc78d Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Tue, 20 Aug 2024 19:13:16 +0200
Subject: [PATCH 05/89] [client] Replace windows network monitor implementation
(#2450)
This new one uses functions from netioapi.h to monitor route changes.
This change ensures that we include routes that point to virtual
interfaces, such as vEthernet created by the Hyper-V Virtual Switch.
---
client/internal/networkmonitor/monitor_bsd.go | 2 +-
.../networkmonitor/monitor_generic.go | 2 +-
.../networkmonitor/monitor_windows.go | 255 ++-----------
.../systemops/systemops_windows.go | 348 +++++++++++++++---
4 files changed, 338 insertions(+), 269 deletions(-)
diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go
index 29df7ea7f..51135a729 100644
--- a/client/internal/networkmonitor/monitor_bsd.go
+++ b/client/internal/networkmonitor/monitor_bsd.go
@@ -65,7 +65,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
continue
}
- if !route.Dst.Addr().IsUnspecified() {
+ if route.Dst.Bits() != 0 {
continue
}
diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go
index f5cc19473..19648edba 100644
--- a/client/internal/networkmonitor/monitor_generic.go
+++ b/client/internal/networkmonitor/monitor_generic.go
@@ -59,7 +59,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
- err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
+ err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go
index 308b2aa45..cd48c269d 100644
--- a/client/internal/networkmonitor/monitor_windows.go
+++ b/client/internal/networkmonitor/monitor_windows.go
@@ -3,252 +3,73 @@ package networkmonitor
import (
"context"
"fmt"
- "net"
- "net/netip"
"strings"
- "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
-const (
- unreachable = 0
- incomplete = 1
- probe = 2
- delay = 3
- stale = 4
- reachable = 5
- permanent = 6
- tbd = 7
-)
-
-const interval = 10 * time.Second
-
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
- var neighborv4, neighborv6 *systemops.Neighbor
- {
- initialNeighbors, err := getNeighbors()
- if err != nil {
- return fmt.Errorf("get neighbors: %w", err)
- }
-
- neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
- neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
+ routeMonitor, err := systemops.NewRouteMonitor(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to create route monitor: %w", err)
}
- log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
-
- ticker := time.NewTicker(interval)
- defer ticker.Stop()
+ defer func() {
+ if err := routeMonitor.Stop(); err != nil {
+ log.Errorf("Network monitor: failed to stop route monitor: %v", err)
+ }
+ }()
for {
select {
case <-ctx.Done():
return ErrStopped
- case <-ticker.C:
- if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
- go callback()
- return nil
+ case route := <-routeMonitor.RouteUpdates():
+ if route.Destination.Bits() != 0 {
+ continue
+ }
+
+ if routeChanged(route, nexthopv4, nexthopv6, callback) {
+ break
}
}
}
}
-func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
- if n, ok := initialNeighbors[nexthop.IP]; ok &&
- n.State != unreachable &&
- n.State != incomplete &&
- n.State != tbd {
- return &n
- }
- return nil
-}
-
-func changed(
- nexthopv4 systemops.Nexthop,
- neighborv4 *systemops.Neighbor,
- nexthopv6 systemops.Nexthop,
- neighborv6 *systemops.Neighbor,
-) bool {
- neighbors, err := getNeighbors()
- if err != nil {
- log.Errorf("network monitor: error fetching current neighbors: %v", err)
- return false
- }
- if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
- return true
- }
-
- routes, err := getRoutes()
- if err != nil {
- log.Errorf("network monitor: error fetching current routes: %v", err)
- return false
- }
-
- if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
- return true
- }
-
- return false
-}
-
-// routeChanged checks if the default routes still point to our nexthop/interface
-func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
- if !nexthop.IP.IsValid() {
- return false
- }
-
- if isSoftInterface(nexthop.Intf.Name) {
- log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
- return false
- }
-
- unspec := getUnspecifiedPrefix(nexthop.IP)
- defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
-
- log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
-
- if !foundMatchingRoute {
- logRouteChange(nexthop.IP, intf)
- return true
- }
-
- return false
-}
-
-func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
- if ip.Is6() {
- return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
- }
- return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
-}
-
-func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
- var defaultRoutes []string
- foundMatchingRoute := false
-
- for _, r := range routes {
- if r.Destination == unspec {
- routeInfo := formatRouteInfo(r)
- defaultRoutes = append(defaultRoutes, routeInfo)
-
- if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
- foundMatchingRoute = true
- log.Debugf("network monitor: found matching default route: %s", routeInfo)
- }
+func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
+ intf := ""
+ if route.Interface != nil {
+ intf = route.Interface.Name
+ if isSoftInterface(intf) {
+ log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf)
+ return false
}
}
- return defaultRoutes, foundMatchingRoute
-}
-
-func formatRouteInfo(r systemops.Route) string {
- newIntf := ""
- if r.Interface != nil {
- newIntf = r.Interface.Name
- }
- return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
-}
-
-func logRouteChange(ip netip.Addr, intf *net.Interface) {
- oldIntf := ""
- if intf != nil {
- oldIntf = intf.Name
- }
- log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
-}
-
-func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
- if neighbor == nil {
- return false
- }
-
- // TODO: consider non-local nexthops, e.g. on point-to-point interfaces
- if n, ok := neighbors[nexthop.IP]; ok {
- if n.State == unreachable || n.State == incomplete {
- log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
- return true
- } else if n.InterfaceIndex != neighbor.InterfaceIndex {
- log.Infof(
- "network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
- neighbor.IPAddress,
- neighbor.LinkLayerAddress,
- neighbor.InterfaceAlias,
- neighbor.InterfaceIndex,
- n.InterfaceAlias,
- n.InterfaceIndex,
- stateFromInt(n.State),
- )
+ switch route.Type {
+ case systemops.RouteModified:
+ // TODO: get routing table to figure out if our route is affected for modified routes
+ log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
+ go callback()
+ return true
+ case systemops.RouteAdded:
+ if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
+ log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
+ go callback()
+ return true
+ }
+ case systemops.RouteDeleted:
+ if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
+ log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
+ go callback()
return true
}
- } else {
- log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
- return true
}
return false
}
-func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
- entries, err := systemops.GetNeighbors()
- if err != nil {
- return nil, fmt.Errorf("get neighbors: %w", err)
- }
-
- neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
- for _, entry := range entries {
- neighbours[entry.IPAddress] = entry
- }
-
- return neighbours, nil
-}
-
-func getRoutes() ([]systemops.Route, error) {
- entries, err := systemops.GetRoutes()
- if err != nil {
- return nil, fmt.Errorf("get routes: %w", err)
- }
-
- return entries, nil
-}
-
-func stateFromInt(state uint8) string {
- switch state {
- case unreachable:
- return "unreachable"
- case incomplete:
- return "incomplete"
- case probe:
- return "probe"
- case delay:
- return "delay"
- case stale:
- return "stale"
- case reachable:
- return "reachable"
- case permanent:
- return "permanent"
- case tbd:
- return "tbd"
- default:
- return "unknown"
- }
-}
-
-func compareIntf(a, b *net.Interface) int {
- switch {
- case a == nil && b == nil:
- return 0
- case a == nil:
- return -1
- case b == nil:
- return 1
- default:
- return a.Index - b.Index
- }
-}
-
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
}
diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go
index 0d3630cb8..3f756788e 100644
--- a/client/internal/routemanager/systemops/systemops_windows.go
+++ b/client/internal/routemanager/systemops/systemops_windows.go
@@ -3,6 +3,8 @@
package systemops
import (
+ "context"
+ "encoding/binary"
"fmt"
"net"
"net/netip"
@@ -11,15 +13,43 @@ import (
"strconv"
"strings"
"sync"
+ "syscall"
"time"
+ "unsafe"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
+ "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nbnet "github.com/netbirdio/netbird/util/net"
)
+type RouteUpdateType int
+
+// RouteUpdate represents a change in the routing table.
+// The interface field contains the index only.
+type RouteUpdate struct {
+ Type RouteUpdateType
+ Destination netip.Prefix
+ NextHop netip.Addr
+ Interface *net.Interface
+}
+
+// RouteMonitor provides a way to monitor changes in the routing table.
+type RouteMonitor struct {
+ updates chan RouteUpdate
+ handle windows.Handle
+ done chan struct{}
+}
+
+// Route represents a single routing table entry.
+type Route struct {
+ Destination netip.Prefix
+ Nexthop netip.Addr
+ Interface *net.Interface
+}
+
type MSFT_NetRoute struct {
DestinationPrefix string
NextHop string
@@ -28,33 +58,77 @@ type MSFT_NetRoute struct {
AddressFamily uint16
}
-type Route struct {
- Destination netip.Prefix
- Nexthop netip.Addr
- Interface *net.Interface
+// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
+type MIB_IPFORWARD_ROW2 struct {
+ InterfaceLuid uint64
+ InterfaceIndex uint32
+ DestinationPrefix IP_ADDRESS_PREFIX
+ NextHop SOCKADDR_INET_NEXTHOP
+ SitePrefixLength uint8
+ ValidLifetime uint32
+ PreferredLifetime uint32
+ Metric uint32
+ Protocol uint32
+ Loopback uint8
+ AutoconfigureAddress uint8
+ Publish uint8
+ Immortal uint8
+ Age uint32
+ Origin uint32
}
-type MSFT_NetNeighbor struct {
- IPAddress string
- LinkLayerAddress string
- State uint8
- AddressFamily uint16
- InterfaceIndex uint32
- InterfaceAlias string
+// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
+type IP_ADDRESS_PREFIX struct {
+ Prefix SOCKADDR_INET
+ PrefixLength uint8
}
-type Neighbor struct {
- IPAddress netip.Addr
- LinkLayerAddress string
- State uint8
- AddressFamily uint16
- InterfaceIndex uint32
- InterfaceAlias string
+// SOCKADDR_INET is defined in https://learn.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-sockaddr_inet
+// It represents the union of IPv4 and IPv6 socket addresses
+type SOCKADDR_INET struct {
+ sin6_family int16
+ // nolint:unused
+ sin6_port uint16
+ // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id
+ data [24]byte
}
-var prefixList []netip.Prefix
-var lastUpdate time.Time
-var mux = sync.Mutex{}
+// SOCKADDR_INET_NEXTHOP is the same as SOCKADDR_INET but offset by 2 bytes
+type SOCKADDR_INET_NEXTHOP struct {
+ // nolint:unused
+ pad [2]byte
+ sin6_family int16
+ // nolint:unused
+ sin6_port uint16
+ // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id
+ data [24]byte
+}
+
+// MIB_NOTIFICATION_TYPE is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ne-netioapi-mib_notification_type
+type MIB_NOTIFICATION_TYPE int32
+
+var (
+ modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
+ procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
+ procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
+
+ prefixList []netip.Prefix
+ lastUpdate time.Time
+ mux sync.Mutex
+)
+
+const (
+ MibParemeterModification MIB_NOTIFICATION_TYPE = iota
+ MibAddInstance
+ MibDeleteInstance
+ MibInitialNotification
+)
+
+const (
+ RouteModified RouteUpdateType = iota
+ RouteAdded
+ RouteDeleted
+)
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses)
@@ -94,6 +168,155 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
return nil
}
+// NewRouteMonitor creates and starts a new RouteMonitor.
+// It returns a pointer to the RouteMonitor and an error if the monitor couldn't be started.
+func NewRouteMonitor(ctx context.Context) (*RouteMonitor, error) {
+ rm := &RouteMonitor{
+ updates: make(chan RouteUpdate, 5),
+ done: make(chan struct{}),
+ }
+
+ if err := rm.start(ctx); err != nil {
+ return nil, err
+ }
+
+ return rm, nil
+}
+
+func (rm *RouteMonitor) start(ctx context.Context) error {
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
+
+ callbackPtr := windows.NewCallback(func(callerContext uintptr, row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) uintptr {
+ if ctx.Err() != nil {
+ return 0
+ }
+
+ update, err := rm.parseUpdate(row, notificationType)
+ if err != nil {
+ log.Errorf("Failed to parse route update: %v", err)
+ return 0
+ }
+
+ select {
+ case <-rm.done:
+ return 0
+ case rm.updates <- update:
+ default:
+ log.Warn("Route update channel is full, dropping update")
+ }
+ return 0
+ })
+
+ var handle windows.Handle
+ if err := notifyRouteChange2(windows.AF_UNSPEC, callbackPtr, 0, false, &handle); err != nil {
+ return fmt.Errorf("NotifyRouteChange2 failed: %w", err)
+ }
+
+ rm.handle = handle
+
+ return nil
+}
+
+func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) (RouteUpdate, error) {
+ // destination prefix, next hop, interface index, interface luid are guaranteed to be there
+ // GetIpForwardEntry2 is not needed
+
+ var update RouteUpdate
+
+ idx := int(row.InterfaceIndex)
+ if idx != 0 {
+ intf, err := net.InterfaceByIndex(idx)
+ if err != nil {
+ return update, fmt.Errorf("get interface name: %w", err)
+ }
+
+ update.Interface = intf
+ }
+
+ log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)
+ dest := parseIPPrefix(row.DestinationPrefix, idx)
+ if !dest.Addr().IsValid() {
+ return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row)
+ }
+
+ nexthop := parseIPNexthop(row.NextHop, idx)
+ if !nexthop.IsValid() {
+ return RouteUpdate{}, fmt.Errorf("invalid next hop %v", row)
+ }
+
+ updateType := RouteModified
+ switch notificationType {
+ case MibParemeterModification:
+ updateType = RouteModified
+ case MibAddInstance:
+ updateType = RouteAdded
+ case MibDeleteInstance:
+ updateType = RouteDeleted
+ }
+
+ update.Type = updateType
+ update.Destination = dest
+ update.NextHop = nexthop
+
+ return update, nil
+}
+
+// Stop stops the RouteMonitor.
+func (rm *RouteMonitor) Stop() error {
+ if rm.handle != 0 {
+ if err := cancelMibChangeNotify2(rm.handle); err != nil {
+ return fmt.Errorf("CancelMibChangeNotify2 failed: %w", err)
+ }
+ rm.handle = 0
+ }
+ close(rm.done)
+ close(rm.updates)
+ return nil
+}
+
+// RouteUpdates returns a channel that receives RouteUpdate messages.
+func (rm *RouteMonitor) RouteUpdates() <-chan RouteUpdate {
+ return rm.updates
+}
+
+func notifyRouteChange2(family uint32, callback uintptr, callerContext uintptr, initialNotification bool, handle *windows.Handle) error {
+ var initNotif uint32
+ if initialNotification {
+ initNotif = 1
+ }
+
+ r1, _, e1 := syscall.SyscallN(
+ procNotifyRouteChange2.Addr(),
+ uintptr(family),
+ callback,
+ callerContext,
+ uintptr(initNotif),
+ uintptr(unsafe.Pointer(handle)),
+ )
+ if r1 != 0 {
+ if e1 != 0 {
+ return e1
+ }
+ return syscall.EINVAL
+ }
+ return nil
+}
+
+func cancelMibChangeNotify2(handle windows.Handle) error {
+ r1, _, e1 := syscall.SyscallN(procCancelMibChangeNotify2.Addr(), uintptr(handle))
+ if r1 != 0 {
+ if e1 != 0 {
+ return e1
+ }
+ return syscall.EINVAL
+ }
+ return nil
+}
+
+// GetRoutesFromTable returns the current routing table from with prefixes only.
+// It ccaches the result for 2 seconds to avoid blocking the caller.
func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
@@ -117,6 +340,7 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return prefixList, nil
}
+// GetRoutes retrieves the current routing table using WMI.
func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute
@@ -146,8 +370,8 @@ func GetRoutes() ([]Route, error) {
Name: entry.InterfaceAlias,
}
- if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
- nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
+ if nexthop.Is6() {
+ nexthop = addZone(nexthop, int(entry.InterfaceIndex))
}
}
@@ -161,33 +385,6 @@ func GetRoutes() ([]Route, error) {
return routes, nil
}
-func GetNeighbors() ([]Neighbor, error) {
- var entries []MSFT_NetNeighbor
- query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor`
- if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
- return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err)
- }
-
- var neighbors []Neighbor
- for _, entry := range entries {
- addr, err := netip.ParseAddr(entry.IPAddress)
- if err != nil {
- log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err)
- continue
- }
- neighbors = append(neighbors, Neighbor{
- IPAddress: addr,
- LinkLayerAddress: entry.LinkLayerAddress,
- State: entry.State,
- AddressFamily: entry.AddressFamily,
- InterfaceIndex: entry.InterfaceIndex,
- InterfaceAlias: entry.InterfaceAlias,
- })
- }
-
- return neighbors, nil
-}
-
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
@@ -220,3 +417,54 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}
+
+func parseIPPrefix(prefix IP_ADDRESS_PREFIX, idx int) netip.Prefix {
+ ip := parseIP(prefix.Prefix, idx)
+ return netip.PrefixFrom(ip, int(prefix.PrefixLength))
+}
+
+func parseIP(addr SOCKADDR_INET, idx int) netip.Addr {
+ return parseIPGeneric(addr.sin6_family, addr.data, idx)
+}
+
+func parseIPNexthop(addr SOCKADDR_INET_NEXTHOP, idx int) netip.Addr {
+ return parseIPGeneric(addr.sin6_family, addr.data, idx)
+}
+
+func parseIPGeneric(family int16, data [24]byte, interfaceIndex int) netip.Addr {
+ switch family {
+ case windows.AF_INET:
+ ipv4 := binary.BigEndian.Uint32(data[:4])
+ return netip.AddrFrom4([4]byte{
+ byte(ipv4 >> 24),
+ byte(ipv4 >> 16),
+ byte(ipv4 >> 8),
+ byte(ipv4),
+ })
+
+ case windows.AF_INET6:
+ // The IPv6 address is stored after the 4-byte flowinfo field
+ var ipv6 [16]byte
+ copy(ipv6[:], data[4:20])
+ ip := netip.AddrFrom16(ipv6)
+
+ // Check if there's a non-zero scope_id
+ scopeID := binary.BigEndian.Uint32(data[20:24])
+ if scopeID != 0 {
+ ip = ip.WithZone(strconv.FormatUint(uint64(scopeID), 10))
+ } else if interfaceIndex != 0 {
+ ip = addZone(ip, interfaceIndex)
+ }
+
+ return ip
+ }
+
+ return netip.IPv4Unspecified()
+}
+
+func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
+ if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ ip = ip.WithZone(strconv.Itoa(interfaceIndex))
+ }
+ return ip
+}
From 8c2d37d3fc0a97dd16986358e7389a12c92b156b Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Tue, 20 Aug 2024 19:13:40 +0200
Subject: [PATCH 06/89] [management] Fix logging out peers on deletion (#2453)
---
management/server/grpcserver.go | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index ff7a71cfd..ead4a29d6 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -132,24 +132,30 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
ctx := srv.Context()
- realIP := getRealIP(ctx)
-
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
return err
}
- //nolint
+ // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
+
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
- // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
- accountID = "UNKNOWN"
+ // nolint:staticcheck
+ ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
+ log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
+ if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
+ return status.Errorf(codes.PermissionDenied, "peer is not registered")
+ }
+ return err
}
- //nolint
+
+ // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
+ realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil {
From 3ed90728e64187191e78144554c7ce060bc2f52f Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Tue, 20 Aug 2024 20:06:01 +0200
Subject: [PATCH 07/89] [management] Add buffering for getAccount requests
during login (#2449)
---
management/server/account.go | 3 +
management/server/account_cache.go | 106 ++++++++++
management/server/management_proto_test.go | 224 ++++++++++++++++++++-
management/server/peer.go | 2 +-
4 files changed, 329 insertions(+), 6 deletions(-)
create mode 100644 management/server/account_cache.go
diff --git a/management/server/account.go b/management/server/account.go
index 4c150fd7e..23781c915 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -161,6 +161,8 @@ type DefaultAccountManager struct {
eventStore activity.Store
geo *geolocation.Geolocation
+ cache *AccountCache
+
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
@@ -967,6 +969,7 @@ func BuildManager(
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
+ cache: NewAccountCache(ctx, store),
}
allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
diff --git a/management/server/account_cache.go b/management/server/account_cache.go
new file mode 100644
index 000000000..13ce45819
--- /dev/null
+++ b/management/server/account_cache.go
@@ -0,0 +1,106 @@
+package server
+
+import (
+ "context"
+ "os"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// AccountRequest holds the result channel to return the requested account.
+type AccountRequest struct {
+ AccountID string
+ ResultChan chan *AccountResult
+}
+
+// AccountResult holds the account data or an error.
+type AccountResult struct {
+ Account *Account
+ Err error
+}
+
+type AccountCache struct {
+ store Store
+ getAccountRequests map[string][]*AccountRequest
+ mu sync.Mutex
+ getAccountRequestCh chan *AccountRequest
+ bufferInterval time.Duration
+}
+
+func NewAccountCache(ctx context.Context, store Store) *AccountCache {
+ bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
+ bufferInterval, err := time.ParseDuration(bufferIntervalStr)
+ if err != nil && bufferIntervalStr != "" {
+ log.WithContext(ctx).Warnf("failed to parse account cache buffer interval: %s", err)
+ bufferInterval = 300 * time.Millisecond
+ }
+
+ log.WithContext(ctx).Infof("set account cache buffer interval to %s", bufferInterval)
+
+ ac := AccountCache{
+ store: store,
+ getAccountRequests: make(map[string][]*AccountRequest),
+ getAccountRequestCh: make(chan *AccountRequest),
+ bufferInterval: bufferInterval,
+ }
+
+ go ac.processGetAccountRequests(ctx)
+
+ return &ac
+}
+func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
+ req := &AccountRequest{
+ AccountID: accountID,
+ ResultChan: make(chan *AccountResult, 1),
+ }
+
+ log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
+ startTime := time.Now()
+ ac.getAccountRequestCh <- req
+
+ result := <-req.ResultChan
+ log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
+ return result.Account, result.Err
+}
+
+func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID string) {
+ ac.mu.Lock()
+ requests := ac.getAccountRequests[accountID]
+ delete(ac.getAccountRequests, accountID)
+ ac.mu.Unlock()
+
+ if len(requests) == 0 {
+ return
+ }
+
+ startTime := time.Now()
+ account, err := ac.store.GetAccount(ctx, accountID)
+ log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
+ result := &AccountResult{Account: account, Err: err}
+
+ for _, req := range requests {
+ req.ResultChan <- result
+ close(req.ResultChan)
+ }
+}
+
+func (ac *AccountCache) processGetAccountRequests(ctx context.Context) {
+ for {
+ select {
+ case req := <-ac.getAccountRequestCh:
+ ac.mu.Lock()
+ ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
+ if len(ac.getAccountRequests[req.AccountID]) == 1 {
+ go func(ctx context.Context, accountID string) {
+ time.Sleep(ac.bufferInterval)
+ ac.processGetAccountBatch(ctx, accountID)
+ }(ctx, req.AccountID)
+ }
+ ac.mu.Unlock()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index fe1e36d47..aa9c0d81e 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -3,13 +3,17 @@ package server
import (
"context"
"fmt"
+ "io"
"net"
"os"
"path/filepath"
"runtime"
+ "sync"
+ "sync/atomic"
"testing"
"time"
+ log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
@@ -24,6 +28,12 @@ import (
"github.com/netbirdio/netbird/util"
)
+type TestingT interface {
+ require.TestingT
+ Helper()
+ Cleanup(func())
+}
+
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) {
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
- mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
+ mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
-func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
+func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@@ -485,7 +495,7 @@ func testSyncStatusRace(t *testing.T) {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
- mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
+ mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -545,7 +555,6 @@ func testSyncStatusRace(t *testing.T) {
ctx2, cancelFunc2 := context.WithCancel(context.Background())
- //client.
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
WgPubKey: concurrentPeerKey2.PublicKey().String(),
Body: message2,
@@ -574,7 +583,7 @@ func testSyncStatusRace(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
- //client.
+ // client.
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
@@ -626,3 +635,208 @@ func testSyncStatusRace(t *testing.T) {
t.Fatal("Peer should be connected")
}
}
+
+func Test_LoginPerformance(t *testing.T) {
+ if os.Getenv("CI") == "true" {
+ t.Skip("Skipping on CI")
+ }
+
+ t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
+
+ benchCases := []struct {
+ name string
+ peers int
+ accounts int
+ }{
+ // {"XXS", 5, 1},
+ // {"XS", 10, 1},
+ // {"S", 100, 1},
+ // {"M", 250, 1},
+ // {"L", 500, 1},
+ // {"XL", 750, 1},
+ {"XXL", 1000, 5},
+ }
+
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ for _, bc := range benchCases {
+ t.Run(bc.name, func(t *testing.T) {
+ t.Helper()
+ dir := t.TempDir()
+ err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ os.Remove(filepath.Join(dir, "store.json")) //nolint
+ }()
+
+ mgmtServer, am, _, err := startManagementForTest(t, &Config{
+ Stuns: []*Host{{
+ Proto: "udp",
+ URI: "stun:stun.wiretrustee.com:3468",
+ }},
+ TURNConfig: &TURNConfig{
+ TimeBasedCredentials: false,
+ CredentialsTTL: util.Duration{},
+ Secret: "whatever",
+ Turns: []*Host{{
+ Proto: "udp",
+ URI: "turn:stun.wiretrustee.com:3468",
+ }},
+ },
+ Signal: &Host{
+ Proto: "http",
+ URI: "signal.wiretrustee.com:10000",
+ },
+ Datadir: dir,
+ HttpConfig: nil,
+ })
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+ defer mgmtServer.GracefulStop()
+
+ var counter int32
+ var counterStart int32
+ var wg sync.WaitGroup
+ var mu sync.Mutex
+ messageCalls := []func() error{}
+ for j := 0; j < bc.accounts; j++ {
+ wg.Add(1)
+ go func(j int, counter *int32, counterStart *int32) {
+ defer wg.Done()
+
+ account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
+ if err != nil {
+ t.Logf("account creation failed: %v", err)
+ return
+ }
+
+ setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
+ if err != nil {
+ t.Logf("error creating setup key: %v", err)
+ return
+ }
+
+ for i := 0; i < bc.peers; i++ {
+ key, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ t.Logf("failed to generate key: %v", err)
+ return
+ }
+
+ meta := &mgmtProto.PeerSystemMeta{
+ Hostname: key.PublicKey().String(),
+ GoOS: runtime.GOOS,
+ OS: runtime.GOOS,
+ Core: "core",
+ Platform: "platform",
+ Kernel: "kernel",
+ WiretrusteeVersion: "",
+ }
+
+ peerLogin := PeerLogin{
+ WireGuardPubKey: key.String(),
+ SSHKey: "random",
+ Meta: extractPeerMeta(context.Background(), meta),
+ SetupKey: setupKey.Key,
+ ConnectionIP: net.IP{1, 1, 1, 1},
+ }
+
+ login := func() error {
+ _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+ if err != nil {
+ t.Logf("failed to login peer: %v", err)
+ return err
+ }
+ atomic.AddInt32(counter, 1)
+ if *counter%100 == 0 {
+ t.Logf("finished %d login calls", *counter)
+ }
+ return nil
+ }
+
+ mu.Lock()
+ messageCalls = append(messageCalls, login)
+ mu.Unlock()
+ _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+ if err != nil {
+ t.Logf("failed to login peer: %v", err)
+ return
+ }
+
+ atomic.AddInt32(counterStart, 1)
+ if *counterStart%100 == 0 {
+ t.Logf("registered %d peers", *counterStart)
+ }
+ }
+ }(j, &counter, &counterStart)
+ }
+
+ wg.Wait()
+
+ t.Logf("prepared %d login calls", len(messageCalls))
+ testLoginPerformance(t, messageCalls)
+
+ })
+ }
+}
+
+func testLoginPerformance(t *testing.T, loginCalls []func() error) {
+ t.Helper()
+ wgSetup := sync.WaitGroup{}
+ startChan := make(chan struct{})
+
+ wgDone := sync.WaitGroup{}
+ durations := []time.Duration{}
+ l := sync.Mutex{}
+
+ for i, function := range loginCalls {
+ wgSetup.Add(1)
+ wgDone.Add(1)
+ go func(function func() error, i int) {
+ defer wgDone.Done()
+ wgSetup.Done()
+
+ <-startChan
+ start := time.Now()
+
+ err := function()
+ if err != nil {
+ t.Logf("Error: %v", err)
+ return
+ }
+
+ duration := time.Since(start)
+ l.Lock()
+ durations = append(durations, duration)
+ l.Unlock()
+ }(function, i)
+ }
+
+ wgSetup.Wait()
+ t.Logf("starting login calls")
+ close(startChan)
+ wgDone.Wait()
+ var tMin, tMax, tSum time.Duration
+ for i, d := range durations {
+ if i == 0 {
+ tMin = d
+ tMax = d
+ tSum = d
+ continue
+ }
+ if d < tMin {
+ tMin = d
+ }
+ if d > tMax {
+ tMax = d
+ }
+ tSum += d
+ }
+ tAvg := tSum / time.Duration(len(durations))
+ t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg)
+}
diff --git a/management/server/peer.go b/management/server/peer.go
index 93234d9de..c7d757bb4 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -714,7 +714,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
unlockPeer()
unlockPeer = nil
- account, err := am.Store.GetAccount(ctx, accountID)
+ account, err := am.cache.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
From 0f0415b92a62970dc4bb07f44566ba8a31ac19d4 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Wed, 21 Aug 2024 11:44:52 +0200
Subject: [PATCH 08/89] rename request buffer and update default interval
(#2459)
---
management/server/account.go | 4 ++--
...unt_cache.go => account_request_buffer.go} | 22 ++++++++++---------
management/server/management_proto_test.go | 2 +-
management/server/peer.go | 2 +-
4 files changed, 16 insertions(+), 14 deletions(-)
rename management/server/{account_cache.go => account_request_buffer.go} (75%)
diff --git a/management/server/account.go b/management/server/account.go
index 23781c915..49341a67b 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -161,7 +161,7 @@ type DefaultAccountManager struct {
eventStore activity.Store
geo *geolocation.Geolocation
- cache *AccountCache
+ requestBuffer *AccountRequestBuffer
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
@@ -969,7 +969,7 @@ func BuildManager(
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
- cache: NewAccountCache(ctx, store),
+ requestBuffer: NewAccountRequestBuffer(ctx, store),
}
allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
diff --git a/management/server/account_cache.go b/management/server/account_request_buffer.go
similarity index 75%
rename from management/server/account_cache.go
rename to management/server/account_request_buffer.go
index 13ce45819..5f4897e6a 100644
--- a/management/server/account_cache.go
+++ b/management/server/account_request_buffer.go
@@ -21,7 +21,7 @@ type AccountResult struct {
Err error
}
-type AccountCache struct {
+type AccountRequestBuffer struct {
store Store
getAccountRequests map[string][]*AccountRequest
mu sync.Mutex
@@ -29,17 +29,19 @@ type AccountCache struct {
bufferInterval time.Duration
}
-func NewAccountCache(ctx context.Context, store Store) *AccountCache {
+func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer {
bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
bufferInterval, err := time.ParseDuration(bufferIntervalStr)
- if err != nil && bufferIntervalStr != "" {
- log.WithContext(ctx).Warnf("failed to parse account cache buffer interval: %s", err)
- bufferInterval = 300 * time.Millisecond
+ if err != nil {
+ if bufferIntervalStr != "" {
+ log.WithContext(ctx).Warnf("failed to parse account request buffer interval: %s", err)
+ }
+ bufferInterval = 100 * time.Millisecond
}
- log.WithContext(ctx).Infof("set account cache buffer interval to %s", bufferInterval)
+ log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval)
- ac := AccountCache{
+ ac := AccountRequestBuffer{
store: store,
getAccountRequests: make(map[string][]*AccountRequest),
getAccountRequestCh: make(chan *AccountRequest),
@@ -50,7 +52,7 @@ func NewAccountCache(ctx context.Context, store Store) *AccountCache {
return &ac
}
-func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
+func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
req := &AccountRequest{
AccountID: accountID,
ResultChan: make(chan *AccountResult, 1),
@@ -65,7 +67,7 @@ func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountI
return result.Account, result.Err
}
-func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID string) {
+func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
ac.mu.Lock()
requests := ac.getAccountRequests[accountID]
delete(ac.getAccountRequests, accountID)
@@ -86,7 +88,7 @@ func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID st
}
}
-func (ac *AccountCache) processGetAccountRequests(ctx context.Context) {
+func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
for {
select {
case req := <-ac.getAccountRequestCh:
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index aa9c0d81e..d48e1f513 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -654,7 +654,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1},
// {"L", 500, 1},
// {"XL", 750, 1},
- {"XXL", 1000, 5},
+ {"XXL", 2000, 1},
}
log.SetOutput(io.Discard)
diff --git a/management/server/peer.go b/management/server/peer.go
index c7d757bb4..6926ef6bc 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -714,7 +714,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
unlockPeer()
unlockPeer = nil
- account, err := am.cache.GetAccountWithBackpressure(ctx, accountID)
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
From 5d6dfe59388d43467599603ad4d803f1e0fcd8eb Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 21 Aug 2024 12:11:45 +0200
Subject: [PATCH 09/89] Add test for SetFlagsFromEnvVars (#2460)
---
client/cmd/root_test.go | 45 +++++++++++++++++++++++++++++++++++++++++
1 file changed, 45 insertions(+)
diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go
index abb7d41b2..f2805cf35 100644
--- a/client/cmd/root_test.go
+++ b/client/cmd/root_test.go
@@ -4,6 +4,10 @@ import (
"fmt"
"io"
"testing"
+
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/iface"
)
func TestInitCommands(t *testing.T) {
@@ -34,3 +38,44 @@ func TestInitCommands(t *testing.T) {
})
}
}
+
+func TestSetFlagsFromEnvVars(t *testing.T) {
+ var cmd = &cobra.Command{
+ Use: "netbird",
+ Long: "test",
+ SilenceUsage: true,
+ Run: func(cmd *cobra.Command, args []string) {
+ SetFlagsFromEnvVars(cmd)
+ },
+ }
+
+ cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
+ `comma separated list of external IPs to map to the Wireguard interface`)
+ cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
+ cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
+ cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
+
+ t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
+ t.Setenv("NB_INTERFACE_NAME", "test-name")
+ t.Setenv("NB_ENABLE_ROSENPASS", "true")
+ t.Setenv("NB_WIREGUARD_PORT", "10000")
+ err := cmd.Execute()
+ if err != nil {
+ t.Fatalf("expected no error while running netbird command, got %v", err)
+ }
+ if len(natExternalIPs) != 2 {
+ t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
+ }
+ if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
+ t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
+ }
+ if interfaceName != "test-name" {
+ t.Errorf("expected test-name, got %s", interfaceName)
+ }
+ if !rosenpassEnabled {
+ t.Errorf("expected rosenpassEnabled to be true, got false")
+ }
+ if wireguardPort != 10000 {
+ t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
+ }
+}
From ddea0011709f091acd902d871822187f75fb091d Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 21 Aug 2024 19:24:40 +0200
Subject: [PATCH 10/89] [client] Refactor free port function (#2455)
Rely on net.ListenUDP to get an available port for wireguard in case the configured one is in use
---------
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
---
client/internal/connect.go | 48 ++++++++++++++++++------
client/internal/connect_test.go | 66 +++++++++++++++++----------------
2 files changed, 71 insertions(+), 43 deletions(-)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 1cfabe910..3937b7846 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -397,19 +397,43 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
return notifier
}
-func freePort(start int) (int, error) {
+// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
+func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
- if start == 0 {
- start = iface.DefaultWgPort
+ if initPort == 0 {
+ initPort = iface.DefaultWgPort
}
- for x := start; x <= 65535; x++ {
- addr.Port = x
- conn, err := net.ListenUDP("udp", &addr)
- if err != nil {
- continue
- }
- conn.Close()
- return x, nil
+
+ addr.Port = initPort
+
+ conn, err := net.ListenUDP("udp", &addr)
+ if err == nil {
+ closeConnWithLog(conn)
+ return initPort, nil
+ }
+
+ // if the port is already in use, ask the system for a free port
+ addr.Port = 0
+ conn, err = net.ListenUDP("udp", &addr)
+ if err != nil {
+ return 0, fmt.Errorf("unable to get a free port: %v", err)
+ }
+
+ udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
+ if !ok {
+ return 0, errors.New("wrong address type when getting a free port")
+ }
+ closeConnWithLog(conn)
+ return udpAddr.Port, nil
+}
+
+func closeConnWithLog(conn *net.UDPConn) {
+ startClosing := time.Now()
+ err := conn.Close()
+ if err != nil {
+ log.Warnf("closing probe port %d failed: %v. NetBird will still attempt to use this port for connection.", conn.LocalAddr().(*net.UDPAddr).Port, err)
+ }
+ if time.Since(startClosing) > time.Second {
+ log.Warnf("closing the testing port %d took %s. Usually it is safe to ignore, but continuous warnings may indicate a problem.", conn.LocalAddr().(*net.UDPAddr).Port, time.Since(startClosing))
}
- return 0, errors.New("no free ports")
}
diff --git a/client/internal/connect_test.go b/client/internal/connect_test.go
index 6f4a6bbb7..78b4b06e8 100644
--- a/client/internal/connect_test.go
+++ b/client/internal/connect_test.go
@@ -7,51 +7,55 @@ import (
func Test_freePort(t *testing.T) {
tests := []struct {
- name string
- port int
- want int
- wantErr bool
+ name string
+ port int
+ want int
+ shouldMatch bool
}{
{
- name: "available",
- port: 51820,
- want: 51820,
- wantErr: false,
+ name: "not provided, fallback to default",
+ port: 0,
+ want: 51820,
+ shouldMatch: true,
},
{
- name: "notavailable",
- port: 51830,
- want: 51831,
- wantErr: false,
+ name: "provided and available",
+ port: 51821,
+ want: 51821,
+ shouldMatch: true,
},
{
- name: "noports",
- port: 65535,
- want: 0,
- wantErr: true,
+ name: "provided and not available",
+ port: 51830,
+ want: 51830,
+ shouldMatch: false,
},
}
+ c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
+ if err != nil {
+ t.Errorf("freePort error = %v", err)
+ }
+ defer func(c1 *net.UDPConn) {
+ _ = c1.Close()
+ }(c1)
+
for _, tt := range tests {
- c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
- if err != nil {
- t.Errorf("freePort error = %v", err)
- }
- c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
- if err != nil {
- t.Errorf("freePort error = %v", err)
- }
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
- if (err != nil) != tt.wantErr {
- t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
- return
+
+ if err != nil {
+ t.Errorf("got an error while getting free port: %v", err)
}
- if got != tt.want {
- t.Errorf("freePort() = %v, want %v", got, tt.want)
+
+ if tt.shouldMatch && got != tt.want {
+ t.Errorf("got a different port %v, want %v", got, tt.want)
+ }
+
+ if !tt.shouldMatch && got == tt.want {
+ t.Errorf("got the same port %v, want a different port", tt.want)
}
})
- c1.Close()
- c2.Close()
+
}
}
From d92f2b633f3842152422daf0573c48d3e1692985 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 22 Aug 2024 18:49:07 +0200
Subject: [PATCH 11/89] Bump github.com/docker/docker (#2426)
Bumps [github.com/docker/docker](https://github.com/docker/docker) from 26.1.4+incompatible to 26.1.5+incompatible.
- [Release notes](https://github.com/docker/docker/releases)
- [Commits](https://github.com/docker/docker/compare/v26.1.4...v26.1.5)
---
updated-dependencies:
- dependency-name: github.com/docker/docker
dependency-type: indirect
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
go.mod | 2 +-
go.sum | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/go.mod b/go.mod
index f47d8cf79..e10394bc6 100644
--- a/go.mod
+++ b/go.mod
@@ -115,7 +115,7 @@ require (
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
- github.com/docker/docker v26.1.4+incompatible // indirect
+ github.com/docker/docker v26.1.5+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
diff --git a/go.sum b/go.sum
index 06df95a33..8407fdec7 100644
--- a/go.sum
+++ b/go.sum
@@ -132,8 +132,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
-github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
-github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g=
+github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
From 33b264e59893c56ba1dc35a07535af62a6cbfd60 Mon Sep 17 00:00:00 2001
From: Aidan <52164617+arosberg@users.noreply.github.com>
Date: Fri, 23 Aug 2024 10:38:57 -0400
Subject: [PATCH 12/89] [misc] Add support for
NETBIRD_STORE_ENGINE_POSTGRES_DSN environment variable in setup.env (#2462)
* Added Postgres DSN env variable
* Added postgres check to script
---
infrastructure_files/configure.sh | 12 ++++++++++++
infrastructure_files/docker-compose.yml.tmpl | 3 +++
infrastructure_files/docker-compose.yml.tmpl.traefik | 4 +++-
3 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh
index f04735de6..bf021c9ac 100755
--- a/infrastructure_files/configure.sh
+++ b/infrastructure_files/configure.sh
@@ -41,6 +41,18 @@ if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then
exit 1
fi
+# Check if PostgreSQL is set as the store engine
+if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then
+ # Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set
+ if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
+ echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
+ echo "Please add the following line to your setup.env file:"
+ echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host= user= password= dbname= port="'
+ exit 1
+ fi
+ export NETBIRD_STORE_ENGINE_POSTGRES_DSN
+fi
+
# local development or tests
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl
index 6b6831493..43c8b470c 100644
--- a/infrastructure_files/docker-compose.yml.tmpl
+++ b/infrastructure_files/docker-compose.yml.tmpl
@@ -77,6 +77,9 @@ services:
options:
max-size: "500m"
max-file: "2"
+ environment:
+ - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
+
# Coturn
coturn:
image: coturn/coturn:$COTURN_TAG
diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik
index d3ae6529a..c4415d848 100644
--- a/infrastructure_files/docker-compose.yml.tmpl.traefik
+++ b/infrastructure_files/docker-compose.yml.tmpl.traefik
@@ -81,7 +81,9 @@ services:
- traefik.http.routers.netbird-management.service=netbird-management
- traefik.http.services.netbird-management.loadbalancer.server.port=443
- traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c
-
+ environment:
+ - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
+
# Coturn
coturn:
image: coturn/coturn:$COTURN_TAG
From d97b03656f3804c0ba4be017c25c799ad9c681b3 Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Fri, 23 Aug 2024 19:42:55 +0300
Subject: [PATCH 13/89] [management] Refactor HTTP metrics (#2476)
* Add logging for slow SQL queries in SaveAccount and GetAccount
* Add resource count log for large accounts
* Refactor metrics middleware to simplify counters and histograms
* Update log levels and remove redundant resource count check
---
management/server/account.go | 6 +
management/server/http/handler.go | 21 ---
management/server/sql_store.go | 13 ++
.../server/telemetry/http_api_metrics.go | 138 +++++++-----------
4 files changed, 73 insertions(+), 105 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 49341a67b..7159aa9ac 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -476,6 +476,12 @@ func (a *Account) GetPeerNetworkMap(
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
+
+ if objectCount > 5000 {
+ log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
+ "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d",
+ a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules))
+ }
}
return nm
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index 3fe26d0ce..366efa9b7 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -100,27 +100,6 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
api.addPostureCheckEndpoint()
api.addLocationsEndpoint()
- err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
- methods, err := route.GetMethods()
- if err != nil { // we may have wildcard routes from integrations without methods, skip them for now
- methods = []string{}
- }
- for _, method := range methods {
- template, err := route.GetPathTemplate()
- if err != nil {
- return err
- }
- err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
- if err != nil {
- return err
- }
- }
- return nil
- })
- if err != nil {
- return nil, err
- }
-
return rootRouter, nil
}
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 912e31410..0fb3d391f 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -134,6 +134,12 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
start := time.Now()
+ defer func() {
+ elapsed := time.Since(start)
+ if elapsed > 1*time.Second {
+ log.WithContext(ctx).Tracef("SaveAccount for account %s exceeded 1s, took: %v", account.Id, elapsed)
+ }
+ }()
// todo: remove this check after the issue is resolved
s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain)
@@ -513,6 +519,13 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
+ start := time.Now()
+ defer func() {
+ elapsed := time.Since(start)
+ if elapsed > 1*time.Second {
+ log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
+ }
+ }()
var account Account
result := s.db.Model(&account).
diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go
index a80453dca..357f019c7 100644
--- a/management/server/telemetry/http_api_metrics.go
+++ b/management/server/telemetry/http_api_metrics.go
@@ -8,6 +8,7 @@ import (
"time"
"github.com/google/uuid"
+ "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -54,112 +55,89 @@ func (rw *WrappedResponseWriter) WriteHeader(code int) {
// HTTPMiddleware handler used to collect metrics of every request/response coming to the API.
// Also adds request tracing (logging).
type HTTPMiddleware struct {
- meter metric.Meter
- ctx context.Context
+ ctx context.Context
// all HTTP requests by endpoint & method
- httpRequestCounters map[string]metric.Int64Counter
+ httpRequestCounter metric.Int64Counter
// all HTTP responses by endpoint & method & status code
- httpResponseCounters map[string]metric.Int64Counter
+ httpResponseCounter metric.Int64Counter
// all HTTP requests
totalHTTPRequestsCounter metric.Int64Counter
// all HTTP responses
totalHTTPResponseCounter metric.Int64Counter
// all HTTP responses by status code
- totalHTTPResponseCodeCounters map[int]metric.Int64Counter
+ totalHTTPResponseCodeCounter metric.Int64Counter
// all HTTP requests durations by endpoint and method
- httpRequestDurations map[string]metric.Int64Histogram
+ httpRequestDuration metric.Int64Histogram
// all HTTP requests durations
totalHTTPRequestDuration metric.Int64Histogram
}
// NewMetricsMiddleware creates a new HTTPMiddleware
func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) {
- totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpRequestCounterPrefix), metric.WithUnit("1"))
+ httpRequestCounter, err := meter.Int64Counter(httpRequestCounterPrefix, metric.WithUnit("1"))
if err != nil {
return nil, err
}
- totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpResponseCounterPrefix), metric.WithUnit("1"))
+ httpResponseCounter, err := meter.Int64Counter(httpResponseCounterPrefix, metric.WithUnit("1"))
if err != nil {
return nil, err
}
- totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s_total", httpRequestDurationPrefix), metric.WithUnit("milliseconds"))
+ totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s.total", httpRequestCounterPrefix), metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
+ totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s.total", httpResponseCounterPrefix), metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
+ totalHTTPResponseCodeCounter, err := meter.Int64Counter(fmt.Sprintf("%s.code.total", httpResponseCounterPrefix), metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
+ httpRequestDuration, err := meter.Int64Histogram(httpRequestDurationPrefix, metric.WithUnit("milliseconds"))
+ if err != nil {
+ return nil, err
+ }
+
+ totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s.total", httpRequestDurationPrefix), metric.WithUnit("milliseconds"))
if err != nil {
return nil, err
}
return &HTTPMiddleware{
- ctx: ctx,
- httpRequestCounters: map[string]metric.Int64Counter{},
- httpResponseCounters: map[string]metric.Int64Counter{},
- httpRequestDurations: map[string]metric.Int64Histogram{},
- totalHTTPResponseCodeCounters: map[int]metric.Int64Counter{},
- meter: meter,
- totalHTTPRequestsCounter: totalHTTPRequestsCounter,
- totalHTTPResponseCounter: totalHTTPResponseCounter,
- totalHTTPRequestDuration: totalHTTPRequestDuration,
+ ctx: ctx,
+ httpRequestCounter: httpRequestCounter,
+ httpResponseCounter: httpResponseCounter,
+ httpRequestDuration: httpRequestDuration,
+ totalHTTPResponseCodeCounter: totalHTTPResponseCodeCounter,
+ totalHTTPRequestsCounter: totalHTTPRequestsCounter,
+ totalHTTPResponseCounter: totalHTTPResponseCounter,
+ totalHTTPRequestDuration: totalHTTPRequestDuration,
},
nil
}
-// AddHTTPRequestResponseCounter adds a new meter for an HTTP defaultEndpoint and Method (GET, POST, etc)
-// Creates one request counter and multiple response counters (one per http response status code).
-func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error {
- meterKey := getRequestCounterKey(endpoint, method)
- httpReqCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
- if err != nil {
- return err
- }
- m.httpRequestCounters[meterKey] = httpReqCounter
-
- durationKey := getRequestDurationKey(endpoint, method)
- requestDuration, err := m.meter.Int64Histogram(durationKey, metric.WithUnit("milliseconds"))
- if err != nil {
- return err
- }
- m.httpRequestDurations[durationKey] = requestDuration
-
- respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503}
- for _, code := range respCodes {
- meterKey = getResponseCounterKey(endpoint, method, code)
- httpRespCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
- if err != nil {
- return err
- }
- m.httpResponseCounters[meterKey] = httpRespCounter
-
- meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code)
- totalHTTPResponseCodeCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
- if err != nil {
- return err
- }
- m.totalHTTPResponseCodeCounters[code] = totalHTTPResponseCodeCounter
- }
-
- return nil
-}
-
func replaceEndpointChars(endpoint string) string {
- endpoint = strings.ReplaceAll(endpoint, "/", "_")
endpoint = strings.ReplaceAll(endpoint, "{", "")
endpoint = strings.ReplaceAll(endpoint, "}", "")
return endpoint
}
-func getRequestCounterKey(endpoint, method string) string {
- endpoint = replaceEndpointChars(endpoint)
- return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix, endpoint, method)
-}
-
-func getRequestDurationKey(endpoint, method string) string {
- endpoint = replaceEndpointChars(endpoint)
- return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix, endpoint, method)
-}
-
-func getResponseCounterKey(endpoint, method string, status int) string {
- endpoint = replaceEndpointChars(endpoint)
- return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix, endpoint, method, status)
+func getEndpointMetricAttr(r *http.Request) string {
+ var endpoint string
+ route := mux.CurrentRoute(r)
+ if route != nil {
+ pathTmpl, err := route.GetPathTemplate()
+ if err == nil {
+ endpoint = replaceEndpointChars(pathTmpl)
+ }
+ }
+ return endpoint
}
// Handler logs every request and response and adds the, to metrics.
@@ -176,11 +154,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)
- metricKey := getRequestCounterKey(r.URL.Path, r.Method)
+ endpointAttr := attribute.String("endpoint", getEndpointMetricAttr(r))
+ methodAttr := attribute.String("method", r.Method)
- if c, ok := m.httpRequestCounters[metricKey]; ok {
- c.Add(m.ctx, 1)
- }
+ m.httpRequestCounter.Add(m.ctx, 1, metric.WithAttributes(endpointAttr, methodAttr))
m.totalHTTPRequestsCounter.Add(m.ctx, 1)
w := WrapResponseWriter(rw)
@@ -193,21 +170,14 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
log.WithContext(ctx).Tracef("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
}
- metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status())
- if c, ok := m.httpResponseCounters[metricKey]; ok {
- c.Add(m.ctx, 1)
- }
+ statusCodeAttr := attribute.Int("code", w.Status())
+ m.httpResponseCounter.Add(m.ctx, 1, metric.WithAttributes(endpointAttr, methodAttr, statusCodeAttr))
m.totalHTTPResponseCounter.Add(m.ctx, 1)
- if c, ok := m.totalHTTPResponseCodeCounters[w.Status()]; ok {
- c.Add(m.ctx, 1)
- }
+ m.totalHTTPResponseCodeCounter.Add(m.ctx, 1, metric.WithAttributes(statusCodeAttr))
- durationKey := getRequestDurationKey(r.URL.Path, r.Method)
reqTook := time.Since(reqStart)
- if c, ok := m.httpRequestDurations[durationKey]; ok {
- c.Record(m.ctx, reqTook.Milliseconds())
- }
+ m.httpRequestDuration.Record(m.ctx, reqTook.Milliseconds(), metric.WithAttributes(endpointAttr, methodAttr))
log.WithContext(ctx).Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status())
if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) {
From be6bc46bcdaa2b7703a040afa0cbdbb758d11385 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Fri, 23 Aug 2024 19:37:20 +0200
Subject: [PATCH 14/89] Update sign pipeline version to 0.0.13 (#2477)
---
.github/workflows/release.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 30f24e92e..5098cd549 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -10,7 +10,7 @@ on:
env:
- SIGN_PIPE_VER: "v0.0.12"
+ SIGN_PIPE_VER: "v0.0.13"
GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
From 00944bcdbf7a31d41f945c4a49f47c32571271b2 Mon Sep 17 00:00:00 2001
From: Harry Kodden
Date: Tue, 27 Aug 2024 16:37:55 +0200
Subject: [PATCH 15/89] [management] Add support to ECDSA public Keys (#2461)
Update the JWT validation logic to handle ECDSA keys in addition to the existing RSA keys
---------
Co-authored-by: Harry Kodden
Co-authored-by: Bethuel Mmbaga
---
management/server/jwtclaims/jwtValidator.go | 145 ++++++++++----------
1 file changed, 75 insertions(+), 70 deletions(-)
diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go
index 39676982e..d5c1e7c9e 100644
--- a/management/server/jwtclaims/jwtValidator.go
+++ b/management/server/jwtclaims/jwtValidator.go
@@ -1,14 +1,12 @@
package jwtclaims
import (
- "bytes"
"context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
"crypto/rsa"
- "crypto/x509"
"encoding/base64"
- "encoding/binary"
"encoding/json"
- "encoding/pem"
"errors"
"fmt"
"math/big"
@@ -41,11 +39,6 @@ type Options struct {
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
- // When set, the middelware verifies that tokens are signed with the specific signing algorithm
- // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
- // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
- // Default: nil
- SigningMethod jwt.SigningMethod
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
@@ -54,6 +47,18 @@ type Jwks struct {
expiresInTime time.Time
}
+// The supported elliptic curves types
+const (
+ // p256 represents a cryptographic elliptical curve type.
+ p256 = "P-256"
+
+ // p384 represents a cryptographic elliptical curve type.
+ p384 = "P-384"
+
+ // p521 represents a cryptographic elliptical curve type.
+ p521 = "P-521"
+)
+
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
@@ -61,6 +66,9 @@ type JSONWebKey struct {
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
+ Crv string `json:"crv"`
+ X string `json:"x"`
+ Y string `json:"y"`
X5c []string `json:"x5c"`
}
@@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
}
}
- cert, err := getPemCert(ctx, token, keys)
+ publicKey, err := getPublicKey(ctx, token, keys)
if err != nil {
+ log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
return nil, err
}
- result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
- return result, nil
+ return publicKey, nil
},
- SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false,
}
@@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if there was an error in parsing...
if err != nil {
log.WithContext(ctx).Errorf("error parsing token: %v", err)
- return nil, fmt.Errorf("Error parsing token: %w", err)
- }
-
- if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
- errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
- m.options.SigningMethod.Alg(),
- parsedToken.Header["alg"])
- log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
- return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
+ return nil, fmt.Errorf("error parsing token: %w", err)
}
// Check if the parsed token is valid...
@@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
return jwks, err
}
-func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
+func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
- cert := ""
for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid {
@@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
}
if len(jwks.Keys[k].X5c) != 0 {
- cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
- return cert, nil
+ cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
+ return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
- log.WithContext(ctx).Debugf("generating validation pem from JWK")
- return generatePemFromJWK(jwks.Keys[k])
+
+ if jwks.Keys[k].Kty == "RSA" {
+ log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
+ return getPublicKeyFromRSA(jwks.Keys[k])
+ }
+ if jwks.Keys[k].Kty == "EC" {
+ log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
+ return getPublicKeyFromECDSA(jwks.Keys[k])
+ }
+
+ log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
}
- return cert, errors.New("unable to find appropriate key")
+ return nil, errors.New("unable to find appropriate key")
}
-func generatePemFromJWK(jwk JSONWebKey) (string, error) {
- decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
- if err != nil {
- return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err)
+func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
+
+ if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
+ return nil, fmt.Errorf("ecdsa key incomplete")
}
- intModules := big.NewInt(0)
- intModules.SetBytes(decodedModulus)
-
- exponent, err := convertExponentStringToInt(jwk.E)
- if err != nil {
- return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
+ var xCoordinate []byte
+ if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
+ return nil, err
}
- publicKey := &rsa.PublicKey{
- N: intModules,
- E: exponent,
+ var yCoordinate []byte
+ if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
+ return nil, err
}
- derKey, err := x509.MarshalPKIXPublicKey(publicKey)
- if err != nil {
- return "", fmt.Errorf("unable to convert public key to DER, error: %s", err)
+ publicKey = &ecdsa.PublicKey{}
+
+ var curve elliptic.Curve
+ switch jwk.Crv {
+ case p256:
+ curve = elliptic.P256()
+ case p384:
+ curve = elliptic.P384()
+ case p521:
+ curve = elliptic.P521()
}
- block := &pem.Block{
- Type: "RSA PUBLIC KEY",
- Bytes: derKey,
- }
+ publicKey.Curve = curve
+ publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
+ publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
- var out bytes.Buffer
- err = pem.Encode(&out, block)
- if err != nil {
- return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
- }
-
- return out.String(), nil
+ return publicKey, nil
}
-func convertExponentStringToInt(stringExponent string) (int, error) {
- decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
+func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
+
+ decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
- return 0, err
+ return nil, err
}
- exponentBytes := decodedString
- if len(decodedString) < 8 {
- exponentBytes = make([]byte, 8-len(decodedString), 8)
- exponentBytes = append(exponentBytes, decodedString...)
+ decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
+ if err != nil {
+ return nil, err
}
- bytesReader := bytes.NewReader(exponentBytes)
- var exponent uint64
- err = binary.Read(bytesReader, binary.BigEndian, &exponent)
- if err != nil {
- return 0, err
- }
+ var n, e big.Int
+ e.SetBytes(decodedE)
+ n.SetBytes(decodedN)
- return int(exponent), nil
+ return &rsa.PublicKey{
+ E: int(e.Int64()),
+ N: &n,
+ }, nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
@@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0
}
+
From 63a75d72fcfff9bc503147bbbbc57ba9041cbd08 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Tue, 27 Aug 2024 16:38:42 +0200
Subject: [PATCH 16/89] [misc] Test infrastructure files generation with
postgres store (#2478)
---
.../workflows/test-infrastructure-files.yml | 33 +++++++++++++++++--
1 file changed, 30 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index 52b8ee3e2..f758e74bd 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -18,7 +18,31 @@ concurrency:
jobs:
test-docker-compose:
runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ store: [ 'sqlite', 'postgres' ]
+ services:
+ postgres:
+ image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
+ env:
+ POSTGRES_USER: netbird
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_DB: netbird
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ ports:
+ - 5432:5432
steps:
+ - name: Set Database Connection String
+ run: |
+ if [ "${{ matrix.store }}" == "postgres" ]; then
+ echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
+ else
+ echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
+ fi
+
- name: Install jq
run: sudo apt-get install -y jq
@@ -58,7 +82,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
- CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
+ CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
+ NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
@@ -85,7 +110,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
- CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
+ CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
+ NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -123,6 +149,7 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
+ grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
- name: Install modules
run: go mod tidy
@@ -159,7 +186,7 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
- test $count -eq 4
+ test $count -eq 4 || docker compose logs
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
From 7efaf7eadbd7e746ffec3edb0083f5d03fe3848a Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Tue, 27 Aug 2024 19:21:14 +0200
Subject: [PATCH 17/89] [client] Use static requested GUID when creating
Windows interface (#2479)
RequestedGUID is the GUID of the created network adapter, which then influences NLA generation deterministically.
With this change, NetBird should not generate multiple interfaces in every restart on Windows.
---
client/internal/engine_test.go | 3 +++
iface/iface_test.go | 8 ++++++++
iface/tun.go | 3 +++
iface/tun_windows.go | 17 ++++++++++++++++-
4 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index e0f85d211..80b79a364 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -13,6 +13,7 @@ import (
"testing"
"time"
+ "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -845,6 +846,8 @@ func TestEngine_MultiplePeers(t *testing.T) {
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
+ guid := fmt.Sprintf("{%s}", uuid.New().String())
+ iface.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start()
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
diff --git a/iface/iface_test.go b/iface/iface_test.go
index 43c44b770..6609c06f4 100644
--- a/iface/iface_test.go
+++ b/iface/iface_test.go
@@ -4,9 +4,11 @@ import (
"fmt"
"net"
"net/netip"
+ "strings"
"testing"
"time"
+ "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -345,6 +347,9 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
+ guid := fmt.Sprintf("{%s}", uuid.New().String())
+ CustomWindowsGUIDString = strings.ToLower(guid)
+
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
@@ -364,6 +369,9 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
+ guid = fmt.Sprintf("{%s}", uuid.New().String())
+ CustomWindowsGUIDString = strings.ToLower(guid)
+
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
diff --git a/iface/tun.go b/iface/tun.go
index b3c0f9d80..7d0a57ed6 100644
--- a/iface/tun.go
+++ b/iface/tun.go
@@ -7,6 +7,9 @@ import (
"github.com/netbirdio/netbird/iface/bind"
)
+// CustomWindowsGUIDString is a custom GUID string for the interface
+var CustomWindowsGUIDString string
+
type wgTunDevice interface {
Create() (wgConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
diff --git a/iface/tun_windows.go b/iface/tun_windows.go
index 0d658059f..8c0a3c3b5 100644
--- a/iface/tun_windows.go
+++ b/iface/tun_windows.go
@@ -14,6 +14,8 @@ import (
"github.com/netbirdio/netbird/iface/bind"
)
+const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
+
type tunDevice struct {
name string
address WGAddress
@@ -40,9 +42,22 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
}
+func getGUID() (windows.GUID, error) {
+ guidString := defaultWindowsGUIDSTring
+ if CustomWindowsGUIDString != "" {
+ guidString = CustomWindowsGUIDString
+ }
+ return windows.GUIDFromString(guidString)
+}
+
func (t *tunDevice) Create() (wgConfigurer, error) {
+ guid, err := getGUID()
+ if err != nil {
+ log.Errorf("failed to get GUID: %s", err)
+ return nil, err
+ }
log.Info("create tun interface")
- tunDevice, err := tun.CreateTUN(t.name, t.mtu)
+ tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
return nil, err
}
From 880b81154f0a03bf040663ba88478f93d96d6f22 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 28 Aug 2024 14:46:35 +0200
Subject: [PATCH 18/89] Use new sign pipeline (#2490)
---
.github/workflows/release.yml | 24 +++++-------------------
1 file changed, 5 insertions(+), 19 deletions(-)
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 5098cd549..a8f7868d5 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -10,7 +10,7 @@ on:
env:
- SIGN_PIPE_VER: "v0.0.13"
+ SIGN_PIPE_VER: "v0.0.14"
GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -231,29 +231,15 @@ jobs:
path: dist/
retention-days: 3
- trigger_windows_signer:
+ trigger_signer:
runs-on: ubuntu-latest
- needs: [release,release_ui]
+ needs: [release,release_ui,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- - name: Trigger Windows binaries sign pipeline
+ - name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@v1
with:
- workflow: Sign windows bin and installer
- repo: netbirdio/sign-pipelines
- ref: ${{ env.SIGN_PIPE_VER }}
- token: ${{ secrets.SIGN_GITHUB_TOKEN }}
- inputs: '{ "tag": "${{ github.ref }}" }'
-
- trigger_darwin_signer:
- runs-on: ubuntu-latest
- needs: [release,release_ui_darwin]
- if: startsWith(github.ref, 'refs/tags/')
- steps:
- - name: Trigger Darwin App binaries sign pipeline
- uses: benc-uk/workflow-dispatch@v1
- with:
- workflow: Sign darwin ui app with dispatch
+ workflow: Sign bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
From 5ac6f565941338f5c8a9ab8307670f83b4978ea0 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Thu, 29 Aug 2024 21:31:19 +0200
Subject: [PATCH 19/89] [relay] Replace the iface to interface (#2473)
Replace the iface to interface
---
client/internal/engine.go | 2 +-
client/internal/engine_test.go | 13 +--
client/internal/peer/conn.go | 2 +-
client/internal/routemanager/client.go | 6 +-
client/internal/routemanager/dynamic/route.go | 4 +-
client/internal/routemanager/manager.go | 4 +-
.../internal/routemanager/server_android.go | 2 +-
.../routemanager/server_nonandroid.go | 4 +-
.../routemanager/sysctl/sysctl_linux.go | 2 +-
.../routemanager/systemops/systemops.go | 4 +-
.../systemops/systemops_generic.go | 2 +-
iface/iface_moc.go | 103 ++++++++++++++++++
iface/iwginterface.go | 32 ++++++
iface/iwginterface_windows.go | 31 ++++++
14 files changed, 188 insertions(+), 23 deletions(-)
create mode 100644 iface/iface_moc.go
create mode 100644 iface/iwginterface.go
create mode 100644 iface/iwginterface_windows.go
diff --git a/client/internal/engine.go b/client/internal/engine.go
index d65322d6a..b3fc2b628 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -134,7 +134,7 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 80b79a364..e024dd323 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -215,14 +215,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
- newNet, err := stdnet.NewNet()
- if err != nil {
- t.Fatal(err)
- }
- engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
- if err != nil {
- t.Fatal(err)
+
+ wgIface := &iface.MockWGIface{
+ RemovePeerFunc: func(peerKey string) error {
+ return nil
+ },
}
+ engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 0d8fd932c..d1fe0d419 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -36,7 +36,7 @@ const (
type WgConfig struct {
WgListenPort int
RemoteKey string
- WgInterface *iface.WGIface
+ WgInterface iface.IWGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index 1566d10dd..cebdd2b0f 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -44,7 +44,7 @@ type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
statusRecorder *peer.Status
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
@@ -54,7 +54,7 @@ type clientNetwork struct {
updateSerial uint64
}
-func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
+func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
@@ -384,7 +384,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
-func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
+func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
if rt.IsDynamic() {
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go
index 3296f3ddf..5897031e7 100644
--- a/client/internal/routemanager/dynamic/route.go
+++ b/client/internal/routemanager/dynamic/route.go
@@ -48,7 +48,7 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
resolverAddr string
}
@@ -58,7 +58,7 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
- wgInterface *iface.WGIface,
+ wgInterface iface.IWGIface,
resolverAddr string,
) *Route {
return &Route{
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index 0b10dbe33..597eddd51 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -49,7 +49,7 @@ type DefaultManager struct {
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
pubKey string
notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter
@@ -61,7 +61,7 @@ func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
- wgInterface *iface.WGIface,
+ wgInterface iface.IWGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {
diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go
index b4065bca6..2057b9cc8 100644
--- a/client/internal/routemanager/server_android.go
+++ b/client/internal/routemanager/server_android.go
@@ -11,6 +11,6 @@ import (
"github.com/netbirdio/netbird/iface"
)
-func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
+func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}
diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go
index 8470934c2..43a266cd2 100644
--- a/client/internal/routemanager/server_nonandroid.go
+++ b/client/internal/routemanager/server_nonandroid.go
@@ -22,11 +22,11 @@ type defaultServerRouter struct {
ctx context.Context
routes map[route.ID]*route.Route
firewall firewall.Manager
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
statusRecorder *peer.Status
}
-func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
+func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go
index 43394a823..13e1229f8 100644
--- a/client/internal/routemanager/sysctl/sysctl_linux.go
+++ b/client/internal/routemanager/sysctl/sysctl_linux.go
@@ -23,7 +23,7 @@ const (
)
// Setup configures sysctl settings for RP filtering and source validation.
-func Setup(wgIface *iface.WGIface) (map[string]int, error) {
+func Setup(wgIface iface.IWGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go
index cddd7e7e2..ae27b0123 100644
--- a/client/internal/routemanager/systemops/systemops.go
+++ b/client/internal/routemanager/systemops/systemops.go
@@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
@@ -30,7 +30,7 @@ type SysOps struct {
notifier *notifier.Notifier
}
-func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
+func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 671545b86..d76824c10 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
-func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
+func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
diff --git a/iface/iface_moc.go b/iface/iface_moc.go
new file mode 100644
index 000000000..fab3054a0
--- /dev/null
+++ b/iface/iface_moc.go
@@ -0,0 +1,103 @@
+package iface
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/iface/bind"
+)
+
+type MockWGIface struct {
+ CreateFunc func() error
+ CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
+ IsUserspaceBindFunc func() bool
+ NameFunc func() string
+ AddressFunc func() WGAddress
+ ToInterfaceFunc func() *net.Interface
+ UpFunc func() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddrFunc func(newAddr string) error
+ UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeerFunc func(peerKey string) error
+ AddAllowedIPFunc func(peerKey string, allowedIP string) error
+ RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
+ CloseFunc func() error
+ SetFilterFunc func(filter PacketFilter) error
+ GetFilterFunc func() PacketFilter
+ GetDeviceFunc func() *DeviceWrapper
+ GetStatsFunc func(peerKey string) (WGStats, error)
+ GetInterfaceGUIDStringFunc func() (string, error)
+}
+
+func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
+ return m.GetInterfaceGUIDStringFunc()
+}
+
+func (m *MockWGIface) Create() error {
+ return m.CreateFunc()
+}
+
+func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
+ return m.CreateOnAndroidFunc(routeRange, ip, domains)
+}
+
+func (m *MockWGIface) IsUserspaceBind() bool {
+ return m.IsUserspaceBindFunc()
+}
+
+func (m *MockWGIface) Name() string {
+ return m.NameFunc()
+}
+
+func (m *MockWGIface) Address() WGAddress {
+ return m.AddressFunc()
+}
+
+func (m *MockWGIface) ToInterface() *net.Interface {
+ return m.ToInterfaceFunc()
+}
+
+func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
+ return m.UpFunc()
+}
+
+func (m *MockWGIface) UpdateAddr(newAddr string) error {
+ return m.UpdateAddrFunc(newAddr)
+}
+
+func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+ return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
+}
+
+func (m *MockWGIface) RemovePeer(peerKey string) error {
+ return m.RemovePeerFunc(peerKey)
+}
+
+func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
+ return m.AddAllowedIPFunc(peerKey, allowedIP)
+}
+
+func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
+ return m.RemoveAllowedIPFunc(peerKey, allowedIP)
+}
+
+func (m *MockWGIface) Close() error {
+ return m.CloseFunc()
+}
+
+func (m *MockWGIface) SetFilter(filter PacketFilter) error {
+ return m.SetFilterFunc(filter)
+}
+
+func (m *MockWGIface) GetFilter() PacketFilter {
+ return m.GetFilterFunc()
+}
+
+func (m *MockWGIface) GetDevice() *DeviceWrapper {
+ return m.GetDeviceFunc()
+}
+
+func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) {
+ return m.GetStatsFunc(peerKey)
+}
diff --git a/iface/iwginterface.go b/iface/iwginterface.go
new file mode 100644
index 000000000..501f51d2b
--- /dev/null
+++ b/iface/iwginterface.go
@@ -0,0 +1,32 @@
+//go:build !windows
+
+package iface
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/iface/bind"
+)
+
+type IWGIface interface {
+ Create() error
+ CreateOnAndroid(routeRange []string, ip string, domains []string) error
+ IsUserspaceBind() bool
+ Name() string
+ Address() WGAddress
+ ToInterface() *net.Interface
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(newAddr string) error
+ UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+ Close() error
+ SetFilter(filter PacketFilter) error
+ GetFilter() PacketFilter
+ GetDevice() *DeviceWrapper
+ GetStats(peerKey string) (WGStats, error)
+}
diff --git a/iface/iwginterface_windows.go b/iface/iwginterface_windows.go
new file mode 100644
index 000000000..b5053474e
--- /dev/null
+++ b/iface/iwginterface_windows.go
@@ -0,0 +1,31 @@
+package iface
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/iface/bind"
+)
+
+type IWGIface interface {
+ Create() error
+ CreateOnAndroid(routeRange []string, ip string, domains []string) error
+ IsUserspaceBind() bool
+ Name() string
+ Address() WGAddress
+ ToInterface() *net.Interface
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(newAddr string) error
+ UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+ Close() error
+ SetFilter(filter PacketFilter) error
+ GetFilter() PacketFilter
+ GetDevice() *DeviceWrapper
+ GetStats(peerKey string) (WGStats, error)
+ GetInterfaceGUIDString() (string, error)
+}
From 92a0092ad5262a3687d64a42edbf0938e3872039 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Fri, 30 Aug 2024 15:44:07 +0200
Subject: [PATCH 20/89] [signal] Use signal dispatcher (#2373)
---
go.mod | 1 +
go.sum | 2 +
signal/peer/peer.go | 12 ++--
signal/server/signal.go | 142 ++++++++++++++++++++--------------------
4 files changed, 82 insertions(+), 75 deletions(-)
diff --git a/go.mod b/go.mod
index e10394bc6..cbe32427b 100644
--- a/go.mod
+++ b/go.mod
@@ -58,6 +58,7 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
+ github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
diff --git a/go.sum b/go.sum
index 8407fdec7..6625dbb71 100644
--- a/go.sum
+++ b/go.sum
@@ -477,6 +477,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
diff --git a/signal/peer/peer.go b/signal/peer/peer.go
index 3149526b2..85de91581 100644
--- a/signal/peer/peer.go
+++ b/signal/peer/peer.go
@@ -18,16 +18,20 @@ type Peer struct {
StreamID int64
- //a gRpc connection stream to the Peer
+ // a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
+
+ // registration time
+ RegisteredAt time.Time
}
// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
return &Peer{
- Id: id,
- Stream: stream,
- StreamID: time.Now().UnixNano(),
+ Id: id,
+ Stream: stream,
+ StreamID: time.Now().UnixNano(),
+ RegisteredAt: time.Now(),
}
}
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 219bdcc41..69387cc69 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -13,6 +13,8 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
+ "github.com/netbirdio/signal-dispatcher/dispatcher"
+
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
"github.com/netbirdio/netbird/signal/proto"
@@ -40,8 +42,8 @@ const (
type Server struct {
registry *peer.Registry
proto.UnimplementedSignalExchangeServer
-
- metrics *metrics.AppMetrics
+ dispatcher *dispatcher.Dispatcher
+ metrics *metrics.AppMetrics
}
// NewServer creates a new Signal server
@@ -51,9 +53,15 @@ func NewServer(meter metric.Meter) (*Server, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
+ dispatcher, err := dispatcher.NewDispatcher()
+ if err != nil {
+ return nil, fmt.Errorf("creating dispatcher: %v", err)
+ }
+
s := &Server{
- registry: peer.NewRegistry(appMetrics),
- metrics: appMetrics,
+ dispatcher: dispatcher,
+ registry: peer.NewRegistry(appMetrics),
+ metrics: appMetrics,
}
return s, nil
@@ -61,57 +69,31 @@ func NewServer(meter metric.Meter) (*Server, error) {
// Send forwards a message to the signal peer
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
- if !s.registry.IsPeerRegistered(msg.Key) {
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotRegistered)))
+ log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
- return nil, fmt.Errorf("peer %s is not registered", msg.Key)
+ if msg.RemoteKey == "dummy" {
+ // Test message send during netbird status
+ return &proto.EncryptedMessage{}, nil
}
- getRegistrationStart := time.Now()
-
- if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
- s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
- start := time.Now()
- //forward the message to the target peer
- if err := dstPeer.Stream.Send(msg); err != nil {
- log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
- //todo respond to the sender?
-
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
- } else {
- s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
- s.metrics.MessagesForwarded.Add(context.Background(), 1)
- }
- } else {
- s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
- log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
- //todo respond to the sender?
-
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
+ if _, found := s.registry.Get(msg.RemoteKey); found {
+ s.forwardMessageToPeer(ctx, msg)
+ return &proto.EncryptedMessage{}, nil
}
- return &proto.EncryptedMessage{}, nil
+
+ return s.dispatcher.SendMessage(context.Background(), msg)
}
// ConnectStream connects to the exchange stream
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
- p, err := s.connectPeer(stream)
+ p, err := s.RegisterPeer(stream)
if err != nil {
return err
}
- startRegister := time.Now()
+ defer s.DeregisterPeer(p)
- s.metrics.ActivePeers.Add(stream.Context(), 1)
-
- defer func() {
- log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
- s.registry.Deregister(p)
-
- s.metrics.PeerConnectionDuration.Record(stream.Context(), int64(time.Since(startRegister).Seconds()))
- s.metrics.ActivePeers.Add(context.Background(), -1)
- }()
-
- //needed to confirm that the peer has been registered so that the client can proceed
+ // needed to confirm that the peer has been registered so that the client can proceed
header := metadata.Pairs(proto.HeaderRegistered, "1")
err = stream.SendHeader(header)
if err != nil {
@@ -119,11 +101,10 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
return err
}
- log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
+ log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
for {
-
- //read incoming messages
+ // read incoming messages
msg, err := stream.Recv()
if err == io.EOF {
break
@@ -131,44 +112,28 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
return err
}
- log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
+ log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
- getRegistrationStart := time.Now()
-
- // lookup the target peer where the message is going to
- if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
- s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
- start := time.Now()
- //forward the message to the target peer
- if err := dstPeer.Stream.Send(msg); err != nil {
- log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err)
- //todo respond to the sender?
- s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
- } else {
- // in milliseconds
- s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
- s.metrics.MessagesForwarded.Add(stream.Context(), 1)
- }
- } else {
- s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
- s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
- log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
- //todo respond to the sender?
+ _, err = s.dispatcher.SendMessage(stream.Context(), msg)
+ if err != nil {
+ log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
}
}
+
<-stream.Context().Done()
return stream.Context().Err()
}
-// Handles initial Peer connection.
-// Each connection must provide an Id header.
-// At this moment the connecting Peer will be registered in the peer.Registry
-func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+ log.Debugf("registering new peer")
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
if id, found := meta[proto.HeaderId]; found {
p := peer.NewPeer(id[0], stream)
s.registry.Register(p)
+ s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
+
+ s.metrics.ActivePeers.Add(stream.Context(), 1)
return p, nil
} else {
@@ -180,3 +145,38 @@ func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*p
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
}
}
+
+func (s *Server) DeregisterPeer(p *peer.Peer) {
+ log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
+ s.registry.Deregister(p)
+
+ s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
+ s.metrics.ActivePeers.Add(context.Background(), -1)
+}
+
+func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
+ log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
+
+ getRegistrationStart := time.Now()
+
+ // lookup the target peer where the message is going to
+ if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
+ s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
+ start := time.Now()
+ // forward the message to the target peer
+ if err := dstPeer.Stream.Send(msg); err != nil {
+ log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
+ // todo respond to the sender?
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
+ } else {
+ // in milliseconds
+ s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
+ s.metrics.MessagesForwarded.Add(ctx, 1)
+ }
+ } else {
+ s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
+ log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
+ // todo respond to the sender?
+ }
+}
From 95174d4619cf8d86fc0c29cc33bfe0ba6d981e0d Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Mon, 2 Sep 2024 17:40:34 +0200
Subject: [PATCH 21/89] Update route API doc with max domain number (#2516)
---
management/server/http/api/openapi.yml | 6 +++---
management/server/http/api/types.gen.go | 4 ++--
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 45887dc2e..d32ec6167 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -1064,12 +1064,12 @@ components:
type: string
example: 10.64.0.0/24
domains:
- description: Domain list to be dynamically resolved. Conflicts with network
+ description: Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network
type: array
items:
type: string
minLength: 1
- maxLength: 255
+ maxLength: 32
example: "example.com"
metric:
description: Route metric number. Lowest number has higher priority
@@ -2759,4 +2759,4 @@ paths:
'403':
"$ref": "#/components/responses/forbidden"
'500':
- "$ref": "#/components/responses/internal_error"
\ No newline at end of file
+ "$ref": "#/components/responses/internal_error"
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index 77a6c643d..a575ff54b 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -995,7 +995,7 @@ type Route struct {
// Description Route description
Description string `json:"description"`
- // Domains Domain list to be dynamically resolved. Conflicts with network
+ // Domains Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network
Domains *[]string `json:"domains,omitempty"`
// Enabled Route status
@@ -1037,7 +1037,7 @@ type RouteRequest struct {
// Description Route description
Description string `json:"description"`
- // Domains Domain list to be dynamically resolved. Conflicts with network
+ // Domains Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network
Domains *[]string `json:"domains,omitempty"`
// Enabled Route status
From 13e7198046a0d73a9cd91bf8e063fafb3d41885c Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 2 Sep 2024 19:19:14 +0200
Subject: [PATCH 22/89] [client] Destory WG interface on down timeout (#2435)
wait on engine down to not only wait for the interface to be down but completely removed. If the waiting loop reaches the timeout we will trigger an interface destroy. On the up command, it now waits until the engine is fully running before sending the response to the CLI. Includes a small refactor of probes to comply with sonar rules about parameter count in the function call
---
client/cmd/down.go | 2 +
client/internal/connect.go | 38 +++++++--------
client/internal/engine.go | 85 +++++++++++-----------------------
client/internal/probe.go | 7 +++
client/server/server.go | 52 ++++++++++++---------
client/server/server_test.go | 2 +-
iface/iface.go | 45 +++++++++++++++++-
iface/iface_destroy_bsd.go | 17 +++++++
iface/iface_destroy_linux.go | 22 +++++++++
iface/iface_destroy_mobile.go | 9 ++++
iface/iface_destroy_windows.go | 32 +++++++++++++
iface/tun_darwin.go | 7 +--
iface/tun_kernel_unix.go | 2 +-
iface/tun_netstack.go | 4 +-
iface/tun_usp_unix.go | 8 ++--
iface/tun_windows.go | 6 +--
16 files changed, 222 insertions(+), 116 deletions(-)
create mode 100644 iface/iface_destroy_bsd.go
create mode 100644 iface/iface_destroy_linux.go
create mode 100644 iface/iface_destroy_mobile.go
create mode 100644 iface/iface_destroy_windows.go
diff --git a/client/cmd/down.go b/client/cmd/down.go
index 4d9f1eba4..3a324cc19 100644
--- a/client/cmd/down.go
+++ b/client/cmd/down.go
@@ -42,6 +42,8 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err)
return err
}
+
+ cmd.Println("Disconnected")
return nil
},
}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 3937b7846..62fd3c61d 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -55,17 +55,15 @@ func NewConnectClient(
// Run with main logic.
func (c *ConnectClient) Run() error {
- return c.run(MobileDependency{}, nil, nil, nil, nil)
+ return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
- mgmProbe *Probe,
- signalProbe *Probe,
- relayProbe *Probe,
- wgProbe *Probe,
+ probes *ProbeHolder,
+ runningWg *sync.WaitGroup,
) error {
- return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
+ return c.run(MobileDependency{}, probes, runningWg)
}
// RunOnAndroid with main logic on mobile system
@@ -84,7 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
- return c.run(mobileDependency, nil, nil, nil, nil)
+ return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) RunOniOS(
@@ -100,15 +98,13 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
- return c.run(mobileDependency, nil, nil, nil, nil)
+ return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) run(
mobileDependency MobileDependency,
- mgmProbe *Probe,
- signalProbe *Probe,
- relayProbe *Probe,
- wgProbe *Probe,
+ probes *ProbeHolder,
+ runningWg *sync.WaitGroup,
) error {
defer func() {
if r := recover(); r != nil {
@@ -255,7 +251,7 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
- c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
+ c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
err = c.engine.Start()
@@ -267,17 +263,15 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
+ if runningWg != nil {
+ runningWg.Done()
+ }
+
<-engineCtx.Done()
c.statusRecorder.ClientTeardown()
backOff.Reset()
- err = c.engine.Stop()
- if err != nil {
- log.Errorf("failed stopping engine %v", err)
- return wrapErr(err)
- }
-
log.Info("stopped NetBird client")
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
@@ -307,6 +301,12 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
+func (c *ConnectClient) Stop() error {
+ c.engineMutex.Lock()
+ defer c.engineMutex.Unlock()
+ return c.engine.Stop()
+}
+
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
diff --git a/client/internal/engine.go b/client/internal/engine.go
index b3fc2b628..ca93fa482 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -155,10 +155,7 @@ type Engine struct {
dnsServer dns.Server
- mgmProbe *Probe
- signalProbe *Probe
- relayProbe *Probe
- wgProbe *Probe
+ probes *ProbeHolder
wgConnWorker sync.WaitGroup
@@ -192,9 +189,6 @@ func NewEngine(
mobileDep,
statusRecorder,
nil,
- nil,
- nil,
- nil,
checks,
)
}
@@ -208,10 +202,7 @@ func NewEngineWithProbes(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
- mgmProbe *Probe,
- signalProbe *Probe,
- relayProbe *Probe,
- wgProbe *Probe,
+ probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
@@ -229,22 +220,20 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
- mgmProbe: mgmProbe,
- signalProbe: signalProbe,
- relayProbe: relayProbe,
- wgProbe: wgProbe,
+ probes: probes,
checks: checks,
}
}
func (e *Engine) Stop() error {
+ if e == nil {
+ // this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
+ log.Debugf("tried stopping engine that is nil")
+ return nil
+ }
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
- if e.cancel != nil {
- e.cancel()
- }
-
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
@@ -260,29 +249,21 @@ func (e *Engine) Stop() error {
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
+ if e.cancel != nil {
+ e.cancel()
+ }
+
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
+
e.wgConnWorker.Wait()
- maxWaitTime := 5 * time.Second
- timeout := time.After(maxWaitTime)
+ log.Infof("Engine stopped")
- for {
- if !e.IsWGIfaceUp() {
- log.Infof("stopped Netbird Engine")
- return nil
- }
-
- select {
- case <-timeout:
- return fmt.Errorf("timeout when waiting for interface shutdown")
- default:
- time.Sleep(100 * time.Millisecond)
- }
- }
+ return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@@ -1415,24 +1396,27 @@ func (e *Engine) getRosenpassAddr() string {
}
func (e *Engine) receiveProbeEvents() {
- if e.signalProbe != nil {
- go e.signalProbe.Receive(e.ctx, func() bool {
+ if e.probes == nil {
+ return
+ }
+ if e.probes.SignalProbe != nil {
+ go e.probes.SignalProbe.Receive(e.ctx, func() bool {
healthy := e.signal.IsHealthy()
log.Debugf("received signal probe request, healthy: %t", healthy)
return healthy
})
}
- if e.mgmProbe != nil {
- go e.mgmProbe.Receive(e.ctx, func() bool {
+ if e.probes.MgmProbe != nil {
+ go e.probes.MgmProbe.Receive(e.ctx, func() bool {
healthy := e.mgmClient.IsHealthy()
log.Debugf("received management probe request, healthy: %t", healthy)
return healthy
})
}
- if e.relayProbe != nil {
- go e.relayProbe.Receive(e.ctx, func() bool {
+ if e.probes.RelayProbe != nil {
+ go e.probes.RelayProbe.Receive(e.ctx, func() bool {
healthy := true
results := append(e.probeSTUNs(), e.probeTURNs()...)
@@ -1451,8 +1435,8 @@ func (e *Engine) receiveProbeEvents() {
})
}
- if e.wgProbe != nil {
- go e.wgProbe.Receive(e.ctx, func() bool {
+ if e.probes.WgProbe != nil {
+ go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request")
for _, peer := range e.peerConns {
@@ -1548,20 +1532,3 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
-
-func (e *Engine) IsWGIfaceUp() bool {
- if e == nil || e.wgInterface == nil {
- return false
- }
- iface, err := net.InterfaceByName(e.wgInterface.Name())
- if err != nil {
- log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
- return false
- }
-
- if iface.Flags&net.FlagUp != 0 {
- return true
- }
-
- return false
-}
diff --git a/client/internal/probe.go b/client/internal/probe.go
index 743b6b190..23290cf74 100644
--- a/client/internal/probe.go
+++ b/client/internal/probe.go
@@ -2,6 +2,13 @@ package internal
import "context"
+type ProbeHolder struct {
+ MgmProbe *Probe
+ SignalProbe *Probe
+ RelayProbe *Probe
+ WgProbe *Probe
+}
+
// Probe allows to run on-demand callbacks from different code locations.
// Pass the probe to a receiving and a sending end. The receiving end starts listening
// to requests with Receive and executes a callback when the sending end requests it
diff --git a/client/server/server.go b/client/server/server.go
index 8173d0741..ce6e90864 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -12,7 +12,6 @@ import (
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
-
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
@@ -143,10 +142,14 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
+ runningWg := sync.WaitGroup{}
+ runningWg.Add(1)
if !config.DisableAutoConnect {
- go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+ go s.connectWithRetryRuns(ctx, config, s.statusRecorder, &runningWg)
}
+ runningWg.Wait()
+
return nil
}
@@ -154,7 +157,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
- mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
+ runningWg *sync.WaitGroup,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -185,7 +188,15 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
- err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe)
+
+ probes := internal.ProbeHolder{
+ MgmProbe: s.mgmProbe,
+ SignalProbe: s.signalProbe,
+ RelayProbe: s.relayProbe,
+ WgProbe: s.wgProbe,
+ }
+
+ err := s.connectClient.RunWithProbes(&probes, runningWg)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@@ -576,7 +587,11 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
- go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+ runningWg := sync.WaitGroup{}
+ runningWg.Add(1)
+ go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, &runningWg)
+
+ runningWg.Wait()
return &proto.UpResponse{}, nil
}
@@ -590,28 +605,19 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
return nil, fmt.Errorf("service is not up")
}
s.actCancel()
+
+ err := s.connectClient.Stop()
+ if err != nil {
+ log.Errorf("failed to shut down properly: %v", err)
+ return nil, err
+ }
+
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
- maxWaitTime := 5 * time.Second
- timeout := time.After(maxWaitTime)
+ log.Infof("service is down")
- engine := s.connectClient.Engine()
-
- for {
- if !engine.IsWGIfaceUp() {
- return &proto.DownResponse{}, nil
- }
-
- select {
- case <-ctx.Done():
- return &proto.DownResponse{}, nil
- case <-timeout:
- return nil, fmt.Errorf("failed to shut down properly")
- default:
- time.Sleep(100 * time.Millisecond)
- }
- }
+ return &proto.DownResponse{}, nil
}
// Status returns the daemon status
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 6a3de774c..242d399ec 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -73,7 +73,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
- s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+ s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
diff --git a/iface/iface.go b/iface/iface.go
index 928077a3d..545feffcf 100644
--- a/iface/iface.go
+++ b/iface/iface.go
@@ -124,7 +124,23 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
- return w.tun.Close()
+
+ err := w.tun.Close()
+ if err != nil {
+ return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
+ }
+
+ err = w.waitUntilRemoved()
+ if err != nil {
+ log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
+ err = w.Destroy()
+ if err != nil {
+ return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
+ }
+ log.Infof("interface %s successfully removed", w.Name())
+ }
+
+ return nil
}
// SetFilter sets packet filters for the userspace implementation
@@ -163,3 +179,30 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
}
+
+func (w *WGIface) waitUntilRemoved() error {
+ maxWaitTime := 5 * time.Second
+ timeout := time.NewTimer(maxWaitTime)
+ defer timeout.Stop()
+
+ for {
+ iface, err := net.InterfaceByName(w.Name())
+ if err != nil {
+ if _, ok := err.(*net.OpError); ok {
+ log.Infof("interface %s has been removed", w.Name())
+ return nil
+ }
+ log.Debugf("failed to get interface by name %s: %v", w.Name(), err)
+ } else if iface == nil {
+ log.Infof("interface %s has been removed", w.Name())
+ return nil
+ }
+
+ select {
+ case <-timeout.C:
+ return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name())
+ default:
+ time.Sleep(100 * time.Millisecond)
+ }
+ }
+}
diff --git a/iface/iface_destroy_bsd.go b/iface/iface_destroy_bsd.go
new file mode 100644
index 000000000..c16010a1c
--- /dev/null
+++ b/iface/iface_destroy_bsd.go
@@ -0,0 +1,17 @@
+//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+
+package iface
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+func (w *WGIface) Destroy() error {
+ out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
+ }
+
+ return nil
+}
diff --git a/iface/iface_destroy_linux.go b/iface/iface_destroy_linux.go
new file mode 100644
index 000000000..e9d54bed1
--- /dev/null
+++ b/iface/iface_destroy_linux.go
@@ -0,0 +1,22 @@
+//go:build linux && !android
+
+package iface
+
+import (
+ "fmt"
+
+ "github.com/vishvananda/netlink"
+)
+
+func (w *WGIface) Destroy() error {
+ link, err := netlink.LinkByName(w.Name())
+ if err != nil {
+ return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
+ }
+
+ if err := netlink.LinkDel(link); err != nil {
+ return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
+ }
+
+ return nil
+}
diff --git a/iface/iface_destroy_mobile.go b/iface/iface_destroy_mobile.go
new file mode 100644
index 000000000..89f87a598
--- /dev/null
+++ b/iface/iface_destroy_mobile.go
@@ -0,0 +1,9 @@
+//go:build android || (ios && !darwin)
+
+package iface
+
+import "errors"
+
+func (w *WGIface) Destroy() error {
+ return errors.New("not supported on mobile")
+}
diff --git a/iface/iface_destroy_windows.go b/iface/iface_destroy_windows.go
new file mode 100644
index 000000000..0bfa4e211
--- /dev/null
+++ b/iface/iface_destroy_windows.go
@@ -0,0 +1,32 @@
+//go:build windows
+
+package iface
+
+import (
+ "fmt"
+ "os/exec"
+
+ log "github.com/sirupsen/logrus"
+)
+
+func (w *WGIface) Destroy() error {
+ netshCmd := GetSystem32Command("netsh")
+ out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
+ }
+ return nil
+}
+
+// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
+// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
+func GetSystem32Command(command string) string {
+ _, err := exec.LookPath(command)
+ if err == nil {
+ return command
+ }
+
+ log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
+
+ return "C:\\windows\\system32\\" + command + ".exe"
+}
diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go
index 364e5dfad..fcf9f8ba0 100644
--- a/iface/tun_darwin.go
+++ b/iface/tun_darwin.go
@@ -3,6 +3,7 @@
package iface
import (
+ "fmt"
"os/exec"
"github.com/pion/transport/v3"
@@ -41,7 +42,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
func (t *tunDevice) Create() (wgConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunDevice)
@@ -55,7 +56,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
- return nil, err
+ return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@@ -63,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
- return nil, err
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
diff --git a/iface/tun_kernel_unix.go b/iface/tun_kernel_unix.go
index 019dd786b..220c07888 100644
--- a/iface/tun_kernel_unix.go
+++ b/iface/tun_kernel_unix.go
@@ -70,7 +70,7 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
configurer := newWGConfigurer(t.name)
if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
- return nil, err
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil
diff --git a/iface/tun_netstack.go b/iface/tun_netstack.go
index df2f75c45..de1ff6654 100644
--- a/iface/tun_netstack.go
+++ b/iface/tun_netstack.go
@@ -47,7 +47,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
@@ -61,7 +61,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
err = t.configurer.configureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
- return nil, err
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)
diff --git a/iface/tun_usp_unix.go b/iface/tun_usp_unix.go
index 814c9ca89..1c1d3ac89 100644
--- a/iface/tun_usp_unix.go
+++ b/iface/tun_usp_unix.go
@@ -48,8 +48,8 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
- log.Debugf("failed to create tun unterface (%s, %d): %s", t.name, t.mtu, err)
- return nil, err
+ log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
@@ -63,7 +63,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
- return nil, err
+ return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@@ -71,7 +71,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
- return nil, err
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
diff --git a/iface/tun_windows.go b/iface/tun_windows.go
index 8c0a3c3b5..afb67bcc0 100644
--- a/iface/tun_windows.go
+++ b/iface/tun_windows.go
@@ -59,7 +59,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
t.wrapper = newDeviceWrapper(tunDevice)
@@ -89,7 +89,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
- return nil, err
+ return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@@ -97,7 +97,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
- return nil, err
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
From 13e923b7c6a68f5d4ce7e90f8d102b848358244e Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 2 Sep 2024 23:46:36 +0200
Subject: [PATCH 23/89] Fix service down (#2519)
---
client/cmd/service.go | 9 ++++++---
client/cmd/service_controller.go | 10 ++++++++++
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/client/cmd/service.go b/client/cmd/service.go
index 5c60744f9..855eb30fa 100644
--- a/client/cmd/service.go
+++ b/client/cmd/service.go
@@ -2,18 +2,21 @@ package cmd
import (
"context"
+
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/server"
)
type program struct {
- ctx context.Context
- cancel context.CancelFunc
- serv *grpc.Server
+ ctx context.Context
+ cancel context.CancelFunc
+ serv *grpc.Server
+ serverInstance *server.Server
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go
index d416afaac..86546e31c 100644
--- a/client/cmd/service_controller.go
+++ b/client/cmd/service_controller.go
@@ -61,6 +61,8 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
+ p.serverInstance = serverInstance
+
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
@@ -70,6 +72,14 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
+ if p.serverInstance != nil {
+ in := new(proto.DownRequest)
+ _, err := p.serverInstance.Down(p.ctx, in)
+ if err != nil {
+ log.Errorf("failed to stop daemon: %v", err)
+ }
+ }
+
p.cancel()
if p.serv != nil {
From 1ff7a953a02842d18cb557562658431bbf25387e Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Wed, 4 Sep 2024 11:14:58 +0200
Subject: [PATCH 24/89] [relay] Store the StunTurn address in thread safe store
(#2470)
Store the StunTurn address in atomic store
---
client/internal/engine.go | 45 +++++++++++-----------
client/internal/peer/conn.go | 52 +++++++++++++-------------
client/internal/peer/conn_test.go | 14 +++----
client/internal/peer/stdnet.go | 2 +-
client/internal/peer/stdnet_android.go | 2 +-
5 files changed, 58 insertions(+), 57 deletions(-)
diff --git a/client/internal/engine.go b/client/internal/engine.go
index ca93fa482..52912f11e 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -13,6 +13,7 @@ import (
"slices"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/pion/ice/v3"
@@ -122,7 +123,8 @@ type Engine struct {
// STUNs is a list of STUN servers used by ICE
STUNs []*stun.URI
// TURNs is a list of STUN servers used by ICE
- TURNs []*stun.URI
+ TURNs []*stun.URI
+ stunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
@@ -535,6 +537,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
+ var stunTurn []*stun.URI
+ stunTurn = append(stunTurn, e.STUNs...)
+ stunTurn = append(stunTurn, e.TURNs...)
+ e.stunTurn.Store(stunTurn)
+
// todo update signal
}
@@ -961,11 +968,6 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
continue
}
- // we might have received new STUN and TURN servers meanwhile, so update them
- e.syncMsgMux.Lock()
- conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
- e.syncMsgMux.Unlock()
-
err := conn.Open(e.ctx)
if err != nil {
log.Debugf("connection to peer %s failed: %v", peerKey, err)
@@ -989,9 +991,6 @@ func (e *Engine) peerExists(peerKey string) bool {
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
- var stunTurn []*stun.URI
- stunTurn = append(stunTurn, e.STUNs...)
- stunTurn = append(stunTurn, e.TURNs...)
wgConfig := peer.WgConfig{
RemoteKey: pubKey,
@@ -1024,19 +1023,21 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
- Key: pubKey,
- LocalKey: e.config.WgPrivateKey.PublicKey().String(),
- StunTurn: stunTurn,
- InterfaceBlackList: e.config.IFaceBlackList,
- DisableIPv6Discovery: e.config.DisableIPv6Discovery,
- Timeout: timeout,
- UDPMux: e.udpMux.UDPMuxDefault,
- UDPMuxSrflx: e.udpMux,
- WgConfig: wgConfig,
- LocalWgPort: e.config.WgPort,
- NATExternalIPs: e.parseNATExternalIPMappings(),
- RosenpassPubKey: e.getRosenpassPubKey(),
- RosenpassAddr: e.getRosenpassAddr(),
+ Key: pubKey,
+ LocalKey: e.config.WgPrivateKey.PublicKey().String(),
+ Timeout: timeout,
+ WgConfig: wgConfig,
+ LocalWgPort: e.config.WgPort,
+ RosenpassPubKey: e.getRosenpassPubKey(),
+ RosenpassAddr: e.getRosenpassAddr(),
+ ICEConfig: peer.ICEConfig{
+ StunTurn: &e.stunTurn,
+ InterfaceBlackList: e.config.IFaceBlackList,
+ DisableIPv6Discovery: e.config.DisableIPv6Discovery,
+ UDPMux: e.udpMux.UDPMuxDefault,
+ UDPMuxSrflx: e.udpMux,
+ NATExternalIPs: e.parseNATExternalIPMappings(),
+ },
}
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index d1fe0d419..7c3b60011 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -7,6 +7,7 @@ import (
"runtime"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/pion/ice/v3"
@@ -41,37 +42,41 @@ type WgConfig struct {
PreSharedKey *wgtypes.Key
}
-// ConnConfig is a peer Connection configuration
-type ConnConfig struct {
-
- // Key is a public key of a remote peer
- Key string
- // LocalKey is a public key of a local peer
- LocalKey string
-
+type ICEConfig struct {
// StunTurn is a list of STUN and TURN URLs
- StunTurn []*stun.URI
+ StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
+ UDPMux ice.UDPMux
+ UDPMuxSrflx ice.UniversalUDPMux
+
+ NATExternalIPs []string
+}
+
+// ConnConfig is a peer Connection configuration
+type ConnConfig struct {
+ // Key is a public key of a remote peer
+ Key string
+ // LocalKey is a public key of a local peer
+ LocalKey string
+
Timeout time.Duration
WgConfig WgConfig
- UDPMux ice.UDPMux
- UDPMuxSrflx ice.UniversalUDPMux
-
LocalWgPort int
- NATExternalIPs []string
-
// RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
RosenpassAddr string
+
+ // ICEConfig ICE protocol configuration
+ ICEConfig ICEConfig
}
// OfferAnswer represents a session establishment offer or answer
@@ -146,11 +151,6 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig
}
-// UpdateStunTurn update the turn and stun addresses
-func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) {
- conn.config.StunTurn = turnStun
-}
-
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
@@ -187,20 +187,20 @@ func (conn *Conn) reCreateAgent() error {
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
- Urls: conn.config.StunTurn,
+ Urls: conn.config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
- InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
- UDPMux: conn.config.UDPMux,
- UDPMuxSrflx: conn.config.UDPMuxSrflx,
- NAT1To1IPs: conn.config.NATExternalIPs,
+ InterfaceFilter: stdnet.InterfaceFilter(conn.config.ICEConfig.InterfaceBlackList),
+ UDPMux: conn.config.ICEConfig.UDPMux,
+ UDPMuxSrflx: conn.config.ICEConfig.UDPMuxSrflx,
+ NAT1To1IPs: conn.config.ICEConfig.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
}
- if conn.config.DisableIPv6Discovery {
+ if conn.config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
@@ -480,7 +480,7 @@ func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
return
}
- mux, ok := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
+ mux, ok := conn.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
if !ok {
log.Warn("invalid udp mux conversion")
return
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index b608a5929..c124208d1 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -7,7 +7,6 @@ import (
"time"
"github.com/magiconair/properties/assert"
- "github.com/pion/stun/v2"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
@@ -15,12 +14,13 @@ import (
)
var connConf = ConnConfig{
- Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
- LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
- StunTurn: []*stun.URI{},
- InterfaceBlackList: nil,
- Timeout: time.Second,
- LocalWgPort: 51820,
+ Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ Timeout: time.Second,
+ LocalWgPort: 51820,
+ ICEConfig: ICEConfig{
+ InterfaceBlackList: nil,
+ },
}
func TestNewConn_interfaceFilter(t *testing.T) {
diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go
index 13f5886f5..1faa30ce3 100644
--- a/client/internal/peer/stdnet.go
+++ b/client/internal/peer/stdnet.go
@@ -7,5 +7,5 @@ import (
)
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNet(conn.config.InterfaceBlackList)
+ return stdnet.NewNet(conn.config.ICEConfig.InterfaceBlackList)
}
diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go
index 8a2454371..90865242b 100644
--- a/client/internal/peer/stdnet_android.go
+++ b/client/internal/peer/stdnet_android.go
@@ -3,5 +3,5 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
+ return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.ICEConfig.InterfaceBlackList)
}
From c52b406afa192daf3e680cf60dcd620efe942903 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 4 Sep 2024 19:22:33 +0200
Subject: [PATCH 25/89] [client] Avoid deadlock when auto connect and early
exit (#2528)
---
client/internal/connect.go | 13 ++++++++-----
client/server/server.go | 33 ++++++++++++++++++++-------------
2 files changed, 28 insertions(+), 18 deletions(-)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 62fd3c61d..6e1994f96 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -61,9 +61,9 @@ func (c *ConnectClient) Run() error {
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
probes *ProbeHolder,
- runningWg *sync.WaitGroup,
+ runningChan chan error,
) error {
- return c.run(MobileDependency{}, probes, runningWg)
+ return c.run(MobileDependency{}, probes, runningChan)
}
// RunOnAndroid with main logic on mobile system
@@ -104,7 +104,7 @@ func (c *ConnectClient) RunOniOS(
func (c *ConnectClient) run(
mobileDependency MobileDependency,
probes *ProbeHolder,
- runningWg *sync.WaitGroup,
+ runningChan chan error,
) error {
defer func() {
if r := recover(); r != nil {
@@ -195,6 +195,7 @@ func (c *ConnectClient) run(
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
+ _ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
}
return wrapErr(err)
@@ -263,8 +264,9 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
- if runningWg != nil {
- runningWg.Done()
+ if runningChan != nil {
+ runningChan <- nil
+ close(runningChan)
}
<-engineCtx.Done()
@@ -287,6 +289,7 @@ func (c *ConnectClient) run(
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
+ _ = c.Stop()
}
return err
}
diff --git a/client/server/server.go b/client/server/server.go
index ce6e90864..d8d32e1ce 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -142,13 +142,11 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
- runningWg := sync.WaitGroup{}
- runningWg.Add(1)
- if !config.DisableAutoConnect {
- go s.connectWithRetryRuns(ctx, config, s.statusRecorder, &runningWg)
+ if config.DisableAutoConnect {
+ return nil
}
- runningWg.Wait()
+ go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
return nil
}
@@ -157,7 +155,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
- runningWg *sync.WaitGroup,
+ runningChan chan error,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -196,7 +194,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
WgProbe: s.wgProbe,
}
- err := s.connectClient.RunWithProbes(&probes, runningWg)
+ err := s.connectClient.RunWithProbes(&probes, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@@ -587,13 +585,22 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
- runningWg := sync.WaitGroup{}
- runningWg.Add(1)
- go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, &runningWg)
+ runningChan := make(chan error)
+ go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
- runningWg.Wait()
-
- return &proto.UpResponse{}, nil
+ for {
+ select {
+ case err := <-runningChan:
+ if err != nil {
+ log.Debugf("waiting for engine to become ready failed: %s", err)
+ } else {
+ return &proto.UpResponse{}, nil
+ }
+ case <-callerCtx.Done():
+ log.Debug("context done, stopping the wait for engine to become ready")
+ return nil, callerCtx.Err()
+ }
+ }
}
// Down engine work in the daemon.
From f2b5b2e9b55d093852d04e563e503ecf5e56a1fd Mon Sep 17 00:00:00 2001
From: Gianluca Boiano <491117+M0Rf30@users.noreply.github.com>
Date: Wed, 4 Sep 2024 19:22:52 +0200
Subject: [PATCH 26/89] [misc] Support rpm-ostree based distros in installation
script (#2508)
* Detect rpm-ostree-based distro and use proper package manager
* Update kardianos/service module to fix folders detection
---
go.mod | 4 ++--
go.sum | 4 ++--
release_files/install.sh | 10 ++++++++++
3 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/go.mod b/go.mod
index cbe32427b..9e440e342 100644
--- a/go.mod
+++ b/go.mod
@@ -10,7 +10,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.0
- github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7
+ github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.23.0
github.com/pion/ice/v3 v3.0.2
@@ -205,7 +205,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect
)
-replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0
+replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
diff --git a/go.sum b/go.sum
index 6625dbb71..916f1f0c8 100644
--- a/go.sum
+++ b/go.sum
@@ -475,8 +475,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
-github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
-github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk=
+github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
diff --git a/release_files/install.sh b/release_files/install.sh
index d9d436ba5..7b6774d84 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -242,6 +242,13 @@ install_netbird() {
${SUDO} dnf -y install netbird-ui
fi
;;
+ rpm-ostree)
+ add_rpm_repo
+ ${SUDO} rpm-ostree -y install netbird
+ if ! $SKIP_UI_APP; then
+ ${SUDO} rpm-ostree -y install netbird-ui
+ fi
+ ;;
pacman)
${SUDO} pacman -Syy
add_aur_repo
@@ -403,6 +410,9 @@ if type uname >/dev/null 2>&1; then
elif [ -x "$(command -v dnf)" ]; then
PACKAGE_MANAGER="dnf"
echo "The installation will be performed using dnf package manager"
+ elif [ -x "$(command -v rpm-ostree)" ]; then
+ PACKAGE_MANAGER="rpm-ostree"
+ echo "The installation will be performed using rpm-ostree package manager"
elif [ -x "$(command -v yum)" ]; then
PACKAGE_MANAGER="yum"
echo "The installation will be performed using yum package manager"
From bdbd1db843397082447342f86b9ee10223bef8b8 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Thu, 5 Sep 2024 15:09:46 +0200
Subject: [PATCH 27/89] [client] Avoid panic when there is no conn client
(#2541)
---
client/internal/connect.go | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 6e1994f96..5dacde746 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -297,6 +297,9 @@ func (c *ConnectClient) run(
}
func (c *ConnectClient) Engine() *Engine {
+ if c == nil {
+ return nil
+ }
var e *Engine
c.engineMutex.Lock()
e = c.engine
@@ -305,8 +308,15 @@ func (c *ConnectClient) Engine() *Engine {
}
func (c *ConnectClient) Stop() error {
+ if c == nil {
+ return nil
+ }
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
+
+ if c.engine == nil {
+ return nil
+ }
return c.engine.Stop()
}
From a33b11946df27965c2cc1f9949cb6e845dedbebe Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Thu, 5 Sep 2024 22:28:31 +0200
Subject: [PATCH 28/89] [misc] Update slack url (#2544)
* Update slack url
* correct url
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 1c5e76627..aa3ec41e5 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@
See Documentation
- Join our Slack channel
+ Join our Slack channel
From fcf150f70421f3f6d26f19cf06c5334e2ee57d61 Mon Sep 17 00:00:00 2001
From: Eduard Gert
Date: Fri, 6 Sep 2024 15:39:08 +0200
Subject: [PATCH 29/89] Use X-Frame-Options sameorigin header (#2547)
---
infrastructure_files/getting-started-with-zitadel.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh
index 5c33e2db6..1aae212ee 100644
--- a/infrastructure_files/getting-started-with-zitadel.sh
+++ b/infrastructure_files/getting-started-with-zitadel.sh
@@ -541,7 +541,7 @@ renderCaddyfile() {
# clickjacking protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-frame-options
- X-Frame-Options "DENY"
+ X-Frame-Options "SAMEORIGIN"
# xss protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-xss-protection
From a7e46bf7b18df9fbfe3e46ae615b9bd38377bb18 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Fri, 6 Sep 2024 16:28:19 +0200
Subject: [PATCH 30/89] Reduce test logs (#2550)
---
management/server/route_test.go | 57 ++++++++++++++++-------------
management/server/scheduler_test.go | 1 +
management/server/updatechannel.go | 2 +-
3 files changed, 33 insertions(+), 27 deletions(-)
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 47dc4d078..506bfb0a8 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1272,11 +1272,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer1 := &nbpeer.Peer{
- IP: peer1IP,
- ID: peer1ID,
- Key: peer1Key,
- Name: "test-host1@netbird.io",
- UserID: userID,
+ IP: peer1IP,
+ ID: peer1ID,
+ Key: peer1Key,
+ Name: "test-host1@netbird.io",
+ DNSLabel: "test-host1",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
@@ -1298,11 +1299,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer2 := &nbpeer.Peer{
- IP: peer2IP,
- ID: peer2ID,
- Key: peer2Key,
- Name: "test-host2@netbird.io",
- UserID: userID,
+ IP: peer2IP,
+ ID: peer2ID,
+ Key: peer2Key,
+ Name: "test-host2@netbird.io",
+ DNSLabel: "test-host2",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
@@ -1324,11 +1326,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer3 := &nbpeer.Peer{
- IP: peer3IP,
- ID: peer3ID,
- Key: peer3Key,
- Name: "test-host3@netbird.io",
- UserID: userID,
+ IP: peer3IP,
+ ID: peer3ID,
+ Key: peer3Key,
+ Name: "test-host3@netbird.io",
+ DNSLabel: "test-host3",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
@@ -1350,11 +1353,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer4 := &nbpeer.Peer{
- IP: peer4IP,
- ID: peer4ID,
- Key: peer4Key,
- Name: "test-host4@netbird.io",
- UserID: userID,
+ IP: peer4IP,
+ ID: peer4ID,
+ Key: peer4Key,
+ Name: "test-host4@netbird.io",
+ DNSLabel: "test-host4",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
@@ -1376,13 +1380,14 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer5 := &nbpeer.Peer{
- IP: peer5IP,
- ID: peer5ID,
- Key: peer5Key,
- Name: "test-host4@netbird.io",
- UserID: userID,
+ IP: peer5IP,
+ ID: peer5ID,
+ Key: peer5Key,
+ Name: "test-host5@netbird.io",
+ DNSLabel: "test-host5",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
- Hostname: "test-host4@netbird.io",
+ Hostname: "test-host5@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go
index 7c287a554..fa279d4db 100644
--- a/management/server/scheduler_test.go
+++ b/management/server/scheduler_test.go
@@ -63,6 +63,7 @@ func TestScheduler_Cancel(t *testing.T) {
scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
return scheduletime, true
})
+ defer scheduler.Cancel(context.Background(), []string{jobID2})
time.Sleep(sleepTime)
assert.Len(t, scheduler.jobs, 2)
diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go
index c11225dbc..0188cef52 100644
--- a/management/server/updatechannel.go
+++ b/management/server/updatechannel.go
@@ -55,7 +55,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
- log.WithContext(ctx).Warnf("channel for peer %s is %d full", peerID, len(channel))
+ log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
}
} else {
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
From fcac02a92f42014907face0a95cb655581165efd Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Fri, 6 Sep 2024 19:04:34 +0200
Subject: [PATCH 31/89] add log (#2546)
---
client/internal/engine.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 52912f11e..0d80806a4 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -244,7 +244,7 @@ func (e *Engine) Stop() error {
err := e.removeAllPeers()
if err != nil {
- return err
+ return fmt.Errorf("failed to remove all peers: %s", err)
}
e.clientRoutesMu.Lock()
From 0c039274a41ea0d2284545f49527ce0c2a38386b Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Sun, 8 Sep 2024 12:06:14 +0200
Subject: [PATCH 32/89] [relay] Feature/relay integration (#2244)
This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port.
- Adds new relay implementation with websocket with single port relaying mechanism
- refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection
- peer connections are faster since it connects first to relay and then upgrades to P2P
- maintains compatibility with old clients by not using the new relay
- updates infrastructure scripts with new relay service
---
.../workflows/test-infrastructure-files.yml | 30 +-
.goreleaser.yaml | 74 +-
client/cmd/status.go | 27 +-
client/cmd/status_test.go | 16 +-
client/cmd/testutil_test.go | 5 +-
client/cmd/up.go | 5 +-
client/internal/connect.go | 56 +-
client/internal/engine.go | 184 +--
client/internal/engine_test.go | 77 +-
client/internal/peer/conn.go | 1326 +++++++++--------
client/internal/peer/conn_test.go | 71 +-
client/internal/peer/handshaker.go | 192 +++
client/internal/peer/signaler.go | 70 +
client/internal/peer/status.go | 194 ++-
client/internal/peer/status_test.go | 10 +-
client/internal/peer/stdnet.go | 4 +-
client/internal/peer/stdnet_android.go | 4 +-
client/internal/peer/worker_ice.go | 470 ++++++
client/internal/peer/worker_relay.go | 223 +++
client/internal/relay/relay.go | 4 +-
client/internal/routemanager/client.go | 8 +-
client/internal/routemanager/client_test.go | 43 -
client/internal/routemanager/manager.go | 13 +-
client/internal/routemanager/manager_test.go | 2 +-
client/internal/wgproxy/proxy_ebpf.go | 4 +-
client/internal/wgproxy/proxy_userspace.go | 14 +-
client/ios/NetBirdSDK/client.go | 1 -
client/proto/daemon.pb.go | 423 +++---
client/proto/daemon.proto | 2 +-
client/server/debug.go | 4 +-
client/server/server.go | 4 +-
client/server/server_test.go | 8 +-
client/testdata/management.json | 9 +-
encryption/cert.go | 19 +
encryption/letsencrypt.go | 4 +-
encryption/route53.go | 87 ++
encryption/route53_test.go | 84 ++
go.mod | 35 +-
go.sum | 79 +-
infrastructure_files/base.setup.env | 12 +-
infrastructure_files/configure.sh | 5 +
infrastructure_files/docker-compose.yml.tmpl | 17 +
.../getting-started-with-zitadel.sh | 49 +-
infrastructure_files/management.json.tmpl | 5 +
infrastructure_files/setup.env.example | 12 +
infrastructure_files/tests/setup.env | 3 +-
management/client/client_test.go | 5 +-
management/cmd/management.go | 8 +-
management/cmd/management_test.go | 59 +
management/proto/management.pb.go | 859 ++++++-----
management/proto/management.proto | 9 +
management/server/config.go | 7 +
management/server/grpcserver.go | 144 +-
management/server/management_proto_test.go | 5 +-
management/server/management_test.go | 5 +-
management/server/peer.go | 2 +-
management/server/peer_test.go | 8 +-
management/server/token_mgr.go | 222 +++
management/server/token_mgr_test.go | 218 +++
management/server/turncredentials.go | 126 --
management/server/turncredentials_test.go | 136 --
relay/Dockerfile | 4 +
relay/auth/allow/allow_all.go | 12 +
relay/auth/doc.go | 26 +
relay/auth/hmac/doc.go | 8 +
relay/auth/hmac/store.go | 36 +
relay/auth/hmac/token.go | 105 ++
relay/auth/hmac/token_test.go | 105 ++
relay/auth/hmac/validator.go | 33 +
relay/auth/validator.go | 8 +
relay/client/addr.go | 13 +
relay/client/client.go | 553 +++++++
relay/client/client_test.go | 631 ++++++++
relay/client/conn.go | 76 +
relay/client/dialer/ws/addr.go | 13 +
relay/client/dialer/ws/conn.go | 66 +
relay/client/dialer/ws/ws.go | 67 +
relay/client/doc.go | 12 +
relay/client/guard.go | 48 +
relay/client/manager.go | 365 +++++
relay/client/manager_test.go | 432 ++++++
relay/cmd/env.go | 35 +
relay/cmd/root.go | 214 +++
relay/doc.go | 14 +
relay/healthcheck/doc.go | 17 +
relay/healthcheck/receiver.go | 82 +
relay/healthcheck/receiver_test.go | 42 +
relay/healthcheck/sender.go | 68 +
relay/healthcheck/sender_test.go | 103 ++
relay/main.go | 13 +
relay/messages/address/address.go | 30 +
relay/messages/auth/auth.go | 51 +
relay/messages/doc.go | 5 +
relay/messages/id.go | 31 +
relay/messages/id_test.go | 13 +
relay/messages/message.go | 239 +++
relay/messages/message_test.go | 43 +
relay/metrics/realy.go | 136 ++
relay/server/listener/listener.go | 11 +
relay/server/listener/ws/conn.go | 114 ++
relay/server/listener/ws/listener.go | 92 ++
relay/server/peer.go | 203 +++
relay/server/relay.go | 206 +++
relay/server/relay_test.go | 36 +
relay/server/server.go | 76 +
relay/server/store.go | 64 +
relay/server/store_test.go | 40 +
relay/test/benchmark_test.go | 386 +++++
relay/testec2/main.go | 258 ++++
relay/testec2/relay.go | 176 +++
relay/testec2/signal.go | 91 ++
relay/testec2/start_msg.go | 39 +
relay/testec2/tun/proxy.go | 72 +
relay/testec2/tun/tun.go | 110 ++
relay/testec2/turn.go | 181 +++
relay/testec2/turn_allocator.go | 83 ++
signal/client/client.go | 6 +-
signal/proto/signalexchange.pb.go | 20 +-
signal/proto/signalexchange.proto | 3 +
util/net/dialer_nonios.go | 2 +
120 files changed, 9879 insertions(+), 1940 deletions(-)
create mode 100644 client/internal/peer/handshaker.go
create mode 100644 client/internal/peer/signaler.go
create mode 100644 client/internal/peer/worker_ice.go
create mode 100644 client/internal/peer/worker_relay.go
create mode 100644 encryption/cert.go
create mode 100644 encryption/route53.go
create mode 100644 encryption/route53_test.go
create mode 100644 management/cmd/management_test.go
create mode 100644 management/server/token_mgr.go
create mode 100644 management/server/token_mgr_test.go
delete mode 100644 management/server/turncredentials.go
delete mode 100644 management/server/turncredentials_test.go
create mode 100644 relay/Dockerfile
create mode 100644 relay/auth/allow/allow_all.go
create mode 100644 relay/auth/doc.go
create mode 100644 relay/auth/hmac/doc.go
create mode 100644 relay/auth/hmac/store.go
create mode 100644 relay/auth/hmac/token.go
create mode 100644 relay/auth/hmac/token_test.go
create mode 100644 relay/auth/hmac/validator.go
create mode 100644 relay/auth/validator.go
create mode 100644 relay/client/addr.go
create mode 100644 relay/client/client.go
create mode 100644 relay/client/client_test.go
create mode 100644 relay/client/conn.go
create mode 100644 relay/client/dialer/ws/addr.go
create mode 100644 relay/client/dialer/ws/conn.go
create mode 100644 relay/client/dialer/ws/ws.go
create mode 100644 relay/client/doc.go
create mode 100644 relay/client/guard.go
create mode 100644 relay/client/manager.go
create mode 100644 relay/client/manager_test.go
create mode 100644 relay/cmd/env.go
create mode 100644 relay/cmd/root.go
create mode 100644 relay/doc.go
create mode 100644 relay/healthcheck/doc.go
create mode 100644 relay/healthcheck/receiver.go
create mode 100644 relay/healthcheck/receiver_test.go
create mode 100644 relay/healthcheck/sender.go
create mode 100644 relay/healthcheck/sender_test.go
create mode 100644 relay/main.go
create mode 100644 relay/messages/address/address.go
create mode 100644 relay/messages/auth/auth.go
create mode 100644 relay/messages/doc.go
create mode 100644 relay/messages/id.go
create mode 100644 relay/messages/id_test.go
create mode 100644 relay/messages/message.go
create mode 100644 relay/messages/message_test.go
create mode 100644 relay/metrics/realy.go
create mode 100644 relay/server/listener/listener.go
create mode 100644 relay/server/listener/ws/conn.go
create mode 100644 relay/server/listener/ws/listener.go
create mode 100644 relay/server/peer.go
create mode 100644 relay/server/relay.go
create mode 100644 relay/server/relay_test.go
create mode 100644 relay/server/server.go
create mode 100644 relay/server/store.go
create mode 100644 relay/server/store_test.go
create mode 100644 relay/test/benchmark_test.go
create mode 100644 relay/testec2/main.go
create mode 100644 relay/testec2/relay.go
create mode 100644 relay/testec2/signal.go
create mode 100644 relay/testec2/start_msg.go
create mode 100644 relay/testec2/tun/proxy.go
create mode 100644 relay/testec2/tun/tun.go
create mode 100644 relay/testec2/turn.go
create mode 100644 relay/testec2/turn_allocator.go
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index f758e74bd..03ecbd445 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -150,6 +150,13 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
+ # check relay values
+ grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
+ grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
+ grep '33445:33445' docker-compose.yml
+ grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
+ grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
+ grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules
run: go mod tidy
@@ -175,6 +182,15 @@ jobs:
run: |
docker build -t netbirdio/signal:latest .
+ - name: Build relay binary
+ working-directory: relay
+ run: CGO_ENABLED=0 go build -o netbird-relay main.go
+
+ - name: Build relay docker image
+ working-directory: relay
+ run: |
+ docker build -t netbirdio/relay:latest .
+
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
@@ -186,7 +202,7 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
- test $count -eq 4 || docker compose logs
+ test $count -eq 5 || docker compose logs
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
@@ -205,6 +221,9 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
+ - name: handle insisting image # remove after release
+ run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
+
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -229,6 +248,9 @@ jobs:
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
+ - name: test relay.env file gen postgres
+ run: test -f relay.env
+
- name: test zdb.env file gen postgres
run: test -f zdb.env
@@ -237,6 +259,9 @@ jobs:
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
+ - name: handle insisting image gen CockroachDB # remove after release
+ run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
+
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
@@ -264,6 +289,9 @@ jobs:
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
+ - name: test relay.env file gen CockroachDB
+ run: test -f relay.env
+
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 7a219110a..068864d6e 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -80,6 +80,20 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
+ - id: netbird-relay
+ dir: relay
+ env: [CGO_ENABLED=0]
+ binary: netbird-relay
+ goos:
+ - linux
+ goarch:
+ - amd64
+ - arm64
+ - arm
+ ldflags:
+ - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
+ mod_timestamp: '{{ .CommitTimestamp }}'
+
archives:
- builds:
- netbird
@@ -161,6 +175,52 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
+ - image_templates:
+ - netbirdio/relay:{{ .Version }}-amd64
+ ids:
+ - netbird-relay
+ goarch: amd64
+ use: buildx
+ dockerfile: relay/Dockerfile
+ build_flag_templates:
+ - "--platform=linux/amd64"
+ - "--label=org.opencontainers.image.created={{.Date}}"
+ - "--label=org.opencontainers.image.title={{.ProjectName}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=maintainer=dev@netbird.io"
+ - image_templates:
+ - netbirdio/relay:{{ .Version }}-arm64v8
+ ids:
+ - netbird-relay
+ goarch: arm64
+ use: buildx
+ dockerfile: relay/Dockerfile
+ build_flag_templates:
+ - "--platform=linux/arm64"
+ - "--label=org.opencontainers.image.created={{.Date}}"
+ - "--label=org.opencontainers.image.title={{.ProjectName}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=maintainer=dev@netbird.io"
+ - image_templates:
+ - netbirdio/relay:{{ .Version }}-arm
+ ids:
+ - netbird-relay
+ goarch: arm
+ goarm: 6
+ use: buildx
+ dockerfile: relay/Dockerfile
+ build_flag_templates:
+ - "--platform=linux/arm"
+ - "--label=org.opencontainers.image.created={{.Date}}"
+ - "--label=org.opencontainers.image.title={{.ProjectName}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
ids:
@@ -313,6 +373,18 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
+ - name_template: netbirdio/relay:{{ .Version }}
+ image_templates:
+ - netbirdio/relay:{{ .Version }}-arm64v8
+ - netbirdio/relay:{{ .Version }}-arm
+ - netbirdio/relay:{{ .Version }}-amd64
+
+ - name_template: netbirdio/relay:latest
+ image_templates:
+ - netbirdio/relay:{{ .Version }}-arm64v8
+ - netbirdio/relay:{{ .Version }}-arm
+ - netbirdio/relay:{{ .Version }}-amd64
+
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
@@ -386,4 +458,4 @@ checksum:
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- - glob: ./release_files/install.sh
\ No newline at end of file
+ - glob: ./release_files/install.sh
diff --git a/client/cmd/status.go b/client/cmd/status.go
index d9b7a9c91..1ef8b4913 100644
--- a/client/cmd/status.go
+++ b/client/cmd/status.go
@@ -31,9 +31,9 @@ type peerStateDetailOutput struct {
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
- Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
+ RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
@@ -335,16 +335,18 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
- localICE := ""
- remoteICE := ""
- localICEEndpoint := ""
- remoteICEEndpoint := ""
- connType := ""
peersConnected := 0
- lastHandshake := time.Time{}
- transferReceived := int64(0)
- transferSent := int64(0)
for _, pbPeerState := range peers {
+ localICE := ""
+ remoteICE := ""
+ localICEEndpoint := ""
+ remoteICEEndpoint := ""
+ relayServerAddress := ""
+ connType := ""
+ lastHandshake := time.Time{}
+ transferReceived := int64(0)
+ transferSent := int64(0)
+
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
@@ -360,6 +362,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
if pbPeerState.Relayed {
connType = "Relayed"
}
+ relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
@@ -372,7 +375,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
- Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
@@ -381,6 +383,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
+ RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
@@ -641,9 +644,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
- " Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
+ " Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
@@ -655,11 +658,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
peerState.PubKey,
peerState.Status,
peerState.ConnType,
- peerState.Direct,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
+ peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go
index 46620a956..ca43df8a5 100644
--- a/client/cmd/status_test.go
+++ b/client/cmd/status_test.go
@@ -37,7 +37,6 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
- Direct: true,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
@@ -57,7 +56,6 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
- Direct: false,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
@@ -137,7 +135,6 @@ var overview = statusOutputOverview{
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
- Direct: true,
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
@@ -161,7 +158,6 @@ var overview = statusOutputOverview{
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
- Direct: false,
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
@@ -283,7 +279,6 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
- "direct": true,
"iceCandidateType": {
"local": "",
"remote": ""
@@ -292,6 +287,7 @@ func TestParsingToJSON(t *testing.T) {
"local": "",
"remote": ""
},
+ "relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
@@ -308,7 +304,6 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
- "direct": false,
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
@@ -317,6 +312,7 @@ func TestParsingToJSON(t *testing.T) {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
+ "relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
@@ -408,13 +404,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
- direct: true
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
+ relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
@@ -428,13 +424,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
- direct: false
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
+ relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
@@ -505,9 +501,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected
-- detail --
Connection type: P2P
- Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
+ Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
@@ -521,9 +517,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected
-- detail --
Connection type: Relayed
- Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
+ Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index 984aa6df7..780cc8b04 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -98,8 +98,9 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
t.Fatal(err)
}
- turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
t.Fatal(err)
}
diff --git a/client/cmd/up.go b/client/cmd/up.go
index 2ed6e41d2..b447f7141 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -168,7 +168,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
- connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
+ r := peer.NewRecorder(config.ManagementURL.String())
+ r.GetFullStatus()
+
+ connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 5dacde746..515321f7f 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -26,6 +26,8 @@ import (
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -158,10 +160,8 @@ func (c *ConnectClient) run(
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
- select {
- case <-c.ctx.Done():
+ if c.isContextCancelled() {
return nil
- default:
}
state.Set(StatusConnecting)
@@ -183,8 +183,7 @@ func (c *ConnectClient) run(
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
- err = mgmClient.Close()
- if err != nil {
+ if err = mgmClient.Close(); err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
@@ -208,7 +207,6 @@ func (c *ConnectClient) run(
KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
-
c.statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s",
@@ -241,6 +239,23 @@ func (c *ConnectClient) run(
c.statusRecorder.MarkSignalConnected()
+ relayURLs, token := parseRelayInfo(loginResp)
+ relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
+ if len(relayURLs) > 0 {
+ if token != nil {
+ if err := relayManager.UpdateToken(token); err != nil {
+ log.Errorf("failed to update token: %s", err)
+ return wrapErr(err)
+ }
+ }
+ log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
+ if err = relayManager.Serve(); err != nil {
+ log.Error(err)
+ return wrapErr(err)
+ }
+ c.statusRecorder.SetRelayMgr(relayManager)
+ }
+
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
@@ -252,11 +267,11 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
- c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
+ c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
+
c.engineMutex.Unlock()
- err = c.engine.Start()
- if err != nil {
+ if err := c.engine.Start(); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -296,6 +311,20 @@ func (c *ConnectClient) run(
return nil
}
+func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
+ relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
+ if relayCfg == nil {
+ return nil, nil
+ }
+
+ token := &hmac.Token{
+ Payload: relayCfg.GetTokenPayload(),
+ Signature: relayCfg.GetTokenSignature(),
+ }
+
+ return relayCfg.GetUrls(), token
+}
+
func (c *ConnectClient) Engine() *Engine {
if c == nil {
return nil
@@ -320,6 +349,15 @@ func (c *ConnectClient) Stop() error {
return c.engine.Stop()
}
+func (c *ConnectClient) isContextCancelled() bool {
+ select {
+ case <-c.ctx.Done():
+ return true
+ default:
+ return false
+ }
+}
+
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 0d80806a4..47a36c4bf 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
+
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay"
@@ -40,6 +41,8 @@ import (
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
@@ -102,7 +105,8 @@ type EngineConfig struct {
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
- signal signal.Client
+ signal signal.Client
+ signaler *peer.Signaler
// mgmClient is a Management Service client
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
@@ -159,10 +163,10 @@ type Engine struct {
probes *ProbeHolder
- wgConnWorker sync.WaitGroup
-
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
+
+ relayManager *relayClient.Manager
}
// Peer is an instance of the Connection Peer
@@ -177,6 +181,7 @@ func NewEngine(
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
+ relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
@@ -187,6 +192,7 @@ func NewEngine(
clientCancel,
signalClient,
mgmClient,
+ relayManager,
config,
mobileDep,
statusRecorder,
@@ -201,18 +207,20 @@ func NewEngineWithProbes(
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
+ relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
-
return &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
+ signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
+ relayManager: relayManager,
peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{},
config: config,
@@ -260,11 +268,7 @@ func (e *Engine) Stop() error {
time.Sleep(500 * time.Millisecond)
e.close()
-
- e.wgConnWorker.Wait()
-
- log.Infof("Engine stopped")
-
+ log.Infof("stopped Netbird Engine")
return nil
}
@@ -314,7 +318,7 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
- e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
+ e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
@@ -463,78 +467,25 @@ func (e *Engine) removePeer(peerKey string) error {
conn, exists := e.peerConns[peerKey]
if exists {
delete(e.peerConns, peerKey)
- err := conn.Close()
- if err != nil {
- switch err.(type) {
- case *peer.ConnectionAlreadyClosedError:
- return nil
- default:
- return err
- }
- }
+ conn.Close()
}
return nil
}
-func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
- err := s.Send(&sProto.Message{
- Key: myKey.PublicKey().String(),
- RemoteKey: remoteKey.String(),
- Body: &sProto.Body{
- Type: sProto.Body_CANDIDATE,
- Payload: candidate.Marshal(),
- },
- })
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func sendSignal(message *sProto.Message, s signal.Client) error {
- return s.Send(message)
-}
-
-// SignalOfferAnswer signals either an offer or an answer to remote peer
-func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client,
- isAnswer bool) error {
- var t sProto.Body_Type
- if isAnswer {
- t = sProto.Body_ANSWER
- } else {
- t = sProto.Body_OFFER
- }
-
- msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
- UFrag: offerAnswer.IceCredentials.UFrag,
- Pwd: offerAnswer.IceCredentials.Pwd,
- }, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
- if err != nil {
- return err
- }
-
- err = s.Send(msg)
- if err != nil {
- return err
- }
-
- return nil
-}
-
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if update.GetWiretrusteeConfig() != nil {
- err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
+ wCfg := update.GetWiretrusteeConfig()
+ err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
- return err
+ return fmt.Errorf("update TURNs: %w", err)
}
- err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
+ err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
- return err
+ return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
@@ -542,6 +493,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
+ relayMsg := wCfg.GetRelay()
+ if relayMsg != nil {
+ c := &auth.Token{
+ Payload: relayMsg.GetTokenPayload(),
+ Signature: relayMsg.GetTokenSignature(),
+ }
+ if err := e.relayManager.UpdateToken(c); err != nil {
+ log.Errorf("failed to update relay token: %v", err)
+ return fmt.Errorf("update relay token: %w", err)
+ }
+ }
+
+ // todo update relay address in the relay manager
// todo update signal
}
@@ -937,58 +901,11 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
- e.wgConnWorker.Add(1)
- go e.connWorker(conn, peerKey)
+ conn.Open()
}
return nil
}
-func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
- defer e.wgConnWorker.Done()
- for {
-
- // randomize starting time a bit
- minValue := 500
- maxValue := 2000
- duration := time.Duration(rand.Intn(maxValue-minValue)+minValue) * time.Millisecond
- select {
- case <-e.ctx.Done():
- return
- case <-time.After(duration):
- }
-
- // if peer has been removed -> give up
- if !e.peerExists(peerKey) {
- log.Debugf("peer %s doesn't exist anymore, won't retry connection", peerKey)
- return
- }
-
- if !e.signal.Ready() {
- log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey)
- continue
- }
-
- err := conn.Open(e.ctx)
- if err != nil {
- log.Debugf("connection to peer %s failed: %v", peerKey, err)
- var connectionClosedError *peer.ConnectionClosedError
- switch {
- case errors.As(err, &connectionClosedError):
- // conn has been forced to close, so we exit the loop
- return
- default:
- }
- }
- }
-}
-
-func (e *Engine) peerExists(peerKey string) bool {
- e.syncMsgMux.Lock()
- defer e.syncMsgMux.Unlock()
- _, ok := e.peerConns[peerKey]
- return ok
-}
-
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
@@ -1040,37 +957,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
},
}
- peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
+ peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
if err != nil {
return nil, err
}
- wgPubKey, err := wgtypes.ParseKey(pubKey)
- if err != nil {
- return nil, err
- }
-
- signalOffer := func(offerAnswer peer.OfferAnswer) error {
- return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
- }
-
- signalCandidate := func(candidate ice.Candidate) error {
- return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
- }
-
- signalAnswer := func(offerAnswer peer.OfferAnswer) error {
- return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
- }
-
- peerConn.SetSignalCandidate(signalCandidate)
- peerConn.SetSignalOffer(signalOffer)
- peerConn.SetSignalAnswer(signalAnswer)
- peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
- return sendSignal(message, e.signal)
- })
-
if e.rpManager != nil {
-
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
}
@@ -1113,6 +1005,7 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
+ RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1135,6 +1028,7 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
+ RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1143,7 +1037,7 @@ func (e *Engine) receiveSignalEvents() {
return err
}
- conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
+ go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE:
}
@@ -1442,7 +1336,7 @@ func (e *Engine) receiveProbeEvents() {
for _, peer := range e.peerConns {
key := peer.GetKey()
- wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key)
+ wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
}
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index e024dd323..f30566380 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -38,6 +38,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto"
@@ -59,6 +60,12 @@ var (
}
)
+func TestMain(m *testing.M) {
+ _ = util.InitLog("debug", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
@@ -74,13 +81,23 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
- WgIfaceName: "utun101",
- WgAddr: "100.64.0.1/24",
- WgPrivateKey: key,
- WgPort: 33100,
- ServerSSHAllowed: true,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(
+ ctx, cancel,
+ &signal.MockClient{},
+ &mgmt.MockClient{},
+ relayMgr,
+ &EngineConfig{
+ WgIfaceName: "utun101",
+ WgAddr: "100.64.0.1/24",
+ WgPrivateKey: key,
+ WgPort: 33100,
+ ServerSSHAllowed: true,
+ },
+ MobileDependency{},
+ peer.NewRecorder("https://mgm"),
+ nil,
+ )
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -209,12 +226,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
- WgIfaceName: "utun102",
- WgAddr: "100.64.0.1/24",
- WgPrivateKey: key,
- WgPort: 33100,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(
+ ctx, cancel,
+ &signal.MockClient{},
+ &mgmt.MockClient{},
+ relayMgr,
+ &EngineConfig{
+ WgIfaceName: "utun102",
+ WgAddr: "100.64.0.1/24",
+ WgPrivateKey: key,
+ WgPort: 33100,
+ },
+ MobileDependency{},
+ peer.NewRecorder("https://mgm"),
+ nil)
wgIface := &iface.MockWGIface{
RemovePeerFunc: func(peerKey string) error {
@@ -222,7 +248,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
}
engine.wgInterface = wgIface
- engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
+ engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
@@ -404,8 +430,8 @@ func TestEngine_Sync(t *testing.T) {
}
return nil
}
-
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
@@ -564,7 +590,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -734,7 +761,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -1012,7 +1040,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
- e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}
@@ -1046,6 +1075,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
+ Relay: &server.Relay{
+ Addresses: []string{"127.0.0.1:1234"},
+ CredentialsTTL: util.Duration{Duration: time.Hour},
+ Secret: "222222222222222222",
+ },
Signal: &server.Host{
Proto: "http",
URI: "localhost:10000",
@@ -1080,8 +1114,9 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
- turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 7c3b60011..8b8b3c5c0 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -2,36 +2,35 @@ package peer
import (
"context"
- "fmt"
+ "math/rand"
"net"
+ "os"
"runtime"
"strings"
"sync"
- "sync/atomic"
"time"
+ "github.com/cenkalti/backoff/v4"
"github.com/pion/ice/v3"
- "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
- sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
- "github.com/netbirdio/netbird/version"
)
-const (
- iceKeepAliveDefault = 4 * time.Second
- iceDisconnectedTimeoutDefault = 6 * time.Second
- // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
- iceRelayAcceptanceMinWaitDefault = 2 * time.Second
+type ConnPriority int
+const (
defaultWgKeepAlive = 25 * time.Second
+
+ connPriorityRelay ConnPriority = 1
+ connPriorityICETurn ConnPriority = 1
+ connPriorityICEP2P ConnPriority = 2
)
type WgConfig struct {
@@ -42,21 +41,6 @@ type WgConfig struct {
PreSharedKey *wgtypes.Key
}
-type ICEConfig struct {
- // StunTurn is a list of STUN and TURN URLs
- StunTurn *atomic.Value // []*stun.URI
-
- // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
- // (e.g. if eth0 is in the list, host candidate of this interface won't be used)
- InterfaceBlackList []string
- DisableIPv6Discovery bool
-
- UDPMux ice.UDPMux
- UDPMuxSrflx ice.UniversalUDPMux
-
- NATExternalIPs []string
-}
-
// ConnConfig is a peer Connection configuration
type ConnConfig struct {
// Key is a public key of a remote peer
@@ -79,493 +63,215 @@ type ConnConfig struct {
ICEConfig ICEConfig
}
-// OfferAnswer represents a session establishment offer or answer
-type OfferAnswer struct {
- IceCredentials IceCredentials
- // WgListenPort is a remote WireGuard listen port.
- // This field is used when establishing a direct WireGuard connection without any proxy.
- // We can set the remote peer's endpoint with this port.
- WgListenPort int
+type WorkerCallbacks struct {
+ OnRelayReadyCallback func(info RelayConnInfo)
+ OnRelayStatusChanged func(ConnStatus)
- // Version of NetBird Agent
- Version string
- // RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
- // This value is the local Rosenpass server public key when sending the message
- RosenpassPubKey []byte
- // RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
- // This value is the local Rosenpass server address when sending the message
- RosenpassAddr string
-}
-
-// IceCredentials ICE protocol credentials struct
-type IceCredentials struct {
- UFrag string
- Pwd string
+ OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
+ OnICEStatusChanged func(ConnStatus)
}
type Conn struct {
- config ConnConfig
- mu sync.Mutex
-
- // signalCandidate is a handler function to signal remote peer about local connection candidate
- signalCandidate func(candidate ice.Candidate) error
- // signalOffer is a handler function to signal remote peer our connection offer (credentials)
- signalOffer func(OfferAnswer) error
- signalAnswer func(OfferAnswer) error
- sendSignalMessage func(message *sProto.Message) error
- onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
- onDisconnected func(remotePeer string, wgIP string)
-
- // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
- remoteOffersCh chan OfferAnswer
- // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
- remoteAnswerCh chan OfferAnswer
- closeCh chan struct{}
- ctx context.Context
- notifyDisconnected context.CancelFunc
-
- agent *ice.Agent
- status ConnStatus
-
+ log *log.Entry
+ mu sync.Mutex
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ config ConnConfig
statusRecorder *Status
-
wgProxyFactory *wgproxy.Factory
- wgProxy wgproxy.Proxy
+ wgProxyICE wgproxy.Proxy
+ wgProxyRelay wgproxy.Proxy
+ signaler *Signaler
+ relayManager *relayClient.Manager
+ allowedIPsIP string
+ handshaker *Handshaker
- adapter iface.TunAdapter
- iFaceDiscover stdnet.ExternalIFaceDiscover
- sentExtraSrflx bool
+ onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
+ onDisconnected func(remotePeer string, wgIP string)
- connID nbnet.ConnectionID
+ statusRelay ConnStatus
+ statusICE ConnStatus
+ currentConnPriority ConnPriority
+ opened bool // this flag is used to prevent close in case of not opened connection
+
+ workerICE *WorkerICE
+ workerRelay *WorkerRelay
+
+ connIDRelay nbnet.ConnectionID
+ connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
-}
-// GetConf returns the connection config
-func (conn *Conn) GetConf() ConnConfig {
- return conn.config
-}
+ endpointRelay *net.UDPAddr
-// WgConfig returns the WireGuard config
-func (conn *Conn) WgConfig() WgConfig {
- return conn.config.WgConfig
+ // for reconnection operations
+ iCEDisconnected chan bool
+ relayDisconnected chan bool
}
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
-func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
- return &Conn{
- config: config,
- mu: sync.Mutex{},
- status: StatusDisconnected,
- closeCh: make(chan struct{}),
- remoteOffersCh: make(chan OfferAnswer),
- remoteAnswerCh: make(chan OfferAnswer),
- statusRecorder: statusRecorder,
- wgProxyFactory: wgProxyFactory,
- adapter: adapter,
- iFaceDiscover: iFaceDiscover,
- }, nil
+func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) {
+ _, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps)
+ if err != nil {
+ log.Errorf("failed to parse allowedIPS: %v", err)
+ return nil, err
+ }
+
+ ctx, ctxCancel := context.WithCancel(engineCtx)
+ connLog := log.WithField("peer", config.Key)
+
+ var conn = &Conn{
+ log: connLog,
+ ctx: ctx,
+ ctxCancel: ctxCancel,
+ config: config,
+ statusRecorder: statusRecorder,
+ wgProxyFactory: wgProxyFactory,
+ signaler: signaler,
+ relayManager: relayManager,
+ allowedIPsIP: allowedIPsIP.String(),
+ statusRelay: StatusDisconnected,
+ statusICE: StatusDisconnected,
+ iCEDisconnected: make(chan bool, 1),
+ relayDisconnected: make(chan bool, 1),
+ }
+
+ rFns := WorkerRelayCallbacks{
+ OnConnReady: conn.relayConnectionIsReady,
+ OnDisconnected: conn.onWorkerRelayStateDisconnected,
+ }
+
+ wFns := WorkerICECallbacks{
+ OnConnReady: conn.iCEConnectionIsReady,
+ OnStatusChanged: conn.onWorkerICEStateDisconnected,
+ }
+
+ conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns)
+
+ relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
+ conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
+ if err != nil {
+ return nil, err
+ }
+
+ conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
+
+ conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
+ if os.Getenv("NB_FORCE_RELAY") != "true" {
+ conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
+ }
+
+ go conn.handshaker.Listen()
+
+ return conn, nil
}
-func (conn *Conn) reCreateAgent() error {
+// Open opens connection to the remote peer
+// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
+// be used.
+func (conn *Conn) Open() {
+ conn.log.Debugf("open connection to peer")
conn.mu.Lock()
defer conn.mu.Unlock()
-
- failedTimeout := 6 * time.Second
-
- var err error
- transportNet, err := conn.newStdNet()
- if err != nil {
- log.Errorf("failed to create pion's stdnet: %s", err)
- }
-
- iceKeepAlive := iceKeepAlive()
- iceDisconnectedTimeout := iceDisconnectedTimeout()
- iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
-
- agentConfig := &ice.AgentConfig{
- MulticastDNSMode: ice.MulticastDNSModeDisabled,
- NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
- Urls: conn.config.ICEConfig.StunTurn.Load().([]*stun.URI),
- CandidateTypes: conn.candidateTypes(),
- FailedTimeout: &failedTimeout,
- InterfaceFilter: stdnet.InterfaceFilter(conn.config.ICEConfig.InterfaceBlackList),
- UDPMux: conn.config.ICEConfig.UDPMux,
- UDPMuxSrflx: conn.config.ICEConfig.UDPMuxSrflx,
- NAT1To1IPs: conn.config.ICEConfig.NATExternalIPs,
- Net: transportNet,
- DisconnectedTimeout: &iceDisconnectedTimeout,
- KeepaliveInterval: &iceKeepAlive,
- RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
- }
-
- if conn.config.ICEConfig.DisableIPv6Discovery {
- agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
- }
-
- conn.agent, err = ice.NewAgent(agentConfig)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnCandidate(conn.onICECandidate)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnConnectionStateChange(conn.onICEConnectionStateChange)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnSelectedCandidatePairChange(conn.onICESelectedCandidatePair)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
- err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency())
- if err != nil {
- log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err)
- return
- }
- })
- if err != nil {
- return fmt.Errorf("failed setting binding response callback: %w", err)
- }
-
- return nil
-}
-
-func (conn *Conn) candidateTypes() []ice.CandidateType {
- if hasICEForceRelayConn() {
- return []ice.CandidateType{ice.CandidateTypeRelay}
- }
- // TODO: remove this once we have refactored userspace proxy into the bind package
- if runtime.GOOS == "ios" {
- return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
- }
- return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
-}
-
-// Open opens connection to the remote peer starting ICE candidate gathering process.
-// Blocks until connection has been closed or connection timeout.
-// ConnStatus will be set accordingly
-func (conn *Conn) Open(ctx context.Context) error {
- log.Debugf("trying to connect to peer %s", conn.config.Key)
+ conn.opened = true
peerState := State{
PubKey: conn.config.Key,
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
- ConnStatus: conn.status,
+ ConnStatus: StatusDisconnected,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
- log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
+ conn.log.Warnf("error while updating the state err: %v", err)
}
- defer func() {
- err := conn.cleanup()
- if err != nil {
- log.Warnf("error while cleaning up peer connection %s: %v", conn.config.Key, err)
- return
- }
- }()
+ go conn.startHandshakeAndReconnect()
+}
- err = conn.reCreateAgent()
+func (conn *Conn) startHandshakeAndReconnect() {
+ conn.waitInitialRandomSleepTime()
+
+ err := conn.handshaker.sendOffer()
if err != nil {
- return err
+ conn.log.Errorf("failed to send initial offer: %v", err)
}
- err = conn.sendOffer()
- if err != nil {
- return err
- }
-
- log.Debugf("connection offer sent to peer %s, waiting for the confirmation", conn.config.Key)
-
- // Only continue once we got a connection confirmation from the remote peer.
- // The connection timeout could have happened before a confirmation received from the remote.
- // The connection could have also been closed externally (e.g. when we received an update from the management that peer shouldn't be connected)
- var remoteOfferAnswer OfferAnswer
- select {
- case remoteOfferAnswer = <-conn.remoteOffersCh:
- // received confirmation from the remote peer -> ready to proceed
- err = conn.sendAnswer()
- if err != nil {
- return err
- }
- case remoteOfferAnswer = <-conn.remoteAnswerCh:
- case <-time.After(conn.config.Timeout):
- return NewConnectionTimeoutError(conn.config.Key, conn.config.Timeout)
- case <-conn.closeCh:
- // closed externally
- return NewConnectionClosedError(conn.config.Key)
- }
-
- log.Debugf("received connection confirmation from peer %s running version %s and with remote WireGuard listen port %d",
- conn.config.Key, remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
-
- // at this point we received offer/answer and we are ready to gather candidates
- conn.mu.Lock()
- conn.status = StatusConnecting
- conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx)
- defer conn.notifyDisconnected()
- conn.mu.Unlock()
-
- peerState = State{
- PubKey: conn.config.Key,
- ConnStatus: conn.status,
- ConnStatusUpdate: time.Now(),
- Mux: new(sync.RWMutex),
- }
- err = conn.statusRecorder.UpdatePeerState(peerState)
- if err != nil {
- log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
- }
-
- err = conn.agent.GatherCandidates()
- if err != nil {
- return fmt.Errorf("gather candidates: %v", err)
- }
-
- // will block until connection succeeded
- // but it won't release if ICE Agent went into Disconnected or Failed state,
- // so we have to cancel it with the provided context once agent detected a broken connection
- isControlling := conn.config.LocalKey > conn.config.Key
- var remoteConn *ice.Conn
- if isControlling {
- remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ if conn.workerRelay.IsController() {
+ conn.reconnectLoopWithRetry()
} else {
- remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
- }
- if err != nil {
- return err
- }
-
- // dynamically set remote WireGuard port if other side specified a different one from the default one
- remoteWgPort := iface.DefaultWgPort
- if remoteOfferAnswer.WgListenPort != 0 {
- remoteWgPort = remoteOfferAnswer.WgListenPort
- }
-
- // the ice connection has been established successfully so we are ready to start the proxy
- remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
- remoteOfferAnswer.RosenpassAddr)
- if err != nil {
- return err
- }
-
- log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
-
- // wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
- select {
- case <-conn.closeCh:
- // closed externally
- return NewConnectionClosedError(conn.config.Key)
- case <-conn.ctx.Done():
- // disconnected from the remote peer
- return NewConnectionDisconnectedError(conn.config.Key)
+ conn.reconnectLoopForOnDisconnectedEvent()
}
}
-func isRelayCandidate(candidate ice.Candidate) bool {
- return candidate.Type() == ice.CandidateTypeRelay
+// Close closes this peer Conn issuing a close event to the Conn closeCh
+func (conn *Conn) Close() {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ conn.log.Infof("close peer connection")
+ conn.ctxCancel()
+
+ if !conn.opened {
+ conn.log.Debugf("ignore close connection to peer")
+ return
+ }
+
+ conn.workerRelay.DisableWgWatcher()
+ conn.workerRelay.CloseConn()
+ conn.workerICE.Close()
+
+ if conn.wgProxyRelay != nil {
+ err := conn.wgProxyRelay.CloseConn()
+ if err != nil {
+ conn.log.Errorf("failed to close wg proxy for relay: %v", err)
+ }
+ conn.wgProxyRelay = nil
+ }
+
+ if conn.wgProxyICE != nil {
+ err := conn.wgProxyICE.CloseConn()
+ if err != nil {
+ conn.log.Errorf("failed to close wg proxy for ice: %v", err)
+ }
+ conn.wgProxyICE = nil
+ }
+
+ err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
+ if err != nil {
+ conn.log.Errorf("failed to remove wg endpoint: %v", err)
+ }
+
+ conn.freeUpConnID()
+
+ if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
+ conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
+ }
+
+ conn.setStatusToDisconnected()
+}
+
+// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
+// doesn't block, discards the message if connection wasn't ready
+func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
+ conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
+ return conn.handshaker.OnRemoteAnswer(answer)
+}
+
+// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
+func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
+ conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
-
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
-// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
-func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- pair, err := conn.agent.GetSelectedCandidatePair()
- if err != nil {
- return nil, err
- }
-
- var endpoint net.Addr
- if isRelayCandidate(pair.Local) {
- log.Debugf("setup relay connection")
- conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx)
- endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
- if err != nil {
- return nil, err
- }
- } else {
- // To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
- go conn.punchRemoteWGPort(pair, remoteWgPort)
- endpoint = remoteConn.RemoteAddr()
- }
-
- endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
- log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
-
- conn.connID = nbnet.GenerateConnID()
- for _, hook := range conn.beforeAddPeerHooks {
- if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
- log.Errorf("Before add peer hook failed: %v", err)
- }
- }
-
- err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
- if err != nil {
- if conn.wgProxy != nil {
- if err := conn.wgProxy.CloseConn(); err != nil {
- log.Warnf("Failed to close turn connection: %v", err)
- }
- }
- return nil, fmt.Errorf("update peer: %w", err)
- }
-
- conn.status = StatusConnected
- rosenpassEnabled := false
- if remoteRosenpassPubKey != nil {
- rosenpassEnabled = true
- }
-
- peerState := State{
- PubKey: conn.config.Key,
- ConnStatus: conn.status,
- ConnStatusUpdate: time.Now(),
- LocalIceCandidateType: pair.Local.Type().String(),
- RemoteIceCandidateType: pair.Remote.Type().String(),
- LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
- RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
- Direct: !isRelayCandidate(pair.Local),
- RosenpassEnabled: rosenpassEnabled,
- Mux: new(sync.RWMutex),
- }
- if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
- peerState.Relayed = true
- }
-
- err = conn.statusRecorder.UpdatePeerState(peerState)
- if err != nil {
- log.Warnf("unable to save peer's state, got error: %v", err)
- }
-
- _, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
- if err != nil {
- return nil, err
- }
-
- if runtime.GOOS == "ios" {
- runtime.GC()
- }
-
- if conn.onConnected != nil {
- conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
- }
-
- return endpoint, nil
-}
-
-func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
- // wait local endpoint configuration
- time.Sleep(time.Second)
- addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
- if err != nil {
- log.Warnf("got an error while resolving the udp address, err: %s", err)
- return
- }
-
- mux, ok := conn.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
- if !ok {
- log.Warn("invalid udp mux conversion")
- return
- }
- _, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
- if err != nil {
- log.Warnf("got an error while sending the punch packet, err: %s", err)
- }
-}
-
-// cleanup closes all open resources and sets status to StatusDisconnected
-func (conn *Conn) cleanup() error {
- log.Debugf("trying to cleanup %s", conn.config.Key)
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- conn.sentExtraSrflx = false
-
- var err1, err2, err3 error
- if conn.agent != nil {
- err1 = conn.agent.Close()
- if err1 == nil {
- conn.agent = nil
- }
- }
-
- if conn.wgProxy != nil {
- err2 = conn.wgProxy.CloseConn()
- conn.wgProxy = nil
- }
-
- // todo: is it problem if we try to remove a peer what is never existed?
- err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
-
- if conn.connID != "" {
- for _, hook := range conn.afterRemovePeerHooks {
- if err := hook(conn.connID); err != nil {
- log.Errorf("After remove peer hook failed: %v", err)
- }
- }
- }
- conn.connID = ""
-
- if conn.notifyDisconnected != nil {
- conn.notifyDisconnected()
- conn.notifyDisconnected = nil
- }
-
- if conn.status == StatusConnected && conn.onDisconnected != nil {
- conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
- }
-
- conn.status = StatusDisconnected
-
- peerState := State{
- PubKey: conn.config.Key,
- ConnStatus: conn.status,
- ConnStatusUpdate: time.Now(),
- Mux: new(sync.RWMutex),
- }
- err := conn.statusRecorder.UpdatePeerState(peerState)
- if err != nil {
- // pretty common error because by that time Engine can already remove the peer and status won't be available.
- // todo rethink status updates
- log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err)
- }
- if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
- log.Debugf("failed to reset wireguard stats for peer %s: %s", conn.config.Key, err)
- }
-
- log.Debugf("cleaned up connection to peer %s", conn.config.Key)
- if err1 != nil {
- return err1
- }
- if err2 != nil {
- return err2
- }
- return err3
-}
-
-// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
-func (conn *Conn) SetSignalOffer(handler func(offer OfferAnswer) error) {
- conn.signalOffer = handler
-}
-
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler
@@ -576,218 +282,514 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)
conn.onDisconnected = handler
}
-// SetSignalAnswer sets a handler function to be triggered by Conn when a new connection answer has to be signalled to the remote peer
-func (conn *Conn) SetSignalAnswer(handler func(answer OfferAnswer) error) {
- conn.signalAnswer = handler
+func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
+ conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
+ return conn.handshaker.OnRemoteOffer(offer)
}
-// SetSignalCandidate sets a handler function to be triggered by Conn when a new ICE local connection candidate has to be signalled to the remote peer
-func (conn *Conn) SetSignalCandidate(handler func(candidate ice.Candidate) error) {
- conn.signalCandidate = handler
-}
-
-// SetSendSignalMessage sets a handler function to be triggered by Conn when there is new message to send via signal
-func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) error) {
- conn.sendSignalMessage = handler
-}
-
-// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
-// and then signals them to the remote peer
-func (conn *Conn) onICECandidate(candidate ice.Candidate) {
- // nil means candidate gathering has been ended
- if candidate == nil {
- return
- }
-
- // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
- log.Debugf("discovered local candidate %s", candidate.String())
- go func() {
- err := conn.signalCandidate(candidate)
- if err != nil {
- log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
- }
- }()
-
- if !conn.shouldSendExtraSrflxCandidate(candidate) {
- return
- }
-
- // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
- // this is useful when network has an existing port forwarding rule for the wireguard port and this peer
- extraSrflx, err := extraSrflxCandidate(candidate)
- if err != nil {
- log.Errorf("failed creating extra server reflexive candidate %s", err)
- return
- }
- conn.sentExtraSrflx = true
-
- go func() {
- err = conn.signalCandidate(extraSrflx)
- if err != nil {
- log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
- }
- }()
-}
-
-func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
- log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
- conn.config.Key)
-}
-
-// onICEConnectionStateChange registers callback of an ICE Agent to track connection state
-func (conn *Conn) onICEConnectionStateChange(state ice.ConnectionState) {
- log.Debugf("peer %s ICE ConnectionState has changed to %s", conn.config.Key, state.String())
- if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
- conn.notifyDisconnected()
- }
-}
-
-func (conn *Conn) sendAnswer() error {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- localUFrag, localPwd, err := conn.agent.GetLocalUserCredentials()
- if err != nil {
- return err
- }
-
- log.Debugf("sending answer to %s", conn.config.Key)
- err = conn.signalAnswer(OfferAnswer{
- IceCredentials: IceCredentials{localUFrag, localPwd},
- WgListenPort: conn.config.LocalWgPort,
- Version: version.NetbirdVersion(),
- RosenpassPubKey: conn.config.RosenpassPubKey,
- RosenpassAddr: conn.config.RosenpassAddr,
- })
- if err != nil {
- return err
- }
-
- return nil
-}
-
-// sendOffer prepares local user credentials and signals them to the remote peer
-func (conn *Conn) sendOffer() error {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- localUFrag, localPwd, err := conn.agent.GetLocalUserCredentials()
- if err != nil {
- return err
- }
- err = conn.signalOffer(OfferAnswer{
- IceCredentials: IceCredentials{localUFrag, localPwd},
- WgListenPort: conn.config.LocalWgPort,
- Version: version.NetbirdVersion(),
- RosenpassPubKey: conn.config.RosenpassPubKey,
- RosenpassAddr: conn.config.RosenpassAddr,
- })
- if err != nil {
- return err
- }
- return nil
-}
-
-// Close closes this peer Conn issuing a close event to the Conn closeCh
-func (conn *Conn) Close() error {
- conn.mu.Lock()
- defer conn.mu.Unlock()
- select {
- case conn.closeCh <- struct{}{}:
- return nil
- default:
- // probably could happen when peer has been added and removed right after not even starting to connect
- // todo further investigate
- // this really happens due to unordered messages coming from management
- // more importantly it causes inconsistency -> 2 Conn objects for the same peer
- // e.g. this flow:
- // update from management has peers: [1,2,3,4]
- // engine creates a Conn for peers: [1,2,3,4] and schedules Open in ~1sec
- // before conn.Open() another update from management arrives with peers: [1,2,3]
- // engine removes peer 4 and calls conn.Close() which does nothing (this default clause)
- // before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
- // engine adds a new Conn for 4 and 5
- // therefore peer 4 has 2 Conn objects
- log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key)
- return NewConnectionAlreadyClosed(conn.config.Key)
- }
+// WgConfig returns the WireGuard config
+func (conn *Conn) WgConfig() WgConfig {
+ return conn.config.WgConfig
}
// Status returns current status of the Conn
func (conn *Conn) Status() ConnStatus {
conn.mu.Lock()
defer conn.mu.Unlock()
- return conn.status
-}
-
-// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
-// doesn't block, discards the message if connection wasn't ready
-func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
- log.Debugf("OnRemoteOffer from peer %s on status %s", conn.config.Key, conn.status.String())
-
- select {
- case conn.remoteOffersCh <- offer:
- return true
- default:
- log.Debugf("OnRemoteOffer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String())
- // connection might not be ready yet to receive so we ignore the message
- return false
- }
-}
-
-// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
-// doesn't block, discards the message if connection wasn't ready
-func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
- log.Debugf("OnRemoteAnswer from peer %s on status %s", conn.config.Key, conn.status.String())
-
- select {
- case conn.remoteAnswerCh <- answer:
- return true
- default:
- // connection might not be ready yet to receive so we ignore the message
- log.Debugf("OnRemoteAnswer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String())
- return false
- }
-}
-
-// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
-func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
- log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
- go func() {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- if conn.agent == nil {
- return
- }
-
- err := conn.agent.AddRemoteCandidate(candidate)
- if err != nil {
- log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
- return
- }
- }()
+ return conn.evalStatus()
}
func (conn *Conn) GetKey() string {
return conn.config.Key
}
-func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
- if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
- return true
+func (conn *Conn) reconnectLoopWithRetry() {
+ // Give chance to the peer to establish the initial connection.
+ // With it, we can decrease to send necessary offer
+ select {
+ case <-conn.ctx.Done():
+ case <-time.After(3 * time.Second):
+ }
+
+ ticker := conn.prepareExponentTicker()
+ defer ticker.Stop()
+ time.Sleep(1 * time.Second)
+ for {
+ select {
+ case t := <-ticker.C:
+ if t.IsZero() {
+ // in case if the ticker has been canceled by context then avoid the temporary loop
+ return
+ }
+
+ if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
+ if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
+ conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
+ }
+ } else {
+ if conn.statusICE == StatusDisconnected {
+ conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
+ }
+ }
+
+ // checks if there is peer connection is established via relay or ice
+ if conn.isConnected() {
+ continue
+ }
+
+ err := conn.handshaker.sendOffer()
+ if err != nil {
+ conn.log.Errorf("failed to do handshake: %v", err)
+ }
+ case changed := <-conn.relayDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("Relay state changed, reset reconnect timer")
+ ticker.Stop()
+ ticker = conn.prepareExponentTicker()
+ case changed := <-conn.iCEDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("ICE state changed, reset reconnect timer")
+ ticker.Stop()
+ ticker = conn.prepareExponentTicker()
+ case <-conn.ctx.Done():
+ conn.log.Debugf("context is done, stop reconnect loop")
+ return
+ }
}
- return false
}
-func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
- relatedAdd := candidate.RelatedAddress()
- return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
- Network: candidate.NetworkType().String(),
- Address: candidate.Address(),
- Port: relatedAdd.Port,
- Component: candidate.Component(),
- RelAddr: relatedAdd.Address,
- RelPort: relatedAdd.Port,
- })
+func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
+ bo := backoff.WithContext(&backoff.ExponentialBackOff{
+ InitialInterval: 800 * time.Millisecond,
+ RandomizationFactor: 0.01,
+ Multiplier: 2,
+ MaxInterval: conn.config.Timeout,
+ MaxElapsedTime: 0,
+ Stop: backoff.Stop,
+ Clock: backoff.SystemClock,
+ }, conn.ctx)
+
+ ticker := backoff.NewTicker(bo)
+ <-ticker.C // consume the initial tick what is happening right after the ticker has been created
+
+ return ticker
+}
+
+// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer
+// when the connection is lost. It will try to establish a connection only once time if before the connection was established
+// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not
+// mean that to switch to it. We always force to use the higher priority connection.
+func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
+ for {
+ select {
+ case changed := <-conn.relayDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("Relay state changed, try to send new offer")
+ case changed := <-conn.iCEDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("ICE state changed, try to send new offer")
+ case <-conn.ctx.Done():
+ conn.log.Debugf("context is done, stop reconnect loop")
+ return
+ }
+
+ err := conn.handshaker.SendOffer()
+ if err != nil {
+ conn.log.Errorf("failed to do handshake: %v", err)
+ }
+ }
+}
+
+// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
+func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ conn.log.Debugf("ICE connection is ready")
+
+ conn.statusICE = StatusConnected
+
+ defer conn.updateIceState(iceConnInfo)
+
+ if conn.currentConnPriority > priority {
+ return
+ }
+
+ conn.log.Infof("set ICE to active connection")
+
+ endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
+ if err != nil {
+ return
+ }
+
+ endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
+ conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
+
+ conn.connIDICE = nbnet.GenerateConnID()
+ for _, hook := range conn.beforeAddPeerHooks {
+ if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
+ conn.log.Errorf("Before add peer hook failed: %v", err)
+ }
+ }
+
+ conn.workerRelay.DisableWgWatcher()
+
+ err = conn.configureWGEndpoint(endpointUdpAddr)
+ if err != nil {
+ if wgProxy != nil {
+ if err := wgProxy.CloseConn(); err != nil {
+ conn.log.Warnf("Failed to close turn connection: %v", err)
+ }
+ }
+ conn.log.Warnf("Failed to update wg peer configuration: %v", err)
+ return
+ }
+ wgConfigWorkaround()
+
+ if conn.wgProxyICE != nil {
+ if err := conn.wgProxyICE.CloseConn(); err != nil {
+ conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
+ }
+ }
+ conn.wgProxyICE = wgProxy
+
+ conn.currentConnPriority = priority
+
+ conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
+}
+
+// todo review to make sense to handle connecting and disconnected status also?
+func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ conn.log.Tracef("ICE connection state changed to %s", newState)
+
+ // switch back to relay connection
+ if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
+ conn.log.Debugf("ICE disconnected, set Relay to active connection")
+ err := conn.configureWGEndpoint(conn.endpointRelay)
+ if err != nil {
+ conn.log.Errorf("failed to switch to relay conn: %v", err)
+ }
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
+ conn.currentConnPriority = connPriorityRelay
+ }
+
+ changed := conn.statusICE != newState && newState != StatusConnecting
+ conn.statusICE = newState
+
+ select {
+ case conn.iCEDisconnected <- changed:
+ default:
+ }
+
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatus: conn.evalStatus(),
+ Relayed: conn.isRelayed(),
+ ConnStatusUpdate: time.Now(),
+ }
+
+ err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
+ }
+}
+
+func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ conn.log.Debugf("Relay connection is ready to use")
+ conn.statusRelay = StatusConnected
+
+ wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
+ endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
+ if err != nil {
+ conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
+ return
+ }
+
+ endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
+ conn.endpointRelay = endpointUdpAddr
+ conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
+
+ defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
+
+ if conn.currentConnPriority > connPriorityRelay {
+ if conn.statusICE == StatusConnected {
+ log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
+ return
+ }
+ }
+
+ conn.connIDRelay = nbnet.GenerateConnID()
+ for _, hook := range conn.beforeAddPeerHooks {
+ if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
+ conn.log.Errorf("Before add peer hook failed: %v", err)
+ }
+ }
+
+ err = conn.configureWGEndpoint(endpointUdpAddr)
+ if err != nil {
+ if err := wgProxy.CloseConn(); err != nil {
+ conn.log.Warnf("Failed to close relay connection: %v", err)
+ }
+ conn.log.Errorf("Failed to update wg peer configuration: %v", err)
+ return
+ }
+ wgConfigWorkaround()
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
+
+ if conn.wgProxyRelay != nil {
+ if err := conn.wgProxyRelay.CloseConn(); err != nil {
+ conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
+ }
+ }
+ conn.wgProxyRelay = wgProxy
+ conn.currentConnPriority = connPriorityRelay
+
+ conn.log.Infof("start to communicate with peer via relay")
+ conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
+}
+
+func (conn *Conn) onWorkerRelayStateDisconnected() {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ if conn.wgProxyRelay != nil {
+ log.Debugf("relayed connection is closed, clean up WireGuard config")
+ err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
+ if err != nil {
+ conn.log.Errorf("failed to remove wg endpoint: %v", err)
+ }
+
+ conn.endpointRelay = nil
+ _ = conn.wgProxyRelay.CloseConn()
+ conn.wgProxyRelay = nil
+ }
+
+ changed := conn.statusRelay != StatusDisconnected
+ conn.statusRelay = StatusDisconnected
+
+ select {
+ case conn.relayDisconnected <- changed:
+ default:
+ }
+
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatus: conn.evalStatus(),
+ Relayed: conn.isRelayed(),
+ ConnStatusUpdate: time.Now(),
+ }
+
+ err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
+ }
+}
+
+func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
+ return conn.config.WgConfig.WgInterface.UpdatePeer(
+ conn.config.WgConfig.RemoteKey,
+ conn.config.WgConfig.AllowedIps,
+ defaultWgKeepAlive,
+ addr,
+ conn.config.WgConfig.PreSharedKey,
+ )
+}
+
+func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatusUpdate: time.Now(),
+ ConnStatus: conn.evalStatus(),
+ Relayed: conn.isRelayed(),
+ RelayServerAddress: relayServerAddr,
+ RosenpassEnabled: isRosenpassEnabled(rosenpassPubKey),
+ }
+
+ err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to save peer's Relay state, got error: %v", err)
+ }
+}
+
+func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatusUpdate: time.Now(),
+ ConnStatus: conn.evalStatus(),
+ Relayed: iceConnInfo.Relayed,
+ LocalIceCandidateType: iceConnInfo.LocalIceCandidateType,
+ RemoteIceCandidateType: iceConnInfo.RemoteIceCandidateType,
+ LocalIceCandidateEndpoint: iceConnInfo.LocalIceCandidateEndpoint,
+ RemoteIceCandidateEndpoint: iceConnInfo.RemoteIceCandidateEndpoint,
+ RosenpassEnabled: isRosenpassEnabled(iceConnInfo.RosenpassPubKey),
+ }
+
+ err := conn.statusRecorder.UpdatePeerICEState(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to save peer's ICE state, got error: %v", err)
+ }
+}
+
+func (conn *Conn) setStatusToDisconnected() {
+ conn.statusRelay = StatusDisconnected
+ conn.statusICE = StatusDisconnected
+
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatus: StatusDisconnected,
+ ConnStatusUpdate: time.Now(),
+ Mux: new(sync.RWMutex),
+ }
+ err := conn.statusRecorder.UpdatePeerState(peerState)
+ if err != nil {
+ // pretty common error because by that time Engine can already remove the peer and status won't be available.
+ // todo rethink status updates
+ conn.log.Debugf("error while updating peer's state, err: %v", err)
+ }
+ if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
+ conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
+ }
+}
+
+func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) {
+ if runtime.GOOS == "ios" {
+ runtime.GC()
+ }
+
+ if conn.onConnected != nil {
+ conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr)
+ }
+}
+
+func (conn *Conn) waitInitialRandomSleepTime() {
+ minWait := 100
+ maxWait := 800
+ duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
+
+ timeout := time.NewTimer(duration)
+ defer timeout.Stop()
+
+ select {
+ case <-conn.ctx.Done():
+ case <-timeout.C:
+ }
+}
+
+func (conn *Conn) isRelayed() bool {
+ if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
+ return false
+ }
+
+ if conn.currentConnPriority == connPriorityICEP2P {
+ return false
+ }
+
+ return true
+}
+
+func (conn *Conn) evalStatus() ConnStatus {
+ if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
+ return StatusConnected
+ }
+
+ if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
+ return StatusConnecting
+ }
+
+ return StatusDisconnected
+}
+
+func (conn *Conn) isConnected() bool {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
+ return false
+ }
+
+ if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
+ if conn.statusRelay != StatusConnected {
+ return false
+ }
+ }
+
+ return true
+}
+
+func (conn *Conn) freeUpConnID() {
+ if conn.connIDRelay != "" {
+ for _, hook := range conn.afterRemovePeerHooks {
+ if err := hook(conn.connIDRelay); err != nil {
+ conn.log.Errorf("After remove peer hook failed: %v", err)
+ }
+ }
+ conn.connIDRelay = ""
+ }
+
+ if conn.connIDICE != "" {
+ for _, hook := range conn.afterRemovePeerHooks {
+ if err := hook(conn.connIDICE); err != nil {
+ conn.log.Errorf("After remove peer hook failed: %v", err)
+ }
+ }
+ conn.connIDICE = ""
+ }
+}
+
+func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
+ if !iceConnInfo.RelayedOnLocal {
+ return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
+ }
+ conn.log.Debugf("setup ice turn connection")
+ wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
+ ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
+ if err != nil {
+ conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
+ err = wgProxy.CloseConn()
+ if err != nil {
+ conn.log.Warnf("failed to close turn proxy connection: %v", err)
+ }
+ return nil, nil, err
+ }
+ return ep, wgProxy, nil
+}
+
+func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
+ return remoteRosenpassPubKey != nil
+}
+
+// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
+// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
+func wgConfigWorkaround() {
+ time.Sleep(100 * time.Millisecond)
}
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index c124208d1..59f249b82 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -2,6 +2,7 @@ package peer
import (
"context"
+ "os"
"sync"
"testing"
"time"
@@ -11,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/util"
)
var connConf = ConnConfig{
@@ -23,6 +25,12 @@ var connConf = ConnConfig{
},
}
+func TestMain(m *testing.M) {
+ _ = util.InitLog("trace", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
@@ -40,7 +48,7 @@ func TestConn_GetKey(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -55,7 +63,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -63,7 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
- <-conn.remoteOffersCh
+ <-conn.handshaker.remoteOffersCh
wg.Done()
}()
@@ -92,7 +100,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -100,7 +108,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
- <-conn.remoteAnswerCh
+ <-conn.handshaker.remoteAnswerCh
wg.Done()
}()
@@ -128,58 +136,33 @@ func TestConn_Status(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
tables := []struct {
- name string
- status ConnStatus
- want ConnStatus
+ name string
+ statusIce ConnStatus
+ statusRelay ConnStatus
+ want ConnStatus
}{
- {"StatusConnected", StatusConnected, StatusConnected},
- {"StatusDisconnected", StatusDisconnected, StatusDisconnected},
- {"StatusConnecting", StatusConnecting, StatusConnecting},
+ {"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
+ {"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
+ {"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
+ {"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
+ {"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
+ {"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
+ {"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
- conn.status = table.status
+ conn.statusICE = table.statusIce
+ conn.statusRelay = table.statusRelay
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
-
-func TestConn_Close(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
- defer func() {
- _ = wgProxyFactory.Free()
- }()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
- if err != nil {
- return
- }
-
- wg := sync.WaitGroup{}
- wg.Add(1)
- go func() {
- <-conn.closeCh
- wg.Done()
- }()
-
- go func() {
- for {
- err := conn.Close()
- if err != nil {
- continue
- } else {
- return
- }
- }
- }()
-
- wg.Wait()
-}
diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go
new file mode 100644
index 000000000..545f81966
--- /dev/null
+++ b/client/internal/peer/handshaker.go
@@ -0,0 +1,192 @@
+package peer
+
+import (
+ "context"
+ "errors"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/version"
+)
+
+var (
+ ErrSignalIsNotReady = errors.New("signal is not ready")
+)
+
+// IceCredentials ICE protocol credentials struct
+type IceCredentials struct {
+ UFrag string
+ Pwd string
+}
+
+// OfferAnswer represents a session establishment offer or answer
+type OfferAnswer struct {
+ IceCredentials IceCredentials
+ // WgListenPort is a remote WireGuard listen port.
+ // This field is used when establishing a direct WireGuard connection without any proxy.
+ // We can set the remote peer's endpoint with this port.
+ WgListenPort int
+
+ // Version of NetBird Agent
+ Version string
+ // RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
+ // This value is the local Rosenpass server public key when sending the message
+ RosenpassPubKey []byte
+ // RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
+ // This value is the local Rosenpass server address when sending the message
+ RosenpassAddr string
+
+ // relay server address
+ RelaySrvAddress string
+}
+
+type Handshaker struct {
+ mu sync.Mutex
+ ctx context.Context
+ log *log.Entry
+ config ConnConfig
+ signaler *Signaler
+ ice *WorkerICE
+ relay *WorkerRelay
+ onNewOfferListeners []func(*OfferAnswer)
+
+ // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
+ remoteOffersCh chan OfferAnswer
+ // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
+ remoteAnswerCh chan OfferAnswer
+}
+
+func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
+ return &Handshaker{
+ ctx: ctx,
+ log: log,
+ config: config,
+ signaler: signaler,
+ ice: ice,
+ relay: relay,
+ remoteOffersCh: make(chan OfferAnswer),
+ remoteAnswerCh: make(chan OfferAnswer),
+ }
+}
+
+func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
+ h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
+}
+
+func (h *Handshaker) Listen() {
+ for {
+ h.log.Debugf("wait for remote offer confirmation")
+ remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
+ if err != nil {
+ var connectionClosedError *ConnectionClosedError
+ if errors.As(err, &connectionClosedError) {
+ h.log.Tracef("stop handshaker")
+ return
+ }
+ h.log.Errorf("failed to received remote offer confirmation: %s", err)
+ continue
+ }
+
+ h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
+ for _, listener := range h.onNewOfferListeners {
+ go listener(remoteOfferAnswer)
+ }
+ }
+}
+
+func (h *Handshaker) SendOffer() error {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ return h.sendOffer()
+}
+
+// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
+// doesn't block, discards the message if connection wasn't ready
+func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
+ select {
+ case h.remoteOffersCh <- offer:
+ return true
+ default:
+ h.log.Debugf("OnRemoteOffer skipping message because is not ready")
+ // connection might not be ready yet to receive so we ignore the message
+ return false
+ }
+}
+
+// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
+// doesn't block, discards the message if connection wasn't ready
+func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
+ select {
+ case h.remoteAnswerCh <- answer:
+ return true
+ default:
+ // connection might not be ready yet to receive so we ignore the message
+ h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
+ return false
+ }
+}
+
+func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
+ select {
+ case remoteOfferAnswer := <-h.remoteOffersCh:
+ // received confirmation from the remote peer -> ready to proceed
+ err := h.sendAnswer()
+ if err != nil {
+ return nil, err
+ }
+ return &remoteOfferAnswer, nil
+ case remoteOfferAnswer := <-h.remoteAnswerCh:
+ return &remoteOfferAnswer, nil
+ case <-h.ctx.Done():
+ // closed externally
+ return nil, NewConnectionClosedError(h.config.Key)
+ }
+}
+
+// sendOffer prepares local user credentials and signals them to the remote peer
+func (h *Handshaker) sendOffer() error {
+ if !h.signaler.Ready() {
+ return ErrSignalIsNotReady
+ }
+
+ iceUFrag, icePwd := h.ice.GetLocalUserCredentials()
+ offer := OfferAnswer{
+ IceCredentials: IceCredentials{iceUFrag, icePwd},
+ WgListenPort: h.config.LocalWgPort,
+ Version: version.NetbirdVersion(),
+ RosenpassPubKey: h.config.RosenpassPubKey,
+ RosenpassAddr: h.config.RosenpassAddr,
+ }
+
+ addr, err := h.relay.RelayInstanceAddress()
+ if err == nil {
+ offer.RelaySrvAddress = addr
+ }
+
+ return h.signaler.SignalOffer(offer, h.config.Key)
+}
+
+func (h *Handshaker) sendAnswer() error {
+ h.log.Debugf("sending answer")
+ uFrag, pwd := h.ice.GetLocalUserCredentials()
+
+ answer := OfferAnswer{
+ IceCredentials: IceCredentials{uFrag, pwd},
+ WgListenPort: h.config.LocalWgPort,
+ Version: version.NetbirdVersion(),
+ RosenpassPubKey: h.config.RosenpassPubKey,
+ RosenpassAddr: h.config.RosenpassAddr,
+ }
+ addr, err := h.relay.RelayInstanceAddress()
+ if err == nil {
+ answer.RelaySrvAddress = addr
+ }
+
+ err = h.signaler.SignalAnswer(answer, h.config.Key)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go
new file mode 100644
index 000000000..713123e5d
--- /dev/null
+++ b/client/internal/peer/signaler.go
@@ -0,0 +1,70 @@
+package peer
+
+import (
+ "github.com/pion/ice/v3"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ signal "github.com/netbirdio/netbird/signal/client"
+ sProto "github.com/netbirdio/netbird/signal/proto"
+)
+
+type Signaler struct {
+ signal signal.Client
+ wgPrivateKey wgtypes.Key
+}
+
+func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler {
+ return &Signaler{
+ signal: signal,
+ wgPrivateKey: wgPrivateKey,
+ }
+}
+
+func (s *Signaler) SignalOffer(offer OfferAnswer, remoteKey string) error {
+ return s.signalOfferAnswer(offer, remoteKey, sProto.Body_OFFER)
+}
+
+func (s *Signaler) SignalAnswer(offer OfferAnswer, remoteKey string) error {
+ return s.signalOfferAnswer(offer, remoteKey, sProto.Body_ANSWER)
+}
+
+func (s *Signaler) SignalICECandidate(candidate ice.Candidate, remoteKey string) error {
+ return s.signal.Send(&sProto.Message{
+ Key: s.wgPrivateKey.PublicKey().String(),
+ RemoteKey: remoteKey,
+ Body: &sProto.Body{
+ Type: sProto.Body_CANDIDATE,
+ Payload: candidate.Marshal(),
+ },
+ })
+}
+
+func (s *Signaler) Ready() bool {
+ return s.signal.Ready()
+}
+
+// SignalOfferAnswer signals either an offer or an answer to remote peer
+func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
+ msg, err := signal.MarshalCredential(
+ s.wgPrivateKey,
+ offerAnswer.WgListenPort,
+ remoteKey,
+ &signal.Credential{
+ UFrag: offerAnswer.IceCredentials.UFrag,
+ Pwd: offerAnswer.IceCredentials.Pwd,
+ },
+ bodyType,
+ offerAnswer.RosenpassPubKey,
+ offerAnswer.RosenpassAddr,
+ offerAnswer.RelaySrvAddress)
+ if err != nil {
+ return err
+ }
+
+ err = s.signal.Send(msg)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go
index a7cfb95c4..f116f3fef 100644
--- a/client/internal/peer/status.go
+++ b/client/internal/peer/status.go
@@ -3,6 +3,7 @@ package peer
import (
"errors"
"net/netip"
+ "slices"
"sync"
"time"
@@ -13,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
+ relayClient "github.com/netbirdio/netbird/relay/client"
)
// State contains the latest state of a peer
@@ -24,11 +26,11 @@ type State struct {
ConnStatus ConnStatus
ConnStatusUpdate time.Time
Relayed bool
- Direct bool
LocalIceCandidateType string
RemoteIceCandidateType string
LocalIceCandidateEndpoint string
RemoteIceCandidateEndpoint string
+ RelayServerAddress string
LastWireguardHandshake time.Time
BytesTx int64
BytesRx int64
@@ -142,6 +144,8 @@ type Status struct {
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool
+
+ relayMgr *relayClient.Manager
}
// NewRecorder returns a new Status instance
@@ -156,6 +160,12 @@ func NewRecorder(mgmAddress string) *Status {
}
}
+func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+ d.relayMgr = manager
+}
+
// ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock()
@@ -231,17 +241,17 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.SetRoutes(receivedState.GetRoutes())
}
- skipNotification := shouldSkipNotify(receivedState, peerState)
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
- peerState.Direct = receivedState.Direct
peerState.Relayed = receivedState.Relayed
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
+ peerState.RelayServerAddress = receivedState.RelayServerAddress
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
}
@@ -261,6 +271,146 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
+func (d *Status) UpdatePeerICEState(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ if receivedState.IP != "" {
+ peerState.IP = receivedState.IP
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.Relayed = receivedState.Relayed
+ peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
+ peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
+ peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
+ peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
+ peerState.RosenpassEnabled = receivedState.RosenpassEnabled
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
+func (d *Status) UpdatePeerRelayedState(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.Relayed = receivedState.Relayed
+ peerState.RelayServerAddress = receivedState.RelayServerAddress
+ peerState.RosenpassEnabled = receivedState.RosenpassEnabled
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
+func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.Relayed = receivedState.Relayed
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.RelayServerAddress = ""
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
+func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.Relayed = receivedState.Relayed
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
+ peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
+ peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
+ peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
d.mux.Lock()
@@ -280,13 +430,13 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats)
return nil
}
-func shouldSkipNotify(received, curr State) bool {
+func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
switch {
- case received.ConnStatus == StatusConnecting:
+ case receivedConnStatus == StatusConnecting:
return true
- case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
+ case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
return true
- case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
+ case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
return curr.IP != ""
default:
return false
@@ -502,8 +652,35 @@ func (d *Status) GetSignalState() SignalState {
}
}
+// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
- return d.relayStates
+ if d.relayMgr == nil {
+ return d.relayStates
+ }
+
+ // extend the list of stun, turn servers with relay address
+ relayStates := slices.Clone(d.relayStates)
+
+ var relayState relay.ProbeResult
+
+ // if the server connection is not established then we will use the general address
+ // in case of connection we will use the instance specific address
+ instanceAddr, err := d.relayMgr.RelayInstanceAddress()
+ if err != nil {
+ // TODO add their status
+ if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
+ for _, r := range d.relayMgr.ServerURLs() {
+ relayStates = append(relayStates, relay.ProbeResult{
+ URI: r,
+ })
+ }
+ return relayStates
+ }
+ relayState.Err = err
+ }
+
+ relayState.URI = instanceAddr
+ return append(relayStates, relayState)
}
func (d *Status) GetDNSStates() []NSGroupState {
@@ -535,7 +712,6 @@ func (d *Status) GetFullStatus() FullStatus {
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
-
return fullStatus
}
diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go
index a4a6e6081..1d283433b 100644
--- a/client/internal/peer/status_test.go
+++ b/client/internal/peer/status_test.go
@@ -2,8 +2,8 @@ package peer
import (
"errors"
- "testing"
"sync"
+ "testing"
"github.com/stretchr/testify/assert"
)
@@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go
index 1faa30ce3..ae31ebbf0 100644
--- a/client/internal/peer/stdnet.go
+++ b/client/internal/peer/stdnet.go
@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
-func (conn *Conn) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNet(conn.config.ICEConfig.InterfaceBlackList)
+func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
+ return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
}
diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go
index 90865242b..b411405bb 100644
--- a/client/internal/peer/stdnet_android.go
+++ b/client/internal/peer/stdnet_android.go
@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
-func (conn *Conn) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.ICEConfig.InterfaceBlackList)
+func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
+ return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
}
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
new file mode 100644
index 000000000..8bf1b7568
--- /dev/null
+++ b/client/internal/peer/worker_ice.go
@@ -0,0 +1,470 @@
+package peer
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pion/ice/v3"
+ "github.com/pion/randutil"
+ "github.com/pion/stun/v2"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/stdnet"
+ "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/route"
+)
+
+const (
+ iceKeepAliveDefault = 4 * time.Second
+ iceDisconnectedTimeoutDefault = 6 * time.Second
+ // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
+ iceRelayAcceptanceMinWaitDefault = 2 * time.Second
+
+ lenUFrag = 16
+ lenPwd = 32
+ runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+)
+
+var (
+ failedTimeout = 6 * time.Second
+)
+
+type ICEConfig struct {
+ // StunTurn is a list of STUN and TURN URLs
+ StunTurn *atomic.Value // []*stun.URI
+
+ // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
+ // (e.g. if eth0 is in the list, host candidate of this interface won't be used)
+ InterfaceBlackList []string
+ DisableIPv6Discovery bool
+
+ UDPMux ice.UDPMux
+ UDPMuxSrflx ice.UniversalUDPMux
+
+ NATExternalIPs []string
+}
+
+type ICEConnInfo struct {
+ RemoteConn net.Conn
+ RosenpassPubKey []byte
+ RosenpassAddr string
+ LocalIceCandidateType string
+ RemoteIceCandidateType string
+ RemoteIceCandidateEndpoint string
+ LocalIceCandidateEndpoint string
+ Relayed bool
+ RelayedOnLocal bool
+}
+
+type WorkerICECallbacks struct {
+ OnConnReady func(ConnPriority, ICEConnInfo)
+ OnStatusChanged func(ConnStatus)
+}
+
+type WorkerICE struct {
+ ctx context.Context
+ log *log.Entry
+ config ConnConfig
+ signaler *Signaler
+ iFaceDiscover stdnet.ExternalIFaceDiscover
+ statusRecorder *Status
+ hasRelayOnLocally bool
+ conn WorkerICECallbacks
+
+ selectedPriority ConnPriority
+
+ agent *ice.Agent
+ muxAgent sync.Mutex
+
+ StunTurn []*stun.URI
+
+ sentExtraSrflx bool
+
+ localUfrag string
+ localPwd string
+}
+
+func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
+ w := &WorkerICE{
+ ctx: ctx,
+ log: log,
+ config: config,
+ signaler: signaler,
+ iFaceDiscover: ifaceDiscover,
+ statusRecorder: statusRecorder,
+ hasRelayOnLocally: hasRelayOnLocally,
+ conn: callBacks,
+ }
+
+ localUfrag, localPwd, err := generateICECredentials()
+ if err != nil {
+ return nil, err
+ }
+ w.localUfrag = localUfrag
+ w.localPwd = localPwd
+ return w, nil
+}
+
+func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
+ w.log.Debugf("OnNewOffer for ICE")
+ w.muxAgent.Lock()
+
+ if w.agent != nil {
+ w.log.Debugf("agent already exists, skipping the offer")
+ w.muxAgent.Unlock()
+ return
+ }
+
+ var preferredCandidateTypes []ice.CandidateType
+ if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
+ w.selectedPriority = connPriorityICEP2P
+ preferredCandidateTypes = candidateTypesP2P()
+ } else {
+ w.selectedPriority = connPriorityICETurn
+ preferredCandidateTypes = candidateTypes()
+ }
+
+ w.log.Debugf("recreate ICE agent")
+ agentCtx, agentCancel := context.WithCancel(w.ctx)
+ agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
+ if err != nil {
+ w.log.Errorf("failed to recreate ICE Agent: %s", err)
+ w.muxAgent.Unlock()
+ return
+ }
+ w.agent = agent
+ w.muxAgent.Unlock()
+
+ w.log.Debugf("gather candidates")
+ err = w.agent.GatherCandidates()
+ if err != nil {
+ w.log.Debugf("failed to gather candidates: %s", err)
+ return
+ }
+
+ // will block until connection succeeded
+ // but it won't release if ICE Agent went into Disconnected or Failed state,
+ // so we have to cancel it with the provided context once agent detected a broken connection
+ w.log.Debugf("turn agent dial")
+ remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
+ if err != nil {
+ w.log.Debugf("failed to dial the remote peer: %s", err)
+ return
+ }
+ w.log.Debugf("agent dial succeeded")
+
+ pair, err := w.agent.GetSelectedCandidatePair()
+ if err != nil {
+ return
+ }
+
+ if !isRelayCandidate(pair.Local) {
+ // dynamically set remote WireGuard port if other side specified a different one from the default one
+ remoteWgPort := iface.DefaultWgPort
+ if remoteOfferAnswer.WgListenPort != 0 {
+ remoteWgPort = remoteOfferAnswer.WgListenPort
+ }
+
+ // To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
+ go w.punchRemoteWGPort(pair, remoteWgPort)
+ }
+
+ ci := ICEConnInfo{
+ RemoteConn: remoteConn,
+ RosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
+ RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
+ LocalIceCandidateType: pair.Local.Type().String(),
+ RemoteIceCandidateType: pair.Remote.Type().String(),
+ LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
+ RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
+ Relayed: isRelayed(pair),
+ RelayedOnLocal: isRelayCandidate(pair.Local),
+ }
+ w.log.Debugf("on ICE conn read to use ready")
+ go w.conn.OnConnReady(w.selectedPriority, ci)
+}
+
+// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
+func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
+ w.muxAgent.Lock()
+ defer w.muxAgent.Unlock()
+ w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
+ if w.agent == nil {
+ w.log.Warnf("ICE Agent is not initialized yet")
+ return
+ }
+
+ if candidateViaRoutes(candidate, haRoutes) {
+ return
+ }
+
+ err := w.agent.AddRemoteCandidate(candidate)
+ if err != nil {
+ w.log.Errorf("error while handling remote candidate")
+ return
+ }
+}
+
+func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
+ w.muxAgent.Lock()
+ defer w.muxAgent.Unlock()
+ return w.localUfrag, w.localPwd
+}
+
+func (w *WorkerICE) Close() {
+ w.muxAgent.Lock()
+ defer w.muxAgent.Unlock()
+
+ if w.agent == nil {
+ return
+ }
+
+ err := w.agent.Close()
+ if err != nil {
+ w.log.Warnf("failed to close ICE agent: %s", err)
+ }
+}
+
+func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
+ transportNet, err := w.newStdNet()
+ if err != nil {
+ w.log.Errorf("failed to create pion's stdnet: %s", err)
+ }
+
+ iceKeepAlive := iceKeepAlive()
+ iceDisconnectedTimeout := iceDisconnectedTimeout()
+ iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
+
+ agentConfig := &ice.AgentConfig{
+ MulticastDNSMode: ice.MulticastDNSModeDisabled,
+ NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
+ Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
+ CandidateTypes: relaySupport,
+ InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
+ UDPMux: w.config.ICEConfig.UDPMux,
+ UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
+ NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
+ Net: transportNet,
+ FailedTimeout: &failedTimeout,
+ DisconnectedTimeout: &iceDisconnectedTimeout,
+ KeepaliveInterval: &iceKeepAlive,
+ RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
+ LocalUfrag: w.localUfrag,
+ LocalPwd: w.localPwd,
+ }
+
+ if w.config.ICEConfig.DisableIPv6Discovery {
+ agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
+ }
+
+ w.sentExtraSrflx = false
+ agent, err := ice.NewAgent(agentConfig)
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnCandidate(w.onICECandidate)
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
+ w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
+ if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
+ w.conn.OnStatusChanged(StatusDisconnected)
+
+ w.muxAgent.Lock()
+ agentCancel()
+ _ = agent.Close()
+ w.agent = nil
+
+ w.muxAgent.Unlock()
+ }
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
+ err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
+ if err != nil {
+ w.log.Debugf("failed to update latency for peer: %s", err)
+ return
+ }
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed setting binding response callback: %w", err)
+ }
+
+ return agent, nil
+}
+
+func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
+ // wait local endpoint configuration
+ time.Sleep(time.Second)
+ addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
+ if err != nil {
+ w.log.Warnf("got an error while resolving the udp address, err: %s", err)
+ return
+ }
+
+ mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
+ if !ok {
+ w.log.Warn("invalid udp mux conversion")
+ return
+ }
+ _, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
+ if err != nil {
+ w.log.Warnf("got an error while sending the punch packet, err: %s", err)
+ }
+}
+
+// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
+// and then signals them to the remote peer
+func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
+ // nil means candidate gathering has been ended
+ if candidate == nil {
+ return
+ }
+
+ // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
+ w.log.Debugf("discovered local candidate %s", candidate.String())
+ go func() {
+ err := w.signaler.SignalICECandidate(candidate, w.config.Key)
+ if err != nil {
+ w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
+ }
+ }()
+
+ if !w.shouldSendExtraSrflxCandidate(candidate) {
+ return
+ }
+
+ // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
+ // this is useful when network has an existing port forwarding rule for the wireguard port and this peer
+ extraSrflx, err := extraSrflxCandidate(candidate)
+ if err != nil {
+ w.log.Errorf("failed creating extra server reflexive candidate %s", err)
+ return
+ }
+ w.sentExtraSrflx = true
+
+ go func() {
+ err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
+ if err != nil {
+ w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
+ }
+ }()
+}
+
+func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
+ w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
+ w.config.Key)
+}
+
+func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
+ if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
+ return true
+ }
+ return false
+}
+
+func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
+ isControlling := w.config.LocalKey > w.config.Key
+ if isControlling {
+ return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ } else {
+ return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ }
+}
+
+func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
+ relatedAdd := candidate.RelatedAddress()
+ return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
+ Network: candidate.NetworkType().String(),
+ Address: candidate.Address(),
+ Port: relatedAdd.Port,
+ Component: candidate.Component(),
+ RelAddr: relatedAdd.Address,
+ RelPort: relatedAdd.Port,
+ })
+}
+
+func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
+ var routePrefixes []netip.Prefix
+ for _, routes := range clientRoutes {
+ if len(routes) > 0 && routes[0] != nil {
+ routePrefixes = append(routePrefixes, routes[0].Network)
+ }
+ }
+
+ addr, err := netip.ParseAddr(candidate.Address())
+ if err != nil {
+ log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
+ return false
+ }
+
+ for _, prefix := range routePrefixes {
+ // default route is
+ if prefix.Bits() == 0 {
+ continue
+ }
+
+ if prefix.Contains(addr) {
+ log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
+ return true
+ }
+ }
+ return false
+}
+
+func candidateTypes() []ice.CandidateType {
+ if hasICEForceRelayConn() {
+ return []ice.CandidateType{ice.CandidateTypeRelay}
+ }
+ // TODO: remove this once we have refactored userspace proxy into the bind package
+ if runtime.GOOS == "ios" {
+ return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
+ }
+ return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
+}
+
+func candidateTypesP2P() []ice.CandidateType {
+ return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
+}
+
+func isRelayCandidate(candidate ice.Candidate) bool {
+ return candidate.Type() == ice.CandidateTypeRelay
+}
+
+func isRelayed(pair *ice.CandidatePair) bool {
+ if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
+ return true
+ }
+ return false
+}
+
+func generateICECredentials() (string, string, error) {
+ ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
+ if err != nil {
+ return "", "", err
+ }
+
+ pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
+ if err != nil {
+ return "", "", err
+ }
+ return ufrag, pwd, nil
+}
diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go
new file mode 100644
index 000000000..930a8f5b6
--- /dev/null
+++ b/client/internal/peer/worker_relay.go
@@ -0,0 +1,223 @@
+package peer
+
+import (
+ "context"
+ "errors"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ relayClient "github.com/netbirdio/netbird/relay/client"
+)
+
+var (
+ wgHandshakePeriod = 2 * time.Minute
+ wgHandshakeOvertime = 30 * time.Second
+)
+
+type RelayConnInfo struct {
+ relayedConn net.Conn
+ rosenpassPubKey []byte
+ rosenpassAddr string
+}
+
+type WorkerRelayCallbacks struct {
+ OnConnReady func(RelayConnInfo)
+ OnDisconnected func()
+}
+
+type WorkerRelay struct {
+ log *log.Entry
+ config ConnConfig
+ relayManager relayClient.ManagerService
+ callBacks WorkerRelayCallbacks
+
+ relayedConn net.Conn
+ relayLock sync.Mutex
+ ctxWgWatch context.Context
+ ctxCancelWgWatch context.CancelFunc
+ ctxLock sync.Mutex
+
+ relaySupportedOnRemotePeer atomic.Bool
+}
+
+func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
+ r := &WorkerRelay{
+ log: log,
+ config: config,
+ relayManager: relayManager,
+ callBacks: callbacks,
+ }
+ return r
+}
+
+func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
+ if !w.isRelaySupported(remoteOfferAnswer) {
+ w.log.Infof("Relay is not supported by remote peer")
+ w.relaySupportedOnRemotePeer.Store(false)
+ return
+ }
+ w.relaySupportedOnRemotePeer.Store(true)
+
+ // the relayManager will return with error in case if the connection has lost with relay server
+ currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
+ if err != nil {
+ w.log.Errorf("failed to handle new offer: %s", err)
+ return
+ }
+
+ srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
+
+ relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
+ if err != nil {
+ if errors.Is(err, relayClient.ErrConnAlreadyExists) {
+ w.log.Infof("do not need to reopen relay connection")
+ return
+ }
+ w.log.Errorf("failed to open connection via Relay: %s", err)
+ return
+ }
+ w.relayLock.Lock()
+ w.relayedConn = relayedConn
+ w.relayLock.Unlock()
+
+ err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected)
+ if err != nil {
+ log.Errorf("failed to add close listener: %s", err)
+ _ = relayedConn.Close()
+ return
+ }
+
+ w.log.Debugf("peer conn opened via Relay: %s", srv)
+ go w.callBacks.OnConnReady(RelayConnInfo{
+ relayedConn: relayedConn,
+ rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
+ rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
+ })
+}
+
+func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
+ w.log.Debugf("enable WireGuard watcher")
+ w.ctxLock.Lock()
+ defer w.ctxLock.Unlock()
+
+ if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
+ return
+ }
+
+ ctx, ctxCancel := context.WithCancel(ctx)
+ go w.wgStateCheck(ctx)
+ w.ctxWgWatch = ctx
+ w.ctxCancelWgWatch = ctxCancel
+
+}
+
+func (w *WorkerRelay) DisableWgWatcher() {
+ w.ctxLock.Lock()
+ defer w.ctxLock.Unlock()
+
+ if w.ctxCancelWgWatch == nil {
+ return
+ }
+
+ w.log.Debugf("disable WireGuard watcher")
+
+ w.ctxCancelWgWatch()
+}
+
+func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
+ return w.relayManager.RelayInstanceAddress()
+}
+
+func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool {
+ return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally()
+}
+
+func (w *WorkerRelay) IsController() bool {
+ return w.config.LocalKey > w.config.Key
+}
+
+func (w *WorkerRelay) RelayIsSupportedLocally() bool {
+ return w.relayManager.HasRelayAddress()
+}
+
+func (w *WorkerRelay) CloseConn() {
+ w.relayLock.Lock()
+ defer w.relayLock.Unlock()
+ if w.relayedConn == nil {
+ return
+ }
+
+ err := w.relayedConn.Close()
+ if err != nil {
+ w.log.Warnf("failed to close relay connection: %v", err)
+ }
+}
+
+// wgStateCheck help to check the state of the wireguard handshake and relay connection
+func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
+ timer := time.NewTimer(wgHandshakeOvertime)
+ defer timer.Stop()
+ expected := wgHandshakeOvertime
+ for {
+ select {
+ case <-timer.C:
+ lastHandshake, err := w.wgState()
+ if err != nil {
+ w.log.Errorf("failed to read wg stats: %v", err)
+ continue
+ }
+ w.log.Tracef("last handshake: %v", lastHandshake)
+
+ if time.Since(lastHandshake) > expected {
+ w.log.Infof("Wireguard handshake timed out, closing relay connection")
+ w.relayLock.Lock()
+ _ = w.relayedConn.Close()
+ w.relayLock.Unlock()
+ w.callBacks.OnDisconnected()
+ return
+ }
+ resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
+ timer.Reset(resetTime)
+ expected = wgHandshakePeriod
+ case <-ctx.Done():
+ w.log.Debugf("WireGuard watcher stopped")
+ return
+ }
+ }
+}
+
+func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
+ if !w.relayManager.HasRelayAddress() {
+ return false
+ }
+ return answer.RelaySrvAddress != ""
+}
+
+func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
+ if w.IsController() {
+ return myRelayAddress
+ }
+ return remoteRelayAddress
+}
+
+func (w *WorkerRelay) wgState() (time.Time, error) {
+ wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key)
+ if err != nil {
+ return time.Time{}, err
+ }
+ return wgState.LastHandshake, nil
+}
+
+func (w *WorkerRelay) onRelayMGDisconnected() {
+ w.ctxLock.Lock()
+ defer w.ctxLock.Unlock()
+
+ if w.ctxCancelWgWatch != nil {
+ w.ctxCancelWgWatch()
+ }
+ go w.callBacks.OnDisconnected()
+}
diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go
index 4542a37fe..7d98a6060 100644
--- a/client/internal/relay/relay.go
+++ b/client/internal/relay/relay.go
@@ -17,7 +17,7 @@ import (
// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
- URI *stun.URI
+ URI string
Err error
Addr string
}
@@ -176,7 +176,7 @@ func ProbeAll(
wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
- res.URI = stunURI
+ res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
}
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index cebdd2b0f..db2caea7f 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -22,7 +22,6 @@ import (
type routerPeerStatus struct {
connected bool
relayed bool
- direct bool
latency time.Duration
}
@@ -82,7 +81,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses[r.ID] = routerPeerStatus{
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
- direct: peerStatus.Direct,
latency: peerStatus.Latency,
}
}
@@ -97,8 +95,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
-// * Direct connections: Routes with direct peer connections are favored.
// * Latency: Routes with lower latency are prioritized.
+// * we compare the current score + 10ms to the chosen score to avoid flapping between routes
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
@@ -137,10 +135,6 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore++
}
- if peerStatus.direct {
- tempScore++
- }
-
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID
chosenScore = tempScore
diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go
index 0ae10e568..583156e4d 100644
--- a/client/internal/routemanager/client_test.go
+++ b/client/internal/routemanager/client_test.go
@@ -24,7 +24,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -43,7 +42,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: true,
- direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -62,7 +60,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: true,
- direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -81,7 +78,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: false,
relayed: false,
- direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -100,12 +96,10 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -129,41 +123,10 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
},
"route2": {
connected: true,
relayed: true,
- direct: true,
- },
- },
- existingRoutes: map[route.ID]*route.Route{
- "route1": {
- ID: "route1",
- Metric: route.MaxMetric,
- Peer: "peer1",
- },
- "route2": {
- ID: "route2",
- Metric: route.MaxMetric,
- Peer: "peer2",
- },
- },
- currentRoute: "",
- expectedRouteID: "route1",
- },
- {
- name: "multiple connected peers with one direct",
- statuses: map[route.ID]routerPeerStatus{
- "route1": {
- connected: true,
- relayed: false,
- direct: true,
- },
- "route2": {
- connected: true,
- relayed: false,
- direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -241,13 +204,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
latency: 15 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
latency: 10 * time.Millisecond,
},
},
@@ -272,13 +233,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
latency: 200 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
latency: 10 * time.Millisecond,
},
},
@@ -303,13 +262,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
latency: 10 * time.Millisecond,
},
},
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index 597eddd51..cdfd322bd 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
@@ -49,6 +50,7 @@ type DefaultManager struct {
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
+ relayMgr *relayClient.Manager
wgInterface iface.IWGIface
pubKey string
notifier *notifier.Notifier
@@ -63,6 +65,7 @@ func NewManager(
dnsRouteInterval time.Duration,
wgInterface iface.IWGIface,
statusRecorder *peer.Status,
+ relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
@@ -74,6 +77,7 @@ func NewManager(
stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
+ relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
@@ -124,9 +128,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
log.Warnf("Failed cleaning up routing: %v", err)
}
- mgmtAddress := m.statusRecorder.GetManagementState().URL
- signalAddress := m.statusRecorder.GetSignalState().URL
- ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
+ initialAddresses := []string{m.statusRecorder.GetManagementState().URL, m.statusRecorder.GetSignalState().URL}
+ if m.relayMgr != nil {
+ initialAddresses = append(initialAddresses, m.relayMgr.ServerURLs()...)
+ }
+
+ ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
if err != nil {
diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go
index 455c7ac0b..2995e2740 100644
--- a/client/internal/routemanager/manager_test.go
+++ b/client/internal/routemanager/manager_test.go
@@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
- routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
+ routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
_, _, err = routeManager.Init()
diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go
index bbd00d6e2..d385cc4ca 100644
--- a/client/internal/wgproxy/proxy_ebpf.go
+++ b/client/internal/wgproxy/proxy_ebpf.go
@@ -181,7 +181,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
- log.Infof("turn conn not found by port: %d", addr.Port)
+ log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
continue
}
@@ -206,7 +206,7 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
- log.Tracef("remove turn conn from store by port: %d", turnConnID)
+ log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go
index 234ea2a42..c2c8a9b51 100644
--- a/client/internal/wgproxy/proxy_userspace.go
+++ b/client/internal/wgproxy/proxy_userspace.go
@@ -3,6 +3,7 @@ package wgproxy
import (
"context"
"fmt"
+ "io"
"net"
log "github.com/sirupsen/logrus"
@@ -64,7 +65,6 @@ func (p *WGUserSpaceProxy) Free() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
-
buf := make([]byte, 1500)
for {
select {
@@ -73,11 +73,17 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
default:
n, err := p.localConn.Read(buf)
if err != nil {
+ log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
+ if err == io.EOF {
+ p.cancel()
+ } else {
+ log.Debugf("failed to write to remote conn: %s", err)
+ }
continue
}
}
@@ -96,11 +102,17 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
+ if err == io.EOF {
+ p.cancel()
+ return
+ }
+ log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
+ log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go
index 779c27a4d..dc13706bf 100644
--- a/client/ios/NetBirdSDK/client.go
+++ b/client/ios/NetBirdSDK/client.go
@@ -168,7 +168,6 @@ func (c *Client) GetStatusDetails() *StatusDetails {
BytesTx: p.BytesTx,
ConnStatus: p.ConnStatus.String(),
ConnStatusUpdate: p.ConnStatusUpdate.Format("2006-01-02 15:04:05"),
- Direct: p.Direct,
LastWireguardHandshake: p.LastWireguardHandshake.String(),
Relayed: p.Relayed,
RosenpassEnabled: p.RosenpassEnabled,
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index fb10a38d3..b942d8b6e 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
-// protoc v4.23.4
+// protoc v3.21.12
// source: daemon.proto
package proto
@@ -899,7 +899,6 @@ type PeerState struct {
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
- Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
@@ -911,6 +910,7 @@ type PeerState struct {
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
+ RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"`
}
func (x *PeerState) Reset() {
@@ -980,13 +980,6 @@ func (x *PeerState) GetRelayed() bool {
return false
}
-func (x *PeerState) GetDirect() bool {
- if x != nil {
- return x.Direct
- }
- return false
-}
-
func (x *PeerState) GetLocalIceCandidateType() string {
if x != nil {
return x.LocalIceCandidateType
@@ -1064,6 +1057,13 @@ func (x *PeerState) GetLatency() *durationpb.Duration {
return nil
}
+func (x *PeerState) GetRelayAddress() string {
+ if x != nil {
+ return x.RelayAddress
+ }
+ return ""
+}
+
// LocalPeerState contains the latest state of the local peer
type LocalPeerState struct {
state protoimpl.MessageState
@@ -2243,7 +2243,7 @@ var file_daemon_proto_rawDesc = []byte{
0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e,
0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c,
0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50,
- 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65,
+ 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65,
0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65,
0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12,
@@ -2255,209 +2255,210 @@ var file_daemon_proto_rawDesc = []byte{
0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75,
0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79,
0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65,
- 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28,
- 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63,
- 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79,
- 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49,
+ 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e,
+ 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64,
+ 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74,
+ 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70,
+ 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49,
0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12,
- 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64,
- 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52,
- 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64,
- 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18,
- 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c,
- 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65,
- 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19,
- 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74,
- 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d,
- 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45,
- 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72,
- 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74,
- 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73,
- 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68,
- 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
- 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65,
- 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67,
- 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a,
- 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07,
- 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73,
- 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54,
- 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e,
- 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73,
- 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a,
- 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72,
- 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79,
- 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e,
- 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f,
- 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c,
- 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a,
- 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a,
- 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70,
- 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49,
- 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f,
- 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12,
- 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66,
- 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
- 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72,
- 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12,
- 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d,
- 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f,
- 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76,
- 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28,
- 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67,
- 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18,
- 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f,
- 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63,
- 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f,
- 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57,
- 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
- 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03,
- 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64,
- 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65,
- 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79,
- 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c,
- 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69,
- 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e,
- 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73,
- 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65,
- 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73,
- 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12,
- 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08,
- 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72,
- 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22,
- 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41,
- 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
- 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65,
- 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
- 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65,
- 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67,
- 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61,
- 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b,
- 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50,
- 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50,
- 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72,
- 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72,
- 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28,
- 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79,
- 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a,
- 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03,
- 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72,
- 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72,
- 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74,
- 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73,
- 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
- 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32,
- 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06,
- 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
- 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a,
- 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52,
- 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70,
- 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e,
- 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03,
- 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
- 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49,
- 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03,
- 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74,
- 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49,
- 0x44, 0x12, 0x18, 0x0a, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73,
- 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73,
- 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69,
- 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e,
- 0x73, 0x12, 0x40, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73,
- 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x52, 0x6f, 0x75, 0x74, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50,
- 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64,
- 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49,
- 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c,
- 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a,
- 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
- 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
- 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
- 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
- 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
- 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20,
- 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22,
- 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
- 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,
- 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
- 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c,
- 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22,
- 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
- 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a,
- 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c,
- 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a,
- 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41,
- 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08,
- 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f,
- 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a,
- 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65,
- 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f,
- 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67,
- 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
- 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67,
- 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74,
- 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
- 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f,
- 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
- 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55,
- 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39,
- 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
- 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77,
- 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42,
- 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61,
- 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47,
- 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
- 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
- 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
- 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61,
- 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c,
- 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65,
- 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
- 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
- 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
- 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
- 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
- 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
- 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
- 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a,
- 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
- 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
+ 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66,
+ 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43,
+ 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
+ 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65,
+ 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e,
+ 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61,
+ 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18,
+ 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65,
+ 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e,
+ 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61,
+ 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
+ 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c,
+ 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64,
+ 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78,
+ 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12,
+ 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03,
+ 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73,
+ 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20,
+ 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18,
+ 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a,
+ 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19,
+ 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
+ 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e,
+ 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65,
+ 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41,
+ 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c,
+ 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18,
+ 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62,
+ 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65,
+ 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72,
+ 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e,
+ 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66,
+ 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12,
+ 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62,
+ 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e,
+ 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72,
+ 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69,
+ 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
+ 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a,
+ 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72,
+ 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65,
+ 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a,
+ 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12,
+ 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01,
+ 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a,
+ 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72,
+ 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74,
+ 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03,
+ 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65,
+ 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c,
+ 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f,
+ 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65,
+ 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
+ 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03,
+ 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65,
+ 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a,
+ 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20,
+ 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a,
+ 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01,
+ 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e,
+ 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65,
+ 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20,
+ 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65,
+ 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a,
+ 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e,
+ 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74,
+ 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73,
+ 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73,
+ 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75,
+ 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72,
+ 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74,
+ 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74,
+ 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75,
+ 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75,
+ 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18,
+ 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a,
+ 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22,
+ 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
+ 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73,
+ 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03,
+ 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
+ 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
+ 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
+ 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04,
+ 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a,
+ 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 0x20, 0x03,
+ 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74,
+ 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74,
+ 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a,
+ 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e,
+ 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50,
+ 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22,
+ 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69,
+ 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d,
+ 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73,
+ 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52,
+ 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44,
+ 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
+ 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67,
+ 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13,
+ 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01,
+ 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c,
+ 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53,
+ 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e,
+ 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
+ 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74,
+ 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
+ 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e,
+ 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12,
+ 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41,
+ 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09,
+ 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41,
+ 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53,
+ 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
+ 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
+ 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
+ 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
+ 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
+ 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55,
+ 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71,
+ 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
+ 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74,
+ 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
+ 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e,
+ 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e,
+ 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65,
+ 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
+ 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f,
+ 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
+ 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
+ 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
+ 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52,
+ 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
+ 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65,
+ 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
+ 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
+ 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
+ 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
+ 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
+ 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
+ 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c,
+ 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47,
+ 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74,
- 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71,
- 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
- 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
- 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
- 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
+ 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
+ 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
+ 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
+ 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67,
+ 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
+ 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
+ 0x33,
}
var (
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 43c379fb5..384bc0e62 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -168,7 +168,6 @@ message PeerState {
string connStatus = 3;
google.protobuf.Timestamp connStatusUpdate = 4;
bool relayed = 5;
- bool direct = 6;
string localIceCandidateType = 7;
string remoteIceCandidateType = 8;
string fqdn = 9;
@@ -180,6 +179,7 @@ message PeerState {
bool rosenpassEnabled = 15;
repeated string routes = 16;
google.protobuf.Duration latency = 17;
+ string relayAddress = 18;
}
// LocalPeerState contains the latest state of the local peer
diff --git a/client/server/debug.go b/client/server/debug.go
index 1187f3187..5ed43293b 100644
--- a/client/server/debug.go
+++ b/client/server/debug.go
@@ -369,8 +369,8 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
for _, relay := range status.Relays {
- if relay.URI != nil {
- a.AnonymizeURI(relay.URI.String())
+ if relay.URI != "" {
+ a.AnonymizeURI(relay.URI)
}
}
}
diff --git a/client/server/server.go b/client/server/server.go
index d8d32e1ce..0a4c18131 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -758,11 +758,11 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
- Direct: peerState.Direct,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
+ RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
@@ -776,7 +776,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
- URI: relayState.URI.String(),
+ URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 242d399ec..795060fab 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -6,10 +6,11 @@ import (
"testing"
"time"
- "github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
+ "github.com/netbirdio/management-integrations/integrations"
+
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -129,8 +130,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
- turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/testdata/management.json b/client/testdata/management.json
index 4745f2e8c..674c66e06 100644
--- a/client/testdata/management.json
+++ b/client/testdata/management.json
@@ -20,6 +20,13 @@
"Secret": "c29tZV9wYXNzd29yZA==",
"TimeBasedCredentials": true
},
+ "Relay": {
+ "Addresses": [
+ "localhost:0"
+ ],
+ "CredentialsTTL": "1h",
+ "Secret": "b29tZV9wYXNzd29yZA=="
+ },
"Signal": {
"Proto": "http",
"URI": "signal.wiretrustee.com:10000",
@@ -34,4 +41,4 @@
"AuthAudience": "",
"AuthKeysLocation": ""
}
-}
\ No newline at end of file
+}
diff --git a/encryption/cert.go b/encryption/cert.go
new file mode 100644
index 000000000..3f6d5c679
--- /dev/null
+++ b/encryption/cert.go
@@ -0,0 +1,19 @@
+package encryption
+
+import "crypto/tls"
+
+func LoadTLSConfig(certFile, keyFile string) (*tls.Config, error) {
+ serverCert, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return nil, err
+ }
+
+ config := &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ ClientAuth: tls.NoClientCert,
+ NextProtos: []string{
+ "h2", "http/1.1", // enable HTTP/2
+ },
+ }
+ return config, nil
+}
diff --git a/encryption/letsencrypt.go b/encryption/letsencrypt.go
index cfe54ec5a..27a5e3110 100644
--- a/encryption/letsencrypt.go
+++ b/encryption/letsencrypt.go
@@ -9,7 +9,7 @@ import (
)
// CreateCertManager wraps common logic of generating Let's encrypt certificate.
-func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Manager, error) {
+func CreateCertManager(datadir string, letsencryptDomain ...string) (*autocert.Manager, error) {
certDir := filepath.Join(datadir, "letsencrypt")
if _, err := os.Stat(certDir); os.IsNotExist(err) {
@@ -24,7 +24,7 @@ func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Mana
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
- HostPolicy: autocert.HostWhitelist(letsencryptDomain),
+ HostPolicy: autocert.HostWhitelist(letsencryptDomain...),
}
return certManager, nil
diff --git a/encryption/route53.go b/encryption/route53.go
new file mode 100644
index 000000000..3c81ab103
--- /dev/null
+++ b/encryption/route53.go
@@ -0,0 +1,87 @@
+package encryption
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/caddyserver/certmagic"
+ "github.com/libdns/route53"
+ log "github.com/sirupsen/logrus"
+ "go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
+ "golang.org/x/crypto/acme"
+)
+
+// Route53TLS by default, loads the AWS configuration from the environment.
+// env variables: AWS_REGION, AWS_PROFILE, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN
+type Route53TLS struct {
+ DataDir string
+ Email string
+ Domains []string
+ CA string
+}
+
+func (r *Route53TLS) GetCertificate() (*tls.Config, error) {
+ if len(r.Domains) == 0 {
+ return nil, fmt.Errorf("no domains provided")
+ }
+
+ certmagic.Default.Logger = logger()
+ certmagic.Default.Storage = &certmagic.FileStorage{Path: r.DataDir}
+ certmagic.DefaultACME.Agreed = true
+ if r.Email != "" {
+ certmagic.DefaultACME.Email = r.Email
+ } else {
+ certmagic.DefaultACME.Email = emailFromDomain(r.Domains[0])
+ }
+
+ if r.CA == "" {
+ certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA
+ } else {
+ certmagic.DefaultACME.CA = r.CA
+ }
+
+ certmagic.DefaultACME.DNS01Solver = &certmagic.DNS01Solver{
+ DNSManager: certmagic.DNSManager{
+ DNSProvider: &route53.Provider{},
+ },
+ }
+ cm := certmagic.NewDefault()
+ if err := cm.ManageSync(context.Background(), r.Domains); err != nil {
+ log.Errorf("failed to manage certificate: %v", err)
+ return nil, err
+ }
+
+ tlsConfig := &tls.Config{
+ GetCertificate: cm.GetCertificate,
+ NextProtos: []string{"h2", "http/1.1", acme.ALPNProto},
+ }
+
+ return tlsConfig, nil
+}
+
+func emailFromDomain(domain string) string {
+ if domain == "" {
+ return ""
+ }
+
+ parts := strings.Split(domain, ".")
+ if len(parts) < 2 {
+ return ""
+ }
+ if parts[0] == "" {
+ return ""
+ }
+ return fmt.Sprintf("admin@%s.%s", parts[len(parts)-2], parts[len(parts)-1])
+}
+
+func logger() *zap.Logger {
+ return zap.New(zapcore.NewCore(
+ zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()),
+ os.Stderr,
+ zap.ErrorLevel,
+ ))
+}
diff --git a/encryption/route53_test.go b/encryption/route53_test.go
new file mode 100644
index 000000000..765b60f84
--- /dev/null
+++ b/encryption/route53_test.go
@@ -0,0 +1,84 @@
+package encryption
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestRoute53TLSConfig(t *testing.T) {
+ t.SkipNow() // This test requires AWS credentials
+ exampleString := "Hello, world!"
+ rtls := &Route53TLS{
+ DataDir: t.TempDir(),
+ Email: os.Getenv("LE_EMAIL_ROUTE53"),
+ Domains: []string{os.Getenv("DOMAIN")},
+ }
+ tlsConfig, err := rtls.GetCertificate()
+ if err != nil {
+ t.Errorf("Route53TLSConfig failed: %v", err)
+ }
+
+ server := &http.Server{
+ Addr: ":8443",
+ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte(exampleString))
+ }),
+ TLSConfig: tlsConfig,
+ }
+
+ go func() {
+ err := server.ListenAndServeTLS("", "")
+ if err != http.ErrServerClosed {
+ t.Errorf("Failed to start server: %v", err)
+ }
+ }()
+ defer func() {
+ if err := server.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown server: %v", err)
+ }
+ }()
+
+ time.Sleep(1 * time.Second)
+ resp, err := http.Get("https://relay.godevltd.com:8443")
+ if err != nil {
+ t.Errorf("Failed to get response: %v", err)
+ return
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("Failed to read response body: %v", err)
+ }
+ if string(body) != exampleString {
+ t.Errorf("Unexpected response: %s", body)
+ }
+}
+
+func Test_emailFromDomain(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"example.com", "admin@example.com"},
+ {"x.example.com", "admin@example.com"},
+ {"x.x.example.com", "admin@example.com"},
+ {"*.example.com", "admin@example.com"},
+ {"example", ""},
+ {"", ""},
+ {".com", ""},
+ }
+ for _, tt := range tests {
+ t.Run("domain test", func(t *testing.T) {
+ if got := emailFromDomain(tt.input); got != tt.want {
+ t.Errorf("emailFromDomain() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/go.mod b/go.mod
index 9e440e342..7d5817769 100644
--- a/go.mod
+++ b/go.mod
@@ -12,7 +12,7 @@ require (
github.com/gorilla/mux v1.8.0
github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83
github.com/onsi/ginkgo v1.16.5
- github.com/onsi/gomega v1.23.0
+ github.com/onsi/gomega v1.27.6
github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3
@@ -34,6 +34,7 @@ require (
fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3
+ github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
@@ -50,11 +51,12 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
+ github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.7
github.com/mattn/go-sqlite3 v1.14.19
github.com/mdlayher/socket v0.4.1
- github.com/miekg/dns v1.1.43
+ github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
@@ -63,6 +65,7 @@ require (
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
+ github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
@@ -70,6 +73,7 @@ require (
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
+ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
@@ -81,6 +85,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.26.0
go.opentelemetry.io/otel/sdk/metric v1.26.0
+ go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
@@ -93,6 +98,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
+ nhooyr.io/websocket v1.8.11
)
require (
@@ -106,8 +112,23 @@ require (
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
+ github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
+ github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect
+ github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect
+ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect
+ github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect
+ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect
+ github.com/aws/smithy-go v1.20.3 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
+ github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.16 // indirect
github.com/containerd/log v0.1.0 // indirect
@@ -140,7 +161,7 @@ require (
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
- github.com/hashicorp/go-uuid v1.0.2 // indirect
+ github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
@@ -149,13 +170,17 @@ require (
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
+ github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.17.8 // indirect
+ github.com/klauspost/cpuid/v2 v2.2.7 // indirect
+ github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
+ github.com/mholt/acmez/v2 v2.0.1 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
@@ -164,12 +189,12 @@ require (
github.com/morikuni/aec v1.0.0 // indirect
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
+ github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect
- github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -186,10 +211,12 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
+ github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
+ go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect
diff --git a/go.sum b/go.sum
index 916f1f0c8..7a587c0d1 100644
--- a/go.sum
+++ b/go.sum
@@ -79,6 +79,34 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
+github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
+github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
+github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90=
+github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII=
+github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU=
+github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA=
+github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM=
+github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw=
+github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE=
+github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ=
+github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE=
+github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
@@ -87,6 +115,10 @@ github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d h1:pVrfxiGfwel
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA=
github.com/c-robinson/iplib v1.0.3 h1:NG0UF0GoEsrC1/vyfX1Lx2Ss7CySWl3KqqXh3q4DdPU=
github.com/c-robinson/iplib v1.0.3/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo=
+github.com/caddyserver/certmagic v0.21.3 h1:pqRRry3yuB4CWBVq9+cUqu+Y6E2z8TswbhNx1AZeYm0=
+github.com/caddyserver/certmagic v0.21.3/go.mod h1:Zq6pklO9nVRl3DIFUw9gVUfXKdpc/0qwTUAQMBlfgtI=
+github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
+github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
@@ -207,6 +239,8 @@ github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZs
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
+github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
+github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-text/render v0.1.0 h1:osrmVDZNHuP1RSu3pNG7Z77Sd2xSbcb/xWytAj9kyVs=
github.com/go-text/render v0.1.0/go.mod h1:jqEuNMenrmj6QRnkdpeaP0oKGFLDNhDkVKwGjsWWYU4=
github.com/go-text/typesetting v0.1.0 h1:vioSaLPYcHwPEPLT7gsjCGDCoYSbljxoHJzMnKwVvHw=
@@ -350,8 +384,9 @@ github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerX
github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4=
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
-github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
+github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
+github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90=
@@ -382,6 +417,10 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@@ -401,6 +440,9 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
+github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
+github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
+github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -413,6 +455,10 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
+github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
+github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
+github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI=
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
@@ -431,9 +477,11 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
+github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
+github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
-github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
-github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
+github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
+github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
@@ -494,14 +542,14 @@ github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
-github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs=
-github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo=
+github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
+github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
-github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys=
-github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
+github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
+github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
@@ -592,6 +640,8 @@ github.com/smartystreets/assertions v1.13.0/go.mod h1:wDmR7qL282YbGsPy6H/yAsesrx
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
+github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
+github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
@@ -660,6 +710,12 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc=
github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30=
+github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
+github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
+github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
+github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ=
+github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
+github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g=
go.etcd.io/etcd/client/v2 v2.305.0/go.mod h1:h9puh54ZTgAKtEbut2oe9P4L/oqKCVB6xsXlzd7alYQ=
@@ -695,8 +751,14 @@ go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZu
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
+go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
+go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
+go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
+go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
+go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -890,7 +952,6 @@ golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -1187,6 +1248,8 @@ k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
k8s.io/kube-openapi v0.0.0-20191107075043-30be4d16710a/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
+nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
+nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env
index 296e165f0..45dce8d88 100644
--- a/infrastructure_files/base.setup.env
+++ b/infrastructure_files/base.setup.env
@@ -20,6 +20,12 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
NETBIRD_SIGNAL_PROTOCOL="http"
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
+# Relay
+NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN}
+NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080}
+# Relay auth secret
+NETBIRD_RELAY_AUTH_SECRET=
+
# Turn
TURN_DOMAIN=${NETBIRD_TURN_DOMAIN:-$NETBIRD_DOMAIN}
@@ -69,7 +75,7 @@ NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-"latest"}
NETBIRD_SIGNAL_TAG=${NETBIRD_SIGNAL_TAG:-"latest"}
NETBIRD_MANAGEMENT_TAG=${NETBIRD_MANAGEMENT_TAG:-"latest"}
COTURN_TAG=${COTURN_TAG:-"latest"}
-
+NETBIRD_RELAY_TAG=${NETBIRD_RELAY_TAG:-"latest"}
# exports
export NETBIRD_DOMAIN
@@ -123,3 +129,7 @@ export NETBIRD_SIGNAL_TAG
export NETBIRD_MANAGEMENT_TAG
export COTURN_TAG
export NETBIRD_TURN_EXTERNAL_IP
+export NETBIRD_RELAY_DOMAIN
+export NETBIRD_RELAY_PORT
+export NETBIRD_RELAY_AUTH_SECRET
+export NETBIRD_RELAY_TAG
diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh
index bf021c9ac..ff33004b2 100755
--- a/infrastructure_files/configure.sh
+++ b/infrastructure_files/configure.sh
@@ -89,6 +89,11 @@ fi
export TURN_EXTERNAL_IP_CONFIG
+# if not provided, we generate a relay auth secret
+if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then
+ export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
+fi
+
artifacts_path="./artifacts"
mkdir -p $artifacts_path
diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl
index 43c8b470c..ba68b3f8d 100644
--- a/infrastructure_files/docker-compose.yml.tmpl
+++ b/infrastructure_files/docker-compose.yml.tmpl
@@ -49,6 +49,23 @@ services:
options:
max-size: "500m"
max-file: "2"
+ # Relay
+ relay:
+ image: netbirdio/relay:$NETBIRD_RELAY_TAG
+ restart: unless-stopped
+ environment:
+ - NB_LOG_LEVEL=info
+ - NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
+ - NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT
+ # todo: change to a secure secret
+ - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
+ ports:
+ - $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
+ logging:
+ driver: "json-file"
+ options:
+ max-size: "500m"
+ max-file: "2"
# Management
management:
diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh
index 1aae212ee..c0275536b 100644
--- a/infrastructure_files/getting-started-with-zitadel.sh
+++ b/infrastructure_files/getting-started-with-zitadel.sh
@@ -103,13 +103,25 @@ wait_api() {
INSTANCE_URL=$1
PAT=$2
set +e
+ counter=1
while true; do
- curl -s --fail -o /dev/null "$INSTANCE_URL/auth/v1/users/me" -H "Authorization: Bearer $PAT"
+ FLAGS="-s"
+ if [[ $counter -eq 45 ]]; then
+ FLAGS="-v"
+ echo ""
+ fi
+
+ curl $FLAGS --fail --connect-timeout 1 -o /dev/null "$INSTANCE_URL/auth/v1/users/me" -H "Authorization: Bearer $PAT"
if [[ $? -eq 0 ]]; then
break
fi
+ if [[ $counter -eq 45 ]]; then
+ echo ""
+ echo "Unable to connect to Zitadel for more than 45s, please check the output above, your firewall rules and the caddy container logs to confirm if there are any issues provisioning TLS certificates"
+ fi
echo -n " ."
sleep 1
+ counter=$((counter + 1))
done
echo " done"
set -e
@@ -424,8 +436,10 @@ initEnvironment() {
ZITADEL_MASTERKEY="$(openssl rand -base64 32 | head -c 32)"
NETBIRD_PORT=80
NETBIRD_HTTP_PROTOCOL="http"
+ NETBIRD_RELAY_PROTO="rel"
TURN_USER="self"
TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
+ NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
TURN_MIN_PORT=49152
TURN_MAX_PORT=65535
TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip)
@@ -442,6 +456,7 @@ initEnvironment() {
NETBIRD_PORT=443
CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:$NETBIRD_PORT"
NETBIRD_HTTP_PROTOCOL="https"
+ NETBIRD_RELAY_PROTO="rels"
fi
if [[ "$OSTYPE" == "darwin"* ]]; then
@@ -458,7 +473,7 @@ initEnvironment() {
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
- echo " rm -f docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json"
+ echo " rm -f docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json relay.env"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1
fi
@@ -484,6 +499,7 @@ initEnvironment() {
echo "" > dashboard.env
echo "" > turnserver.conf
echo "" > management.json
+ echo "" > relay.env
mkdir -p machinekey
chmod 777 machinekey
@@ -498,6 +514,7 @@ initEnvironment() {
renderTurnServerConf > turnserver.conf
renderManagementJson > management.json
renderDashboardEnv > dashboard.env
+ renderRelayEnv > relay.env
echo -e "\nStarting NetBird services\n"
$DOCKER_COMPOSE_COMMAND up -d
@@ -559,6 +576,8 @@ renderCaddyfile() {
:80${CADDY_SECURE_DOMAIN} {
import security_headers
+ # relay
+ reverse_proxy /relay* relay:80
# Signal
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management
@@ -629,6 +648,11 @@ renderManagementJson() {
],
"TimeBasedCredentials": false
},
+ "Relay": {
+ "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
+ "CredentialsTTL": "24h",
+ "Secret": "$NETBIRD_RELAY_AUTH_SECRET"
+ },
"Signal": {
"Proto": "$NETBIRD_HTTP_PROTOCOL",
"URI": "$NETBIRD_DOMAIN:$NETBIRD_PORT"
@@ -744,6 +768,15 @@ POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD
EOF
}
+renderRelayEnv() {
+ cat < management.PeerSystemMeta
17, // 1: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
- 20, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig
- 22, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig
- 21, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap
- 37, // 5: management.SyncResponse.Checks:type_name -> management.Checks
+ 21, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig
+ 23, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig
+ 22, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap
+ 38, // 5: management.SyncResponse.Checks:type_name -> management.Checks
13, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta
13, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta
10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
- 36, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress
+ 37, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress
11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment
12, // 11: management.PeerSystemMeta.files:type_name -> management.File
17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
- 20, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
- 37, // 14: management.LoginResponse.Checks:type_name -> management.Checks
- 38, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
+ 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
+ 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks
+ 39, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
- 19, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
+ 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
- 0, // 19: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
- 18, // 20: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig
- 23, // 21: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
- 20, // 22: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
- 22, // 23: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
- 29, // 24: management.NetworkMap.Routes:type_name -> management.Route
- 30, // 25: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
- 22, // 26: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
- 35, // 27: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
- 23, // 28: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
- 1, // 29: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
- 28, // 30: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 28, // 31: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 33, // 32: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
- 31, // 33: management.DNSConfig.CustomZones:type_name -> management.CustomZone
- 32, // 34: management.CustomZone.Records:type_name -> management.SimpleRecord
- 34, // 35: management.NameServerGroup.NameServers:type_name -> management.NameServer
- 2, // 36: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction
- 3, // 37: management.FirewallRule.Action:type_name -> management.FirewallRule.action
- 4, // 38: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol
- 5, // 39: management.ManagementService.Login:input_type -> management.EncryptedMessage
- 5, // 40: management.ManagementService.Sync:input_type -> management.EncryptedMessage
- 16, // 41: management.ManagementService.GetServerKey:input_type -> management.Empty
- 16, // 42: management.ManagementService.isHealthy:input_type -> management.Empty
- 5, // 43: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 44: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 45: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
- 5, // 46: management.ManagementService.Login:output_type -> management.EncryptedMessage
- 5, // 47: management.ManagementService.Sync:output_type -> management.EncryptedMessage
- 15, // 48: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
- 16, // 49: management.ManagementService.isHealthy:output_type -> management.Empty
- 5, // 50: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
- 5, // 51: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
- 16, // 52: management.ManagementService.SyncMeta:output_type -> management.Empty
- 46, // [46:53] is the sub-list for method output_type
- 39, // [39:46] is the sub-list for method input_type
- 39, // [39:39] is the sub-list for extension type_name
- 39, // [39:39] is the sub-list for extension extendee
- 0, // [0:39] is the sub-list for field type_name
+ 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig
+ 0, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
+ 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig
+ 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
+ 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
+ 23, // 24: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
+ 30, // 25: management.NetworkMap.Routes:type_name -> management.Route
+ 31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
+ 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
+ 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
+ 24, // 29: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
+ 1, // 30: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
+ 29, // 31: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 29, // 32: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 34, // 33: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
+ 32, // 34: management.DNSConfig.CustomZones:type_name -> management.CustomZone
+ 33, // 35: management.CustomZone.Records:type_name -> management.SimpleRecord
+ 35, // 36: management.NameServerGroup.NameServers:type_name -> management.NameServer
+ 2, // 37: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction
+ 3, // 38: management.FirewallRule.Action:type_name -> management.FirewallRule.action
+ 4, // 39: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol
+ 5, // 40: management.ManagementService.Login:input_type -> management.EncryptedMessage
+ 5, // 41: management.ManagementService.Sync:input_type -> management.EncryptedMessage
+ 16, // 42: management.ManagementService.GetServerKey:input_type -> management.Empty
+ 16, // 43: management.ManagementService.isHealthy:input_type -> management.Empty
+ 5, // 44: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 45: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 46: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
+ 5, // 47: management.ManagementService.Login:output_type -> management.EncryptedMessage
+ 5, // 48: management.ManagementService.Sync:output_type -> management.EncryptedMessage
+ 15, // 49: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
+ 16, // 50: management.ManagementService.isHealthy:output_type -> management.Empty
+ 5, // 51: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
+ 5, // 52: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
+ 16, // 53: management.ManagementService.SyncMeta:output_type -> management.Empty
+ 47, // [47:54] is the sub-list for method output_type
+ 40, // [40:47] is the sub-list for method input_type
+ 40, // [40:40] is the sub-list for extension type_name
+ 40, // [40:40] is the sub-list for extension extendee
+ 0, // [0:40] is the sub-list for field type_name
}
func init() { file_management_proto_init() }
@@ -3256,7 +3339,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ProtectedHostConfig); i {
+ switch v := v.(*RelayConfig); i {
case 0:
return &v.state
case 1:
@@ -3268,7 +3351,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PeerConfig); i {
+ switch v := v.(*ProtectedHostConfig); i {
case 0:
return &v.state
case 1:
@@ -3280,7 +3363,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NetworkMap); i {
+ switch v := v.(*PeerConfig); i {
case 0:
return &v.state
case 1:
@@ -3292,7 +3375,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*RemotePeerConfig); i {
+ switch v := v.(*NetworkMap); i {
case 0:
return &v.state
case 1:
@@ -3304,7 +3387,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*SSHConfig); i {
+ switch v := v.(*RemotePeerConfig); i {
case 0:
return &v.state
case 1:
@@ -3316,7 +3399,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DeviceAuthorizationFlowRequest); i {
+ switch v := v.(*SSHConfig); i {
case 0:
return &v.state
case 1:
@@ -3328,7 +3411,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DeviceAuthorizationFlow); i {
+ switch v := v.(*DeviceAuthorizationFlowRequest); i {
case 0:
return &v.state
case 1:
@@ -3340,7 +3423,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PKCEAuthorizationFlowRequest); i {
+ switch v := v.(*DeviceAuthorizationFlow); i {
case 0:
return &v.state
case 1:
@@ -3352,7 +3435,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PKCEAuthorizationFlow); i {
+ switch v := v.(*PKCEAuthorizationFlowRequest); i {
case 0:
return &v.state
case 1:
@@ -3364,7 +3447,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ProviderConfig); i {
+ switch v := v.(*PKCEAuthorizationFlow); i {
case 0:
return &v.state
case 1:
@@ -3376,7 +3459,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*Route); i {
+ switch v := v.(*ProviderConfig); i {
case 0:
return &v.state
case 1:
@@ -3388,7 +3471,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DNSConfig); i {
+ switch v := v.(*Route); i {
case 0:
return &v.state
case 1:
@@ -3400,7 +3483,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*CustomZone); i {
+ switch v := v.(*DNSConfig); i {
case 0:
return &v.state
case 1:
@@ -3412,7 +3495,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*SimpleRecord); i {
+ switch v := v.(*CustomZone); i {
case 0:
return &v.state
case 1:
@@ -3424,7 +3507,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NameServerGroup); i {
+ switch v := v.(*SimpleRecord); i {
case 0:
return &v.state
case 1:
@@ -3436,7 +3519,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NameServer); i {
+ switch v := v.(*NameServerGroup); i {
case 0:
return &v.state
case 1:
@@ -3448,7 +3531,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*FirewallRule); i {
+ switch v := v.(*NameServer); i {
case 0:
return &v.state
case 1:
@@ -3460,7 +3543,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NetworkAddress); i {
+ switch v := v.(*FirewallRule); i {
case 0:
return &v.state
case 1:
@@ -3472,6 +3555,18 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*NetworkAddress); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Checks); i {
case 0:
return &v.state
@@ -3490,7 +3585,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc,
NumEnums: 5,
- NumMessages: 33,
+ NumMessages: 34,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/management/proto/management.proto b/management/proto/management.proto
index 06b243773..c5646820f 100644
--- a/management/proto/management.proto
+++ b/management/proto/management.proto
@@ -177,6 +177,8 @@ message WiretrusteeConfig {
// a Signal server config
HostConfig signal = 3;
+
+ RelayConfig relay = 4;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@@ -193,6 +195,13 @@ message HostConfig {
DTLS = 4;
}
}
+
+message RelayConfig {
+ repeated string urls = 1;
+ string tokenPayload = 2;
+ string tokenSignature = 3;
+}
+
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {
diff --git a/management/server/config.go b/management/server/config.go
index 4efe4fe74..2f7e49766 100644
--- a/management/server/config.go
+++ b/management/server/config.go
@@ -34,6 +34,7 @@ const (
type Config struct {
Stuns []*Host
TURNConfig *TURNConfig
+ Relay *Relay
Signal *Host
Datadir string
@@ -75,6 +76,12 @@ type TURNConfig struct {
Turns []*Host
}
+type Relay struct {
+ Addresses []string
+ CredentialsTTL util.Duration
+ Secret string
+}
+
// HttpServerConfig is a config of the HTTP Management service server
type HttpServerConfig struct {
LetsEncryptDomain string
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index ead4a29d6..5d7094b6a 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -16,13 +16,12 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
- nbContext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/management/server/posture"
-
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
+ nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -32,17 +31,25 @@ type GRPCServer struct {
accountManager AccountManager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
- peersUpdateManager *PeersUpdateManager
- config *Config
- turnCredentialsManager TURNCredentialsManager
- jwtValidator *jwtclaims.JWTValidator
- jwtClaimsExtractor *jwtclaims.ClaimsExtractor
- appMetrics telemetry.AppMetrics
- ephemeralManager *EphemeralManager
+ peersUpdateManager *PeersUpdateManager
+ config *Config
+ secretsManager SecretsManager
+ jwtValidator *jwtclaims.JWTValidator
+ jwtClaimsExtractor *jwtclaims.ClaimsExtractor
+ appMetrics telemetry.AppMetrics
+ ephemeralManager *EphemeralManager
}
// NewServer creates a new Management server
-func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
+func NewServer(
+ ctx context.Context,
+ config *Config,
+ accountManager AccountManager,
+ peersUpdateManager *PeersUpdateManager,
+ secretsManager SecretsManager,
+ appMetrics telemetry.AppMetrics,
+ ephemeralManager *EphemeralManager,
+) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -88,14 +95,14 @@ func NewServer(ctx context.Context, config *Config, accountManager AccountManage
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
- peersUpdateManager: peersUpdateManager,
- accountManager: accountManager,
- config: config,
- turnCredentialsManager: turnCredentialsManager,
- jwtValidator: jwtValidator,
- jwtClaimsExtractor: jwtClaimsExtractor,
- appMetrics: appMetrics,
- ephemeralManager: ephemeralManager,
+ peersUpdateManager: peersUpdateManager,
+ accountManager: accountManager,
+ config: config,
+ secretsManager: secretsManager,
+ jwtValidator: jwtValidator,
+ jwtClaimsExtractor: jwtClaimsExtractor,
+ appMetrics: appMetrics,
+ ephemeralManager: ephemeralManager,
}, nil
}
@@ -177,9 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(ctx, peer)
- if s.config.TURNConfig.TimeBasedCredentials {
- s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
- }
+ s.secretsManager.SetupRefresh(ctx, peer.ID)
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
@@ -241,7 +246,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
- s.turnCredentialsManager.CancelRefresh(peer.ID)
+ s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
@@ -427,9 +432,17 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
+ var relayToken *Token
+ if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
+ relayToken, err = s.secretsManager.GenerateRelayToken()
+ if err != nil {
+ log.Errorf("failed generating Relay token: %v", err)
+ }
+ }
+
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
- WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
+ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(ctx, postureChecks),
}
@@ -487,10 +500,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
}
}
-func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
+func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Token) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
+
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
@@ -498,25 +512,40 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Protocol: ToResponseProto(stun.Proto),
})
}
+
var turns []*proto.ProtectedHostConfig
- for _, turn := range config.TURNConfig.Turns {
- var username string
- var password string
- if turnCredentials != nil {
- username = turnCredentials.Username
- password = turnCredentials.Password
- } else {
- username = turn.Username
- password = turn.Password
+ if config.TURNConfig != nil {
+ for _, turn := range config.TURNConfig.Turns {
+ var username string
+ var password string
+ if turnCredentials != nil {
+ username = turnCredentials.Payload
+ password = turnCredentials.Signature
+ } else {
+ username = turn.Username
+ password = turn.Password
+ }
+ turns = append(turns, &proto.ProtectedHostConfig{
+ HostConfig: &proto.HostConfig{
+ Uri: turn.URI,
+ Protocol: ToResponseProto(turn.Proto),
+ },
+ User: username,
+ Password: password,
+ })
+ }
+ }
+
+ var relayCfg *proto.RelayConfig
+ if config.Relay != nil && len(config.Relay.Addresses) > 0 {
+ relayCfg = &proto.RelayConfig{
+ Urls: config.Relay.Addresses,
+ }
+
+ if relayToken != nil {
+ relayCfg.TokenPayload = relayToken.Payload
+ relayCfg.TokenSignature = relayToken.Signature
}
- turns = append(turns, &proto.ProtectedHostConfig{
- HostConfig: &proto.HostConfig{
- Uri: turn.URI,
- Protocol: ToResponseProto(turn.Proto),
- },
- User: username,
- Password: password,
- })
}
return &proto.WiretrusteeConfig{
@@ -526,6 +555,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
},
+ Relay: relayCfg,
}
}
@@ -539,9 +569,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
}
}
-func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
+func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
response := &proto.SyncResponse{
- WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
+ WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
@@ -588,15 +618,25 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
- // make secret time based TURN credentials optional
- var turnCredentials *TURNCredentials
- if s.config.TURNConfig.TimeBasedCredentials {
- creds := s.turnCredentialsManager.GenerateCredentials()
- turnCredentials = &creds
- } else {
- turnCredentials = nil
+ var err error
+
+ var turnToken *Token
+ if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
+ turnToken, err = s.secretsManager.GenerateTurnToken()
+ if err != nil {
+ log.Errorf("failed generating TURN token: %v", err)
+ }
}
- plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
+
+ var relayToken *Token
+ if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
+ relayToken, err = s.secretsManager.GenerateRelayToken()
+ if err != nil {
+ log.Errorf("failed generating Relay token: %v", err)
+ }
+ }
+
+ plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index d48e1f513..00ee4bda2 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -439,10 +439,11 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
if err != nil {
return nil, nil, "", err
}
- turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
+
+ secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
ephemeralMgr := NewEphemeralManager(store, accountManager)
- mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
+ mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr)
if err != nil {
return nil, nil, "", err
}
diff --git a/management/server/management_test.go b/management/server/management_test.go
index 62e7f5a05..3956d96b1 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -552,8 +552,9 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
- turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
diff --git a/management/server/peer.go b/management/server/peer.go
index 6926ef6bc..5fc6352ee 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -964,7 +964,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
- update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
+ update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 918436515..448e83a08 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -848,9 +848,9 @@ func TestToSyncResponse(t *testing.T) {
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
- turnCredentials := &TURNCredentials{
- Username: "turn-user",
- Password: "turn-pass",
+ turnRelayToken := &Token{
+ Payload: "turn-user",
+ Signature: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
@@ -916,7 +916,7 @@ func TestToSyncResponse(t *testing.T) {
}
dnsCache := &DNSConfigCache{}
- response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
+ response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache)
assert.NotNil(t, response)
// assert peer config
diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go
new file mode 100644
index 000000000..8a6648a3a
--- /dev/null
+++ b/management/server/token_mgr.go
@@ -0,0 +1,222 @@
+package server
+
+import (
+ "context"
+ "crypto/sha1"
+ "crypto/sha256"
+ "fmt"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/management/proto"
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+)
+
+const defaultDuration = 12 * time.Hour
+
+// SecretsManager used to manage TURN and relay secrets
+type SecretsManager interface {
+ GenerateTurnToken() (*Token, error)
+ GenerateRelayToken() (*Token, error)
+ SetupRefresh(ctx context.Context, peerKey string)
+ CancelRefresh(peerKey string)
+}
+
+// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
+type TimeBasedAuthSecretsManager struct {
+ mux sync.Mutex
+ turnCfg *TURNConfig
+ relayCfg *Relay
+ turnHmacToken *auth.TimedHMAC
+ relayHmacToken *auth.TimedHMAC
+ updateManager *PeersUpdateManager
+ turnCancelMap map[string]chan struct{}
+ relayCancelMap map[string]chan struct{}
+}
+
+type Token auth.Token
+
+func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay) *TimeBasedAuthSecretsManager {
+ mgr := &TimeBasedAuthSecretsManager{
+ updateManager: updateManager,
+ turnCfg: turnCfg,
+ relayCfg: relayCfg,
+ turnCancelMap: make(map[string]chan struct{}),
+ relayCancelMap: make(map[string]chan struct{}),
+ }
+
+ if turnCfg != nil {
+ duration := turnCfg.CredentialsTTL.Duration
+ if turnCfg.CredentialsTTL.Duration <= 0 {
+ log.Warnf("TURN credentials TTL is not set or invalid, using default value %s", defaultDuration)
+ duration = defaultDuration
+ }
+ mgr.turnHmacToken = auth.NewTimedHMAC(turnCfg.Secret, duration)
+ }
+
+ if relayCfg != nil {
+ duration := relayCfg.CredentialsTTL.Duration
+ if relayCfg.CredentialsTTL.Duration <= 0 {
+ log.Warnf("Relay credentials TTL is not set or invalid, using default value %s", defaultDuration)
+ duration = defaultDuration
+ }
+
+ mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration)
+ }
+
+ return mgr
+}
+
+// GenerateTurnToken generates new time-based secret credentials for TURN
+func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
+ if m.turnHmacToken == nil {
+ return nil, fmt.Errorf("TURN configuration is not set")
+ }
+ turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate TURN token: %s", err)
+ }
+ return (*Token)(turnToken), nil
+}
+
+// GenerateRelayToken generates new time-based secret credentials for relay
+func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
+ if m.relayHmacToken == nil {
+ return nil, fmt.Errorf("relay configuration is not set")
+ }
+ relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate relay token: %s", err)
+ }
+ return (*Token)(relayToken), nil
+}
+
+func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
+ if channel, ok := m.turnCancelMap[peerID]; ok {
+ close(channel)
+ delete(m.turnCancelMap, peerID)
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) cancelRelay(peerID string) {
+ if channel, ok := m.relayCancelMap[peerID]; ok {
+ close(channel)
+ delete(m.relayCancelMap, peerID)
+ }
+}
+
+// CancelRefresh cancels scheduled peer credentials refresh
+func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
+ m.mux.Lock()
+ defer m.mux.Unlock()
+ m.cancelTURN(peerID)
+ m.cancelRelay(peerID)
+}
+
+// SetupRefresh starts peer credentials refresh
+func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) {
+ m.mux.Lock()
+ defer m.mux.Unlock()
+
+ m.cancelTURN(peerID)
+ m.cancelRelay(peerID)
+
+ if m.turnCfg != nil && m.turnCfg.TimeBasedCredentials {
+ turnCancel := make(chan struct{}, 1)
+ m.turnCancelMap[peerID] = turnCancel
+ go m.refreshTURNTokens(ctx, peerID, turnCancel)
+ log.WithContext(ctx).Debugf("starting TURN refresh for %s", peerID)
+ }
+
+ if m.relayCfg != nil {
+ relayCancel := make(chan struct{}, 1)
+ m.relayCancelMap[peerID] = relayCancel
+ go m.refreshRelayTokens(ctx, peerID, relayCancel)
+ log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, peerID string, cancel chan struct{}) {
+ ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-cancel:
+ log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
+ return
+ case <-ticker.C:
+ m.pushNewTURNTokens(ctx, peerID)
+ }
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, peerID string, cancel chan struct{}) {
+ ticker := time.NewTicker(m.relayCfg.CredentialsTTL.Duration / 4 * 3)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-cancel:
+ log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
+ return
+ case <-ticker.C:
+ m.pushNewRelayTokens(ctx, peerID)
+ }
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, peerID string) {
+ turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
+ if err != nil {
+ log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
+ return
+ }
+
+ var turns []*proto.ProtectedHostConfig
+ for _, host := range m.turnCfg.Turns {
+ turn := &proto.ProtectedHostConfig{
+ HostConfig: &proto.HostConfig{
+ Uri: host.URI,
+ Protocol: ToResponseProto(host.Proto),
+ },
+ User: turnToken.Payload,
+ Password: turnToken.Signature,
+ }
+ turns = append(turns, turn)
+ }
+
+ update := &proto.SyncResponse{
+ WiretrusteeConfig: &proto.WiretrusteeConfig{
+ Turns: turns,
+ // omit Relay to avoid updates there
+ },
+ }
+
+ log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
+ m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
+}
+
+func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
+ relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
+ if err != nil {
+ log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
+ return
+ }
+
+ update := &proto.SyncResponse{
+ WiretrusteeConfig: &proto.WiretrusteeConfig{
+ Relay: &proto.RelayConfig{
+ Urls: m.relayCfg.Addresses,
+ TokenPayload: relayToken.Payload,
+ TokenSignature: relayToken.Signature,
+ },
+ // omit Turns to avoid updates there
+ },
+ }
+
+ log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
+ m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
+}
diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go
new file mode 100644
index 000000000..d59fd3a3f
--- /dev/null
+++ b/management/server/token_mgr_test.go
@@ -0,0 +1,218 @@
+package server
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "hash"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/util"
+)
+
+var TurnTestHost = &Host{
+ Proto: UDP,
+ URI: "turn:turn.wiretrustee.com:77777",
+ Username: "username",
+ Password: "",
+}
+
+func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
+ ttl := util.Duration{Duration: time.Hour}
+ secret := "some_secret"
+ peersManager := NewPeersUpdateManager(nil)
+
+ rc := &Relay{
+ Addresses: []string{"localhost:0"},
+ CredentialsTTL: ttl,
+ Secret: secret,
+ }
+
+ tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
+ CredentialsTTL: ttl,
+ Secret: secret,
+ Turns: []*Host{TurnTestHost},
+ TimeBasedCredentials: true,
+ }, rc)
+
+ turnCredentials, err := tested.GenerateTurnToken()
+ require.NoError(t, err)
+
+ if turnCredentials.Payload == "" {
+ t.Errorf("expected generated TURN username not to be empty, got empty")
+ }
+ if turnCredentials.Signature == "" {
+ t.Errorf("expected generated TURN password not to be empty, got empty")
+ }
+
+ validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
+
+ relayCredentials, err := tested.GenerateRelayToken()
+ require.NoError(t, err)
+
+ if relayCredentials.Payload == "" {
+ t.Errorf("expected generated relay payload not to be empty, got empty")
+ }
+ if relayCredentials.Signature == "" {
+ t.Errorf("expected generated relay signature not to be empty, got empty")
+ }
+
+ validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret))
+}
+
+func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
+ ttl := util.Duration{Duration: 2 * time.Second}
+ secret := "some_secret"
+ peersManager := NewPeersUpdateManager(nil)
+ peer := "some_peer"
+ updateChannel := peersManager.CreateChannel(context.Background(), peer)
+
+ rc := &Relay{
+ Addresses: []string{"localhost:0"},
+ CredentialsTTL: ttl,
+ Secret: secret,
+ }
+ tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
+ CredentialsTTL: ttl,
+ Secret: secret,
+ Turns: []*Host{TurnTestHost},
+ TimeBasedCredentials: true,
+ }, rc)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tested.SetupRefresh(ctx, peer)
+
+ if _, ok := tested.turnCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in the turn cancel map, got not present")
+ }
+
+ if _, ok := tested.relayCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in the relay cancel map, got not present")
+ }
+
+ var updates []*UpdateMessage
+
+loop:
+ for timeout := time.After(5 * time.Second); ; {
+ select {
+ case update := <-updateChannel:
+ updates = append(updates, update)
+ case <-timeout:
+ break loop
+ }
+
+ if len(updates) >= 2 {
+ break loop
+ }
+ }
+
+ if len(updates) < 2 {
+ t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
+ }
+
+ var turnUpdates, relayUpdates int
+ var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
+ var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
+
+ for _, update := range updates {
+ if turns := update.Update.GetWiretrusteeConfig().GetTurns(); len(turns) > 0 {
+ turnUpdates++
+ if turnUpdates == 1 {
+ firstTurnUpdate = turns[0]
+ } else {
+ secondTurnUpdate = turns[0]
+ }
+ }
+ if relay := update.Update.GetWiretrusteeConfig().GetRelay(); relay != nil {
+ relayUpdates++
+ if relayUpdates == 1 {
+ firstRelayUpdate = relay
+ } else {
+ secondRelayUpdate = relay
+ }
+ }
+ }
+
+ if turnUpdates < 1 {
+ t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
+ }
+ if relayUpdates < 1 {
+ t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
+ }
+
+ if firstTurnUpdate != nil && secondTurnUpdate != nil {
+ if firstTurnUpdate.Password == secondTurnUpdate.Password {
+ t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
+ }
+ }
+
+ if firstRelayUpdate != nil && secondRelayUpdate != nil {
+ if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
+ t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
+ }
+ }
+}
+
+func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
+ ttl := util.Duration{Duration: time.Hour}
+ secret := "some_secret"
+ peersManager := NewPeersUpdateManager(nil)
+ peer := "some_peer"
+
+ rc := &Relay{
+ Addresses: []string{"localhost:0"},
+ CredentialsTTL: ttl,
+ Secret: secret,
+ }
+ tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
+ CredentialsTTL: ttl,
+ Secret: secret,
+ Turns: []*Host{TurnTestHost},
+ TimeBasedCredentials: true,
+ }, rc)
+
+ tested.SetupRefresh(context.Background(), peer)
+ if _, ok := tested.turnCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in turn cancel map, got not present")
+ }
+ if _, ok := tested.relayCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in relay cancel map, got not present")
+ }
+
+ tested.CancelRefresh(peer)
+ if _, ok := tested.turnCancelMap[peer]; ok {
+ t.Errorf("expecting peer to be not present in turn cancel map, got present")
+ }
+ if _, ok := tested.relayCancelMap[peer]; ok {
+ t.Errorf("expecting peer to be not present in relay cancel map, got present")
+ }
+}
+
+func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
+ t.Helper()
+ mac := hmac.New(algo, key)
+
+ _, err := mac.Write([]byte(username))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expectedMAC := mac.Sum(nil)
+ decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
+ if err != nil {
+ t.Fatal(err)
+ }
+ equal := hmac.Equal(decodedMAC, expectedMAC)
+
+ if !equal {
+ t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
+ }
+}
diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go
deleted file mode 100644
index 79f42e882..000000000
--- a/management/server/turncredentials.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package server
-
-import (
- "context"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
- "fmt"
- "sync"
- "time"
-
- log "github.com/sirupsen/logrus"
-
- "github.com/netbirdio/netbird/management/proto"
-)
-
-// TURNCredentialsManager used to manage TURN credentials
-type TURNCredentialsManager interface {
- GenerateCredentials() TURNCredentials
- SetupRefresh(ctx context.Context, peerKey string)
- CancelRefresh(peerKey string)
-}
-
-// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
-type TimeBasedAuthSecretsManager struct {
- mux sync.Mutex
- config *TURNConfig
- updateManager *PeersUpdateManager
- cancelMap map[string]chan struct{}
-}
-
-type TURNCredentials struct {
- Username string
- Password string
-}
-
-func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
- return &TimeBasedAuthSecretsManager{
- mux: sync.Mutex{},
- config: config,
- updateManager: updateManager,
- cancelMap: make(map[string]chan struct{}),
- }
-}
-
-// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
-func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
- mac := hmac.New(sha1.New, []byte(m.config.Secret))
-
- timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
-
- username := fmt.Sprint(timeAuth)
-
- _, err := mac.Write([]byte(username))
- if err != nil {
- log.Errorln("Generating turn password failed with error: ", err)
- }
-
- bytePassword := mac.Sum(nil)
- password := base64.StdEncoding.EncodeToString(bytePassword)
-
- return TURNCredentials{
- Username: username,
- Password: password,
- }
-
-}
-
-func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
- if channel, ok := m.cancelMap[peerID]; ok {
- close(channel)
- delete(m.cancelMap, peerID)
- }
-}
-
-// CancelRefresh cancels scheduled peer credentials refresh
-func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
- m.mux.Lock()
- defer m.mux.Unlock()
- m.cancel(peerID)
-}
-
-// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
-// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
-func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) {
- m.mux.Lock()
- defer m.mux.Unlock()
- m.cancel(peerID)
- cancel := make(chan struct{}, 1)
- m.cancelMap[peerID] = cancel
- log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID)
-
- go func() {
- // we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
- ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
-
- for {
- select {
- case <-cancel:
- log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID)
- return
- case <-ticker.C:
- c := m.GenerateCredentials()
- var turns []*proto.ProtectedHostConfig
- for _, host := range m.config.Turns {
- turns = append(turns, &proto.ProtectedHostConfig{
- HostConfig: &proto.HostConfig{
- Uri: host.URI,
- Protocol: ToResponseProto(host.Proto),
- },
- User: c.Username,
- Password: c.Password,
- })
- }
-
- update := &proto.SyncResponse{
- WiretrusteeConfig: &proto.WiretrusteeConfig{
- Turns: turns,
- },
- }
- log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
- m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
- }
- }
- }()
-}
diff --git a/management/server/turncredentials_test.go b/management/server/turncredentials_test.go
deleted file mode 100644
index 667dccbb5..000000000
--- a/management/server/turncredentials_test.go
+++ /dev/null
@@ -1,136 +0,0 @@
-package server
-
-import (
- "context"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
- "testing"
- "time"
-
- "github.com/netbirdio/netbird/util"
-)
-
-var TurnTestHost = &Host{
- Proto: UDP,
- URI: "turn:turn.wiretrustee.com:77777",
- Username: "username",
- Password: "",
-}
-
-func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
- ttl := util.Duration{Duration: time.Hour}
- secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
-
- tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
- CredentialsTTL: ttl,
- Secret: secret,
- Turns: []*Host{TurnTestHost},
- })
-
- credentials := tested.GenerateCredentials()
-
- if credentials.Username == "" {
- t.Errorf("expected generated TURN username not to be empty, got empty")
- }
- if credentials.Password == "" {
- t.Errorf("expected generated TURN password not to be empty, got empty")
- }
-
- validateMAC(t, credentials.Username, credentials.Password, []byte(secret))
-
-}
-
-func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
- ttl := util.Duration{Duration: 2 * time.Second}
- secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
- peer := "some_peer"
- updateChannel := peersManager.CreateChannel(context.Background(), peer)
-
- tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
- CredentialsTTL: ttl,
- Secret: secret,
- Turns: []*Host{TurnTestHost},
- })
-
- tested.SetupRefresh(context.Background(), peer)
-
- if _, ok := tested.cancelMap[peer]; !ok {
- t.Errorf("expecting peer to be present in a cancel map, got not present")
- }
-
- var updates []*UpdateMessage
-
-loop:
- for timeout := time.After(5 * time.Second); ; {
-
- select {
- case update := <-updateChannel:
- updates = append(updates, update)
- case <-timeout:
- break loop
- }
-
- if len(updates) >= 2 {
- break loop
- }
- }
-
- if len(updates) < 2 {
- t.Errorf("expecting 2 peer credentials updates, got %v", len(updates))
- }
-
- firstUpdate := updates[0].Update.GetWiretrusteeConfig().Turns[0]
- secondUpdate := updates[1].Update.GetWiretrusteeConfig().Turns[0]
-
- if firstUpdate.Password == secondUpdate.Password {
- t.Errorf("expecting first credential update password %v to be diffeerent from second, got equal", firstUpdate.Password)
- }
-
-}
-
-func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
- ttl := util.Duration{Duration: time.Hour}
- secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
- peer := "some_peer"
-
- tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
- CredentialsTTL: ttl,
- Secret: secret,
- Turns: []*Host{TurnTestHost},
- })
-
- tested.SetupRefresh(context.Background(), peer)
- if _, ok := tested.cancelMap[peer]; !ok {
- t.Errorf("expecting peer to be present in a cancel map, got not present")
- }
-
- tested.CancelRefresh(peer)
- if _, ok := tested.cancelMap[peer]; ok {
- t.Errorf("expecting peer to be not present in a cancel map, got present")
- }
-}
-
-func validateMAC(t *testing.T, username string, actualMAC string, key []byte) {
- t.Helper()
- mac := hmac.New(sha1.New, key)
-
- _, err := mac.Write([]byte(username))
- if err != nil {
- t.Fatal(err)
- }
-
- expectedMAC := mac.Sum(nil)
- decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
- if err != nil {
- t.Fatal(err)
- }
- equal := hmac.Equal(decodedMAC, expectedMAC)
-
- if !equal {
- t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
- }
-}
diff --git a/relay/Dockerfile b/relay/Dockerfile
new file mode 100644
index 000000000..f750027c3
--- /dev/null
+++ b/relay/Dockerfile
@@ -0,0 +1,4 @@
+FROM gcr.io/distroless/base:debug
+ENTRYPOINT [ "/go/bin/netbird-relay" ]
+ENV NB_LOG_FILE=console
+COPY netbird-relay /go/bin/netbird-relay
diff --git a/relay/auth/allow/allow_all.go b/relay/auth/allow/allow_all.go
new file mode 100644
index 000000000..92845818b
--- /dev/null
+++ b/relay/auth/allow/allow_all.go
@@ -0,0 +1,12 @@
+package allow
+
+import "hash"
+
+// Auth is a Validator that allows all connections.
+// Used this for testing purposes only.
+type Auth struct {
+}
+
+func (a *Auth) Validate(func() hash.Hash, any) error {
+ return nil
+}
diff --git a/relay/auth/doc.go b/relay/auth/doc.go
new file mode 100644
index 000000000..b3e8dbb08
--- /dev/null
+++ b/relay/auth/doc.go
@@ -0,0 +1,26 @@
+/*
+Package auth manages the authentication process with the relay server.
+
+Key Components:
+
+Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a
+Validator.
+
+Methods:
+
+Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication.
+
+Usage:
+
+To create a new AllowAllAuth validator, simply instantiate it:
+
+ validator := &allow.Auth{}
+
+To validate the authentication, use the Validate method:
+
+ err := validator.Validate(sha256.New, any)
+
+This package provides a simple and effective way to manage authentication with the relay server, ensuring that the
+peers are authenticated properly.
+*/
+package auth
diff --git a/relay/auth/hmac/doc.go b/relay/auth/hmac/doc.go
new file mode 100644
index 000000000..a1b135aa6
--- /dev/null
+++ b/relay/auth/hmac/doc.go
@@ -0,0 +1,8 @@
+/*
+This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the
+tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store
+that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello
+message.
+*/
+
+package hmac
diff --git a/relay/auth/hmac/store.go b/relay/auth/hmac/store.go
new file mode 100644
index 000000000..36c195a7b
--- /dev/null
+++ b/relay/auth/hmac/store.go
@@ -0,0 +1,36 @@
+package hmac
+
+import (
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// TokenStore is a simple in-memory store for token
+// With this can update the token in thread safe way
+type TokenStore struct {
+ mu sync.Mutex
+ token []byte
+}
+
+func (a *TokenStore) UpdateToken(token *Token) error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if token == nil {
+ return nil
+ }
+
+ t, err := marshalToken(*token)
+ if err != nil {
+ log.Debugf("failed to marshal token: %s", err)
+ return err
+ }
+ a.token = t
+ return nil
+}
+
+func (a *TokenStore) TokenBinary() []byte {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.token
+}
diff --git a/relay/auth/hmac/token.go b/relay/auth/hmac/token.go
new file mode 100644
index 000000000..e2e62b84e
--- /dev/null
+++ b/relay/auth/hmac/token.go
@@ -0,0 +1,105 @@
+package hmac
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "encoding/base64"
+ "encoding/gob"
+ "fmt"
+ "hash"
+ "strconv"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type Token struct {
+ Payload string
+ Signature string
+}
+
+func marshalToken(token Token) ([]byte, error) {
+ var buffer bytes.Buffer
+ encoder := gob.NewEncoder(&buffer)
+ err := encoder.Encode(token)
+ if err != nil {
+ log.Debugf("failed to marshal token: %s", err)
+ return nil, fmt.Errorf("failed to marshal token: %w", err)
+ }
+ return buffer.Bytes(), nil
+}
+
+func unmarshalToken(payload []byte) (Token, error) {
+ var creds Token
+ buffer := bytes.NewBuffer(payload)
+ decoder := gob.NewDecoder(buffer)
+ err := decoder.Decode(&creds)
+ return creds, err
+}
+
+// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server
+type TimedHMAC struct {
+ secret string
+ timeToLive time.Duration
+}
+
+// NewTimedHMAC creates a new TimedHMAC instance
+func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
+ return &TimedHMAC{
+ secret: secret,
+ timeToLive: timeToLive,
+ }
+}
+
+// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
+// hash of a timestamp with a preshared TURN secret
+func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
+ timeAuth := time.Now().Add(m.timeToLive).Unix()
+ timeStamp := strconv.FormatInt(timeAuth, 10)
+
+ checksum, err := m.generate(algo, timeStamp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Token{
+ Payload: timeStamp,
+ Signature: base64.StdEncoding.EncodeToString(checksum),
+ }, nil
+}
+
+// Validate checks if the token is valid
+func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
+ expectedMAC, err := m.generate(algo, token.Payload)
+ if err != nil {
+ return err
+ }
+
+ expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
+
+ if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
+ return fmt.Errorf("signature mismatch")
+ }
+
+ timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid payload: %w", err)
+ }
+
+ if time.Now().Unix() > timeAuthInt {
+ return fmt.Errorf("expired token")
+ }
+
+ return nil
+}
+
+func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
+ mac := hmac.New(algo, []byte(m.secret))
+ _, err := mac.Write([]byte(payload))
+ if err != nil {
+ log.Debugf("failed to generate token: %s", err)
+ return nil, fmt.Errorf("failed to generate token: %w", err)
+ }
+
+ return mac.Sum(nil), nil
+}
diff --git a/relay/auth/hmac/token_test.go b/relay/auth/hmac/token_test.go
new file mode 100644
index 000000000..e629eab97
--- /dev/null
+++ b/relay/auth/hmac/token_test.go
@@ -0,0 +1,105 @@
+package hmac
+
+import (
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestGenerateCredentials(t *testing.T) {
+ secret := "secret"
+ timeToLive := 1 * time.Hour
+ v := NewTimedHMAC(secret, timeToLive)
+
+ creds, err := v.GenerateToken(sha1.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if creds.Payload == "" {
+ t.Fatalf("expected non-empty payload")
+ }
+
+ _, err = strconv.ParseInt(creds.Payload, 10, 64)
+ if err != nil {
+ t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
+ }
+
+ _, err = base64.StdEncoding.DecodeString(creds.Signature)
+ if err != nil {
+ t.Fatalf("expected signature to be base64 encoded, got %v", err)
+ }
+}
+
+func TestValidateCredentials(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ manager := NewTimedHMAC(secret, timeToLive)
+
+ // Test valid token
+ creds, err := manager.GenerateToken(sha1.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if err := manager.Validate(sha1.New, *creds); err != nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestInvalidSignature(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ manager := NewTimedHMAC(secret, timeToLive)
+
+ creds, err := manager.GenerateToken(sha256.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ invalidCreds := &Token{
+ Payload: creds.Payload,
+ Signature: "invalidsignature",
+ }
+
+ if err = manager.Validate(sha1.New, *invalidCreds); err == nil {
+ t.Fatalf("expected invalid token due to signature mismatch")
+ }
+}
+
+func TestExpired(t *testing.T) {
+ secret := "supersecret"
+ v := NewTimedHMAC(secret, -1*time.Hour)
+ expiredCreds, err := v.GenerateToken(sha256.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if err = v.Validate(sha1.New, *expiredCreds); err == nil {
+ t.Fatalf("expected invalid token due to expiration")
+ }
+}
+
+func TestInvalidPayload(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ v := NewTimedHMAC(secret, timeToLive)
+
+ creds, err := v.GenerateToken(sha256.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ // Test invalid payload
+ invalidPayloadCreds := &Token{
+ Payload: "invalidtimestamp",
+ Signature: creds.Signature,
+ }
+
+ if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil {
+ t.Fatalf("expected invalid token due to invalid payload")
+ }
+}
diff --git a/relay/auth/hmac/validator.go b/relay/auth/hmac/validator.go
new file mode 100644
index 000000000..6ddd89c19
--- /dev/null
+++ b/relay/auth/hmac/validator.go
@@ -0,0 +1,33 @@
+package hmac
+
+import (
+ "fmt"
+ "hash"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type TimedHMACValidator struct {
+ *TimedHMAC
+}
+
+func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
+ ta := NewTimedHMAC(secret, duration)
+ return &TimedHMACValidator{
+ ta,
+ }
+}
+
+func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
+ b, ok := credentials.([]byte)
+ if !ok {
+ return fmt.Errorf("invalid credentials type")
+ }
+ c, err := unmarshalToken(b)
+ if err != nil {
+ log.Debugf("failed to unmarshal token: %s", err)
+ return err
+ }
+ return a.TimedHMAC.Validate(algo, c)
+}
diff --git a/relay/auth/validator.go b/relay/auth/validator.go
new file mode 100644
index 000000000..078811f3d
--- /dev/null
+++ b/relay/auth/validator.go
@@ -0,0 +1,8 @@
+package auth
+
+import "hash"
+
+// Validator is an interface that defines the Validate method.
+type Validator interface {
+ Validate(func() hash.Hash, any) error
+}
diff --git a/relay/client/addr.go b/relay/client/addr.go
new file mode 100644
index 000000000..af4f459f8
--- /dev/null
+++ b/relay/client/addr.go
@@ -0,0 +1,13 @@
+package client
+
+type RelayAddr struct {
+ addr string
+}
+
+func (a RelayAddr) Network() string {
+ return "relay"
+}
+
+func (a RelayAddr) String() string {
+ return a.addr
+}
diff --git a/relay/client/client.go b/relay/client/client.go
new file mode 100644
index 000000000..1160d1c9e
--- /dev/null
+++ b/relay/client/client.go
@@ -0,0 +1,553 @@
+package client
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/client/dialer/ws"
+ "github.com/netbirdio/netbird/relay/healthcheck"
+ "github.com/netbirdio/netbird/relay/messages"
+ "github.com/netbirdio/netbird/relay/messages/address"
+ auth2 "github.com/netbirdio/netbird/relay/messages/auth"
+)
+
+const (
+ bufferSize = 8820
+ serverResponseTimeout = 8 * time.Second
+)
+
+var (
+ ErrConnAlreadyExists = fmt.Errorf("connection already exists")
+)
+
+type internalStopFlag struct {
+ sync.Mutex
+ stop bool
+}
+
+func newInternalStopFlag() *internalStopFlag {
+ return &internalStopFlag{}
+}
+
+func (isf *internalStopFlag) set() {
+ isf.Lock()
+ defer isf.Unlock()
+ isf.stop = true
+}
+
+func (isf *internalStopFlag) isSet() bool {
+ isf.Lock()
+ defer isf.Unlock()
+ return isf.stop
+}
+
+// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
+type Msg struct {
+ Payload []byte
+
+ bufPool *sync.Pool
+ bufPtr *[]byte
+}
+
+func (m *Msg) Free() {
+ m.bufPool.Put(m.bufPtr)
+}
+
+type connContainer struct {
+ conn *Conn
+ messages chan Msg
+ msgChanLock sync.Mutex
+ closed bool // flag to check if channel is closed
+}
+
+func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
+ return &connContainer{
+ conn: conn,
+ messages: messages,
+ }
+}
+
+func (cc *connContainer) writeMsg(msg Msg) {
+ cc.msgChanLock.Lock()
+ defer cc.msgChanLock.Unlock()
+ if cc.closed {
+ return
+ }
+ cc.messages <- msg
+}
+
+func (cc *connContainer) close() {
+ cc.msgChanLock.Lock()
+ defer cc.msgChanLock.Unlock()
+ if cc.closed {
+ return
+ }
+ close(cc.messages)
+ cc.closed = true
+}
+
+// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
+// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
+// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
+// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
+type Client struct {
+ log *log.Entry
+ parentCtx context.Context
+ connectionURL string
+ authTokenStore *auth.TokenStore
+ hashedID []byte
+
+ bufPool *sync.Pool
+
+ relayConn net.Conn
+ conns map[string]*connContainer
+ serviceIsRunning bool
+ mu sync.Mutex // protect serviceIsRunning and conns
+ readLoopMutex sync.Mutex
+ wgReadLoop sync.WaitGroup
+ instanceURL *RelayAddr
+ muInstanceURL sync.Mutex
+
+ onDisconnectListener func()
+ listenerMutex sync.Mutex
+}
+
+// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
+func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
+ hashedID, hashedStringId := messages.HashID(peerID)
+ return &Client{
+ log: log.WithField("client_id", hashedStringId),
+ parentCtx: ctx,
+ connectionURL: serverURL,
+ authTokenStore: authTokenStore,
+ hashedID: hashedID,
+ bufPool: &sync.Pool{
+ New: func() any {
+ buf := make([]byte, bufferSize)
+ return &buf
+ },
+ },
+ conns: make(map[string]*connContainer),
+ }
+}
+
+// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
+func (c *Client) Connect() error {
+ c.log.Infof("connecting to relay server: %s", c.connectionURL)
+ c.readLoopMutex.Lock()
+ defer c.readLoopMutex.Unlock()
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.serviceIsRunning {
+ return nil
+ }
+
+ err := c.connect()
+ if err != nil {
+ return err
+ }
+
+ c.serviceIsRunning = true
+
+ c.wgReadLoop.Add(1)
+ go c.readLoop(c.relayConn)
+
+ c.log.Infof("relay connection established with: %s", c.connectionURL)
+ return nil
+}
+
+// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
+// to the relay server, the function will block until the connection is established or timed out. Otherwise,
+// it will return immediately.
+// todo: what should happen if call with the same peerID with multiple times?
+func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if !c.serviceIsRunning {
+ return nil, fmt.Errorf("relay connection is not established")
+ }
+
+ hashedID, hashedStringID := messages.HashID(dstPeerID)
+ _, ok := c.conns[hashedStringID]
+ if ok {
+ return nil, ErrConnAlreadyExists
+ }
+
+ log.Infof("open connection to peer: %s", hashedStringID)
+ msgChannel := make(chan Msg, 2)
+ conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
+
+ c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
+ return conn, nil
+}
+
+// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
+func (c *Client) ServerInstanceURL() (string, error) {
+ c.muInstanceURL.Lock()
+ defer c.muInstanceURL.Unlock()
+ if c.instanceURL == nil {
+ return "", fmt.Errorf("relay connection is not established")
+ }
+ return c.instanceURL.String(), nil
+}
+
+// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
+func (c *Client) SetOnDisconnectListener(fn func()) {
+ c.listenerMutex.Lock()
+ defer c.listenerMutex.Unlock()
+ c.onDisconnectListener = fn
+}
+
+// HasConns returns true if there are connections.
+func (c *Client) HasConns() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return len(c.conns) > 0
+}
+
+// Close closes the connection to the relay server and all connections to other peers.
+func (c *Client) Close() error {
+ return c.close(true)
+}
+
+func (c *Client) connect() error {
+ conn, err := ws.Dial(c.connectionURL)
+ if err != nil {
+ return err
+ }
+ c.relayConn = conn
+
+ err = c.handShake()
+ if err != nil {
+ cErr := conn.Close()
+ if cErr != nil {
+ log.Errorf("failed to close connection: %s", cErr)
+ }
+ return err
+ }
+
+ return nil
+}
+
+func (c *Client) handShake() error {
+ authMsg := &auth2.Msg{
+ AuthAlgorithm: auth2.AlgoHMACSHA256,
+ AdditionalData: c.authTokenStore.TokenBinary(),
+ }
+
+ authData, err := authMsg.Marshal()
+ if err != nil {
+ return fmt.Errorf("marshal auth message: %w", err)
+ }
+
+ msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
+ if err != nil {
+ log.Errorf("failed to marshal hello message: %s", err)
+ return err
+ }
+
+ _, err = c.relayConn.Write(msg)
+ if err != nil {
+ log.Errorf("failed to send hello message: %s", err)
+ return err
+ }
+ buf := make([]byte, messages.MaxHandshakeSize)
+ n, err := c.readWithTimeout(buf)
+ if err != nil {
+ log.Errorf("failed to read hello response: %s", err)
+ return err
+ }
+
+ _, err = messages.ValidateVersion(buf[:n])
+ if err != nil {
+ return fmt.Errorf("validate version: %w", err)
+ }
+
+ msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
+ if err != nil {
+ log.Errorf("failed to determine message type: %s", err)
+ return err
+ }
+
+ if msgType != messages.MsgTypeHelloResponse {
+ log.Errorf("unexpected message type: %s", msgType)
+ return fmt.Errorf("unexpected message type")
+ }
+
+ additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
+ if err != nil {
+ return err
+ }
+
+ addr, err := address.Unmarshal(additionalData)
+ if err != nil {
+ return fmt.Errorf("unmarshal address: %w", err)
+ }
+
+ c.muInstanceURL.Lock()
+ c.instanceURL = &RelayAddr{addr: addr.URL}
+ c.muInstanceURL.Unlock()
+ return nil
+}
+
+func (c *Client) readLoop(relayConn net.Conn) {
+ internallyStoppedFlag := newInternalStopFlag()
+ hc := healthcheck.NewReceiver()
+ go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
+
+ var (
+ errExit error
+ n int
+ )
+ for {
+ bufPtr := c.bufPool.Get().(*[]byte)
+ buf := *bufPtr
+ n, errExit = relayConn.Read(buf)
+ if errExit != nil {
+ c.mu.Lock()
+ if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
+ c.log.Debugf("failed to read message from relay server: %s", errExit)
+ }
+ c.mu.Unlock()
+ break
+ }
+
+ _, err := messages.ValidateVersion(buf[:n])
+ if err != nil {
+ c.log.Errorf("failed to validate protocol version: %s", err)
+ c.bufPool.Put(bufPtr)
+ continue
+ }
+
+ msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
+ if err != nil {
+ c.log.Errorf("failed to determine message type: %s", err)
+ c.bufPool.Put(bufPtr)
+ continue
+ }
+
+ if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
+ break
+ }
+ }
+
+ hc.Stop()
+
+ c.muInstanceURL.Lock()
+ c.instanceURL = nil
+ c.muInstanceURL.Unlock()
+
+ c.notifyDisconnected()
+ c.wgReadLoop.Done()
+ _ = c.close(false)
+}
+
+func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
+ switch msgType {
+ case messages.MsgTypeHealthCheck:
+ c.handleHealthCheck(hc, internallyStoppedFlag)
+ c.bufPool.Put(bufPtr)
+ case messages.MsgTypeTransport:
+ return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
+ case messages.MsgTypeClose:
+ log.Debugf("relay connection close by server")
+ c.bufPool.Put(bufPtr)
+ return false
+ }
+
+ return true
+}
+
+func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
+ msg := messages.MarshalHealthcheck()
+ _, wErr := c.relayConn.Write(msg)
+ if wErr != nil {
+ if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
+ c.log.Errorf("failed to send heartbeat: %s", wErr)
+ }
+ }
+ hc.Heartbeat()
+}
+
+func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
+ peerID, payload, err := messages.UnmarshalTransportMsg(buf)
+ if err != nil {
+ if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
+ c.log.Errorf("failed to parse transport message: %v", err)
+ }
+
+ c.bufPool.Put(bufPtr)
+ return true
+ }
+
+ stringID := messages.HashIDToString(peerID)
+
+ c.mu.Lock()
+ if !c.serviceIsRunning {
+ c.mu.Unlock()
+ c.bufPool.Put(bufPtr)
+ return false
+ }
+ container, ok := c.conns[stringID]
+ c.mu.Unlock()
+ if !ok {
+ c.log.Errorf("peer not found: %s", stringID)
+ c.bufPool.Put(bufPtr)
+ return true
+ }
+ msg := Msg{
+ bufPool: c.bufPool,
+ bufPtr: bufPtr,
+ Payload: payload,
+ }
+ container.writeMsg(msg)
+ return true
+}
+
+func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
+ c.mu.Lock()
+ conn, ok := c.conns[id]
+ c.mu.Unlock()
+ if !ok {
+ return 0, io.EOF
+ }
+
+ if conn.conn != connReference {
+ return 0, io.EOF
+ }
+
+ // todo: use buffer pool instead of create new transport msg.
+ msg, err := messages.MarshalTransportMsg(dstID, payload)
+ if err != nil {
+ log.Errorf("failed to marshal transport message: %s", err)
+ return 0, err
+ }
+
+ // the write always return with 0 length because the underling does not support the size feedback.
+ _, err = c.relayConn.Write(msg)
+ if err != nil {
+ log.Errorf("failed to write transport message: %s", err)
+ }
+ return len(payload), err
+}
+
+func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
+ for {
+ select {
+ case _, ok := <-hc.OnTimeout:
+ if !ok {
+ return
+ }
+ c.log.Errorf("health check timeout")
+ internalStopFlag.set()
+ _ = conn.Close() // ignore the err because the readLoop will handle it
+ return
+ case <-c.parentCtx.Done():
+ err := c.close(true)
+ if err != nil {
+ log.Errorf("failed to teardown connection: %s", err)
+ }
+ return
+ }
+ }
+}
+
+func (c *Client) closeAllConns() {
+ for _, container := range c.conns {
+ container.close()
+ }
+ c.conns = make(map[string]*connContainer)
+}
+
+func (c *Client) closeConn(connReference *Conn, id string) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ container, ok := c.conns[id]
+ if !ok {
+ return fmt.Errorf("connection already closed")
+ }
+
+ if container.conn != connReference {
+ return fmt.Errorf("conn reference mismatch")
+ }
+ container.close()
+ delete(c.conns, id)
+
+ return nil
+}
+
+func (c *Client) close(gracefullyExit bool) error {
+ c.readLoopMutex.Lock()
+ defer c.readLoopMutex.Unlock()
+
+ c.mu.Lock()
+ var err error
+ if !c.serviceIsRunning {
+ c.mu.Unlock()
+ return nil
+ }
+
+ c.serviceIsRunning = false
+ c.closeAllConns()
+ if gracefullyExit {
+ c.writeCloseMsg()
+ }
+ err = c.relayConn.Close()
+ c.mu.Unlock()
+
+ c.wgReadLoop.Wait()
+ c.log.Infof("relay connection closed with: %s", c.connectionURL)
+ return err
+}
+
+func (c *Client) notifyDisconnected() {
+ c.listenerMutex.Lock()
+ defer c.listenerMutex.Unlock()
+
+ if c.onDisconnectListener == nil {
+ return
+ }
+ go c.onDisconnectListener()
+}
+
+func (c *Client) writeCloseMsg() {
+ msg := messages.MarshalCloseMsg()
+ _, err := c.relayConn.Write(msg)
+ if err != nil {
+ c.log.Errorf("failed to send close message: %s", err)
+ }
+}
+
+func (c *Client) readWithTimeout(buf []byte) (int, error) {
+ ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
+ defer cancel()
+
+ readDone := make(chan struct{})
+ var (
+ n int
+ err error
+ )
+
+ go func() {
+ n, err = c.relayConn.Read(buf)
+ close(readDone)
+ }()
+
+ select {
+ case <-ctx.Done():
+ return 0, fmt.Errorf("read operation timed out")
+ case <-readDone:
+ return n, err
+ }
+}
diff --git a/relay/client/client_test.go b/relay/client/client_test.go
new file mode 100644
index 000000000..b7f1a63ca
--- /dev/null
+++ b/relay/client/client_test.go
@@ -0,0 +1,631 @@
+package client
+
+import (
+ "context"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/auth/allow"
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/util"
+
+ "github.com/netbirdio/netbird/relay/server"
+)
+
+var (
+ av = &allow.Auth{}
+ hmacTokenStore = &hmac.TokenStore{}
+ serverListenAddr = "127.0.0.1:1234"
+ serverURL = "rel://127.0.0.1:1234"
+)
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("error", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
+func TestClient(t *testing.T) {
+ ctx := context.Background()
+
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ listenCfg := server.ListenerConfig{Address: serverListenAddr}
+ err := srv.Listen(listenCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for server to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+ t.Log("alice connecting to server")
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer clientAlice.Close()
+
+ t.Log("placeholder connecting to server")
+ clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
+ err = clientPlaceHolder.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer clientPlaceHolder.Close()
+
+ t.Log("Bob connecting to server")
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
+ err = clientBob.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer clientBob.Close()
+
+ t.Log("Alice open connection to Bob")
+ connAliceToBob, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ t.Log("Bob open connection to Alice")
+ connBobToAlice, err := clientBob.OpenConn("alice")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+ log.Debugf("alice sent message to bob")
+
+ buf := make([]byte, 65535)
+ n, err := connBobToAlice.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+ log.Debugf("on new message from alice to bob")
+
+ if payload != string(buf[:n]) {
+ t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
+ }
+}
+
+func TestRegistration(t *testing.T) {
+ ctx := context.Background()
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ // wait for server to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ _ = srv.Shutdown(ctx)
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close conn: %s", err)
+ }
+ err = srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+}
+
+func TestRegistrationTimeout(t *testing.T) {
+ ctx := context.Background()
+ fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
+ Port: 1234,
+ IP: net.ParseIP("0.0.0.0"),
+ })
+ if err != nil {
+ t.Fatalf("failed to bind UDP server: %s", err)
+ }
+ defer func(fakeUDPListener *net.UDPConn) {
+ _ = fakeUDPListener.Close()
+ }(fakeUDPListener)
+
+ fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
+ Port: 1234,
+ IP: net.ParseIP("0.0.0.0"),
+ })
+ if err != nil {
+ t.Fatalf("failed to bind TCP server: %s", err)
+ }
+ defer func(fakeTCPListener *net.TCPListener) {
+ _ = fakeTCPListener.Close()
+ }(fakeTCPListener)
+
+ clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err == nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+ log.Debugf("%s", err)
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close conn: %s", err)
+ }
+}
+
+func TestEcho(t *testing.T) {
+ ctx := context.Background()
+ idAlice := "alice"
+ idBob := "bob"
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close Alice client: %s", err)
+ }
+ }()
+
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
+ err = clientBob.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientBob.Close()
+ if err != nil {
+ t.Errorf("failed to close Bob client: %s", err)
+ }
+ }()
+
+ connAliceToBob, err := clientAlice.OpenConn(idBob)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ connBobToAlice, err := clientBob.OpenConn(idAlice)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ buf := make([]byte, 65535)
+ n, err := connBobToAlice.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ _, err = connBobToAlice.Write(buf[:n])
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ n, err = connAliceToBob.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ if payload != string(buf[:n]) {
+ t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
+ }
+}
+
+func TestBindToUnavailabePeer(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ log.Infof("closing server")
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+ _, err = clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ log.Infof("closing client")
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+}
+
+func TestBindReconnect(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ log.Infof("closing server")
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ _, err = clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
+ err = clientBob.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ chBob, err := clientBob.OpenConn("alice")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ log.Infof("closing client Alice")
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+
+ clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ chAlice, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ testString := "hello alice, I am bob"
+ _, err = chBob.Write([]byte(testString))
+ if err != nil {
+ t.Errorf("failed to write to channel: %s", err)
+ }
+
+ buf := make([]byte, 65535)
+ n, err := chAlice.Read(buf)
+ if err != nil {
+ t.Errorf("failed to read from channel: %s", err)
+ }
+
+ if testString != string(buf[:n]) {
+ t.Errorf("expected %s, got %s", testString, string(buf[:n]))
+ }
+
+ log.Infof("closing client")
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+}
+
+func TestCloseConn(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ log.Infof("closing server")
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ conn, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ log.Infof("closing connection")
+ err = conn.Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+
+ _, err = conn.Read(make([]byte, 1))
+ if err == nil {
+ t.Errorf("unexpected reading from closed connection")
+ }
+
+ _, err = conn.Write([]byte("hello"))
+ if err == nil {
+ t.Errorf("unexpected writing from closed connection")
+ }
+}
+
+func TestCloseRelayConn(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ log.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+
+ conn, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ _ = clientAlice.relayConn.Close()
+
+ _, err = conn.Read(make([]byte, 1))
+ if err == nil {
+ t.Errorf("unexpected reading from closed connection")
+ }
+
+ _, err = clientAlice.OpenConn("bob")
+ if err == nil {
+ t.Errorf("unexpected opening connection to closed server")
+ }
+}
+
+func TestCloseByServer(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+
+ go func() {
+ err := srv1.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ idAlice := "alice"
+ log.Debugf("connect by alice")
+ relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+ err = relayClient.Connect()
+ if err != nil {
+ log.Fatalf("failed to connect to server: %s", err)
+ }
+
+ disconnected := make(chan struct{})
+ relayClient.SetOnDisconnectListener(func() {
+ log.Infof("client disconnected")
+ close(disconnected)
+ })
+
+ err = srv1.Shutdown(ctx)
+ if err != nil {
+ t.Fatalf("failed to close server: %s", err)
+ }
+
+ select {
+ case <-disconnected:
+ case <-time.After(3 * time.Second):
+ log.Fatalf("timeout waiting for client to disconnect")
+ }
+
+ _, err = relayClient.OpenConn("bob")
+ if err == nil {
+ t.Errorf("unexpected opening connection to closed server")
+ }
+}
+
+func TestCloseByClient(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ idAlice := "alice"
+ log.Debugf("connect by alice")
+ relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+ err = relayClient.Connect()
+ if err != nil {
+ log.Fatalf("failed to connect to server: %s", err)
+ }
+
+ err = relayClient.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+
+ _, err = relayClient.OpenConn("bob")
+ if err == nil {
+ t.Errorf("unexpected opening connection to closed server")
+ }
+
+ err = srv.Shutdown(ctx)
+ if err != nil {
+ t.Fatalf("failed to close server: %s", err)
+ }
+}
+
+func waitForServerToStart(errChan chan error) error {
+ select {
+ case err := <-errChan:
+ if err != nil {
+ return err
+ }
+ case <-time.After(300 * time.Millisecond):
+ return nil
+ }
+ return nil
+}
diff --git a/relay/client/conn.go b/relay/client/conn.go
new file mode 100644
index 000000000..b4ff903e8
--- /dev/null
+++ b/relay/client/conn.go
@@ -0,0 +1,76 @@
+package client
+
+import (
+ "io"
+ "net"
+ "time"
+)
+
+// Conn represent a connection to a relayed remote peer.
+type Conn struct {
+ client *Client
+ dstID []byte
+ dstStringID string
+ messageChan chan Msg
+ instanceURL *RelayAddr
+}
+
+// NewConn creates a new connection to a relayed remote peer.
+// client: the client instance, it used to send messages to the destination peer
+// dstID: the destination peer ID
+// dstStringID: the destination peer ID in string format
+// messageChan: the channel where the messages will be received
+// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
+func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
+ c := &Conn{
+ client: client,
+ dstID: dstID,
+ dstStringID: dstStringID,
+ messageChan: messageChan,
+ instanceURL: instanceURL,
+ }
+
+ return c
+}
+
+func (c *Conn) Write(p []byte) (n int, err error) {
+ return c.client.writeTo(c, c.dstStringID, c.dstID, p)
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+ msg, ok := <-c.messageChan
+ if !ok {
+ return 0, io.EOF
+ }
+
+ n = copy(b, msg.Payload)
+ msg.Free()
+ return n, nil
+}
+
+func (c *Conn) Close() error {
+ return c.client.closeConn(c, c.dstStringID)
+}
+
+func (c *Conn) LocalAddr() net.Addr {
+ return c.client.relayConn.LocalAddr()
+}
+
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.instanceURL
+}
+
+func (c *Conn) SetDeadline(t time.Time) error {
+ //TODO implement me
+ panic("SetDeadline is not implemented")
+}
+
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ //TODO implement me
+ panic("SetReadDeadline is not implemented")
+}
+
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ //TODO implement me
+ panic("SetReadDeadline is not implemented")
+}
diff --git a/relay/client/dialer/ws/addr.go b/relay/client/dialer/ws/addr.go
new file mode 100644
index 000000000..43f5dd6af
--- /dev/null
+++ b/relay/client/dialer/ws/addr.go
@@ -0,0 +1,13 @@
+package ws
+
+type WebsocketAddr struct {
+ addr string
+}
+
+func (a WebsocketAddr) Network() string {
+ return "websocket"
+}
+
+func (a WebsocketAddr) String() string {
+ return a.addr
+}
diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go
new file mode 100644
index 000000000..e7f771b8d
--- /dev/null
+++ b/relay/client/dialer/ws/conn.go
@@ -0,0 +1,66 @@
+package ws
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "time"
+
+ "nhooyr.io/websocket"
+)
+
+type Conn struct {
+ ctx context.Context
+ *websocket.Conn
+ remoteAddr WebsocketAddr
+}
+
+func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
+ return &Conn{
+ ctx: context.Background(),
+ Conn: wsConn,
+ remoteAddr: WebsocketAddr{serverAddress},
+ }
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+ t, ioReader, err := c.Conn.Reader(c.ctx)
+ if err != nil {
+ return 0, err
+ }
+
+ if t != websocket.MessageBinary {
+ return 0, fmt.Errorf("unexpected message type")
+ }
+
+ return ioReader.Read(b)
+}
+
+func (c *Conn) Write(b []byte) (n int, err error) {
+ err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
+ return 0, err
+}
+
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.remoteAddr
+}
+
+func (c *Conn) LocalAddr() net.Addr {
+ return WebsocketAddr{addr: "unknown"}
+}
+
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ return fmt.Errorf("SetReadDeadline is not implemented")
+}
+
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ return fmt.Errorf("SetWriteDeadline is not implemented")
+}
+
+func (c *Conn) SetDeadline(t time.Time) error {
+ return fmt.Errorf("SetDeadline is not implemented")
+}
+
+func (c *Conn) Close() error {
+ return c.Conn.CloseNow()
+}
diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go
new file mode 100644
index 000000000..d9388aafd
--- /dev/null
+++ b/relay/client/dialer/ws/ws.go
@@ -0,0 +1,67 @@
+package ws
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "nhooyr.io/websocket"
+
+ "github.com/netbirdio/netbird/relay/server/listener/ws"
+ nbnet "github.com/netbirdio/netbird/util/net"
+)
+
+func Dial(address string) (net.Conn, error) {
+ wsURL, err := prepareURL(address)
+ if err != nil {
+ return nil, err
+ }
+
+ opts := &websocket.DialOptions{
+ HTTPClient: httpClientNbDialer(),
+ }
+
+ parsedURL, err := url.Parse(wsURL)
+ if err != nil {
+ return nil, err
+ }
+ parsedURL.Path = ws.URLPath
+
+ wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
+ if err != nil {
+ log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
+ return nil, err
+ }
+ if resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+
+ conn := NewConn(wsConn, address)
+ return conn, nil
+}
+
+func prepareURL(address string) (string, error) {
+ if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
+ return "", fmt.Errorf("unsupported scheme: %s", address)
+ }
+
+ return strings.Replace(address, "rel", "ws", 1), nil
+}
+
+func httpClientNbDialer() *http.Client {
+ customDialer := nbnet.NewDialer()
+
+ customTransport := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return customDialer.DialContext(ctx, network, addr)
+ },
+ }
+
+ return &http.Client{
+ Transport: customTransport,
+ }
+}
diff --git a/relay/client/doc.go b/relay/client/doc.go
new file mode 100644
index 000000000..1339251d9
--- /dev/null
+++ b/relay/client/doc.go
@@ -0,0 +1,12 @@
+/*
+Package client contains the implementation of the Relay client.
+
+The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
+Keep persistent connection with the Relay server and handle the connection issues.
+It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
+
+If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
+the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
+negotiate the common relay instance via signaling service.
+*/
+package client
diff --git a/relay/client/guard.go b/relay/client/guard.go
new file mode 100644
index 000000000..f826cf1b6
--- /dev/null
+++ b/relay/client/guard.go
@@ -0,0 +1,48 @@
+package client
+
+import (
+ "context"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ reconnectingTimeout = 5 * time.Second
+)
+
+// Guard manage the reconnection tries to the Relay server in case of disconnection event.
+type Guard struct {
+ ctx context.Context
+ relayClient *Client
+}
+
+// NewGuard creates a new guard for the relay client.
+func NewGuard(context context.Context, relayClient *Client) *Guard {
+ g := &Guard{
+ ctx: context,
+ relayClient: relayClient,
+ }
+ return g
+}
+
+// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
+// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
+func (g *Guard) OnDisconnected() {
+ ticker := time.NewTicker(reconnectingTimeout)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ err := g.relayClient.Connect()
+ if err != nil {
+ log.Errorf("failed to reconnect to relay server: %s", err)
+ continue
+ }
+ return
+ case <-g.ctx.Done():
+ return
+ }
+ }
+}
diff --git a/relay/client/manager.go b/relay/client/manager.go
new file mode 100644
index 000000000..3e152a963
--- /dev/null
+++ b/relay/client/manager.go
@@ -0,0 +1,365 @@
+package client
+
+import (
+ "container/list"
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "reflect"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
+)
+
+var (
+ relayCleanupInterval = 60 * time.Second
+ connectionTimeout = 30 * time.Second
+ maxConcurrentServers = 7
+
+ ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
+)
+
+// RelayTrack hold the relay clients for the foreign relay servers.
+// With the mutex can ensure we can open new connection in case the relay connection has been established with
+// the relay server.
+type RelayTrack struct {
+ sync.RWMutex
+ relayClient *Client
+}
+
+func NewRelayTrack() *RelayTrack {
+ return &RelayTrack{}
+}
+
+type OnServerCloseListener func()
+
+// ManagerService is the interface for the relay manager.
+type ManagerService interface {
+ Serve() error
+ OpenConn(serverAddress, peerKey string) (net.Conn, error)
+ AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
+ RelayInstanceAddress() (string, error)
+ ServerURLs() []string
+ HasRelayAddress() bool
+ UpdateToken(token *relayAuth.Token) error
+}
+
+// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
+// and automatically reconnect to them in case disconnection.
+// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
+// different relay server, the manager will establish a new connection to the relay server. The connection with these
+// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
+// unused relay connection and close it.
+type Manager struct {
+ ctx context.Context
+ serverURLs []string
+ peerID string
+ tokenStore *relayAuth.TokenStore
+
+ relayClient *Client
+ reconnectGuard *Guard
+
+ relayClients map[string]*RelayTrack
+ relayClientsMutex sync.RWMutex
+
+ onDisconnectedListeners map[string]*list.List
+ listenerLock sync.Mutex
+}
+
+// NewManager creates a new manager instance.
+// The serverURL address can be empty. In this case, the manager will not serve.
+func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
+ return &Manager{
+ ctx: ctx,
+ serverURLs: serverURLs,
+ peerID: peerID,
+ tokenStore: &relayAuth.TokenStore{},
+ relayClients: make(map[string]*RelayTrack),
+ onDisconnectedListeners: make(map[string]*list.List),
+ }
+}
+
+// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
+// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
+func (m *Manager) Serve() error {
+ if m.relayClient != nil {
+ return fmt.Errorf("manager already serving")
+ }
+ log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
+
+ totalServers := len(m.serverURLs)
+
+ successChan := make(chan *Client, 1)
+ errChan := make(chan error, len(m.serverURLs))
+
+ ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
+ defer cancel()
+
+ sem := make(chan struct{}, maxConcurrentServers)
+
+ for _, url := range m.serverURLs {
+ sem <- struct{}{}
+ go func(url string) {
+ defer func() { <-sem }()
+ m.connect(m.ctx, url, successChan, errChan)
+ }(url)
+ }
+
+ var errCount int
+
+ for {
+ select {
+ case client := <-successChan:
+ log.Infof("Successfully connected to relay server: %s", client.connectionURL)
+
+ m.relayClient = client
+
+ m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
+ m.relayClient.SetOnDisconnectListener(func() {
+ m.onServerDisconnected(client.connectionURL)
+ })
+ m.startCleanupLoop()
+ return nil
+ case err := <-errChan:
+ errCount++
+ log.Warnf("Connection attempt failed: %v", err)
+ if errCount == totalServers {
+ return errors.New("failed to connect to any relay server: all attempts failed")
+ }
+ case <-ctx.Done():
+ return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
+ }
+ }
+}
+
+func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
+ // TODO: abort the connection if another connection was successful
+ relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
+ if err := relayClient.Connect(); err != nil {
+ errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
+ return
+ }
+
+ select {
+ case successChan <- relayClient:
+ // This client was the first to connect successfully
+ default:
+ if err := relayClient.Close(); err != nil {
+ log.Debugf("failed to close relay client: %s", err)
+ }
+ }
+}
+
+// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
+// established via the relay server. If the peer is on a different relay server, the manager will establish a new
+// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
+func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
+ if m.relayClient == nil {
+ return nil, ErrRelayClientNotConnected
+ }
+
+ foreign, err := m.isForeignServer(serverAddress)
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ netConn net.Conn
+ )
+ if !foreign {
+ log.Debugf("open peer connection via permanent server: %s", peerKey)
+ netConn, err = m.relayClient.OpenConn(peerKey)
+ } else {
+ log.Debugf("open peer connection via foreign server: %s", serverAddress)
+ netConn, err = m.openConnVia(serverAddress, peerKey)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ return netConn, err
+}
+
+// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
+// closed.
+func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
+ foreign, err := m.isForeignServer(serverAddress)
+ if err != nil {
+ return err
+ }
+
+ var listenerAddr string
+ if foreign {
+ listenerAddr = serverAddress
+ } else {
+ listenerAddr = m.relayClient.connectionURL
+ }
+ m.addListener(listenerAddr, onClosedListener)
+ return nil
+}
+
+// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
+// lost. This address will be sent to the target peer to choose the common relay server for the communication.
+func (m *Manager) RelayInstanceAddress() (string, error) {
+ if m.relayClient == nil {
+ return "", ErrRelayClientNotConnected
+ }
+ return m.relayClient.ServerInstanceURL()
+}
+
+// ServerURLs returns the addresses of the relay servers.
+func (m *Manager) ServerURLs() []string {
+ return m.serverURLs
+}
+
+// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
+// Relay service.
+func (m *Manager) HasRelayAddress() bool {
+ return len(m.serverURLs) > 0
+}
+
+// UpdateToken updates the token in the token store.
+func (m *Manager) UpdateToken(token *relayAuth.Token) error {
+ return m.tokenStore.UpdateToken(token)
+}
+
+func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
+ // check if already has a connection to the desired relay server
+ m.relayClientsMutex.RLock()
+ rt, ok := m.relayClients[serverAddress]
+ if ok {
+ rt.RLock()
+ m.relayClientsMutex.RUnlock()
+ defer rt.RUnlock()
+ return rt.relayClient.OpenConn(peerKey)
+ }
+ m.relayClientsMutex.RUnlock()
+
+ // if not, establish a new connection but check it again (because changed the lock type) before starting the
+ // connection
+ m.relayClientsMutex.Lock()
+ rt, ok = m.relayClients[serverAddress]
+ if ok {
+ rt.RLock()
+ m.relayClientsMutex.Unlock()
+ defer rt.RUnlock()
+ return rt.relayClient.OpenConn(peerKey)
+ }
+
+ // create a new relay client and store it in the relayClients map
+ rt = NewRelayTrack()
+ rt.Lock()
+ m.relayClients[serverAddress] = rt
+ m.relayClientsMutex.Unlock()
+
+ relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
+ err := relayClient.Connect()
+ if err != nil {
+ rt.Unlock()
+ m.relayClientsMutex.Lock()
+ delete(m.relayClients, serverAddress)
+ m.relayClientsMutex.Unlock()
+ return nil, err
+ }
+ // if connection closed then delete the relay client from the list
+ relayClient.SetOnDisconnectListener(func() {
+ m.onServerDisconnected(serverAddress)
+ })
+ rt.relayClient = relayClient
+ rt.Unlock()
+
+ conn, err := relayClient.OpenConn(peerKey)
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+func (m *Manager) onServerDisconnected(serverAddress string) {
+ if serverAddress == m.relayClient.connectionURL {
+ go m.reconnectGuard.OnDisconnected()
+ }
+
+ m.notifyOnDisconnectListeners(serverAddress)
+}
+
+func (m *Manager) isForeignServer(address string) (bool, error) {
+ rAddr, err := m.relayClient.ServerInstanceURL()
+ if err != nil {
+ return false, fmt.Errorf("relay client not connected")
+ }
+ return rAddr != address, nil
+}
+
+func (m *Manager) startCleanupLoop() {
+ if m.ctx.Err() != nil {
+ return
+ }
+
+ ticker := time.NewTicker(relayCleanupInterval)
+ go func() {
+ defer ticker.Stop()
+ for {
+ select {
+ case <-m.ctx.Done():
+ return
+ case <-ticker.C:
+ m.cleanUpUnusedRelays()
+ }
+ }
+ }()
+}
+
+func (m *Manager) cleanUpUnusedRelays() {
+ m.relayClientsMutex.Lock()
+ defer m.relayClientsMutex.Unlock()
+
+ for addr, rt := range m.relayClients {
+ rt.Lock()
+ if rt.relayClient.HasConns() {
+ rt.Unlock()
+ continue
+ }
+ rt.relayClient.SetOnDisconnectListener(nil)
+ go func() {
+ _ = rt.relayClient.Close()
+ }()
+ log.Debugf("clean up unused relay server connection: %s", addr)
+ delete(m.relayClients, addr)
+ rt.Unlock()
+ }
+}
+
+func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
+ m.listenerLock.Lock()
+ defer m.listenerLock.Unlock()
+ l, ok := m.onDisconnectedListeners[serverAddress]
+ if !ok {
+ l = list.New()
+ }
+ for e := l.Front(); e != nil; e = e.Next() {
+ if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
+ return
+ }
+ }
+ l.PushBack(onClosedListener)
+ m.onDisconnectedListeners[serverAddress] = l
+}
+
+func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
+ m.listenerLock.Lock()
+ defer m.listenerLock.Unlock()
+
+ l, ok := m.onDisconnectedListeners[serverAddress]
+ if !ok {
+ return
+ }
+ for e := l.Front(); e != nil; e = e.Next() {
+ go e.Value.(OnServerCloseListener)()
+ }
+ delete(m.onDisconnectedListeners, serverAddress)
+}
diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go
new file mode 100644
index 000000000..e9cc2c581
--- /dev/null
+++ b/relay/client/manager_test.go
@@ -0,0 +1,432 @@
+package client
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/server"
+)
+
+func TestEmptyURL(t *testing.T) {
+ mgr := NewManager(context.Background(), nil, "alice")
+ err := mgr.Serve()
+ if err == nil {
+ t.Errorf("expected error, got nil")
+ }
+}
+
+func TestForeignConn(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg1 := server.ListenerConfig{
+ Address: "localhost:1234",
+ }
+ srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv1.Listen(srvCfg1)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv1.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ srvCfg2 := server.ListenerConfig{
+ Address: "localhost:2234",
+ }
+ srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan2 := make(chan error, 1)
+ go func() {
+ err := srv2.Listen(srvCfg2)
+ if err != nil {
+ errChan2 <- err
+ }
+ }()
+
+ defer func() {
+ err := srv2.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ if err := waitForServerToStart(errChan2); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ idAlice := "alice"
+ log.Debugf("connect by alice")
+ mCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
+ err = clientAlice.Serve()
+ if err != nil {
+ t.Fatalf("failed to serve manager: %s", err)
+ }
+
+ idBob := "bob"
+ log.Debugf("connect by bob")
+ clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
+ err = clientBob.Serve()
+ if err != nil {
+ t.Fatalf("failed to serve manager: %s", err)
+ }
+ bobsSrvAddr, err := clientBob.RelayInstanceAddress()
+ if err != nil {
+ t.Fatalf("failed to get relay address: %s", err)
+ }
+ connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+ connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ buf := make([]byte, 65535)
+ n, err := connBobToAlice.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ _, err = connBobToAlice.Write(buf[:n])
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ n, err = connAliceToBob.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ if payload != string(buf[:n]) {
+ t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
+ }
+}
+
+func TestForeginConnClose(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg1 := server.ListenerConfig{
+ Address: "localhost:1234",
+ }
+ srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv1.Listen(srvCfg1)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv1.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ srvCfg2 := server.ListenerConfig{
+ Address: "localhost:2234",
+ }
+ srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan2 := make(chan error, 1)
+ go func() {
+ err := srv2.Listen(srvCfg2)
+ if err != nil {
+ errChan2 <- err
+ }
+ }()
+
+ defer func() {
+ err := srv2.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ if err := waitForServerToStart(errChan2); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ idAlice := "alice"
+ log.Debugf("connect by alice")
+ mCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
+ err = mgr.Serve()
+ if err != nil {
+ t.Fatalf("failed to serve manager: %s", err)
+ }
+ conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatalf("failed to close connection: %s", err)
+ }
+}
+
+func TestForeginAutoClose(t *testing.T) {
+ ctx := context.Background()
+ relayCleanupInterval = 1 * time.Second
+ srvCfg1 := server.ListenerConfig{
+ Address: "localhost:1234",
+ }
+ srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ t.Log("binding server 1.")
+ err := srv1.Listen(srvCfg1)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ t.Logf("closing server 1.")
+ err := srv1.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ t.Logf("server 1. closed")
+ }()
+
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ srvCfg2 := server.ListenerConfig{
+ Address: "localhost:2234",
+ }
+ srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan2 := make(chan error, 1)
+ go func() {
+ t.Log("binding server 2.")
+ err := srv2.Listen(srvCfg2)
+ if err != nil {
+ errChan2 <- err
+ }
+ }()
+ defer func() {
+ t.Logf("closing server 2.")
+ err := srv2.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ t.Logf("server 2 closed.")
+ }()
+
+ if err := waitForServerToStart(errChan2); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ idAlice := "alice"
+ t.Log("connect to server 1.")
+ mCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
+ err = mgr.Serve()
+ if err != nil {
+ t.Fatalf("failed to serve manager: %s", err)
+ }
+
+ t.Log("open connection to another peer")
+ conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ t.Log("close conn")
+ err = conn.Close()
+ if err != nil {
+ t.Fatalf("failed to close connection: %s", err)
+ }
+
+ t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
+ time.Sleep(relayCleanupInterval + 1*time.Second)
+ if len(mgr.relayClients) != 0 {
+ t.Errorf("expected 0, got %d", len(mgr.relayClients))
+ }
+
+ t.Logf("closing manager")
+}
+
+func TestAutoReconnect(t *testing.T) {
+ ctx := context.Background()
+ reconnectingTimeout = 2 * time.Second
+
+ srvCfg := server.ListenerConfig{
+ Address: "localhost:1234",
+ }
+ srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ log.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ mCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
+ err = clientAlice.Serve()
+ if err != nil {
+ t.Fatalf("failed to serve manager: %s", err)
+ }
+ ra, err := clientAlice.RelayInstanceAddress()
+ if err != nil {
+ t.Errorf("failed to get relay address: %s", err)
+ }
+ conn, err := clientAlice.OpenConn(ra, "bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ t.Log("closing client relay connection")
+ // todo figure out moc server
+ _ = clientAlice.relayClient.relayConn.Close()
+ t.Log("start test reading")
+ _, err = conn.Read(make([]byte, 1))
+ if err == nil {
+ t.Errorf("unexpected reading from closed connection")
+ }
+
+ log.Infof("waiting for reconnection")
+ time.Sleep(reconnectingTimeout + 1*time.Second)
+
+ log.Infof("reopent the connection")
+ _, err = clientAlice.OpenConn(ra, "bob")
+ if err != nil {
+ t.Errorf("failed to open channel: %s", err)
+ }
+}
+
+func TestNotifierDoubleAdd(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg1 := server.ListenerConfig{
+ Address: "localhost:1234",
+ }
+ srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv1.Listen(srvCfg1)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv1.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ idAlice := "alice"
+ log.Debugf("connect by alice")
+ mCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
+ err = clientAlice.Serve()
+ if err != nil {
+ t.Fatalf("failed to serve manager: %s", err)
+ }
+
+ conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ fnCloseListener := OnServerCloseListener(func() {
+ log.Infof("close listener")
+ })
+
+ err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
+ if err != nil {
+ t.Fatalf("failed to add close listener: %s", err)
+ }
+
+ err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
+ if err != nil {
+ t.Fatalf("failed to add close listener: %s", err)
+ }
+
+ err = conn1.Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+
+}
+
+func toURL(address server.ListenerConfig) []string {
+ return []string{"rel://" + address.Address}
+}
diff --git a/relay/cmd/env.go b/relay/cmd/env.go
new file mode 100644
index 000000000..3c15ebe1f
--- /dev/null
+++ b/relay/cmd/env.go
@@ -0,0 +1,35 @@
+package cmd
+
+import (
+ "os"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/spf13/cobra"
+ "github.com/spf13/pflag"
+)
+
+// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
+func setFlagsFromEnvVars(cmd *cobra.Command) {
+ flags := cmd.PersistentFlags()
+ flags.VisitAll(func(f *pflag.Flag) {
+ newEnvVar := flagNameToEnvVar(f.Name, "NB_")
+ value, present := os.LookupEnv(newEnvVar)
+ if !present {
+ return
+ }
+
+ err := flags.Set(f.Name, value)
+ if err != nil {
+ log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
+ }
+ })
+}
+
+// flagNameToEnvVar converts flag name to environment var name adding a prefix,
+// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix)
+func flagNameToEnvVar(cmdFlag string, prefix string) string {
+ parsed := strings.ReplaceAll(cmdFlag, "-", "_")
+ upper := strings.ToUpper(parsed)
+ return prefix + upper
+}
diff --git a/relay/cmd/root.go b/relay/cmd/root.go
new file mode 100644
index 000000000..784b42c1a
--- /dev/null
+++ b/relay/cmd/root.go
@@ -0,0 +1,214 @@
+package cmd
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/encryption"
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/server"
+ "github.com/netbirdio/netbird/signal/metrics"
+ "github.com/netbirdio/netbird/util"
+)
+
+const (
+ metricsPort = 9090
+)
+
+type Config struct {
+ ListenAddress string
+ // in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
+ // it is a domain:port or ip:port
+ ExposedAddress string
+ LetsencryptEmail string
+ LetsencryptDataDir string
+ LetsencryptDomains []string
+ // in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or
+ // in the AWS credentials file
+ LetsencryptAWSRoute53 bool
+ TlsCertFile string
+ TlsKeyFile string
+ AuthSecret string
+ LogLevel string
+ LogFile string
+}
+
+func (c Config) Validate() error {
+ if c.ExposedAddress == "" {
+ return fmt.Errorf("exposed address is required")
+ }
+ if c.AuthSecret == "" {
+ return fmt.Errorf("auth secret is required")
+ }
+ return nil
+}
+
+func (c Config) HasCertConfig() bool {
+ return c.TlsCertFile != "" && c.TlsKeyFile != ""
+}
+
+func (c Config) HasLetsEncrypt() bool {
+ return c.LetsencryptDataDir != "" && c.LetsencryptDomains != nil && len(c.LetsencryptDomains) > 0
+}
+
+var (
+ cobraConfig *Config
+ rootCmd = &cobra.Command{
+ Use: "relay",
+ Short: "Relay service",
+ Long: "Relay service for Netbird agents",
+ SilenceUsage: true,
+ SilenceErrors: true,
+ RunE: execute,
+ }
+)
+
+func init() {
+ _ = util.InitLog("trace", "console")
+ cobraConfig = &Config{}
+ rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
+ rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
+ rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
+ rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
+ rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
+ rootCmd.PersistentFlags().BoolVar(&cobraConfig.LetsencryptAWSRoute53, "letsencrypt-aws-route53", false, "use AWS Route 53 for Let's Encrypt DNS challenge")
+ rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsCertFile, "tls-cert-file", "c", "", "")
+ rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsKeyFile, "tls-key-file", "k", "", "")
+ rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
+ rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
+ rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
+
+ setFlagsFromEnvVars(rootCmd)
+}
+
+func Execute() error {
+ return rootCmd.Execute()
+}
+
+func waitForExitSignal() {
+ osSigs := make(chan os.Signal, 1)
+ signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
+ <-osSigs
+}
+
+func execute(cmd *cobra.Command, args []string) error {
+ err := cobraConfig.Validate()
+ if err != nil {
+ log.Debugf("invalid config: %s", err)
+ return fmt.Errorf("invalid config: %s", err)
+ }
+
+ err = util.InitLog(cobraConfig.LogLevel, cobraConfig.LogFile)
+ if err != nil {
+ log.Debugf("failed to initialize log: %s", err)
+ return fmt.Errorf("failed to initialize log: %s", err)
+ }
+
+ metricsServer, err := metrics.NewServer(metricsPort, "")
+ if err != nil {
+ log.Debugf("setup metrics: %v", err)
+ return fmt.Errorf("setup metrics: %v", err)
+ }
+
+ go func() {
+ log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
+ if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
+ log.Fatalf("Failed to start metrics server: %v", err)
+ }
+ }()
+
+ srvListenerCfg := server.ListenerConfig{
+ Address: cobraConfig.ListenAddress,
+ }
+
+ tlsConfig, tlsSupport, err := handleTLSConfig(cobraConfig)
+ if err != nil {
+ log.Debugf("failed to setup TLS config: %s", err)
+ return fmt.Errorf("failed to setup TLS config: %s", err)
+ }
+ srvListenerCfg.TLSConfig = tlsConfig
+
+ authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
+ srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
+ if err != nil {
+ log.Debugf("failed to create relay server: %v", err)
+ return fmt.Errorf("failed to create relay server: %v", err)
+ }
+ log.Infof("server will be available on: %s", srv.InstanceURL())
+ go func() {
+ if err := srv.Listen(srvListenerCfg); err != nil {
+ log.Fatalf("failed to bind server: %s", err)
+ }
+ }()
+
+ // it will block until exit signal
+ waitForExitSignal()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ var shutDownErrors error
+ if err := srv.Shutdown(ctx); err != nil {
+ shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
+ }
+
+ log.Infof("shutting down metrics server")
+ if err := metricsServer.Shutdown(ctx); err != nil {
+ shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
+ }
+ return shutDownErrors
+}
+
+func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
+ if cfg.LetsencryptAWSRoute53 {
+ log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
+ r53 := encryption.Route53TLS{
+ DataDir: cfg.LetsencryptDataDir,
+ Email: cfg.LetsencryptEmail,
+ Domains: cfg.LetsencryptDomains,
+ }
+ tlsCfg, err := r53.GetCertificate()
+ if err != nil {
+ return nil, false, fmt.Errorf("%s", err)
+ }
+ return tlsCfg, true, nil
+ }
+
+ if cfg.HasLetsEncrypt() {
+ log.Infof("setting up TLS with Let's Encrypt.")
+ tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...)
+ if err != nil {
+ return nil, false, fmt.Errorf("%s", err)
+ }
+ return tlsCfg, true, nil
+ }
+
+ if cfg.HasCertConfig() {
+ log.Debugf("using file based TLS config")
+ tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile)
+ if err != nil {
+ return nil, false, fmt.Errorf("%s", err)
+ }
+ return tlsCfg, true, nil
+ }
+ return nil, false, nil
+}
+
+func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string) (*tls.Config, error) {
+ certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...)
+ if err != nil {
+ return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
+ }
+ return certManager.TLSConfig(), nil
+}
diff --git a/relay/doc.go b/relay/doc.go
new file mode 100644
index 000000000..56e010e3e
--- /dev/null
+++ b/relay/doc.go
@@ -0,0 +1,14 @@
+//Package main
+/*
+The `relay` package contains the implementation of the Relay server and client. The Relay server can be used to relay
+messages between peers on a single network channel. In this implementation the transport layer is the WebSocket
+protocol.
+
+Between the server and client communication has been design a custom protocol and message format. These messages are
+transported over the WebSocket connection. Optionally the server can use TLS to secure the communication.
+
+The service can support multiple Relay server instances. For this purpose the peers must know the server instance URL.
+This URL will be sent to the target peer to choose the common Relay server for the communication via Signal service.
+
+*/
+package main
diff --git a/relay/healthcheck/doc.go b/relay/healthcheck/doc.go
new file mode 100644
index 000000000..da9689c6b
--- /dev/null
+++ b/relay/healthcheck/doc.go
@@ -0,0 +1,17 @@
+/*
+The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It
+ensures that the connection between the client and the server are alive and functioning properly.
+
+The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these
+signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does
+not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel
+and stop working.
+
+The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not
+receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working.
+
+In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is
+closed and the peer is removed from the relay.
+*/
+
+package healthcheck
diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go
new file mode 100644
index 000000000..2b9c9e2e0
--- /dev/null
+++ b/relay/healthcheck/receiver.go
@@ -0,0 +1,82 @@
+package healthcheck
+
+import (
+ "context"
+ "time"
+)
+
+var (
+ heartbeatTimeout = healthCheckInterval + 3*time.Second
+)
+
+// Receiver is a healthcheck receiver
+// It will listen for heartbeat and check if the heartbeat is not received in a certain time
+// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
+// The heartbeat timeout is a bit longer than the sender's healthcheck interval
+type Receiver struct {
+ OnTimeout chan struct{}
+
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ heartbeat chan struct{}
+ alive bool
+}
+
+// NewReceiver creates a new healthcheck receiver and start the timer in the background
+func NewReceiver() *Receiver {
+ ctx, ctxCancel := context.WithCancel(context.Background())
+
+ r := &Receiver{
+ OnTimeout: make(chan struct{}, 1),
+ ctx: ctx,
+ ctxCancel: ctxCancel,
+ heartbeat: make(chan struct{}, 1),
+ }
+
+ go r.waitForHealthcheck()
+ return r
+}
+
+// Heartbeat acknowledge the heartbeat has been received
+func (r *Receiver) Heartbeat() {
+ select {
+ case r.heartbeat <- struct{}{}:
+ default:
+ }
+}
+
+// Stop check the timeout and do not send new notifications
+func (r *Receiver) Stop() {
+ r.ctxCancel()
+}
+
+func (r *Receiver) waitForHealthcheck() {
+ ticker := time.NewTicker(heartbeatTimeout)
+ defer ticker.Stop()
+ defer r.ctxCancel()
+ defer close(r.OnTimeout)
+
+ for {
+ select {
+ case <-r.heartbeat:
+ r.alive = true
+ case <-ticker.C:
+ if r.alive {
+ r.alive = false
+ continue
+ }
+
+ r.notifyTimeout()
+ return
+ case <-r.ctx.Done():
+ return
+ }
+ }
+}
+
+func (r *Receiver) notifyTimeout() {
+ select {
+ case r.OnTimeout <- struct{}{}:
+ default:
+ }
+}
diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go
new file mode 100644
index 000000000..4b4123416
--- /dev/null
+++ b/relay/healthcheck/receiver_test.go
@@ -0,0 +1,42 @@
+package healthcheck
+
+import (
+ "testing"
+ "time"
+)
+
+func TestNewReceiver(t *testing.T) {
+ heartbeatTimeout = 5 * time.Second
+ r := NewReceiver()
+
+ select {
+ case <-r.OnTimeout:
+ t.Error("unexpected timeout")
+ case <-time.After(1 * time.Second):
+
+ }
+}
+
+func TestNewReceiverNotReceive(t *testing.T) {
+ heartbeatTimeout = 1 * time.Second
+ r := NewReceiver()
+
+ select {
+ case <-r.OnTimeout:
+ case <-time.After(2 * time.Second):
+ t.Error("timeout not received")
+ }
+}
+
+func TestNewReceiverAck(t *testing.T) {
+ heartbeatTimeout = 2 * time.Second
+ r := NewReceiver()
+
+ r.Heartbeat()
+
+ select {
+ case <-r.OnTimeout:
+ t.Error("unexpected timeout")
+ case <-time.After(3 * time.Second):
+ }
+}
diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go
new file mode 100644
index 000000000..ec0560ef2
--- /dev/null
+++ b/relay/healthcheck/sender.go
@@ -0,0 +1,68 @@
+package healthcheck
+
+import (
+ "context"
+ "time"
+)
+
+var (
+ healthCheckInterval = 25 * time.Second
+ healthCheckTimeout = 5 * time.Second
+)
+
+// Sender is a healthcheck sender
+// It will send healthcheck signal to the receiver
+// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
+// It will also stop if the context is canceled
+type Sender struct {
+ // HealthCheck is a channel to send health check signal to the peer
+ HealthCheck chan struct{}
+ // Timeout is a channel to the health check signal is not received in a certain time
+ Timeout chan struct{}
+
+ ack chan struct{}
+}
+
+// NewSender creates a new healthcheck sender
+func NewSender() *Sender {
+ hc := &Sender{
+ HealthCheck: make(chan struct{}, 1),
+ Timeout: make(chan struct{}, 1),
+ ack: make(chan struct{}, 1),
+ }
+
+ return hc
+}
+
+// OnHCResponse sends an acknowledgment signal to the sender
+func (hc *Sender) OnHCResponse() {
+ select {
+ case hc.ack <- struct{}{}:
+ default:
+ }
+}
+
+func (hc *Sender) StartHealthCheck(ctx context.Context) {
+ ticker := time.NewTicker(healthCheckInterval)
+ defer ticker.Stop()
+
+ timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
+ defer timeoutTimer.Stop()
+
+ defer close(hc.HealthCheck)
+ defer close(hc.Timeout)
+
+ for {
+ select {
+ case <-ticker.C:
+ hc.HealthCheck <- struct{}{}
+ case <-timeoutTimer.C:
+ hc.Timeout <- struct{}{}
+ return
+ case <-hc.ack:
+ timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
+ case <-ctx.Done():
+ return
+ }
+ }
+}
diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go
new file mode 100644
index 000000000..7a105c308
--- /dev/null
+++ b/relay/healthcheck/sender_test.go
@@ -0,0 +1,103 @@
+package healthcheck
+
+import (
+ "context"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestMain(m *testing.M) {
+ // override the health check interval to speed up the test
+ healthCheckInterval = 2 * time.Second
+ healthCheckTimeout = 100 * time.Millisecond
+ code := m.Run()
+ os.Exit(code)
+}
+
+func TestNewHealthPeriod(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ hc := NewSender()
+ go hc.StartHealthCheck(ctx)
+
+ iterations := 0
+ for i := 0; i < 3; i++ {
+ select {
+ case <-hc.HealthCheck:
+ iterations++
+ hc.OnHCResponse()
+ case <-hc.Timeout:
+ t.Fatalf("health check is timed out")
+ case <-time.After(healthCheckInterval + 100*time.Millisecond):
+ t.Fatalf("health check not received")
+ }
+ }
+}
+
+func TestNewHealthFailed(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ hc := NewSender()
+ go hc.StartHealthCheck(ctx)
+
+ select {
+ case <-hc.Timeout:
+ case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
+ t.Fatalf("health check is not timed out")
+ }
+}
+
+func TestNewHealthcheckStop(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ hc := NewSender()
+ go hc.StartHealthCheck(ctx)
+
+ time.Sleep(100 * time.Millisecond)
+ cancel()
+
+ select {
+ case _, ok := <-hc.HealthCheck:
+ if ok {
+ t.Fatalf("health check on received")
+ }
+ case _, ok := <-hc.Timeout:
+ if ok {
+ t.Fatalf("health check on received")
+ }
+ case <-ctx.Done():
+ // expected
+ case <-time.After(10 * time.Second):
+ t.Fatalf("is not exited")
+ }
+}
+
+func TestTimeoutReset(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ hc := NewSender()
+ go hc.StartHealthCheck(ctx)
+
+ iterations := 0
+ for i := 0; i < 3; i++ {
+ select {
+ case <-hc.HealthCheck:
+ iterations++
+ hc.OnHCResponse()
+ case <-hc.Timeout:
+ t.Fatalf("health check is timed out")
+ case <-time.After(healthCheckInterval + 100*time.Millisecond):
+ t.Fatalf("health check not received")
+ }
+ }
+
+ select {
+ case <-hc.HealthCheck:
+ case <-hc.Timeout:
+ // expected
+ case <-ctx.Done():
+ t.Fatalf("context is done")
+ case <-time.After(10 * time.Second):
+ t.Fatalf("is not exited")
+ }
+}
diff --git a/relay/main.go b/relay/main.go
new file mode 100644
index 000000000..e28f73603
--- /dev/null
+++ b/relay/main.go
@@ -0,0 +1,13 @@
+package main
+
+import (
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/relay/cmd"
+)
+
+func main() {
+ if err := cmd.Execute(); err != nil {
+ log.Fatalf("failed to execute command: %v", err)
+ }
+}
diff --git a/relay/messages/address/address.go b/relay/messages/address/address.go
new file mode 100644
index 000000000..829206294
--- /dev/null
+++ b/relay/messages/address/address.go
@@ -0,0 +1,30 @@
+package address
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+)
+
+type Address struct {
+ URL string
+}
+
+func (addr *Address) Marshal() ([]byte, error) {
+ var buf bytes.Buffer
+ enc := gob.NewEncoder(&buf)
+ if err := enc.Encode(addr); err != nil {
+ return nil, fmt.Errorf("encode Address: %w", err)
+ }
+ return buf.Bytes(), nil
+}
+
+func Unmarshal(data []byte) (*Address, error) {
+ var addr Address
+ buf := bytes.NewBuffer(data)
+ dec := gob.NewDecoder(buf)
+ if err := dec.Decode(&addr); err != nil {
+ return nil, fmt.Errorf("decode Address: %w", err)
+ }
+ return &addr, nil
+}
diff --git a/relay/messages/auth/auth.go b/relay/messages/auth/auth.go
new file mode 100644
index 000000000..8230bccf2
--- /dev/null
+++ b/relay/messages/auth/auth.go
@@ -0,0 +1,51 @@
+package auth
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+)
+
+type Algorithm int
+
+const (
+ AlgoUnknown Algorithm = iota
+ AlgoHMACSHA256
+ AlgoHMACSHA512
+)
+
+func (a Algorithm) String() string {
+ switch a {
+ case AlgoHMACSHA256:
+ return "HMAC-SHA256"
+ case AlgoHMACSHA512:
+ return "HMAC-SHA512"
+ default:
+ return "Unknown"
+ }
+}
+
+type Msg struct {
+ AuthAlgorithm Algorithm
+ AdditionalData []byte
+}
+
+func (msg *Msg) Marshal() ([]byte, error) {
+ var buf bytes.Buffer
+ enc := gob.NewEncoder(&buf)
+ if err := enc.Encode(msg); err != nil {
+ return nil, fmt.Errorf("encode Msg: %w", err)
+ }
+ return buf.Bytes(), nil
+}
+
+func UnmarshalMsg(data []byte) (*Msg, error) {
+ var msg *Msg
+
+ buf := bytes.NewBuffer(data)
+ dec := gob.NewDecoder(buf)
+ if err := dec.Decode(&msg); err != nil {
+ return nil, fmt.Errorf("decode Msg: %w", err)
+ }
+ return msg, nil
+}
diff --git a/relay/messages/doc.go b/relay/messages/doc.go
new file mode 100644
index 000000000..4c719df3a
--- /dev/null
+++ b/relay/messages/doc.go
@@ -0,0 +1,5 @@
+/*
+Package messages provides the message types that are used to communicate between the relay and the client.
+This package is used to determine the type of message that is being sent and received between the relay and the client.
+*/
+package messages
diff --git a/relay/messages/id.go b/relay/messages/id.go
new file mode 100644
index 000000000..e2162cd3b
--- /dev/null
+++ b/relay/messages/id.go
@@ -0,0 +1,31 @@
+package messages
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+)
+
+const (
+ prefixLength = 4
+ IDSize = prefixLength + sha256.Size
+)
+
+var (
+ prefix = []byte("sha-") // 4 bytes
+)
+
+// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
+func HashID(peerID string) ([]byte, string) {
+ idHash := sha256.Sum256([]byte(peerID))
+ idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
+ var prefixedHash []byte
+ prefixedHash = append(prefixedHash, prefix...)
+ prefixedHash = append(prefixedHash, idHash[:]...)
+ return prefixedHash, idHashString
+}
+
+// HashIDToString converts a hash to a human-readable string
+func HashIDToString(idHash []byte) string {
+ return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
+}
diff --git a/relay/messages/id_test.go b/relay/messages/id_test.go
new file mode 100644
index 000000000..271a8f90d
--- /dev/null
+++ b/relay/messages/id_test.go
@@ -0,0 +1,13 @@
+package messages
+
+import (
+ "testing"
+)
+
+func TestHashID(t *testing.T) {
+ hashedID, hashedStringId := HashID("alice")
+ enc := HashIDToString(hashedID)
+ if enc != hashedStringId {
+ t.Errorf("expected %s, got %s", hashedStringId, enc)
+ }
+}
diff --git a/relay/messages/message.go b/relay/messages/message.go
new file mode 100644
index 000000000..cfcac3f72
--- /dev/null
+++ b/relay/messages/message.go
@@ -0,0 +1,239 @@
+package messages
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+)
+
+const (
+ MsgTypeUnknown MsgType = 0
+ MsgTypeHello MsgType = 1
+ MsgTypeHelloResponse MsgType = 2
+ MsgTypeTransport MsgType = 3
+ MsgTypeClose MsgType = 4
+ MsgTypeHealthCheck MsgType = 5
+
+ SizeOfVersionByte = 1
+ SizeOfMsgType = 1
+
+ SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
+
+ sizeOfMagicByte = 4
+
+ headerSizeTransport = IDSize
+ headerSizeHello = sizeOfMagicByte + IDSize
+ headerSizeHelloResp = 0
+
+ MaxHandshakeSize = 8192
+
+ CurrentProtocolVersion = 1
+)
+
+var (
+ ErrInvalidMessageLength = errors.New("invalid message length")
+ ErrUnsupportedVersion = errors.New("unsupported version")
+
+ magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
+
+ healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
+)
+
+type MsgType byte
+
+func (m MsgType) String() string {
+ switch m {
+ case MsgTypeHello:
+ return "hello"
+ case MsgTypeHelloResponse:
+ return "hello response"
+ case MsgTypeTransport:
+ return "transport"
+ case MsgTypeClose:
+ return "close"
+ case MsgTypeHealthCheck:
+ return "health check"
+ default:
+ return "unknown"
+ }
+}
+
+type HelloResponse struct {
+ InstanceAddress string
+}
+
+// ValidateVersion checks if the given version is supported by the protocol
+func ValidateVersion(msg []byte) (int, error) {
+ if len(msg) < SizeOfVersionByte {
+ return 0, ErrInvalidMessageLength
+ }
+ version := int(msg[0])
+ if version != CurrentProtocolVersion {
+ return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
+ }
+ return version, nil
+}
+
+// DetermineClientMessageType determines the message type from the first the message
+func DetermineClientMessageType(msg []byte) (MsgType, error) {
+ if len(msg) < SizeOfMsgType {
+ return 0, ErrInvalidMessageLength
+ }
+
+ msgType := MsgType(msg[0])
+ switch msgType {
+ case
+ MsgTypeHello,
+ MsgTypeTransport,
+ MsgTypeClose,
+ MsgTypeHealthCheck:
+ return msgType, nil
+ default:
+ return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
+ }
+}
+
+// DetermineServerMessageType determines the message type from the first the message
+func DetermineServerMessageType(msg []byte) (MsgType, error) {
+ if len(msg) < SizeOfMsgType {
+ return 0, ErrInvalidMessageLength
+ }
+
+ msgType := MsgType(msg[0])
+ switch msgType {
+ case
+ MsgTypeHelloResponse,
+ MsgTypeTransport,
+ MsgTypeClose,
+ MsgTypeHealthCheck:
+ return msgType, nil
+ default:
+ return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
+ }
+}
+
+// MarshalHelloMsg initial hello message
+// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
+// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
+// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
+// close the network connection without any response.
+func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
+ if len(peerID) != IDSize {
+ return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
+ }
+
+ msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeHello)
+
+ copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
+
+ msg = append(msg, peerID...)
+ msg = append(msg, additions...)
+
+ return msg, nil
+}
+
+// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
+// authenticate the client with the server.
+func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
+ if len(msg) < headerSizeHello {
+ return nil, nil, ErrInvalidMessageLength
+ }
+ if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
+ return nil, nil, errors.New("invalid magic header")
+ }
+
+ return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
+}
+
+// MarshalHelloResponse creates a response message to the hello message.
+// In case of success connection the server response with a Hello Response message. This message contains the server's
+// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
+// servers.
+func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
+ msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeHelloResponse)
+
+ msg = append(msg, additionalData...)
+
+ return msg, nil
+}
+
+// UnmarshalHelloResponse extracts the additional data from the hello response message.
+func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
+ if len(msg) < headerSizeHelloResp {
+ return nil, ErrInvalidMessageLength
+ }
+ return msg, nil
+}
+
+// MarshalCloseMsg creates a close message.
+// The close message is used to close the connection gracefully between the client and the server. The server and the
+// client can send this message. After receiving this message, the server or client will close the connection.
+func MarshalCloseMsg() []byte {
+ msg := make([]byte, SizeOfProtoHeader)
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeClose)
+
+ return msg
+}
+
+// MarshalTransportMsg creates a transport message.
+// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
+// destination peer hashed ID.
+func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
+ if len(peerID) != IDSize {
+ return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
+ }
+
+ msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeTransport)
+
+ copy(msg[SizeOfProtoHeader:], peerID)
+
+ msg = append(msg, payload...)
+
+ return msg, nil
+}
+
+// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
+func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
+ if len(buf) < headerSizeTransport {
+ return nil, nil, ErrInvalidMessageLength
+ }
+
+ return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
+}
+
+// UnmarshalTransportID extracts the peerID from the transport message.
+func UnmarshalTransportID(buf []byte) ([]byte, error) {
+ if len(buf) < headerSizeTransport {
+ return nil, ErrInvalidMessageLength
+ }
+ return buf[:headerSizeTransport], nil
+}
+
+// UpdateTransportMsg updates the peerID in the transport message.
+// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
+// need to allocate a new byte slice.
+func UpdateTransportMsg(msg []byte, peerID []byte) error {
+ if len(msg) < len(peerID) {
+ return ErrInvalidMessageLength
+ }
+ copy(msg, peerID)
+ return nil
+}
+
+// MarshalHealthcheck creates a health check message.
+// Health check message is sent by the server periodically. The client will respond with a health check response
+// message. If the client does not respond to the health check message, the server will close the connection.
+func MarshalHealthcheck() []byte {
+ return healthCheckMsg
+}
diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go
new file mode 100644
index 000000000..a4e7d9fae
--- /dev/null
+++ b/relay/messages/message_test.go
@@ -0,0 +1,43 @@
+package messages
+
+import (
+ "testing"
+)
+
+func TestMarshalHelloMsg(t *testing.T) {
+ peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
+ bHello, err := MarshalHelloMsg(peerID, nil)
+ if err != nil {
+ t.Fatalf("error: %v", err)
+ }
+
+ receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
+ if err != nil {
+ t.Fatalf("error: %v", err)
+ }
+ if string(receivedPeerID) != string(peerID) {
+ t.Errorf("expected %s, got %s", peerID, receivedPeerID)
+ }
+}
+
+func TestMarshalTransportMsg(t *testing.T) {
+ peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
+ payload := []byte("payload")
+ msg, err := MarshalTransportMsg(peerID, payload)
+ if err != nil {
+ t.Fatalf("error: %v", err)
+ }
+
+ id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
+ if err != nil {
+ t.Fatalf("error: %v", err)
+ }
+
+ if string(id) != string(peerID) {
+ t.Errorf("expected %s, got %s", peerID, id)
+ }
+
+ if string(respPayload) != string(payload) {
+ t.Errorf("expected %s, got %s", payload, respPayload)
+ }
+}
diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go
new file mode 100644
index 000000000..80e12ee6b
--- /dev/null
+++ b/relay/metrics/realy.go
@@ -0,0 +1,136 @@
+package metrics
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+)
+
+const (
+ idleTimeout = 30 * time.Second
+)
+
+type Metrics struct {
+ metric.Meter
+
+ TransferBytesSent metric.Int64Counter
+ TransferBytesRecv metric.Int64Counter
+
+ peers metric.Int64UpDownCounter
+ peerActivityChan chan string
+ peerLastActive map[string]time.Time
+ mutexActivity sync.Mutex
+ ctx context.Context
+}
+
+func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
+ bytesSent, err := meter.Int64Counter("relay_transfer_sent_bytes_total")
+ if err != nil {
+ return nil, err
+ }
+
+ bytesRecv, err := meter.Int64Counter("relay_transfer_received_bytes_total")
+ if err != nil {
+ return nil, err
+ }
+
+ peers, err := meter.Int64UpDownCounter("relay_peers")
+ if err != nil {
+ return nil, err
+ }
+
+ peersActive, err := meter.Int64ObservableGauge("relay_peers_active")
+ if err != nil {
+ return nil, err
+ }
+
+ peersIdle, err := meter.Int64ObservableGauge("relay_peers_idle")
+ if err != nil {
+ return nil, err
+ }
+
+ m := &Metrics{
+ Meter: meter,
+ TransferBytesSent: bytesSent,
+ TransferBytesRecv: bytesRecv,
+ peers: peers,
+
+ ctx: ctx,
+ peerActivityChan: make(chan string, 10),
+ peerLastActive: make(map[string]time.Time),
+ }
+
+ _, err = meter.RegisterCallback(
+ func(ctx context.Context, o metric.Observer) error {
+ active, idle := m.calculateActiveIdleConnections()
+ o.ObserveInt64(peersActive, active)
+ o.ObserveInt64(peersIdle, idle)
+ return nil
+ },
+ peersActive, peersIdle,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ go m.readPeerActivity()
+ return m, nil
+}
+
+// PeerConnected increments the number of connected peers and increments number of idle connections
+func (m *Metrics) PeerConnected(id string) {
+ m.peers.Add(m.ctx, 1)
+ m.mutexActivity.Lock()
+ defer m.mutexActivity.Unlock()
+
+ m.peerLastActive[id] = time.Time{}
+}
+
+// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
+func (m *Metrics) PeerDisconnected(id string) {
+ m.peers.Add(m.ctx, -1)
+ m.mutexActivity.Lock()
+ defer m.mutexActivity.Unlock()
+
+ delete(m.peerLastActive, id)
+}
+
+// PeerActivity increases the active connections
+func (m *Metrics) PeerActivity(peerID string) {
+ select {
+ case m.peerActivityChan <- peerID:
+ default:
+ log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID)
+ }
+}
+
+func (m *Metrics) calculateActiveIdleConnections() (int64, int64) {
+ active, idle := int64(0), int64(0)
+ m.mutexActivity.Lock()
+ defer m.mutexActivity.Unlock()
+
+ for _, lastActive := range m.peerLastActive {
+ if time.Since(lastActive) > idleTimeout {
+ idle++
+ } else {
+ active++
+ }
+ }
+ return active, idle
+}
+
+func (m *Metrics) readPeerActivity() {
+ for {
+ select {
+ case peerID := <-m.peerActivityChan:
+ m.mutexActivity.Lock()
+ m.peerLastActive[peerID] = time.Now()
+ m.mutexActivity.Unlock()
+ case <-m.ctx.Done():
+ return
+ }
+ }
+}
diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go
new file mode 100644
index 000000000..535c8bcd9
--- /dev/null
+++ b/relay/server/listener/listener.go
@@ -0,0 +1,11 @@
+package listener
+
+import (
+ "context"
+ "net"
+)
+
+type Listener interface {
+ Listen(func(conn net.Conn)) error
+ Shutdown(ctx context.Context) error
+}
diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go
new file mode 100644
index 000000000..c248963b9
--- /dev/null
+++ b/relay/server/listener/ws/conn.go
@@ -0,0 +1,114 @@
+package ws
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "nhooyr.io/websocket"
+)
+
+const (
+ writeTimeout = 10 * time.Second
+)
+
+type Conn struct {
+ *websocket.Conn
+ lAddr *net.TCPAddr
+ rAddr *net.TCPAddr
+
+ closed bool
+ closedMu sync.Mutex
+ ctx context.Context
+}
+
+func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
+ return &Conn{
+ Conn: wsConn,
+ lAddr: lAddr,
+ rAddr: rAddr,
+ ctx: context.Background(),
+ }
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+ t, r, err := c.Reader(c.ctx)
+ if err != nil {
+ return 0, c.ioErrHandling(err)
+ }
+
+ if t != websocket.MessageBinary {
+ log.Errorf("unexpected message type: %d", t)
+ return 0, fmt.Errorf("unexpected message type")
+ }
+
+ n, err = r.Read(b)
+ if err != nil {
+ return 0, c.ioErrHandling(err)
+ }
+ return n, err
+}
+
+// Write writes a binary message with the given payload.
+// It does not block until fill the internal buffer.
+// If the buffer filled up, wait until the buffer is drained or timeout.
+func (c *Conn) Write(b []byte) (int, error) {
+ ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
+ defer ctxCancel()
+
+ err := c.Conn.Write(ctx, websocket.MessageBinary, b)
+ return len(b), err
+}
+
+func (c *Conn) LocalAddr() net.Addr {
+ return c.lAddr
+}
+
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.rAddr
+}
+
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ return fmt.Errorf("SetReadDeadline is not implemented")
+}
+
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ return fmt.Errorf("SetWriteDeadline is not implemented")
+}
+
+func (c *Conn) SetDeadline(t time.Time) error {
+ return fmt.Errorf("SetDeadline is not implemented")
+}
+
+func (c *Conn) Close() error {
+ c.closedMu.Lock()
+ c.closed = true
+ c.closedMu.Unlock()
+ return c.Conn.CloseNow()
+}
+
+func (c *Conn) isClosed() bool {
+ c.closedMu.Lock()
+ defer c.closedMu.Unlock()
+ return c.closed
+}
+
+func (c *Conn) ioErrHandling(err error) error {
+ if c.isClosed() {
+ return io.EOF
+ }
+
+ var wErr *websocket.CloseError
+ if !errors.As(err, &wErr) {
+ return err
+ }
+ if wErr.Code == websocket.StatusNormalClosure {
+ return io.EOF
+ }
+ return err
+}
diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go
new file mode 100644
index 000000000..10bfbe44d
--- /dev/null
+++ b/relay/server/listener/ws/listener.go
@@ -0,0 +1,92 @@
+package ws
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+
+ log "github.com/sirupsen/logrus"
+ "nhooyr.io/websocket"
+)
+
+// URLPath is the path for the websocket connection.
+const URLPath = "/relay"
+
+type Listener struct {
+ // Address is the address to listen on.
+ Address string
+ // TLSConfig is the TLS configuration for the server.
+ TLSConfig *tls.Config
+
+ server *http.Server
+ acceptFn func(conn net.Conn)
+}
+
+func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
+ l.acceptFn = acceptFn
+ mux := http.NewServeMux()
+ mux.HandleFunc(URLPath, l.onAccept)
+
+ l.server = &http.Server{
+ Addr: l.Address,
+ Handler: mux,
+ TLSConfig: l.TLSConfig,
+ }
+
+ log.Infof("WS server listening address: %s", l.Address)
+ var err error
+ if l.TLSConfig != nil {
+ err = l.server.ListenAndServeTLS("", "")
+ } else {
+ err = l.server.ListenAndServe()
+ }
+ if errors.Is(err, http.ErrServerClosed) {
+ return nil
+ }
+ return err
+}
+
+func (l *Listener) Shutdown(ctx context.Context) error {
+ if l.server == nil {
+ return nil
+ }
+
+ log.Infof("stop WS listener")
+ if err := l.server.Shutdown(ctx); err != nil {
+ return fmt.Errorf("server shutdown failed: %v", err)
+ }
+ log.Infof("WS listener stopped")
+ return nil
+}
+
+func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
+ wsConn, err := websocket.Accept(w, r, nil)
+ if err != nil {
+ log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
+ return
+ }
+
+ rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
+ if err != nil {
+ err = wsConn.Close(websocket.StatusInternalError, "internal error")
+ if err != nil {
+ log.Errorf("failed to close ws connection: %s", err)
+ }
+ return
+ }
+
+ lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
+ if err != nil {
+ err = wsConn.Close(websocket.StatusInternalError, "internal error")
+ if err != nil {
+ log.Errorf("failed to close ws connection: %s", err)
+ }
+ return
+ }
+
+ conn := NewConn(wsConn, lAddr, rAddr)
+ l.acceptFn(conn)
+}
diff --git a/relay/server/peer.go b/relay/server/peer.go
new file mode 100644
index 000000000..a9583700a
--- /dev/null
+++ b/relay/server/peer.go
@@ -0,0 +1,203 @@
+package server
+
+import (
+ "context"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/relay/healthcheck"
+ "github.com/netbirdio/netbird/relay/messages"
+ "github.com/netbirdio/netbird/relay/metrics"
+)
+
+const (
+ bufferSize = 8820
+)
+
+// Peer represents a peer connection
+type Peer struct {
+ metrics *metrics.Metrics
+ log *log.Entry
+ idS string
+ idB []byte
+ conn net.Conn
+ connMu sync.RWMutex
+ store *Store
+}
+
+// NewPeer creates a new Peer instance and prepare custom logging
+func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
+ stringID := messages.HashIDToString(id)
+ return &Peer{
+ metrics: metrics,
+ log: log.WithField("peer_id", stringID),
+ idS: stringID,
+ idB: id,
+ conn: conn,
+ store: store,
+ }
+}
+
+// Work reads data from the connection
+// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
+// the message accordingly.
+func (p *Peer) Work() {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ hc := healthcheck.NewSender()
+ go hc.StartHealthCheck(ctx)
+ go p.handleHealthcheckEvents(ctx, hc)
+
+ buf := make([]byte, bufferSize)
+ for {
+ n, err := p.conn.Read(buf)
+ if err != nil {
+ if err != io.EOF {
+ p.log.Errorf("failed to read message: %s", err)
+ }
+ return
+ }
+
+ if n == 0 {
+ p.log.Errorf("received empty message")
+ return
+ }
+
+ msg := buf[:n]
+
+ _, err = messages.ValidateVersion(msg)
+ if err != nil {
+ p.log.Warnf("failed to validate protocol version: %s", err)
+ return
+ }
+
+ msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
+ if err != nil {
+ p.log.Errorf("failed to determine message type: %s", err)
+ return
+ }
+
+ p.handleMsgType(ctx, msgType, hc, n, msg)
+ }
+}
+
+func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
+ switch msgType {
+ case messages.MsgTypeHealthCheck:
+ hc.OnHCResponse()
+ case messages.MsgTypeTransport:
+ p.metrics.TransferBytesRecv.Add(ctx, int64(n))
+ p.metrics.PeerActivity(p.String())
+ p.handleTransportMsg(msg)
+ case messages.MsgTypeClose:
+ p.log.Infof("peer exited gracefully")
+ if err := p.conn.Close(); err != nil {
+ log.Errorf("failed to close connection to peer: %s", err)
+ }
+ default:
+ p.log.Warnf("received unexpected message type: %s", msgType)
+ }
+}
+
+// Write writes data to the connection
+func (p *Peer) Write(b []byte) (int, error) {
+ p.connMu.RLock()
+ defer p.connMu.RUnlock()
+ return p.conn.Write(b)
+}
+
+// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
+// connection.
+func (p *Peer) CloseGracefully(ctx context.Context) {
+ p.connMu.Lock()
+ err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
+ if err != nil {
+ p.log.Errorf("failed to send close message to peer: %s", p.String())
+ }
+
+ err = p.conn.Close()
+ if err != nil {
+ p.log.Errorf("failed to close connection to peer: %s", err)
+ }
+
+ defer p.connMu.Unlock()
+}
+
+// String returns the peer ID
+func (p *Peer) String() string {
+ return p.idS
+}
+
+func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
+ ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ defer cancel()
+
+ writeDone := make(chan struct{})
+ var err error
+ go func() {
+ _, err = p.conn.Write(buf)
+ close(writeDone)
+ }()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-writeDone:
+ return err
+ }
+}
+
+func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
+ for {
+ select {
+ case <-hc.HealthCheck:
+ _, err := p.Write(messages.MarshalHealthcheck())
+ if err != nil {
+ p.log.Errorf("failed to send healthcheck message: %s", err)
+ return
+ }
+ case <-hc.Timeout:
+ p.log.Errorf("peer healthcheck timeout")
+ err := p.conn.Close()
+ if err != nil {
+ p.log.Errorf("failed to close connection to peer: %s", err)
+ }
+ return
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (p *Peer) handleTransportMsg(msg []byte) {
+ peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
+ if err != nil {
+ p.log.Errorf("failed to unmarshal transport message: %s", err)
+ return
+ }
+
+ stringPeerID := messages.HashIDToString(peerID)
+ dp, ok := p.store.Peer(stringPeerID)
+ if !ok {
+ p.log.Errorf("peer not found: %s", stringPeerID)
+ return
+ }
+
+ err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
+ if err != nil {
+ p.log.Errorf("failed to update transport message: %s", err)
+ return
+ }
+
+ n, err := dp.Write(msg)
+ if err != nil {
+ p.log.Errorf("failed to write transport message to: %s", dp.String())
+ return
+ }
+ p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
+}
diff --git a/relay/server/relay.go b/relay/server/relay.go
new file mode 100644
index 000000000..6d88cbbb2
--- /dev/null
+++ b/relay/server/relay.go
@@ -0,0 +1,206 @@
+package server
+
+import (
+ "context"
+ "crypto/sha256"
+ "fmt"
+ "net"
+ "net/url"
+ "strings"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+
+ "github.com/netbirdio/netbird/relay/auth"
+ "github.com/netbirdio/netbird/relay/messages"
+ "github.com/netbirdio/netbird/relay/messages/address"
+ authmsg "github.com/netbirdio/netbird/relay/messages/auth"
+ "github.com/netbirdio/netbird/relay/metrics"
+)
+
+// Relay represents the relay server
+type Relay struct {
+ metrics *metrics.Metrics
+ metricsCancel context.CancelFunc
+ validator auth.Validator
+
+ store *Store
+ instanceURL string
+
+ closed bool
+ closeMu sync.RWMutex
+}
+
+// NewRelay creates a new Relay instance
+//
+// Parameters:
+// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
+// metrics for the relay server.
+// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
+// address as the relay server's instance URL.
+// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
+// instance URL depends on this value.
+// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
+// peers.
+//
+// Returns:
+// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
+// Otherwise, the error contains the details of what went wrong.
+func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
+ ctx, metricsCancel := context.WithCancel(context.Background())
+ m, err := metrics.NewMetrics(ctx, meter)
+ if err != nil {
+ metricsCancel()
+ return nil, fmt.Errorf("creating app metrics: %v", err)
+ }
+
+ r := &Relay{
+ metrics: m,
+ metricsCancel: metricsCancel,
+ validator: validator,
+ store: NewStore(),
+ }
+
+ r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
+ if err != nil {
+ metricsCancel()
+ return nil, fmt.Errorf("get instance URL: %v", err)
+ }
+
+ return r, nil
+}
+
+// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
+// provided address according to TLS definition and parses the address before returning it
+func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
+ addr := exposedAddress
+ split := strings.Split(exposedAddress, "://")
+ switch {
+ case len(split) == 1 && tlsSupported:
+ addr = "rels://" + exposedAddress
+ case len(split) == 1 && !tlsSupported:
+ addr = "rel://" + exposedAddress
+ case len(split) > 2:
+ return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
+ }
+
+ parsedURL, err := url.ParseRequestURI(addr)
+ if err != nil {
+ return "", fmt.Errorf("invalid exposed address: %v", err)
+ }
+
+ if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
+ return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
+ }
+
+ return parsedURL.String(), nil
+}
+
+// Accept start to handle a new peer connection
+func (r *Relay) Accept(conn net.Conn) {
+ r.closeMu.RLock()
+ defer r.closeMu.RUnlock()
+ if r.closed {
+ return
+ }
+
+ peerID, err := r.handshake(conn)
+ if err != nil {
+ log.Errorf("failed to handshake: %s", err)
+ cErr := conn.Close()
+ if cErr != nil {
+ log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
+ }
+ return
+ }
+
+ peer := NewPeer(r.metrics, peerID, conn, r.store)
+ peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
+ r.store.AddPeer(peer)
+ r.metrics.PeerConnected(peer.String())
+ go func() {
+ peer.Work()
+ r.store.DeletePeer(peer)
+ peer.log.Debugf("relay connection closed")
+ r.metrics.PeerDisconnected(peer.String())
+ }()
+}
+
+// Shutdown closes the relay server
+// It closes the connection with all peers in gracefully and stops accepting new connections.
+func (r *Relay) Shutdown(ctx context.Context) {
+ log.Infof("close connection with all peers")
+ r.closeMu.Lock()
+ wg := sync.WaitGroup{}
+ peers := r.store.Peers()
+ for _, peer := range peers {
+ wg.Add(1)
+ go func(p *Peer) {
+ p.CloseGracefully(ctx)
+ wg.Done()
+ }(peer)
+ }
+ wg.Wait()
+ r.metricsCancel()
+ r.closeMu.Unlock()
+}
+
+// InstanceURL returns the instance URL of the relay server
+func (r *Relay) InstanceURL() string {
+ return r.instanceURL
+}
+
+func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
+ buf := make([]byte, messages.MaxHandshakeSize)
+ n, err := conn.Read(buf)
+ if err != nil {
+ return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
+ }
+
+ _, err = messages.ValidateVersion(buf[:n])
+ if err != nil {
+ return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
+ }
+
+ msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
+ if err != nil {
+ return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
+ }
+
+ if msgType != messages.MsgTypeHello {
+ return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
+ }
+
+ peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
+ if err != nil {
+ return nil, fmt.Errorf("unmarshal hello message: %w", err)
+ }
+
+ authMsg, err := authmsg.UnmarshalMsg(authData)
+ if err != nil {
+ return nil, fmt.Errorf("unmarshal auth message: %w", err)
+ }
+
+ if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
+ return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
+ }
+
+ addr := &address.Address{URL: r.instanceURL}
+ addrData, err := addr.Marshal()
+ if err != nil {
+ return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
+ }
+
+ msg, err := messages.MarshalHelloResponse(addrData)
+ if err != nil {
+ return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
+ }
+
+ _, err = conn.Write(msg)
+ if err != nil {
+ return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
+ }
+
+ return peerID, nil
+}
diff --git a/relay/server/relay_test.go b/relay/server/relay_test.go
new file mode 100644
index 000000000..062039ab9
--- /dev/null
+++ b/relay/server/relay_test.go
@@ -0,0 +1,36 @@
+package server
+
+import "testing"
+
+func TestGetInstanceURL(t *testing.T) {
+ tests := []struct {
+ name string
+ exposedAddress string
+ tlsSupported bool
+ expectedURL string
+ expectError bool
+ }{
+ {"Valid address with TLS", "example.com", true, "rels://example.com", false},
+ {"Valid address without TLS", "example.com", false, "rel://example.com", false},
+ {"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
+ {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
+ {"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
+ {"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
+ {"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
+ {"Invalid address with multiple schemes", "rel://rels://example.com", false, "", true},
+ {"Invalid address with unsupported scheme", "http://example.com", false, "", true},
+ {"Invalid address format", "://example.com", false, "", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ url, err := getInstanceURL(tt.exposedAddress, tt.tlsSupported)
+ if (err != nil) != tt.expectError {
+ t.Errorf("expected error: %v, got: %v", tt.expectError, err)
+ }
+ if url != tt.expectedURL {
+ t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
+ }
+ })
+ }
+}
diff --git a/relay/server/server.go b/relay/server/server.go
new file mode 100644
index 000000000..0036e2390
--- /dev/null
+++ b/relay/server/server.go
@@ -0,0 +1,76 @@
+package server
+
+import (
+ "context"
+ "crypto/tls"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+
+ "github.com/netbirdio/netbird/relay/auth"
+ "github.com/netbirdio/netbird/relay/server/listener"
+ "github.com/netbirdio/netbird/relay/server/listener/ws"
+)
+
+// ListenerConfig is the configuration for the listener.
+// Address: the address to bind the listener to. It could be an address behind a reverse proxy.
+// TLSConfig: the TLS configuration for the listener.
+type ListenerConfig struct {
+ Address string
+ TLSConfig *tls.Config
+}
+
+// Server is the main entry point for the relay server.
+// It is the gate between the WebSocket listener and the Relay server logic.
+// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
+type Server struct {
+ relay *Relay
+ wSListener listener.Listener
+}
+
+// NewServer creates a new relay server instance.
+// meter: the OpenTelemetry meter
+// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
+// tlsSupport: if true, the server will support TLS
+// authValidator: the auth validator to use for the server
+func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
+ relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
+ if err != nil {
+ return nil, err
+ }
+ return &Server{
+ relay: relay,
+ }, nil
+}
+
+// Listen starts the relay server.
+func (r *Server) Listen(cfg ListenerConfig) error {
+ r.wSListener = &ws.Listener{
+ Address: cfg.Address,
+ TLSConfig: cfg.TLSConfig,
+ }
+
+ wslErr := r.wSListener.Listen(r.relay.Accept)
+ if wslErr != nil {
+ log.Errorf("failed to bind ws server: %s", wslErr)
+ }
+
+ return wslErr
+}
+
+// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
+// the connections will be forcefully closed.
+func (r *Server) Shutdown(ctx context.Context) (err error) {
+ // stop service new connections
+ if r.wSListener != nil {
+ err = r.wSListener.Shutdown(ctx)
+ }
+
+ r.relay.Shutdown(ctx)
+ return
+}
+
+// InstanceURL returns the instance URL of the relay server.
+func (r *Server) InstanceURL() string {
+ return r.relay.instanceURL
+}
diff --git a/relay/server/store.go b/relay/server/store.go
new file mode 100644
index 000000000..96879dae1
--- /dev/null
+++ b/relay/server/store.go
@@ -0,0 +1,64 @@
+package server
+
+import (
+ "sync"
+)
+
+// Store is a thread-safe store of peers
+// It is used to store the peers that are connected to the relay server
+type Store struct {
+ peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
+ peersLock sync.RWMutex
+}
+
+// NewStore creates a new Store instance
+func NewStore() *Store {
+ return &Store{
+ peers: make(map[string]*Peer),
+ }
+}
+
+// AddPeer adds a peer to the store
+// todo: consider to close peer conn if the peer already exists
+func (s *Store) AddPeer(peer *Peer) {
+ s.peersLock.Lock()
+ defer s.peersLock.Unlock()
+ s.peers[peer.String()] = peer
+}
+
+// DeletePeer deletes a peer from the store
+func (s *Store) DeletePeer(peer *Peer) {
+ s.peersLock.Lock()
+ defer s.peersLock.Unlock()
+
+ dp, ok := s.peers[peer.String()]
+ if !ok {
+ return
+ }
+ if dp != peer {
+ return
+ }
+
+ delete(s.peers, peer.String())
+}
+
+// Peer returns a peer by its ID
+func (s *Store) Peer(id string) (*Peer, bool) {
+ s.peersLock.RLock()
+ defer s.peersLock.RUnlock()
+
+ p, ok := s.peers[id]
+ return p, ok
+}
+
+// Peers returns all the peers in the store
+func (s *Store) Peers() []*Peer {
+ s.peersLock.RLock()
+ defer s.peersLock.RUnlock()
+
+ peers := make([]*Peer, 0, len(s.peers))
+ for _, p := range s.peers {
+ peers = append(peers, p)
+ }
+ return peers
+}
diff --git a/relay/server/store_test.go b/relay/server/store_test.go
new file mode 100644
index 000000000..4a30bc131
--- /dev/null
+++ b/relay/server/store_test.go
@@ -0,0 +1,40 @@
+package server
+
+import (
+ "context"
+ "testing"
+
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/metrics"
+)
+
+func TestStore_DeletePeer(t *testing.T) {
+ s := NewStore()
+
+ m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
+
+ p := NewPeer(m, []byte("peer_one"), nil, nil)
+ s.AddPeer(p)
+ s.DeletePeer(p)
+ if _, ok := s.Peer(p.String()); ok {
+ t.Errorf("peer was not deleted")
+ }
+}
+
+func TestStore_DeleteDeprecatedPeer(t *testing.T) {
+ s := NewStore()
+
+ m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
+
+ p1 := NewPeer(m, []byte("peer_id"), nil, nil)
+ p2 := NewPeer(m, []byte("peer_id"), nil, nil)
+
+ s.AddPeer(p1)
+ s.AddPeer(p2)
+ s.DeletePeer(p1)
+
+ if _, ok := s.Peer(p2.String()); !ok {
+ t.Errorf("second peer was deleted")
+ }
+}
diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go
new file mode 100644
index 000000000..ec2aa488c
--- /dev/null
+++ b/relay/test/benchmark_test.go
@@ -0,0 +1,386 @@
+package test
+
+import (
+ "context"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/pion/logging"
+ "github.com/pion/turn/v3"
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/auth/allow"
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/client"
+ "github.com/netbirdio/netbird/relay/server"
+ "github.com/netbirdio/netbird/util"
+)
+
+var (
+ av = &allow.Auth{}
+ hmacTokenStore = &hmac.TokenStore{}
+ pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
+ dataSize = 1024 * 1024 * 10
+)
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("error", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
+func TestRelayDataTransfer(t *testing.T) {
+ t.SkipNow() // skip this test on CI because it is a benchmark test
+ testData, err := seedRandomData(dataSize)
+ if err != nil {
+ t.Fatalf("failed to seed random data: %s", err)
+ }
+
+ for _, peerPairs := range pairs {
+ t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
+ transfer(t, testData, peerPairs)
+ })
+ }
+}
+
+// TestTurnDataTransfer run turn server:
+// docker run --rm --name coturn -d --network=host coturn/coturn --user test:test
+func TestTurnDataTransfer(t *testing.T) {
+ t.SkipNow() // skip this test on CI because it is a benchmark test
+ testData, err := seedRandomData(dataSize)
+ if err != nil {
+ t.Fatalf("failed to seed random data: %s", err)
+ }
+
+ for _, peerPairs := range pairs {
+ t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
+ runTurnTest(t, testData, peerPairs)
+ })
+ }
+}
+
+func transfer(t *testing.T, testData []byte, peerPairs int) {
+ t.Helper()
+ ctx := context.Background()
+ port := 35000 + peerPairs
+ serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
+ serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
+
+ srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ listenCfg := server.ListenerConfig{Address: serverAddress}
+ err := srv.Listen(listenCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for server to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientsSender := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsSender); i++ {
+ c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
+ err := c.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsSender[i] = c
+ }
+
+ clientsReceiver := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsReceiver); i++ {
+ c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
+ err := c.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsReceiver[i] = c
+ }
+
+ connsSender := make([]net.Conn, 0, peerPairs)
+ connsReceiver := make([]net.Conn, 0, peerPairs)
+ for i := 0; i < len(clientsSender); i++ {
+ conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+ connsSender = append(connsSender, conn)
+
+ conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+ connsReceiver = append(connsReceiver, conn)
+ }
+
+ var transferDuration []time.Duration
+ wg := sync.WaitGroup{}
+ var writeErr error
+ var readErr error
+ for i := 0; i < len(connsSender); i++ {
+ wg.Add(2)
+ start := time.Now()
+ go func(i int) {
+ defer wg.Done()
+ pieceSize := 1024
+ testDataLen := len(testData)
+
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, writeErr = connsSender[i].Write(testData[j:end])
+ if writeErr != nil {
+ return
+ }
+ }
+
+ }(i)
+
+ go func(i int, start time.Time) {
+ defer wg.Done()
+ buf := make([]byte, 8192)
+ rcv := 0
+ var n int
+ for receivedSize := 0; receivedSize < len(testData); {
+
+ n, readErr = connsReceiver[i].Read(buf)
+ if readErr != nil {
+ return
+ }
+
+ receivedSize += n
+ rcv += n
+ }
+ transferDuration = append(transferDuration, time.Since(start))
+ }(i, start)
+ }
+
+ wg.Wait()
+
+ if writeErr != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ if readErr != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ // calculate the megabytes per second from the average transferDuration against the dataSize
+ var totalDuration time.Duration
+ for _, d := range transferDuration {
+ totalDuration += d
+ }
+ avgDuration := totalDuration / time.Duration(len(transferDuration))
+ mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
+ t.Logf("average transfer duration: %s", avgDuration)
+ t.Logf("average transfer speed: %.2f MB/s", mbps)
+
+ for i := 0; i < len(connsSender); i++ {
+ err := connsSender[i].Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+
+ err = connsReceiver[i].Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+ }
+}
+
+func runTurnTest(t *testing.T, testData []byte, maxPairs int) {
+ t.Helper()
+ var transferDuration []time.Duration
+ var wg sync.WaitGroup
+
+ for i := 0; i < maxPairs; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ d := runTurnDataTransfer(t, testData)
+ transferDuration = append(transferDuration, d)
+ }()
+
+ }
+ wg.Wait()
+
+ var totalDuration time.Duration
+ for _, d := range transferDuration {
+ totalDuration += d
+ }
+ avgDuration := totalDuration / time.Duration(len(transferDuration))
+ mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
+ t.Logf("average transfer duration: %s", avgDuration)
+ t.Logf("average transfer speed: %.2f MB/s", mbps)
+}
+
+func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration {
+ t.Helper()
+ testDataLen := len(testData)
+ relayAddress := "192.168.0.10:3478"
+ conn, err := net.Dial("tcp", relayAddress)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(conn net.Conn) {
+ _ = conn.Close()
+ }(conn)
+
+ turnClient, err := getTurnClient(t, relayAddress, conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer turnClient.Close()
+
+ relayConn, err := turnClient.Allocate()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(relayConn net.PacketConn) {
+ _ = relayConn.Close()
+ }(relayConn)
+
+ receiverConn, err := net.Dial("udp", relayConn.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(receiverConn net.Conn) {
+ _ = receiverConn.Close()
+ }(receiverConn)
+
+ var (
+ tb int
+ start time.Time
+ timerInit bool
+ readDone = make(chan struct{})
+ ack = make([]byte, 1)
+ )
+ go func() {
+ defer func() {
+ readDone <- struct{}{}
+ }()
+ buff := make([]byte, 8192)
+ for {
+ n, e := receiverConn.Read(buff)
+ if e != nil {
+ return
+ }
+ if !timerInit {
+ start = time.Now()
+ timerInit = true
+ }
+ tb += n
+ _, _ = receiverConn.Write(ack)
+
+ if tb >= testDataLen {
+ return
+ }
+ }
+ }()
+
+ pieceSize := 1024
+ ackBuff := make([]byte, 1)
+ pipelineSize := 10
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, err := relayConn.WriteTo(testData[j:end], receiverConn.LocalAddr())
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+ if pipelineSize == 0 {
+ _, _, _ = relayConn.ReadFrom(ackBuff)
+ } else {
+ pipelineSize--
+ }
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
+ defer cancel()
+ select {
+ case <-readDone:
+ if tb != testDataLen {
+ t.Fatalf("failed to read all data: %d/%d", tb, testDataLen)
+ }
+ case <-ctx.Done():
+ t.Fatal("timeout")
+ }
+ return time.Since(start)
+}
+
+func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
+ t.Helper()
+ // Dial TURN Server
+ addrStr := fmt.Sprintf("%s:%d", address, 443)
+
+ fac := logging.NewDefaultLoggerFactory()
+ //fac.DefaultLogLevel = logging.LogLevelTrace
+
+ // Start a new TURN Client and wrap our net.Conn in a STUNConn
+ // This allows us to simulate datagram based communication over a net.Conn
+ cfg := &turn.ClientConfig{
+ TURNServerAddr: address,
+ Conn: turn.NewSTUNConn(conn),
+ Username: "test",
+ Password: "test",
+ LoggerFactory: fac,
+ }
+
+ client, err := turn.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
+ }
+
+ // Start listening on the conn provided.
+ err = client.Listen()
+ if err != nil {
+ client.Close()
+ return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
+ }
+
+ return client, nil
+}
+
+func seedRandomData(size int) ([]byte, error) {
+ token := make([]byte, size)
+ _, err := rand.Read(token)
+ if err != nil {
+ return nil, err
+ }
+ return token, nil
+}
+
+func waitForServerToStart(errChan chan error) error {
+ select {
+ case err := <-errChan:
+ if err != nil {
+ return err
+ }
+ case <-time.After(300 * time.Millisecond):
+ return nil
+ }
+ return nil
+}
diff --git a/relay/testec2/main.go b/relay/testec2/main.go
new file mode 100644
index 000000000..0c8099a5e
--- /dev/null
+++ b/relay/testec2/main.go
@@ -0,0 +1,258 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "crypto/rand"
+ "flag"
+ "fmt"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/util"
+)
+
+const (
+ errMsgFailedReadTCP = "failed to read from tcp: %s"
+)
+
+var (
+ dataSize = 1024 * 1024 * 50 // 50MB
+ pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
+ signalListenAddress = ":8081"
+
+ relaySrvAddress string
+ turnSrvAddress string
+ signalURL string
+ udpListener string // used for TURN test
+)
+
+type testResult struct {
+ numOfPairs int
+ duration time.Duration
+ speed float64
+}
+
+func (tr testResult) Speed() string {
+ speed := tr.speed
+ var unit string
+
+ switch {
+ case speed < 1024:
+ unit = "B/s"
+ case speed < 1048576:
+ speed /= 1024
+ unit = "KB/s"
+ case speed < 1073741824:
+ speed /= 1048576
+ unit = "MB/s"
+ default:
+ speed /= 1073741824
+ unit = "GB/s"
+ }
+
+ return fmt.Sprintf("%.2f %s", speed, unit)
+}
+
+func seedRandomData(size int) ([]byte, error) {
+ token := make([]byte, size)
+ _, err := rand.Read(token)
+ if err != nil {
+ return nil, err
+ }
+ return token, nil
+}
+
+func avg(transferDuration []time.Duration) (time.Duration, float64) {
+ var totalDuration time.Duration
+ for _, d := range transferDuration {
+ totalDuration += d
+ }
+ avgDuration := totalDuration / time.Duration(len(transferDuration))
+ bps := float64(dataSize) / avgDuration.Seconds()
+ return avgDuration, bps
+}
+
+func RelayReceiverMain() []testResult {
+ testResults := make([]testResult, 0, len(pairs))
+ for _, p := range pairs {
+ tr := testResult{numOfPairs: p}
+ td := relayReceive(relaySrvAddress, p)
+ tr.duration, tr.speed = avg(td)
+
+ testResults = append(testResults, tr)
+ }
+
+ return testResults
+}
+
+func RelaySenderMain() {
+ log.Infof("starting sender")
+ log.Infof("starting seed phase")
+
+ testData, err := seedRandomData(dataSize)
+ if err != nil {
+ log.Fatalf("failed to seed random data: %s", err)
+ }
+
+ log.Infof("data size: %d", len(testData))
+
+ for n, p := range pairs {
+ log.Infof("running test with %d pairs", p)
+ relayTransfer(relaySrvAddress, testData, p)
+
+ // grant time to prepare new receivers
+ if n < len(pairs)-1 {
+ time.Sleep(3 * time.Second)
+ }
+ }
+
+}
+
+// TRUNSenderMain is the sender
+// - allocate turn clients
+// - send relayed addresses to signal server in batch
+// - wait for signal server to send back addresses in a map
+// - send test data to each address in parallel
+func TRUNSenderMain() {
+ log.Infof("starting TURN sender test")
+
+ log.Infof("starting seed random data: %d", dataSize)
+ testData, err := seedRandomData(dataSize)
+ if err != nil {
+ log.Fatalf("failed to seed random data: %s", err)
+ }
+
+ ss := SignalClient{signalURL}
+
+ for _, p := range pairs {
+ log.Infof("running test with %d pairs", p)
+ turnSender := &TurnSender{}
+
+ createTurnConns(p, turnSender)
+
+ log.Infof("send addresses via signal server: %d", len(turnSender.addresses))
+ clientAddresses, err := ss.SendAddress(turnSender.addresses)
+ if err != nil {
+ log.Fatalf("failed to send address: %s", err)
+ }
+ log.Infof("received addresses: %v", clientAddresses.Address)
+
+ createSenderDevices(turnSender, clientAddresses)
+
+ log.Infof("waiting for tcpListeners to be ready")
+ time.Sleep(2 * time.Second)
+
+ tcpConns := make([]net.Conn, 0, len(turnSender.devices))
+ for i := range turnSender.devices {
+ addr := fmt.Sprintf("10.0.%d.2:9999", i)
+ log.Infof("dialing: %s", addr)
+ tcpConn, err := net.Dial("tcp", addr)
+ if err != nil {
+ log.Fatalf("failed to dial tcp: %s", err)
+ }
+ tcpConns = append(tcpConns, tcpConn)
+ }
+
+ log.Infof("start test data transfer for %d pairs", p)
+ testDataLen := len(testData)
+ wg := sync.WaitGroup{}
+ wg.Add(len(tcpConns))
+ for i, tcpConn := range tcpConns {
+ log.Infof("sending test data to device: %d", i)
+ go runTurnWriting(tcpConn, testData, testDataLen, &wg)
+ }
+ wg.Wait()
+
+ for _, d := range turnSender.devices {
+ _ = d.Close()
+ }
+
+ log.Infof("test finished with %d pairs", p)
+ }
+}
+
+func TURNReaderMain() []testResult {
+ log.Infof("starting TURN receiver test")
+ si := NewSignalService()
+ go func() {
+ log.Infof("starting signal server")
+ err := si.Listen(signalListenAddress)
+ if err != nil {
+ log.Errorf("failed to listen: %s", err)
+ }
+ }()
+
+ testResults := make([]testResult, 0, len(pairs))
+ for range pairs {
+ addresses := <-si.AddressesChan
+ instanceNumber := len(addresses)
+ log.Infof("received addresses: %d", instanceNumber)
+
+ turnReceiver := &TurnReceiver{}
+ err := createDevices(addresses, turnReceiver)
+ if err != nil {
+ log.Fatalf("%s", err)
+ }
+
+ // send client addresses back via signal server
+ si.ClientAddressChan <- turnReceiver.clientAddresses
+
+ durations := make(chan time.Duration, instanceNumber)
+ for _, device := range turnReceiver.devices {
+ go runTurnReading(device, durations)
+ }
+
+ durationsList := make([]time.Duration, 0, instanceNumber)
+ for d := range durations {
+ durationsList = append(durationsList, d)
+ if len(durationsList) == instanceNumber {
+ close(durations)
+ }
+ }
+
+ avgDuration, avgSpeed := avg(durationsList)
+ ts := testResult{
+ numOfPairs: len(durationsList),
+ duration: avgDuration,
+ speed: avgSpeed,
+ }
+ testResults = append(testResults, ts)
+
+ for _, d := range turnReceiver.devices {
+ _ = d.Close()
+ }
+ }
+ return testResults
+}
+
+func main() {
+ var mode string
+
+ _ = util.InitLog("debug", "console")
+ flag.StringVar(&mode, "mode", "sender", "sender or receiver mode")
+ flag.Parse()
+
+ relaySrvAddress = os.Getenv("TEST_RELAY_SERVER") // rel://ip:port
+ turnSrvAddress = os.Getenv("TEST_TURN_SERVER") // ip:3478
+ signalURL = os.Getenv("TEST_SIGNAL_URL") // http://receiver_ip:8081
+ udpListener = os.Getenv("TEST_UDP_LISTENER") // IP:0
+
+ if mode == "receiver" {
+ relayResult := RelayReceiverMain()
+ turnResults := TURNReaderMain()
+ for i := 0; i < len(turnResults); i++ {
+ log.Infof("pairs: %d,\tRelay speed:\t%s,\trelay duration:\t%s", relayResult[i].numOfPairs, relayResult[i].Speed(), relayResult[i].duration)
+ log.Infof("pairs: %d,\tTURN speed:\t%s,\tturn duration:\t%s", turnResults[i].numOfPairs, turnResults[i].Speed(), turnResults[i].duration)
+ }
+ } else {
+ RelaySenderMain()
+ // grant time for receiver to start
+ time.Sleep(3 * time.Second)
+ TRUNSenderMain()
+ }
+}
diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go
new file mode 100644
index 000000000..93d084387
--- /dev/null
+++ b/relay/testec2/relay.go
@@ -0,0 +1,176 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/client"
+)
+
+var (
+ hmacTokenStore = &hmac.TokenStore{}
+)
+
+func relayTransfer(serverConnURL string, testData []byte, peerPairs int) {
+ connsSender := prepareConnsSender(serverConnURL, peerPairs)
+ defer func() {
+ for i := 0; i < len(connsSender); i++ {
+ err := connsSender[i].Close()
+ if err != nil {
+ log.Errorf("failed to close connection: %s", err)
+ }
+ }
+ }()
+
+ wg := sync.WaitGroup{}
+ wg.Add(len(connsSender))
+ for _, conn := range connsSender {
+ go func(conn net.Conn) {
+ defer wg.Done()
+ runWriter(conn, testData)
+ }(conn)
+ }
+ wg.Wait()
+}
+
+func runWriter(conn net.Conn, testData []byte) {
+ si := NewStartInidication(time.Now(), len(testData))
+ _, err := conn.Write(si)
+ if err != nil {
+ log.Errorf("failed to write to channel: %s", err)
+ return
+ }
+ log.Infof("sent start indication")
+
+ pieceSize := 1024
+ testDataLen := len(testData)
+
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, writeErr := conn.Write(testData[j:end])
+ if writeErr != nil {
+ log.Errorf("failed to write to channel: %s", writeErr)
+ return
+ }
+ }
+}
+
+func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
+ ctx := context.Background()
+ clientsSender := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsSender); i++ {
+ c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
+ if err := c.Connect(); err != nil {
+ log.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsSender[i] = c
+ }
+
+ connsSender := make([]net.Conn, 0, peerPairs)
+ for i := 0; i < len(clientsSender); i++ {
+ conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
+ if err != nil {
+ log.Fatalf("failed to bind channel: %s", err)
+ }
+ connsSender = append(connsSender, conn)
+ }
+ return connsSender
+}
+
+func relayReceive(serverConnURL string, peerPairs int) []time.Duration {
+ connsReceiver := prepareConnsReceiver(serverConnURL, peerPairs)
+ defer func() {
+ for i := 0; i < len(connsReceiver); i++ {
+ if err := connsReceiver[i].Close(); err != nil {
+ log.Errorf("failed to close connection: %s", err)
+ }
+ }
+ }()
+
+ durations := make(chan time.Duration, len(connsReceiver))
+ wg := sync.WaitGroup{}
+ for _, conn := range connsReceiver {
+ wg.Add(1)
+ go func(conn net.Conn) {
+ defer wg.Done()
+ duration := runReader(conn)
+ durations <- duration
+ }(conn)
+ }
+ wg.Wait()
+
+ durationsList := make([]time.Duration, 0, len(connsReceiver))
+ for d := range durations {
+ durationsList = append(durationsList, d)
+ if len(durationsList) == len(connsReceiver) {
+ close(durations)
+ }
+ }
+
+ return durationsList
+}
+
+func runReader(conn net.Conn) time.Duration {
+ buf := make([]byte, 8192)
+
+ n, readErr := conn.Read(buf)
+ if readErr != nil {
+ log.Errorf("failed to read from channel: %s", readErr)
+ return 0
+ }
+
+ si := DecodeStartIndication(buf[:n])
+ log.Infof("received start indication: %v", si)
+
+ receivedSize, err := conn.Read(buf)
+ if err != nil {
+ log.Fatalf("failed to read from relay: %s", err)
+ }
+ now := time.Now()
+
+ rcv := 0
+ for receivedSize < si.TransferSize {
+ n, readErr = conn.Read(buf)
+ if readErr != nil {
+ log.Errorf("failed to read from channel: %s", readErr)
+ return 0
+ }
+
+ receivedSize += n
+ rcv += n
+ }
+ return time.Since(now)
+}
+
+func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
+ clientsReceiver := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsReceiver); i++ {
+ c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
+ err := c.Connect()
+ if err != nil {
+ log.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsReceiver[i] = c
+ }
+
+ connsReceiver := make([]net.Conn, 0, peerPairs)
+ for i := 0; i < len(clientsReceiver); i++ {
+ conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
+ if err != nil {
+ log.Fatalf("failed to bind channel: %s", err)
+ }
+ connsReceiver = append(connsReceiver, conn)
+ }
+ return connsReceiver
+}
diff --git a/relay/testec2/signal.go b/relay/testec2/signal.go
new file mode 100644
index 000000000..fe93a2fe2
--- /dev/null
+++ b/relay/testec2/signal.go
@@ -0,0 +1,91 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type PeerAddr struct {
+ Address []string
+}
+
+type ClientPeerAddr struct {
+ Address map[string]string
+}
+
+type Signal struct {
+ AddressesChan chan []string
+ ClientAddressChan chan map[string]string
+}
+
+func NewSignalService() *Signal {
+ return &Signal{
+ AddressesChan: make(chan []string),
+ ClientAddressChan: make(chan map[string]string),
+ }
+}
+
+func (rs *Signal) Listen(listenAddr string) error {
+ http.HandleFunc("/", rs.onNewAddresses)
+ return http.ListenAndServe(listenAddr, nil)
+}
+
+func (rs *Signal) onNewAddresses(w http.ResponseWriter, r *http.Request) {
+ var msg PeerAddr
+ err := json.NewDecoder(r.Body).Decode(&msg)
+ if err != nil {
+ log.Errorf("Error decoding message: %v", err)
+ }
+
+ log.Infof("received addresses: %d", len(msg.Address))
+ rs.AddressesChan <- msg.Address
+ clientAddresses := <-rs.ClientAddressChan
+
+ respMsg := ClientPeerAddr{
+ Address: clientAddresses,
+ }
+ data, err := json.Marshal(respMsg)
+ if err != nil {
+ log.Errorf("Error marshalling message: %v", err)
+ return
+ }
+
+ _, err = w.Write(data)
+ if err != nil {
+ log.Errorf("Error writing response: %v", err)
+ }
+}
+
+type SignalClient struct {
+ SignalURL string
+}
+
+func (ss SignalClient) SendAddress(addresses []string) (*ClientPeerAddr, error) {
+ msg := PeerAddr{
+ Address: addresses,
+ }
+ data, err := json.Marshal(msg)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := http.Post(ss.SignalURL, "application/json", bytes.NewBuffer(data))
+ if err != nil {
+ return nil, err
+ }
+
+ defer response.Body.Close()
+
+ log.Debugf("wait for signal response")
+ var respPeerAddress ClientPeerAddr
+ err = json.NewDecoder(response.Body).Decode(&respPeerAddress)
+ if err != nil {
+ return nil, err
+ }
+ return &respPeerAddress, nil
+}
diff --git a/relay/testec2/start_msg.go b/relay/testec2/start_msg.go
new file mode 100644
index 000000000..19b65380b
--- /dev/null
+++ b/relay/testec2/start_msg.go
@@ -0,0 +1,39 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "bytes"
+ "encoding/gob"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type StartIndication struct {
+ Started time.Time
+ TransferSize int
+}
+
+func NewStartInidication(started time.Time, transferSize int) []byte {
+ si := StartIndication{
+ Started: started,
+ TransferSize: transferSize,
+ }
+
+ var data bytes.Buffer
+ err := gob.NewEncoder(&data).Encode(si)
+ if err != nil {
+ log.Fatal("encode error:", err)
+ }
+ return data.Bytes()
+}
+
+func DecodeStartIndication(data []byte) StartIndication {
+ var si StartIndication
+ err := gob.NewDecoder(bytes.NewReader(data)).Decode(&si)
+ if err != nil {
+ log.Fatal("decode error:", err)
+ }
+ return si
+}
diff --git a/relay/testec2/tun/proxy.go b/relay/testec2/tun/proxy.go
new file mode 100644
index 000000000..7d84bece7
--- /dev/null
+++ b/relay/testec2/tun/proxy.go
@@ -0,0 +1,72 @@
+//go:build linux || darwin
+
+package tun
+
+import (
+ "net"
+ "sync/atomic"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type Proxy struct {
+ Device *Device
+ PConn net.PacketConn
+ DstAddr net.Addr
+ shutdownFlag atomic.Bool
+}
+
+func (p *Proxy) Start() {
+ go p.readFromDevice()
+ go p.readFromConn()
+}
+
+func (p *Proxy) Close() {
+ p.shutdownFlag.Store(true)
+}
+
+func (p *Proxy) readFromDevice() {
+ buf := make([]byte, 1500)
+ for {
+ n, err := p.Device.Read(buf)
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to read from device: %s", err)
+ return
+ }
+
+ _, err = p.PConn.WriteTo(buf[:n], p.DstAddr)
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to write to conn: %s", err)
+ return
+ }
+ }
+}
+
+func (p *Proxy) readFromConn() {
+ buf := make([]byte, 1500)
+ for {
+ n, _, err := p.PConn.ReadFrom(buf)
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to read from conn: %s", err)
+ return
+ }
+
+ _, err = p.Device.Write(buf[:n])
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to write to device: %s", err)
+ return
+ }
+ }
+}
diff --git a/relay/testec2/tun/tun.go b/relay/testec2/tun/tun.go
new file mode 100644
index 000000000..5580785ce
--- /dev/null
+++ b/relay/testec2/tun/tun.go
@@ -0,0 +1,110 @@
+//go:build linux || darwin
+
+package tun
+
+import (
+ "net"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/songgao/water"
+ "github.com/vishvananda/netlink"
+)
+
+type Device struct {
+ Name string
+ IP string
+ PConn net.PacketConn
+ DstAddr net.Addr
+
+ iFace *water.Interface
+ proxy *Proxy
+}
+
+func (d *Device) Up() error {
+ cfg := water.Config{
+ DeviceType: water.TUN,
+ PlatformSpecificParams: water.PlatformSpecificParams{
+ Name: d.Name,
+ },
+ }
+ iFace, err := water.New(cfg)
+ if err != nil {
+ return err
+ }
+ d.iFace = iFace
+
+ err = d.assignIP()
+ if err != nil {
+ return err
+ }
+
+ err = d.bringUp()
+ if err != nil {
+ return err
+ }
+
+ d.proxy = &Proxy{
+ Device: d,
+ PConn: d.PConn,
+ DstAddr: d.DstAddr,
+ }
+ d.proxy.Start()
+ return nil
+}
+
+func (d *Device) Close() error {
+ if d.proxy != nil {
+ d.proxy.Close()
+ }
+ if d.iFace != nil {
+ return d.iFace.Close()
+ }
+ return nil
+}
+
+func (d *Device) Read(b []byte) (int, error) {
+ return d.iFace.Read(b)
+}
+
+func (d *Device) Write(b []byte) (int, error) {
+ return d.iFace.Write(b)
+}
+
+func (d *Device) assignIP() error {
+ iface, err := netlink.LinkByName(d.Name)
+ if err != nil {
+ log.Errorf("failed to get TUN device: %v", err)
+ return err
+ }
+
+ ip := net.IPNet{
+ IP: net.ParseIP(d.IP),
+ Mask: net.CIDRMask(24, 32),
+ }
+
+ addr := &netlink.Addr{
+ IPNet: &ip,
+ }
+ err = netlink.AddrAdd(iface, addr)
+ if err != nil {
+ log.Errorf("failed to add IP address: %v", err)
+ return err
+ }
+ return nil
+}
+
+func (d *Device) bringUp() error {
+ iface, err := netlink.LinkByName(d.Name)
+ if err != nil {
+ log.Errorf("failed to get device: %v", err)
+ return err
+ }
+
+ // Bring the interface up
+ err = netlink.LinkSetUp(iface)
+ if err != nil {
+ log.Errorf("failed to set device up: %v", err)
+ return err
+ }
+ return nil
+}
diff --git a/relay/testec2/turn.go b/relay/testec2/turn.go
new file mode 100644
index 000000000..8beb40423
--- /dev/null
+++ b/relay/testec2/turn.go
@@ -0,0 +1,181 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/netbirdio/netbird/relay/testec2/tun"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type TurnReceiver struct {
+ conns []*net.UDPConn
+ clientAddresses map[string]string
+ devices []*tun.Device
+}
+
+type TurnSender struct {
+ turnConns map[string]*TurnConn
+ addresses []string
+ devices []*tun.Device
+}
+
+func runTurnWriting(tcpConn net.Conn, testData []byte, testDataLen int, wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer tcpConn.Close()
+
+ log.Infof("start to sending test data: %s", tcpConn.RemoteAddr())
+
+ si := NewStartInidication(time.Now(), testDataLen)
+ _, err := tcpConn.Write(si)
+ if err != nil {
+ log.Errorf("failed to write to tcp: %s", err)
+ return
+ }
+
+ pieceSize := 1024
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, writeErr := tcpConn.Write(testData[j:end])
+ if writeErr != nil {
+ log.Errorf("failed to write to tcp conn: %s", writeErr)
+ return
+ }
+ }
+
+ // grant time to flush out packages
+ time.Sleep(3 * time.Second)
+}
+
+func createSenderDevices(sender *TurnSender, clientAddresses *ClientPeerAddr) {
+ var i int
+ devices := make([]*tun.Device, 0, len(clientAddresses.Address))
+ for k, v := range clientAddresses.Address {
+ tc, ok := sender.turnConns[k]
+ if !ok {
+ log.Fatalf("failed to find turn conn: %s", k)
+ }
+
+ addr, err := net.ResolveUDPAddr("udp", v)
+ if err != nil {
+ log.Fatalf("failed to resolve udp address: %s", err)
+ }
+ device := &tun.Device{
+ Name: fmt.Sprintf("mtun-sender-%d", i),
+ IP: fmt.Sprintf("10.0.%d.1", i),
+ PConn: tc.relayConn,
+ DstAddr: addr,
+ }
+
+ err = device.Up()
+ if err != nil {
+ log.Fatalf("failed to bring up device: %s", err)
+ }
+
+ devices = append(devices, device)
+ i++
+ }
+ sender.devices = devices
+}
+
+func createTurnConns(p int, sender *TurnSender) {
+ turnConns := make(map[string]*TurnConn)
+ addresses := make([]string, 0, len(pairs))
+ for i := 0; i < p; i++ {
+ tc := AllocateTurnClient(turnSrvAddress)
+ log.Infof("allocated turn client: %s", tc.Address().String())
+ turnConns[tc.Address().String()] = tc
+ addresses = append(addresses, tc.Address().String())
+ }
+
+ sender.turnConns = turnConns
+ sender.addresses = addresses
+}
+
+func runTurnReading(d *tun.Device, durations chan time.Duration) {
+ tcpListener, err := net.Listen("tcp", d.IP+":9999")
+ if err != nil {
+ log.Fatalf("failed to listen on tcp: %s", err)
+ }
+ log := log.WithField("device", tcpListener.Addr())
+
+ tcpConn, err := tcpListener.Accept()
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf("failed to accept connection: %s", err)
+ }
+ log.Infof("remote peer connected")
+
+ buf := make([]byte, 103)
+ n, err := tcpConn.Read(buf)
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf(errMsgFailedReadTCP, err)
+ }
+
+ si := DecodeStartIndication(buf[:n])
+ log.Infof("received start indication: %v, %d", si, n)
+
+ buf = make([]byte, 8192)
+ i, err := tcpConn.Read(buf)
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf(errMsgFailedReadTCP, err)
+ }
+ now := time.Now()
+ for i < si.TransferSize {
+ n, err := tcpConn.Read(buf)
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf(errMsgFailedReadTCP, err)
+ }
+ i += n
+ }
+ durations <- time.Since(now)
+}
+
+func createDevices(addresses []string, receiver *TurnReceiver) error {
+ receiver.conns = make([]*net.UDPConn, 0, len(addresses))
+ receiver.clientAddresses = make(map[string]string, len(addresses))
+ receiver.devices = make([]*tun.Device, 0, len(addresses))
+ for i, addr := range addresses {
+ localAddr, err := net.ResolveUDPAddr("udp", udpListener)
+ if err != nil {
+ return fmt.Errorf("failed to resolve UDP address: %s", err)
+ }
+
+ conn, err := net.ListenUDP("udp", localAddr)
+ if err != nil {
+ return fmt.Errorf("failed to create UDP connection: %s", err)
+ }
+
+ receiver.conns = append(receiver.conns, conn)
+ receiver.clientAddresses[addr] = conn.LocalAddr().String()
+
+ dstAddr, err := net.ResolveUDPAddr("udp", addr)
+ if err != nil {
+ return fmt.Errorf("failed to resolve address: %s", err)
+ }
+
+ device := &tun.Device{
+ Name: fmt.Sprintf("mtun-%d", i),
+ IP: fmt.Sprintf("10.0.%d.2", i),
+ PConn: conn,
+ DstAddr: dstAddr,
+ }
+
+ if err = device.Up(); err != nil {
+ return fmt.Errorf("failed to bring up device: %s, %s", device.Name, err)
+ }
+ receiver.devices = append(receiver.devices, device)
+ }
+ return nil
+}
diff --git a/relay/testec2/turn_allocator.go b/relay/testec2/turn_allocator.go
new file mode 100644
index 000000000..fd86208df
--- /dev/null
+++ b/relay/testec2/turn_allocator.go
@@ -0,0 +1,83 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/pion/logging"
+ "github.com/pion/turn/v3"
+ log "github.com/sirupsen/logrus"
+)
+
+type TurnConn struct {
+ conn net.Conn
+ turnClient *turn.Client
+ relayConn net.PacketConn
+}
+
+func (tc *TurnConn) Address() net.Addr {
+ return tc.relayConn.LocalAddr()
+}
+
+func (tc *TurnConn) Close() {
+ _ = tc.relayConn.Close()
+ tc.turnClient.Close()
+ _ = tc.conn.Close()
+}
+
+func AllocateTurnClient(serverAddr string) *TurnConn {
+ conn, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ turnClient, err := getTurnClient(serverAddr, conn)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ relayConn, err := turnClient.Allocate()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ return &TurnConn{
+ conn: conn,
+ turnClient: turnClient,
+ relayConn: relayConn,
+ }
+}
+
+func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
+ // Dial TURN Server
+ addrStr := fmt.Sprintf("%s:%d", address, 443)
+
+ fac := logging.NewDefaultLoggerFactory()
+ //fac.DefaultLogLevel = logging.LogLevelTrace
+
+ // Start a new TURN Client and wrap our net.Conn in a STUNConn
+ // This allows us to simulate datagram based communication over a net.Conn
+ cfg := &turn.ClientConfig{
+ TURNServerAddr: address,
+ Conn: turn.NewSTUNConn(conn),
+ Username: "test",
+ Password: "test",
+ LoggerFactory: fac,
+ }
+
+ client, err := turn.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
+ }
+
+ // Start listening on the conn provided.
+ err = client.Listen()
+ if err != nil {
+ client.Close()
+ return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
+ }
+
+ return client, nil
+}
diff --git a/signal/client/client.go b/signal/client/client.go
index 9d99b3677..ced3fb7d0 100644
--- a/signal/client/client.go
+++ b/signal/client/client.go
@@ -51,11 +51,10 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
}
// MarshalCredential marshal a Credential instance and returns a Message object
-func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
- rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
+func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
- RemoteKey: remoteKey.String(),
+ RemoteKey: remoteKey,
Body: &proto.Body{
Type: t,
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
@@ -65,6 +64,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
+ RelayServerAddress: relaySrvAddress,
},
}, nil
}
diff --git a/signal/proto/signalexchange.pb.go b/signal/proto/signalexchange.pb.go
index 782c45da1..30f704c6f 100644
--- a/signal/proto/signalexchange.pb.go
+++ b/signal/proto/signalexchange.pb.go
@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
-// protoc v3.12.4
+// protoc v3.21.12
// source: signalexchange.proto
package proto
import (
- _ "github.com/golang/protobuf/protoc-gen-go/descriptor"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ _ "google.golang.org/protobuf/types/descriptorpb"
reflect "reflect"
sync "sync"
)
@@ -225,6 +225,8 @@ type Body struct {
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
+ // relayServerAddress is an IP:port of the relay server
+ RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
}
func (x *Body) Reset() {
@@ -308,6 +310,13 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
return nil
}
+func (x *Body) GetRelayServerAddress() string {
+ if x != nil {
+ return x.RelayServerAddress
+ }
+ return ""
+}
+
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
@@ -431,7 +440,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
- 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
+ 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa6, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -451,7 +460,10 @@ var file_signalexchange_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
+ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
+ 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
+ 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
+ 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
diff --git a/signal/proto/signalexchange.proto b/signal/proto/signalexchange.proto
index a8c4c309c..4431edd7c 100644
--- a/signal/proto/signalexchange.proto
+++ b/signal/proto/signalexchange.proto
@@ -60,6 +60,9 @@ message Body {
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig rosenpassConfig = 7;
+
+ // relayServerAddress is url of the relay server
+ string relayServerAddress = 8;
}
// Mode indicates a connection mode
diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go
index 7a5de7587..4032a75c0 100644
--- a/util/net/dialer_nonios.go
+++ b/util/net/dialer_nonios.go
@@ -49,6 +49,8 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ log.Debugf("Dialing %s %s", network, address)
+
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
From 28248ea9f43dc4eae27a71fe90ebccd0cef78b35 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Mon, 9 Sep 2024 14:44:46 +0200
Subject: [PATCH 33/89] add TestRecreation test (#2558)
---
iface/iface_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 66 insertions(+)
diff --git a/iface/iface_test.go b/iface/iface_test.go
index 6609c06f4..8de9f647e 100644
--- a/iface/iface_test.go
+++ b/iface/iface_test.go
@@ -176,6 +176,72 @@ func Test_Close(t *testing.T) {
}
}
+func TestRecreation(t *testing.T) {
+ for i := 0; i < 100; i++ {
+ t.Run(fmt.Sprintf("down-%d", i), func(t *testing.T) {
+ ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
+ wgIP := "10.99.99.2/32"
+ wgPort := 33100
+ newNet, err := stdnet.NewNet()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for {
+ _, err = net.InterfaceByName(ifaceName)
+ if err != nil {
+ t.Logf("interface %s not found: err: %s", ifaceName, err)
+ break
+ }
+ t.Logf("interface %s found", ifaceName)
+ }
+
+ err = iface.Create()
+ if err != nil {
+ t.Fatal(err)
+ }
+ wg, err := wgctrl.New()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ err = wg.Close()
+ if err != nil {
+ t.Error(err)
+ }
+ }()
+
+ _, err = iface.Up()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for {
+ _, err = net.InterfaceByName(ifaceName)
+ if err == nil {
+ t.Logf("interface %s found", ifaceName)
+
+ break
+ }
+ t.Logf("interface %s not found: err: %s", ifaceName, err)
+
+ }
+
+ start := time.Now()
+ err = iface.Close()
+ t.Logf("down time: %s", time.Since(start))
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
+
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
From c720d54de692f141a7d3a2ab6ef9f0fa6b3196ce Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Mon, 9 Sep 2024 18:12:32 +0200
Subject: [PATCH 34/89] Fix error handling in openConnVia function (#2560)
---
relay/client/manager.go | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/relay/client/manager.go b/relay/client/manager.go
index 3e152a963..a9d294160 100644
--- a/relay/client/manager.go
+++ b/relay/client/manager.go
@@ -29,6 +29,7 @@ var (
type RelayTrack struct {
sync.RWMutex
relayClient *Client
+ err error
}
func NewRelayTrack() *RelayTrack {
@@ -235,6 +236,9 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
+ if rt.err != nil {
+ return nil, rt.err
+ }
return rt.relayClient.OpenConn(peerKey)
}
m.relayClientsMutex.RUnlock()
@@ -247,6 +251,9 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
+ if rt.err != nil {
+ return nil, rt.err
+ }
return rt.relayClient.OpenConn(peerKey)
}
@@ -259,6 +266,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect()
if err != nil {
+ rt.err = err
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
From 12c36312b51a771763c3735f5be2a319762e0277 Mon Sep 17 00:00:00 2001
From: benniekiss <63211101+benniekiss@users.noreply.github.com>
Date: Mon, 9 Sep 2024 12:27:42 -0400
Subject: [PATCH 35/89] [management] Auto update geolite (#2297)
introduces helper functions to fetch and verify database versions, downloads new files if outdated, and deletes old ones. It also refactors filename handling to improve clarity and consistency, adding options to disable auto-updating via a flag. The changes aim to simplify GeoLite database management for admins.
---
.../workflows/test-infrastructure-files.yml | 26 +--
.gitignore | 1 -
client/ui/bundled.go | 12 +
infrastructure_files/download-geolite2.sh | 109 ---------
management/cmd/management.go | 8 +-
management/cmd/root.go | 2 +
management/server/geolocation/database.go | 52 +++--
management/server/geolocation/geolocation.go | 207 ++++++++++--------
.../server/geolocation/geolocation_test.go | 14 +-
management/server/geolocation/store.go | 72 +-----
management/server/geolocation/utils.go | 19 ++
.../server/http/geolocation_handler_test.go | 11 +-
...-Test.mmdb => GeoLite2-City_20240305.mmdb} | Bin
...{geonames-test.db => geonames_20240305.db} | Bin
14 files changed, 199 insertions(+), 334 deletions(-)
create mode 100644 client/ui/bundled.go
delete mode 100755 infrastructure_files/download-geolite2.sh
rename management/server/testdata/{GeoLite2-City-Test.mmdb => GeoLite2-City_20240305.mmdb} (100%)
rename management/server/testdata/{geonames-test.db => geonames_20240305.db} (100%)
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index 03ecbd445..d1aef3324 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -209,8 +209,8 @@ jobs:
working-directory: infrastructure_files/artifacts
run: |
sleep 30
- docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
- docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
+ docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
+ docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
test-getting-started-script:
runs-on: ubuntu-latest
@@ -237,7 +237,7 @@ jobs:
run: test -f management.json
- name: test turnserver.conf file gen postgres
- run: |
+ run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
@@ -278,7 +278,7 @@ jobs:
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
- run: |
+ run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
@@ -291,21 +291,3 @@ jobs:
- name: test relay.env file gen CockroachDB
run: test -f relay.env
-
- test-download-geolite2-script:
- runs-on: ubuntu-latest
- steps:
- - name: Install jq
- run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
-
- - name: Checkout code
- uses: actions/checkout@v3
-
- - name: test script
- run: bash -x infrastructure_files/download-geolite2.sh
-
- - name: test mmdb file exists
- run: test -f GeoLite2-City.mmdb
-
- - name: test geonames file exists
- run: test -f geonames.db
diff --git a/.gitignore b/.gitignore
index cdce46975..d0b4f82dd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,4 +29,3 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
-GeoLite2-City*
\ No newline at end of file
diff --git a/client/ui/bundled.go b/client/ui/bundled.go
new file mode 100644
index 000000000..e2c138b14
--- /dev/null
+++ b/client/ui/bundled.go
@@ -0,0 +1,12 @@
+// auto-generated
+// Code generated by '$ fyne bundle'. DO NOT EDIT.
+
+package main
+
+import "fyne.io/fyne/v2"
+
+var resourceNetbirdSystemtrayConnectedPng = &fyne.StaticResource{
+ StaticName: "netbird-systemtray-connected.png",
+ StaticContent: []byte(
+ "\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\b\x06\x00\x00\x00\\r\xa8f\x00\x00\x00\xc3zTXtRaw profile type exif\x00\x00x\xdamP\xdb\r\xc3 \f\xfc\xf7\x14\x1d\xc1\xaf\x80\x19\x874\xa9\xd4\r:~\r8Q\x88r\x92χ\x9d\x1cư\xff\xbe\x1fx50)\xe8\x92-\x95\x94СE\vW\x17\x86\x03\xb53\xa1v>@\xc1S\x1dNɞų\x8c\x86\xa5\xf8\xeb\xa8\xd3d\x83T]-\x17#{Gc\x9d\x1bEGf\xbb\x19\xc5E\xd2&b\x17[\x18\x950\x12\x1e\r\n\x83:\x9e\x85\xa9X\xbe>a\xddq\x86\x8d\x80F\x92\xbb\xf7ir?k\xf6\xedm\x8b\x17\x85y\x17\x12t\x16\xd11\x80\xb4P\x90\xdaE\xf5\xf0\xa1\xfc#u-\x92:[L\xe2\vy\xda\xd3\x01\xf8\x03\xda\xd4Y\x17ݮ\xb7\xee\x00\x00\x01\x84iCCPICC profile\x00\x00x\x9c}\x91=H\xc3@\x1c\xc5_S\xa5\"-\x0e\x16\x14\x11\xccP\x9d\xec\xa2\"\xe2T\xabP\x84\n\xa5Vh\xd5\xc1\xe4\xd2/hҐ\xa4\xb88\n\xae\x05\a?\x16\xab\x0e.κ:\xb8\n\x82\xe0\a\x88\xb3\x83\x93\xa2\x8b\x94\xf8\xbf\xa4\xd0\"ƃ\xe3~\xbc\xbb\xf7\xb8{\a\b\x8d\nSͮ\x18\xa0j\x96\x91N\xc4\xc5lnU\f\xbcB\xc0\x00B\x18\xc1\xac\xc4L}.\x95J\xc2s|\xdd\xc3\xc7\u05fb(\xcf\xf2>\xf7\xe7\b)y\x93\x01>\x918\xc6t\xc3\"\xde \x9e\u07b4t\xce\xfb\xc4aV\x92\x14\xe2s\xe2q\x83.H\xfc\xc8u\xd9\xe57\xceE\x87\x05\x9e\x1962\xe9y\xe20\xb1X\xec`\xb9\x83Y\xc9P\x89\xa7\x88#\x8a\xaaQ\xbe\x90uY\xe1\xbc\xc5Y\xad\xd4X\xeb\x9e\xfc\x85\xc1\xbc\xb6\xb2\xccu\x9a\xc3H`\x11KHA\x84\x8c\x1aʨ\xc0B\x94V\x8d\x14\x13iڏ{\xf8\x87\x1c\x7f\x8a\\2\xb9\xca`\xe4X@\x15*$\xc7\x0f\xfe\a\xbf\xbb5\v\x93\x13nR0\x0et\xbf\xd8\xf6\xc7(\x10\xd8\x05\x9au\xdb\xfe>\xb6\xed\xe6\t\xe0\x7f\x06\xae\xb4\xb6\xbf\xda\x00f>I\xaf\xb7\xb5\xc8\x11з\r\\\\\xb75y\x0f\xb8\xdc\x01\x06\x9ftɐ\x1c\xc9OS(\x14\x80\xf73\xfa\xa6\x1c\xd0\x7f\v\xf4\xae\xb9\xbd\xb5\xf6q\xfa\x00d\xa8\xab\xe4\rpp\b\x8c\x15){\xdd\xe3\xdd=\x9d\xbd\xfd{\xa6\xd5\xdf\x0fںr\xd0VwQ\xba\x00\x00\rxiTXtXML:com.adobe.xmp\x00\x00\x00\x00\x00\n\n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\xf0C\xff\xd9\x00\x00\x00\x06bKGD\x00\xff\x00\xff\x00\xff\xa0\xbd\xa7\x93\x00\x00\x00\tpHYs\x00\x00\v\x13\x00\x00\v\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\atIME\a\xe8\x02\x17\r$'\xdd\xf7ȗ\x00\x00\x13;IDATx\xda\xed\x9d]o\x14W\x9a\xc7\xff\xa7\xaamh\xbf\xc46,I`\x99\xa1\xc3\ni\xb5{1\x95O0\xe4\x1b\xc0'X\xf2\t`.W`hp\xa2\xb9\fH{O\xa3\xcc\xc5\xecJ3q\xa4\x1d\xed\xcdJx>Aj/\"EBJګL \xb1\x00g\xf1\v\xb6\xbb\xeb\xec\x85mb\f\xb6\xfb\xa5^Ω\xfa\xfd\xee\x928v\xf7\xa9z\xfe\xcfs\x9e\xa7ο$\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\u0603a\t\xc0g\xd6\x7f\x1f5\x92\x8e\"k\xd4\b\xa4s\xb2jH\x9afez\n\xfe\xdb\b\x00x\x81mF\xd3/CE]\xa3(\x94~c\xa5\x8b;\xc1\x0e\x83E\x7f{\xecF\xfcA\x8d\x95\x00g\xb3\xfb\\tQ\xd2o\xadtq]\xba(I\x81\x95,K345\xa3˒\x84\x00\x80SY~5Х0\xd0o\x13\xabK\x96R>\x9b\xe4o\xd4\x1a\xbd\x1e\xc7\b\x008\x93\xe9\xadtkM\x8a\x02i\xdaZ\x9aS\x99\x12\xea\xf6\xabJ\x80Հ\x02\xf7\xf4W\x13\xe9\xdan\xa6'\xe8sXw\xe9\xf6ؿ\xc6\xed_Z\x01\x00\x05d{\xed\xec\xe9!\xcf\xda\x7f\xbb\xf1\xf7Z/\x80U\x81<\x03\xdf\x12\xf8\xc5\xc5\x7f\xf2K\xe9O\x05\x00d\xfcje\xffx\xecF\xfc\xe1\xfe\x7fM\x05\x00\xd9\x04~3j$5ݲVWX\r\a\xe2?\xdc\x1e\xfb!\x00\x909ks\xd1\xd5Dj\x1a\xcb\x18ω\xe07j\xd5\xf74\xfe\x10\x00Ȅ\x95O\xa3(Ht_R\xc4\xdeҙҿ\xbdw췟\x80\x15\x82\xb4\xb2~\x90\xe8+I\x11\xab\xe1\x0e\xd6\xea\xc1A\xd9\x7f[\x1f\x00\x86\xdc\xeb\xdbP\xf7E\x93\xcf\xc9\xec\xbf\x7f\xec\xc7\x16\x00\xd2\v\xfe\xb9\xe8b\"}axd\xd7\xcd\xf8O\x0e.\xfd\xd9\x02\xc0\xb0\xc1\x7f\xcbJ\x0f\t~G\t4_\xbf\x19\xb7\x8e\xfa1*\x00\xe8\x9b\xd5O\xa2\xfb\x8c\xf7\x1c\xcf\xfe\x81~\xd7\xcb\xcf!\x00\xd03\xb6\x19M\xaf\x87\xfaB\x96\xfd\xbe\xd3\xc1\x7f\xc8؏-\x00\fV\xf27\xa3\xc6z\xa8\x87\xa2\xd9\xe7x\xf4\x1f>\xf6\xa3\x02\x80\x81\x82\xdf\xd6\xf4\x10\a\x1e\x0f\xe2?\xd1\xed\xfa\x8d\u07b2\xff\xb6^\x00\x10\xfc\xa5\xc9\xfeG\x8d\xfdJW\x01\xd8f4\xfd\xf2ؾNt\xe7\xe8\x9b5\f\xb4\xdc\r\xb4\xbc\xfb\xcf\xc77\xb4l\x9a\xf12wѾ=?\xc1\xef\xcf\xf5\xb2\xbd5\xfe\x9c\xac\x00\xd6\x7f\x1f5\xc2D\xd3[\x89\x1a\xd6jZ\x81\xa6\x8d\xd5t`tn\xe7\xcb5$M\xcb\xec\x04{\x867\xa5\x95\x96\x8dѲ\xacvK\xa9ec\xb4\x9cX-Z\xa3ec\xd5\x0e\xa4e\xd5\xd4\xee\xb5\xd9\xe2#ks\x11O\xf6\xf9\x92\xfc\x8dZ\xf5\x1b\xf1\xc7N\n\x80mF\xd3[#jlv\x15)\xd0\xf4\x1e\xfb憌\xa6\xbd\xcf0F\xed\x1d\xb1X\xb6\xd2\xffX\xabvh\xd4>\xdeU\xeckU\xb1v'\xfaLF\xd7\b-On\xc1\x9a>\x18$\x19\x99,\x82<\b\xf4\x1b\xb3\xed\xed\x16Y\xa9Q\xe5\x87E\xac\xb4l\xa4XFq\"-\x86V\xb1\xeb°\xf3\x90O\x93\xb0\xf2\xe6\x1e\xbb=>\x1b\x0ft\xbd\xcc\xc0\x81nuq'\x93Gv\xfb\xf4\x17O\x84\xf5G,\xa9\x9d\x18\xfd5\xb4\x8a\xeb\xb3\xf1\x82\v\x1fju.\xbad\xa4/\xb8<\xfeT\x9f\xf5\x8e>\x1c4\xa1\x98\xa3\x82}-\xd4Ek\xd4\xe0e\f\xb9\xb0 \xa3\xd8Z\xfdu\xac\xab\x85\xbc\xab\x04:\xfe\x1eƿ\xd5ǽ<\xf2ۓ\x00l~\x1aE\x9bV\x17\tv\x87\xaa\x04\xa3\x05c\xf5e\x1e\x15\xc2\xda'\xd1w\\s\xbf\xb2\x7f\xbfc\xbf7~\xc5\xda\\tU\xd2%\xcax/z\t\v\x89\u0557\xe1\x88\x16Ҟ>\xb0\xef\xf70\xfe\al\xfc\xbd\xf6;V>\x8d\"\x93p\xaa\xcb\xc7\xedBb\xf5 \r1\xd89\xd3\xff\x1dK\xeaQ\xf0\x0f8\xf6{\xeb\x16`ǹ\xf5!\xcbZM1X\x9d\x8b\xbe2\xcc\xfb\xbd\xaa\x06\x83\x9a>L\xa3\n|\xd5\x03X\xbf\x13]\xb1F\xf7Y^\uf677҃\xf1\xd9x\xbe\x97\x1f^\xb9\x13]\t\xb8\xee\xbe\t\xc0\xc0c\xbf\x03\x05\x00\x11([\x8d\xa8\xb6\x12͛\x11\xdd;,S\xd0\xf8\xf3\xef\xba\x0e\xdb\xf8\xdb\xcbkǁ\xeb7\xe3\x96U\xefG\t\xc1\xe94ѐ\xd15\xdb\xd1w\xab\x9fD\xf7w^\xb5\xfd\xfa\xde\x7f.\xbaE\xf0{\x16\xffI\xba\xf1i\x0e\xd8\x136\xcd\xf6\xdb\\\xa0d\xd9#It{\xe2fܢ\xf1\xe7\xe5\xf5[\x18\xbb\x11\x7f\x94\xb9\x00 \x02\x15\x10\x82\xe7\x13\xed`z\xe5\"\x8b\xe1\xd1eKa\xec׳\x00 \x02%\xde\x1dlִ\xf5\xf5\xafeF;\nO?Wp\xe2\x05\x8b\xe2z\xf0\xa74\xf6;\xb4\a\xb0\x9f\xf1ٸ\x99H\x0fX\xfer\xd1yt\xe6\x95\x10t\x16Oi\xeb\xeb_+y6\xc9\xc28\\\xb1\xf5c\xf3\x95\x9a\x00H\xd2\xc4l|\x05\x11(\x0fɳI\xd9\xcd\xda\x1b\x15\xc1\xae\x10ؕ:\x8b\xe4\x1e\xf7\xb2\xf2\x9d\xe8\xf94\xe0\xfa\\\xf4\x90w\xbb{^\xfaw\x03u\xbe9\xfb\x86\x00\xbc\x91\x15N\xac(<\xfd\\ft\x8bEs \xfb\xa79\xf6\xeb\xbb\x02\xd8\xe5eW\x97\xed\xf6\x11V\xf05\xfb\xff4ud\xf0oW\t\x13\xda\xfa\xfaW\xea>\x9eaъ\x8e\xff$۱|_~\x00ϛ\xd1\xf4h\xa8\x87<6\xeaa\xf6\xdfi\xfc\xf5}\x83\x8cvT\xbb\xf0\x98j\xa0\x88\xe0Ϩ\xf17P\x05 I3\xcdx9\xe8게\xda\\\x1e\xbf\x184\x9bo\vǯ\xd4\xfd\xfe\xa4l\x97\xd7H\xe4J\x98\xfdCy}_\xd1z3n\x9b\x8e>B\x04<\xca\xfe+\xf5\xa1\xbb\xfcݥ\xa9\x9d\xfe\xc1\b\v\x9a\xc75\x93n\xe7a8;\x90\xa4#\x02~\xd1Y<\x95\xe26\x82\xde@\xf6\xb5\xbf\xdaAM\xad<\xfe\xd4\xc05\x1d\"\xe0\ao\x1b\xfb\r\xbd\x9dx2\xa3-\xaa\x81\xec\xe2?\xc9'\xfbok͐`(\xe2p\x19\xb9YS\xe7љ\xd4\x05\xe0\xd5\xcd3\xdaQx\xf6\xa9\x82\xa9U\x16;\xc5\xec\x9f\xe5\xd8/\xb5\n`\x97\x89\xebql\x03}d%ު\xe3\x18\xdd\xc73\x99\x05\xff+\x81\xf9\xf6=\xb6\x04)R3\xba\x9c\xe7\xdfK\xa5\xad;q=\x8e%}\xcc\xe5s+\xfb\xe7\xf5xo\xf7Ɍ:\x8b\xef2%\x186\xf9\x1b\xb5F\xb7c\xc9/\x01\x90\xa4\xf1\xd9x\xdeXD\xc0\xa5\xec\x9fo\xafa\x82)\xc1\xb0\x84\xf9{q\xa4*\xd9\xf5\x9bqK\xa6\xff\x17\x14B\xda\xc18Y\xc8\xe1\x9e\xed\x9e\xc3iD`\x90\xb5S~\x8d\xbf\xd7[\x0e\x19\xc01\xe2b\xd9\xfa\xfaי\xee\xfd\x8f\xced\x89F.<\x96\xa9op1z\x8b\xc2\\\x1b\x7f\x99U\x00{\xb6\x03M\xac\xc5\n*\xfd{|\xde?\xdb\x0f\x11h\xeb\xd1i%?\x8fsAz\x89\xff\xa4\xb8X\xc9\xf4\xed\xc0T\x02E\x94\xe0g\x8a\x17\x80\xbd\xc5\xc0\xb9%\x85\x18\x8e\x1c\x16\x81\xf1؍\xf8â\xfe|\xa6m\xdb\x1dC\x91{\\\xe5\x9c\x12o\xc6c\xbf\x81>\xd3\xe2)u1\x1b98\xfe\xc3|\xc7~\xb9\n\x80$M\xcc\xc6\xd70\x14\xc9'\xfb\xbb\xea\xea\x83\b\x1c\x10\xfcF\xad\"\x1a\x7f\xb9\n\xc0\x8e\b\xe0*\x941\x9do\xdfw\xbb:A\x04\xf6\x97\xfe\xed\"\xc6~\x85\b\x80$muu\rC\x91lH\x9eMʮ\x8f\xba\xbfE\xf9\xfe\xa4\xec\xfa1.\x98$k\xf5\xa0\xe8쿭C9\x82\xa1HF\xe2Z\xf4د\x1f\xc2D#\xff\xf8\xb7j\x1b\x8c\x148\xf6+\xac\x02\x90\xb6\rE6\xbb\x9c L5\xab:\xd8\xf8;\xfc\x03\a\x95\x7fX\xa8ȱ_\xa1\x02\xb0+\x02\x1c#N\xa9\x8cܬ\xa9\xfbd\xc6\xcb\xcf\xdd\xf9\xf6\xbdJ\x9e\x1d0F\xad\xfa\u0378UY\x01\x90\xf0\x12H3\xfb{+^\xeb\xa3\xea~\xffwջh\xa1[\x0f\xc8\x15&\xc1\x88\xc0\xb0\x01t\xcc\xfb\x97y$\xcf&*u\x94\u0605\xb1\x9f3\x02\xb0W\x04\xf0\x12\xe8\x9fη\uf563\x8ay2S\x8d\x97\x9182\xf6sJ\x00vE\x00C\x91~3\xe7\xa4_\x8d\xbf#\xd8\xfa\xf6\xbd\xd27\x05\xf3\xb4\xf9\xeaO\x97\x1c\x01k\xb1\x1eK\x7f\a\x9f\xf7O\xe5F\x9cx\xa9\x91\v?\x946\xfb\xbb2\xf6s\xae\x02\xd8e\xe2z\x1c\a\x16/\x81#\xb3\xff\xd3\xc9\xd2\x05\xbf$ٕ\xe3\xea.M\x953\xfe\x1d6\xcaqj\x0eS\xbf\x19\xb7p\x15:<\xfb\xfb8\xf6\xeb\xb9\x1f\xf0\xfd\xc9\xd2\xf5\x03\x8cQ\xab>\x1b/ \x00}\x88\x00\xaeB\a\x04H\x05:\xe6\x9d\xc5S\xe5z> t\xdb\x17\xc3ɕ\x1e\xbb\x11\xdf\xc5Pd_\xf6\xffy\xdc\xfb\xb1_\xafUNR\x12\xa1+\xca\xe6\xcb{\x01\x90p\x15z#3~\x7f\xb2:\x95\xceҔ\xff[\x01\xa3\xf6XWw]\xff\x98N\xd7Z\x88\xc06e\x1b\xfbUA\xf0L\xa2ۦ\x19;?\xda6>,\xe6\xca\\t7\x90\xaeV\xb2\xf4/\xe9د\xa7\xed\xf3٧\nO\xfd\xecg\xf6wt\xec\xe7U\x05\xb0K\x95]\x85\xbc;\xed\x97\xf6w\xf7\xb0!hB}\xe4\xcbg\xf5fu\xab\xe8*\xe4\xb2\xcdW>\n\x10xw`\xc8\xc5\xe7\xfdK!\x00R\xf5\\\x85*yZn\x1fɳ\to\x1a\x82VZv}\xec\xe7\xb5\x00\xec\x1a\x8aTA\x04\x92g\x93J~\x1e\x13H\x1d\x7fƂ\xf7|\xca\xfe\x92'M\xc0\xfd\xac7\xa3\x86\xad顬\x1ae\xbd齲\xf9ʁ\x91\v\x8fe&\xd6]\x8e$o\x1a\x7f\xdeV\x00\xbb\x94\xddK\xa0ʍ?_\xab\x00\x97l\xbeJ/\x00e\x16\x01\xbbYSR\xd2C1C\xad\xcb\xcaqw{\x01\x81\xe6]\xb2\xf9\xaa\x84\x00\x94U\x04|\x1d}U\xb9\n0\x81\xbfgW\xbc\xbf\xd3\xca\xe4*T\xf9\xb1\x9f\x87U\x80oc\xbf\xd2\t\xc0\xae\b\x94\xc1U\xa8\xf3\xe8\fQ\xeeS\x15\xe0\xa8\xcdW\xe5\x04@\xda1\x14Q\xb1/Z\x1c\x86*>\xef\xef{\x15\xe0\xaa\xcdW%\x05@\x92\xea\xb3\U000423c6\"\xb6\x1bT\xca\x1d\xb7\x14U\x80Q\xdb\xd7\xc6_i\x05@\xf2\xd3U(\xf9i\x8a\xec\xdfo\x15P\xb0\x89\xa8-\x89}])\xdb\xcd\xf5\x9bq˗c\xc4e\xb7\xf9\xcan\xcb4Q\\\xf27j\x8d\xcf\xc6\xf3\b\x80\xc3\xf8\xe2%@\xe9?\xe0\xba\xfd4Uܸ4,\x8fGE\xa9\aή\x8b\x80]\xa93\xf6\x1bX\x01\x82B\xd6\xce\a\x9b/\x04\xc0\x13\x11\xe8,\x9e\"\x90\x87\xd9\x06,\x8f\xe7\\\xfb\xab\x1d\xd4\xd4*\xd3\x1aV⑳\xf1ٸ隗\x00c\xbf4*\xa8\xe3\xb2\xeb\xc7\xf2\x8b\xff\xa4\\ٿ2\x02 \xb9e(b7k\xec\xfd\xd3\x12Ҽ\x8eL\x97d\xecWY\x01\xd8\x15\x01#-\x14~\xd32\xf6K\xaf\x15\xf0S>\a\xa7j\xc6߇\xcc\x10\x80=\xbc\xec\xear\x91\x86\"v\xb3V\xdaW`\x15\xa3\x00A\xe6O\x06\x1a\xa3\xd6\xe8\xf58F\x00J@ѮB\x94\xfe\x19\xaci\xd6ۀ\xb0\xbc\xd6\xf4\x95@X\xbd\xec\x8f\x00d \x02\x8c\xfd\x1c\xe8\x03\xf4)\x02I\xc5\x1a\x7f\b\xc0\x80\"Ћ\xa1\bo\xf7q\xa1\x0f\xd0\xc76\xc0\xa8\x1d\xd6t\xb7\xaak\x85\x00\xf4\xc1Q\xaeB\xd8|\xb9\xb2\r\xe8\xdd&\xac\x8c6_\xfd`\xb8]\xfagu.\xfa\xcaH\xd1k7]7P盳\b\x80\v7\xf5hG#\xff\xfc\xbf=e\xff\xb1\x1b\xf1\aU^+*\x80\x01x\x9b\xa1\b6_\x0eU\x00\x9b\xb5\x9e\x9a\xb0>\xbeF\x0e\x01p\x80]W\xa1\xdd\x13\x84\xbc\xdd\xc7A\x8e\xd8\x06\x18\xa3V}6^@\x00``\x11\xd8=F\xcc\x13\x7f\xee\x91\x1c5\t\xa8\xe8\xd8\x0f\x01H\x91z3n'\x1b\xb5{\xae\xbc\xae\x1a\xf6l\x03\x0e\xa9\x00\xcan\xf3\x85\x00乀\xc7:Wk\x17~\x90\u0084\xc5pI\x00\xd6\x0e\xa8\x00\x8c\xdac\xdd\xea\x8e\xfd\x10\x80\x14Y\x9b\x8b\xaeʪaF;\x1a\xb9\xf0\x18\x11pI\x00\x0ehȚD\xb7M3^f\x85\x10\x80\xa1XoF\r\x19]{uc\xd574r\xfeG\x16\xc6\x15\xba\xc1\x9b\x0eA%}\xbb\x0f\x02P\x00IM\xb7d\xd5x\xed\xfe\x9aXWxn\x89\xc5q\xa6\x0f\xf0\xfa6\xc0\x84\xfa\x88UA\x00R\xc9\xfe\xc6\xea\xca\xdb\xfe[x\xe2\x05\"\xe0\xe06\xa0J6_\b@\xd6\xd9?\xd4\x17\x87\xfd\xf7\xf0\xc4\v\x85\xa7\x9f\xb3P\x8e\b\x80\x95\x96\x19\xfb!\x00\xa9\xb0r'\xba\xb2\xff1්\xc0\xfb\xcf\x11\x01w*\x80{d\x7f\x04 \x9d\x05\vt\xabןE\x04\nf\xedX%\xde\xee\x83\x00\xe4\xb5\xf7\x9f\x8b\xdeh\xfc!\x02\x0eW\x00ݠ\x926_\xfd\xc0i\xc0^\x83\xbf\x195l\xa8\xef\x06\xfd\xff;\x8b\xa70\n\xc9\xfb\xe6~g\xbd=\xd5\xfa\xaf\x0fX\t*\x80\xa1Ij\xbd\x97\xfeo\xa3vnI\xc1\x89\x17,d\x8e\x84'V\xc8\xfeT\x00\xc5g\xff\xbdl=:û\x02\xf2\xc8lc\x9b\xf3\xef\xfc\xe1?/\xb3\x12T\x00\xc3\xef%kz\x98\xd6\xef\x1a9\xffD\xa6\xbeɢf\x99\xd5F;\xeav&~\xc7J \x00C\xb3r'\xba\xd2o\xe3\xef\xf0\xba4\xd1ȅ\x1f\x10\x81,K\xff\xa9\xf5\xd6\xcc\x1f\xff\xd8f%\x10\x80\xa1K\xff~\xc6~\xfd\x88@\xed\xfc\x13\x99\xd1\x0e\x8b\x9c>\xed\xb0\xb1\xc4\xde\x1f\x01H#P\xf5/\xa9f\xff}ej\xed\xc2\x0f\x88@\xda7\xf4\xf4\x1ag\xfd\xfb\xb9\x0fY\x82\x83\xb3\x7fZ\x8d\xbfC\xfb\v\x9b5u\x1e\x9d\xc1O0\x8d\x9b\x99\xb1\x1f\x15@Z\f;\xf6\xa3\x12(\xe0f\x9e\xd8\xfa\x98U@\x00\x86fu.\xbat\xd0i\xbf\xccD\xe0\xfc\x8f\x18\x8a\fs#\x8fo\xb4&\xff\xed\xbf\x17X\t\x04`\xf8\x804\xfa,\xf7\xbfY\xdf\xc0Uh\b\x01\x1d9\xb9J\xe3\x0f\x01\x18\x9e\xd4\xc7~}\x8a@\r/\x81\xfe\x19\xedܮ\xdf]h\xb3\x10\b\xc0Pd6\xf6\xeb\xe7\x82L\xadb(\xd2\x1f\xedw\xce\xff\x80\xc9'\x020\xff\v\x8d?*\x80\xe1qe\xec\xd7\x0f\xb5\xb3O+\xed%\x10\x9e\xfa?J\x7f*\x80\x94\xb2\xbfcc\xbf\x9e\xe9\x06\xdb\xd6b\xeb\xa3\xd5\xcaV#\xc9\xfc;\xff>\x8f\xcd\x17\x15@\n\xd9\xff\x88\xb7\xfb\xb8\x9d\x06w\\\x85*t\x82Ќv\xd45DZ\xf9B\x00\x86\xa7\u05f7\xfb\xb8.\x02U:F\x8c\xcd\x17\x02\x90ޗv|\xec\xd7OV\xac\x88\b`\xf3\x85\x00\xa4\xb4\xf7\xf7d\xec\x87\b\xfc\x82\x95\xb0\xf9\xca\xea\xfe\xa9T\xf0\xfb\xdc\xf8;*H6k\xda\xfa\xe6\xac\xd4-\x97\xa6\xf3\xbc?\x15@j\xe4e\xf3UT%PFC\x11l\xbe\xa8\x00Ra\xe5\xd3(\n\x12}U\xf6\xefi\u05cfi\xeb\xd1\xe9RT\x02\xc1\xf8F\xeb\x9d\xcf\xff\x82\x00P\x01\xa4\xf0E\xad\xc7c\xbf~\x14\xbd\xbeQ\n/\x013\xdaQwk\x92\xc6\x1f\x02\x90B\xf6/\xd0\xe6\xab\b\xc2\x13/\xfcw\x15\x1a\xed\xdcf\xec\x87\x00\f\x8d\v6_\x85\x89\x80\xbf^\x02\xd8|!\x00)\xed\x89\x03]\xadR\xf6\x7fM\x04<5\x14\t\xde_\xc6\xe6+\xaf\xadVٳ\x7fY\xc7~\xfd\xe0\x93\xb5\x18c?*\x80\xd4(\xf3د\xac\x95\x006_\b@*\xac܉\xae\xe4\xf9v\x1f/D\xc0qW\xa1`|\x83\xe7\xfd\x11\x80\x94\xbeX@\xf6\x7fC\x04\xdcv\x15\xc2\xe6\v\x01H\x87\xb5\xb9\xa8\xb2\x8d\xbf\xa3\xa8\x9d[\x92\x99x\xe9\xde\xde\x1f\x9b\xafbֽl_\xc8G\x9b\xaf\xdcq\xccP\xc4Ԓ\xf6\xd4\x7f\xcc\xd3\xf8\xa3\x02\x18\x1e\x1fm\xbe\xf2\xdf\v\xec\x18\x8a8b-\x16\x9e}J\xe9O\x05\x90R\xf6g\xec\xd73v\xb3\xa6Σ3\xb2\x9bŽ\x1e\x02\x9b/*\x80\x14S\x89\xeesI\xfbP\xff\x82\xbd\x04\xb0\xf9B\x00Rc\xe5Nt\xc5J\x17\xb9\xa4\xfe\x88\x80\x19\xe92\xf6C\x00R\xfa\"\x8c\xfd|\x13\x01\xc6~\b@J{\xff\x92\xd9|\x15&\x02\xe7\x7f\xcc\xcdP\x04\x9b/G\xae\xbb\xf7\xc1ߌ\x1aI\xa8\xaf\x8c4\xcd\xe5L!0s0\x14\xe1\xed>T\x00\xa9\x91\xd4t\x8b\xe0O18\xeb\x1b\x1a9\xffc\xb67\xdd\xd4\x06.?T\x00\xe9d\x7f\xc6~\xd9\xd0}6\xa9\xee\xe2\xa9\xf4o8cZS\x7f\xfa\x13\x02@\x05\x90B\xb9Z\xd3C.a6d\xe1*dF;JFFh\xfc!\x00\xc3S5\x9b\xaf\xc2D \xcdc\xc4\xd8|!\x00\xa9d\xfef4\xcd\xd8/'\x11H\xcfK\x00\x9b/\x04 \x1d^\x86ⴟg\"\x10\x1c\xdf\xc2\xe6\xcbA\xbck\x02\xd2\xf8+\x8eA\xadŰ\xf9\xa2\x02H\rl\xbe\x8a\xad\x04\x061\x14\xc1\xe6\v\x01H\x85չ\xe8\x126_\xc5R;\xb7ԗ\b`\xf3\x85\x00\xa4\xb7_1\xfa\x8cK\xe6\x86\b\xf4\xe4%\x10&\xcb#'W\x19\xfb!\x00\xc3\xc3\xd8\xcf-z1\x141\xf5\xcd{\xf5\xbb\vd\x7f\x97\x93\xaa\x0f\x1f\x12\x9b/G\xe9\x06\xda\xfa\xe6\xecA\x86\"\xed\xe9?\xff\x99\xc6\x1f\x15\xc0\xf0`\xf3\xe5(ar\xe01\xe2Zc\x89ҟ\n \xa5\xec\xcf\xd8\xcfi\xf6[\x8ba\xf3E\x05\x90\xde\xcd\x15\xd2\xf8s>\x8b\xec1\x141a\x82͗G\xd4\\\xfep\xabs\xd1%\x19E\x92\xda\\*\xc7E\xe0XG#\xff\xf0X\x1b\x8b\xa7\xbe\x9c\xf9\xc3<\xd7\v\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\xe4\xff\x01\xf6P(\xf3)+S\x1f\x00\x00\x00\x00IEND\xaeB`\x82"),
+}
diff --git a/infrastructure_files/download-geolite2.sh b/infrastructure_files/download-geolite2.sh
deleted file mode 100755
index 4a9db5e01..000000000
--- a/infrastructure_files/download-geolite2.sh
+++ /dev/null
@@ -1,109 +0,0 @@
-#!/bin/bash
-
-# to install sha256sum on mac: brew install coreutils
-if ! command -v sha256sum &> /dev/null
-then
- echo "sha256sum is not installed or not in PATH, please install with your package manager. e.g. sudo apt install sha256sum" > /dev/stderr
- exit 1
-fi
-
-if ! command -v sqlite3 &> /dev/null
-then
- echo "sqlite3 is not installed or not in PATH, please install with your package manager. e.g. sudo apt install sqlite3" > /dev/stderr
- exit 1
-fi
-
-if ! command -v unzip &> /dev/null
-then
- echo "unzip is not installed or not in PATH, please install with your package manager. e.g. sudo apt install unzip" > /dev/stderr
- exit 1
-fi
-
-download_geolite_mmdb() {
- DATABASE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
- SIGNATURE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
- # Download the database and signature files
- echo "Downloading mmdb signature file..."
- SIGNATURE_FILE=$(curl -s -L -O -J "$SIGNATURE_URL" -w "%{filename_effective}")
- echo "Downloading mmdb database file..."
- DATABASE_FILE=$(curl -s -L -O -J "$DATABASE_URL" -w "%{filename_effective}")
-
- # Verify the signature
- echo "Verifying signature..."
- if sha256sum -c --status "$SIGNATURE_FILE"; then
- echo "Signature is valid."
- else
- echo "Signature is invalid. Aborting."
- exit 1
- fi
-
- # Unpack the database file
- EXTRACTION_DIR=$(basename "$DATABASE_FILE" .tar.gz)
- echo "Unpacking $DATABASE_FILE..."
- mkdir -p "$EXTRACTION_DIR"
- tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1
-
- MMDB_FILE="GeoLite2-City.mmdb"
- cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE
-
- # Remove downloaded files
- rm -r "$EXTRACTION_DIR"
- rm "$DATABASE_FILE" "$SIGNATURE_FILE"
-
- # Done. Print next steps
- echo ""
- echo "Process completed successfully."
- echo "Now you can place $MMDB_FILE to 'datadir' of management service."
- echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/"
-}
-
-
-download_geolite_csv_and_create_sqlite_db() {
- DATABASE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip"
- SIGNATURE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256"
-
-
- # Download the database file
- echo "Downloading csv signature file..."
- SIGNATURE_FILE=$(curl -s -L -O -J "$SIGNATURE_URL" -w "%{filename_effective}")
- echo "Downloading csv database file..."
- DATABASE_FILE=$(curl -s -L -O -J "$DATABASE_URL" -w "%{filename_effective}")
-
- # Verify the signature
- echo "Verifying signature..."
- if sha256sum -c --status "$SIGNATURE_FILE"; then
- echo "Signature is valid."
- else
- echo "Signature is invalid. Aborting."
- exit 1
- fi
-
- # Unpack the database file
- EXTRACTION_DIR=$(basename "$DATABASE_FILE" .zip)
- DB_NAME="geonames.db"
-
- echo "Unpacking $DATABASE_FILE..."
- unzip "$DATABASE_FILE" > /dev/null 2>&1
-
-# Create SQLite database and import data from CSV
-sqlite3 "$DB_NAME" <
Date: Mon, 9 Sep 2024 18:44:37 +0200
Subject: [PATCH 36/89] fix: client/Dockerfile to reduce vulnerabilities
(#2548)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE319-OPENSSL-7895536
- https://snyk.io/vuln/SNYK-ALPINE319-OPENSSL-7895536
Co-authored-by: snyk-bot
---
client/Dockerfile | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/client/Dockerfile b/client/Dockerfile
index a3220bf33..b9f7c1355 100644
--- a/client/Dockerfile
+++ b/client/Dockerfile
@@ -1,4 +1,4 @@
-FROM alpine:3.19
+FROM alpine:3.20
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
From f43a0a0177a9d6ee2ea2985825d9053702a8d184 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 9 Sep 2024 19:02:10 +0200
Subject: [PATCH 37/89] [client] Retry on tun creation for darwin (#2564)
The interface creation on macOS seems to be asynchronus why the tun.create methode somethimes failes becasue the interface is not ready yet. To work around this issue we introduce a retry on tun.create
---
iface/iface_create.go | 2 +-
iface/iface_darwin.go | 28 ++++++++++++++++++++++++++++
management/server/peer.go | 2 --
3 files changed, 29 insertions(+), 3 deletions(-)
diff --git a/iface/iface_create.go b/iface/iface_create.go
index cfc555f2e..f389019ed 100644
--- a/iface/iface_create.go
+++ b/iface/iface_create.go
@@ -1,4 +1,4 @@
-//go:build !android
+//go:build (!android && !darwin) || ios
package iface
diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go
index 15e4a7817..f48f324c3 100644
--- a/iface/iface_darwin.go
+++ b/iface/iface_darwin.go
@@ -4,7 +4,9 @@ package iface
import (
"fmt"
+ "time"
+ "github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
@@ -36,3 +38,29 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
+
+// Create creates a new Wireguard interface, sets a given IP and brings it up.
+// Will reuse an existing one.
+// this function is different on Android
+func (w *WGIface) Create() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ backOff := &backoff.ExponentialBackOff{
+ InitialInterval: 20 * time.Millisecond,
+ MaxElapsedTime: 500 * time.Millisecond,
+ Stop: backoff.Stop,
+ Clock: backoff.SystemClock,
+ }
+
+ operation := func() error {
+ cfgr, err := w.tun.Create()
+ if err != nil {
+ return err
+ }
+ w.configurer = cfgr
+ return nil
+ }
+
+ return backoff.Retry(operation, backOff)
+}
diff --git a/management/server/peer.go b/management/server/peer.go
index 5fc6352ee..26e27617d 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -550,8 +550,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if peer.UserID != "" {
- log.Infof("Peer has no userID")
-
user, err := account.FindUser(peer.UserID)
if err != nil {
return nil, nil, nil, err
From 50ebbe482e7b38936b47dfda4b6c167390606932 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Wed, 11 Sep 2024 16:05:13 +0200
Subject: [PATCH 38/89] [client] Don't overwrite allowed IPs when updating the
wg peer's endpoint address (#2578)
This will fix broken routes on routing clients when upgrading/downgrading from/to relayed connections.
---
iface/wg_configurer_kernel_unix.go | 5 +++--
iface/wg_configurer_usp.go | 5 +++--
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/iface/wg_configurer_kernel_unix.go b/iface/wg_configurer_kernel_unix.go
index 48ea70b7b..8b47082da 100644
--- a/iface/wg_configurer_kernel_unix.go
+++ b/iface/wg_configurer_kernel_unix.go
@@ -56,8 +56,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return err
}
peer := wgtypes.PeerConfig{
- PublicKey: peerKeyParsed,
- ReplaceAllowedIPs: true,
+ PublicKey: peerKeyParsed,
+ ReplaceAllowedIPs: false,
+ // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go
index 04a29a60b..cd1d9d0b6 100644
--- a/iface/wg_configurer_usp.go
+++ b/iface/wg_configurer_usp.go
@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return err
}
peer := wgtypes.PeerConfig{
- PublicKey: peerKeyParsed,
- ReplaceAllowedIPs: true,
+ PublicKey: peerKeyParsed,
+ ReplaceAllowedIPs: false,
+ // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
From 2d1bf3982dc38502ed4870bfe08a05da8c6f6d84 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Wed, 11 Sep 2024 16:20:30 +0200
Subject: [PATCH 39/89] [relay] Improve relay messages (#2574)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Zoltán Papp
---
management/server/token_mgr.go | 28 ++++---
management/server/token_mgr_test.go | 3 +-
relay/auth/allow/allow_all.go | 8 +-
relay/auth/hmac/store.go | 18 +++--
relay/auth/hmac/token.go | 11 ---
relay/auth/hmac/v2/algo.go | 40 ++++++++++
relay/auth/hmac/v2/generator.go | 45 ++++++++++++
relay/auth/hmac/v2/hmac_test.go | 110 ++++++++++++++++++++++++++++
relay/auth/hmac/v2/token.go | 39 ++++++++++
relay/auth/hmac/v2/validator.go | 59 +++++++++++++++
relay/auth/hmac/validator.go | 6 +-
relay/auth/validator.go | 31 +++++++-
relay/client/client.go | 33 ++-------
relay/cmd/root.go | 7 +-
relay/messages/address/address.go | 11 +--
relay/messages/auth/auth.go | 10 +--
relay/messages/message.go | 96 +++++++++++++++++++++---
relay/messages/message_test.go | 16 ++++
relay/server/relay.go | 97 +++++++++++++++++-------
19 files changed, 552 insertions(+), 116 deletions(-)
create mode 100644 relay/auth/hmac/v2/algo.go
create mode 100644 relay/auth/hmac/v2/generator.go
create mode 100644 relay/auth/hmac/v2/hmac_test.go
create mode 100644 relay/auth/hmac/v2/token.go
create mode 100644 relay/auth/hmac/v2/validator.go
diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go
index 8a6648a3a..ef8276b59 100644
--- a/management/server/token_mgr.go
+++ b/management/server/token_mgr.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha1"
"crypto/sha256"
+ "encoding/base64"
"fmt"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
const defaultDuration = 12 * time.Hour
@@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct {
turnCfg *TURNConfig
relayCfg *Relay
turnHmacToken *auth.TimedHMAC
- relayHmacToken *auth.TimedHMAC
+ relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
@@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
duration = defaultDuration
}
- mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration)
+ hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
+ var err error
+ if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
+ log.Errorf("failed to create relay token generator: %s", err)
+ }
}
return mgr
@@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
}
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
if err != nil {
- return nil, fmt.Errorf("failed to generate TURN token: %s", err)
+ return nil, fmt.Errorf("generate TURN token: %s", err)
}
return (*Token)(turnToken), nil
}
@@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
if m.relayHmacToken == nil {
return nil, fmt.Errorf("relay configuration is not set")
}
- relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
+ relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil {
- return nil, fmt.Errorf("failed to generate relay token: %s", err)
+ return nil, fmt.Errorf("generate relay token: %s", err)
}
- return (*Token)(relayToken), nil
+
+ return &Token{
+ Payload: string(relayToken.Payload),
+ Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
+ }, nil
}
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
@@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
- relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
+ relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil {
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
return
@@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe
WiretrusteeConfig: &proto.WiretrusteeConfig{
Relay: &proto.RelayConfig{
Urls: m.relayCfg.Addresses,
- TokenPayload: relayToken.Payload,
- TokenSignature: relayToken.Signature,
+ TokenPayload: string(relayToken.Payload),
+ TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
},
// omit Turns to avoid updates there
},
diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go
index d59fd3a3f..3e63346c2 100644
--- a/management/server/token_mgr_test.go
+++ b/management/server/token_mgr_test.go
@@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Errorf("expected generated relay signature not to be empty, got empty")
}
- validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret))
+ hashedSecret := sha256.Sum256([]byte(secret))
+ validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
}
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
diff --git a/relay/auth/allow/allow_all.go b/relay/auth/allow/allow_all.go
index 92845818b..2d30c59c9 100644
--- a/relay/auth/allow/allow_all.go
+++ b/relay/auth/allow/allow_all.go
@@ -1,12 +1,14 @@
package allow
-import "hash"
-
// Auth is a Validator that allows all connections.
// Used this for testing purposes only.
type Auth struct {
}
-func (a *Auth) Validate(func() hash.Hash, any) error {
+func (a *Auth) Validate(any) error {
+ return nil
+}
+
+func (a *Auth) ValidateHelloMsgType(any) error {
return nil
}
diff --git a/relay/auth/hmac/store.go b/relay/auth/hmac/store.go
index 36c195a7b..169b8d6b0 100644
--- a/relay/auth/hmac/store.go
+++ b/relay/auth/hmac/store.go
@@ -1,9 +1,11 @@
package hmac
import (
+ "encoding/base64"
+ "fmt"
"sync"
- log "github.com/sirupsen/logrus"
+ v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// TokenStore is a simple in-memory store for token
@@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
return nil
}
- t, err := marshalToken(*token)
+ sig, err := base64.StdEncoding.DecodeString(token.Signature)
if err != nil {
- log.Debugf("failed to marshal token: %s", err)
- return err
+ return fmt.Errorf("decode signature: %w", err)
}
- a.token = t
+
+ tok := v2.Token{
+ AuthAlgo: v2.AuthAlgoHMACSHA256,
+ Signature: sig,
+ Payload: []byte(token.Payload),
+ }
+
+ a.token = tok.Marshal()
return nil
}
diff --git a/relay/auth/hmac/token.go b/relay/auth/hmac/token.go
index e2e62b84e..581b1d6fd 100644
--- a/relay/auth/hmac/token.go
+++ b/relay/auth/hmac/token.go
@@ -18,17 +18,6 @@ type Token struct {
Signature string
}
-func marshalToken(token Token) ([]byte, error) {
- var buffer bytes.Buffer
- encoder := gob.NewEncoder(&buffer)
- err := encoder.Encode(token)
- if err != nil {
- log.Debugf("failed to marshal token: %s", err)
- return nil, fmt.Errorf("failed to marshal token: %w", err)
- }
- return buffer.Bytes(), nil
-}
-
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)
diff --git a/relay/auth/hmac/v2/algo.go b/relay/auth/hmac/v2/algo.go
new file mode 100644
index 000000000..c379c2bd7
--- /dev/null
+++ b/relay/auth/hmac/v2/algo.go
@@ -0,0 +1,40 @@
+package v2
+
+import (
+ "crypto/sha256"
+ "hash"
+)
+
+const (
+ AuthAlgoUnknown AuthAlgo = iota
+ AuthAlgoHMACSHA256
+)
+
+type AuthAlgo uint8
+
+func (a AuthAlgo) String() string {
+ switch a {
+ case AuthAlgoHMACSHA256:
+ return "HMAC-SHA256"
+ default:
+ return "Unknown"
+ }
+}
+
+func (a AuthAlgo) New() func() hash.Hash {
+ switch a {
+ case AuthAlgoHMACSHA256:
+ return sha256.New
+ default:
+ return nil
+ }
+}
+
+func (a AuthAlgo) Size() int {
+ switch a {
+ case AuthAlgoHMACSHA256:
+ return sha256.Size
+ default:
+ return 0
+ }
+}
diff --git a/relay/auth/hmac/v2/generator.go b/relay/auth/hmac/v2/generator.go
new file mode 100644
index 000000000..827532730
--- /dev/null
+++ b/relay/auth/hmac/v2/generator.go
@@ -0,0 +1,45 @@
+package v2
+
+import (
+ "crypto/hmac"
+ "fmt"
+ "hash"
+ "strconv"
+ "time"
+)
+
+type Generator struct {
+ algo func() hash.Hash
+ algoType AuthAlgo
+ secret []byte
+ timeToLive time.Duration
+}
+
+func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
+ algoFunc := algo.New()
+ if algoFunc == nil {
+ return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
+ }
+ return &Generator{
+ algo: algoFunc,
+ algoType: algo,
+ secret: secret,
+ timeToLive: timeToLive,
+ }, nil
+}
+
+func (g *Generator) GenerateToken() (*Token, error) {
+ expirationTime := time.Now().Add(g.timeToLive).Unix()
+
+ payload := []byte(strconv.FormatInt(expirationTime, 10))
+
+ h := hmac.New(g.algo, g.secret)
+ h.Write(payload)
+ signature := h.Sum(nil)
+
+ return &Token{
+ AuthAlgo: g.algoType,
+ Signature: signature,
+ Payload: payload,
+ }, nil
+}
diff --git a/relay/auth/hmac/v2/hmac_test.go b/relay/auth/hmac/v2/hmac_test.go
new file mode 100644
index 000000000..40336363f
--- /dev/null
+++ b/relay/auth/hmac/v2/hmac_test.go
@@ -0,0 +1,110 @@
+package v2
+
+import (
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestGenerateCredentials(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if len(token.Payload) == 0 {
+ t.Fatalf("expected non-empty payload")
+ }
+
+ _, err = strconv.ParseInt(string(token.Payload), 10, 64)
+ if err != nil {
+ t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
+ }
+}
+
+func TestValidateCredentials(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err != nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestInvalidSignature(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err == nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestExpired(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := -1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err == nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestInvalidPayload(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err == nil {
+ t.Fatalf("expected invalid token due to invalid payload")
+ }
+}
diff --git a/relay/auth/hmac/v2/token.go b/relay/auth/hmac/v2/token.go
new file mode 100644
index 000000000..553ac01b9
--- /dev/null
+++ b/relay/auth/hmac/v2/token.go
@@ -0,0 +1,39 @@
+package v2
+
+import "errors"
+
+type Token struct {
+ AuthAlgo AuthAlgo
+ Signature []byte
+ Payload []byte
+}
+
+func (t *Token) Marshal() []byte {
+ size := 1 + len(t.Signature) + len(t.Payload)
+
+ buf := make([]byte, size)
+
+ buf[0] = byte(t.AuthAlgo)
+ copy(buf[1:], t.Signature)
+ copy(buf[1+len(t.Signature):], t.Payload)
+
+ return buf
+}
+
+func UnmarshalToken(data []byte) (*Token, error) {
+ if len(data) == 0 {
+ return nil, errors.New("invalid token data")
+ }
+
+ algo := AuthAlgo(data[0])
+ sigSize := algo.Size()
+ if len(data) < 1+sigSize {
+ return nil, errors.New("invalid token data: insufficient length")
+ }
+
+ return &Token{
+ AuthAlgo: algo,
+ Signature: data[1 : 1+sigSize],
+ Payload: data[1+sigSize:],
+ }, nil
+}
diff --git a/relay/auth/hmac/v2/validator.go b/relay/auth/hmac/v2/validator.go
new file mode 100644
index 000000000..7f448dd5f
--- /dev/null
+++ b/relay/auth/hmac/v2/validator.go
@@ -0,0 +1,59 @@
+package v2
+
+import (
+ "crypto/hmac"
+ "errors"
+ "fmt"
+ "strconv"
+ "time"
+)
+
+const minLengthUnixTimestamp = 10
+
+type Validator struct {
+ secret []byte
+}
+
+func NewValidator(secret []byte) *Validator {
+ return &Validator{secret: secret}
+}
+
+func (v *Validator) Validate(data any) error {
+ d, ok := data.([]byte)
+ if !ok {
+ return fmt.Errorf("invalid data type")
+ }
+
+ token, err := UnmarshalToken(d)
+ if err != nil {
+ return fmt.Errorf("unmarshal token: %w", err)
+ }
+
+ if len(token.Payload) < minLengthUnixTimestamp {
+ return errors.New("invalid payload: insufficient length")
+ }
+
+ hashFunc := token.AuthAlgo.New()
+ if hashFunc == nil {
+ return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
+ }
+
+ h := hmac.New(hashFunc, v.secret)
+ h.Write(token.Payload)
+ expectedMAC := h.Sum(nil)
+
+ if !hmac.Equal(token.Signature, expectedMAC) {
+ return errors.New("invalid signature")
+ }
+
+ timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid payload: %w", err)
+ }
+
+ if time.Now().Unix() > timestamp {
+ return fmt.Errorf("expired token")
+ }
+
+ return nil
+}
diff --git a/relay/auth/hmac/validator.go b/relay/auth/hmac/validator.go
index 6ddd89c19..b0b7542be 100644
--- a/relay/auth/hmac/validator.go
+++ b/relay/auth/hmac/validator.go
@@ -1,8 +1,8 @@
package hmac
import (
+ "crypto/sha256"
"fmt"
- "hash"
"time"
log "github.com/sirupsen/logrus"
@@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
}
}
-func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
+func (a *TimedHMACValidator) Validate(credentials any) error {
b, ok := credentials.([]byte)
if !ok {
return fmt.Errorf("invalid credentials type")
@@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
log.Debugf("failed to unmarshal token: %s", err)
return err
}
- return a.TimedHMAC.Validate(algo, c)
+ return a.TimedHMAC.Validate(sha256.New, c)
}
diff --git a/relay/auth/validator.go b/relay/auth/validator.go
index 078811f3d..854efd5bb 100644
--- a/relay/auth/validator.go
+++ b/relay/auth/validator.go
@@ -1,8 +1,35 @@
package auth
-import "hash"
+import (
+ "time"
+
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
+)
// Validator is an interface that defines the Validate method.
type Validator interface {
- Validate(func() hash.Hash, any) error
+ Validate(any) error
+ // Deprecated: Use Validate instead.
+ ValidateHelloMsgType(any) error
+}
+
+type TimedHMACValidator struct {
+ authenticatorV2 *authv2.Validator
+ authenticator *auth.TimedHMACValidator
+}
+
+func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
+ return &TimedHMACValidator{
+ authenticatorV2: authv2.NewValidator(secret),
+ authenticator: auth.NewTimedHMACValidator(string(secret), duration),
+ }
+}
+
+func (a *TimedHMACValidator) Validate(credentials any) error {
+ return a.authenticatorV2.Validate(credentials)
+}
+
+func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
+ return a.authenticator.Validate(credentials)
}
diff --git a/relay/client/client.go b/relay/client/client.go
index 1160d1c9e..6560c81e1 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
- "github.com/netbirdio/netbird/relay/messages/address"
- auth2 "github.com/netbirdio/netbird/relay/messages/auth"
)
const (
@@ -240,31 +238,21 @@ func (c *Client) connect() error {
}
func (c *Client) handShake() error {
- authMsg := &auth2.Msg{
- AuthAlgorithm: auth2.AlgoHMACSHA256,
- AdditionalData: c.authTokenStore.TokenBinary(),
- }
-
- authData, err := authMsg.Marshal()
+ msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
- return fmt.Errorf("marshal auth message: %w", err)
- }
-
- msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
- if err != nil {
- log.Errorf("failed to marshal hello message: %s", err)
+ log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
- log.Errorf("failed to send hello message: %s", err)
+ log.Errorf("failed to send auth message: %s", err)
return err
}
- buf := make([]byte, messages.MaxHandshakeSize)
+ buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf)
if err != nil {
- log.Errorf("failed to read hello response: %s", err)
+ log.Errorf("failed to read auth response: %s", err)
return err
}
@@ -279,23 +267,18 @@ func (c *Client) handShake() error {
return err
}
- if msgType != messages.MsgTypeHelloResponse {
+ if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
- additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
+ addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return err
}
- addr, err := address.Unmarshal(additionalData)
- if err != nil {
- return fmt.Errorf("unmarshal address: %w", err)
- }
-
c.muInstanceURL.Lock()
- c.instanceURL = &RelayAddr{addr: addr.URL}
+ c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
}
diff --git a/relay/cmd/root.go b/relay/cmd/root.go
index 784b42c1a..dcc1465d0 100644
--- a/relay/cmd/root.go
+++ b/relay/cmd/root.go
@@ -2,6 +2,7 @@ package cmd
import (
"context"
+ "crypto/sha256"
"crypto/tls"
"errors"
"fmt"
@@ -16,7 +17,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption"
- auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util"
@@ -139,7 +140,9 @@ func execute(cmd *cobra.Command, args []string) error {
}
srvListenerCfg.TLSConfig = tlsConfig
- authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
+ hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
+ authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
+
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
if err != nil {
log.Debugf("failed to create relay server: %v", err)
diff --git a/relay/messages/address/address.go b/relay/messages/address/address.go
index 829206294..707e73e55 100644
--- a/relay/messages/address/address.go
+++ b/relay/messages/address/address.go
@@ -1,3 +1,4 @@
+// Deprecated: This package is deprecated and will be removed in a future release.
package address
import (
@@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
}
return buf.Bytes(), nil
}
-
-func Unmarshal(data []byte) (*Address, error) {
- var addr Address
- buf := bytes.NewBuffer(data)
- dec := gob.NewDecoder(buf)
- if err := dec.Decode(&addr); err != nil {
- return nil, fmt.Errorf("decode Address: %w", err)
- }
- return &addr, nil
-}
diff --git a/relay/messages/auth/auth.go b/relay/messages/auth/auth.go
index 8230bccf2..9c2511f2f 100644
--- a/relay/messages/auth/auth.go
+++ b/relay/messages/auth/auth.go
@@ -1,3 +1,4 @@
+// Deprecated: This package is deprecated and will be removed in a future release.
package auth
import (
@@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte
}
-func (msg *Msg) Marshal() ([]byte, error) {
- var buf bytes.Buffer
- enc := gob.NewEncoder(&buf)
- if err := enc.Encode(msg); err != nil {
- return nil, fmt.Errorf("encode Msg: %w", err)
- }
- return buf.Bytes(), nil
-}
-
func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg
diff --git a/relay/messages/message.go b/relay/messages/message.go
index cfcac3f72..39ca0aa90 100644
--- a/relay/messages/message.go
+++ b/relay/messages/message.go
@@ -7,12 +7,21 @@ import (
)
const (
- MsgTypeUnknown MsgType = 0
- MsgTypeHello MsgType = 1
+ MaxHandshakeSize = 212
+ MaxHandshakeRespSize = 8192
+
+ CurrentProtocolVersion = 1
+
+ MsgTypeUnknown MsgType = 0
+ // Deprecated: Use MsgTypeAuth instead.
+ MsgTypeHello MsgType = 1
+ // Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5
+ MsgTypeAuth = 6
+ MsgTypeAuthResponse = 7
SizeOfVersionByte = 1
SizeOfMsgType = 1
@@ -22,12 +31,12 @@ const (
sizeOfMagicByte = 4
headerSizeTransport = IDSize
+
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0
- MaxHandshakeSize = 8192
-
- CurrentProtocolVersion = 1
+ headerSizeAuth = sizeOfMagicByte + IDSize
+ headerSizeAuthResp = 0
)
var (
@@ -47,6 +56,10 @@ func (m MsgType) String() string {
return "hello"
case MsgTypeHelloResponse:
return "hello response"
+ case MsgTypeAuth:
+ return "auth"
+ case MsgTypeAuthResponse:
+ return "auth response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
@@ -58,10 +71,6 @@ func (m MsgType) String() string {
}
}
-type HelloResponse struct {
- InstanceAddress string
-}
-
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
@@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
switch msgType {
case
MsgTypeHello,
+ MsgTypeAuth,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
switch msgType {
case
MsgTypeHelloResponse,
+ MsgTypeAuthResponse,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
}
}
+// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil
}
+// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
@@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
+// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
@@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil
}
+// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
@@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil
}
+// MarshalAuthMsg initial authentication message
+// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
+// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
+// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
+// close the network connection without any response.
+func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
+ if len(peerID) != IDSize {
+ return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
+ }
+
+ msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeAuth)
+
+ copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
+
+ msg = append(msg, peerID...)
+ msg = append(msg, authPayload...)
+
+ return msg, nil
+}
+
+// UnmarshalAuthMsg extracts peerID and the auth payload from the message
+func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
+ if len(msg) < headerSizeAuth {
+ return nil, nil, ErrInvalidMessageLength
+ }
+ if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
+ return nil, nil, errors.New("invalid magic header")
+ }
+
+ return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
+}
+
+// MarshalAuthResponse creates a response message to the auth.
+// In case of success connection the server response with a AuthResponse message. This message contains the server's
+// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
+// servers.
+func MarshalAuthResponse(address string) ([]byte, error) {
+ ab := []byte(address)
+ msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeAuthResponse)
+
+ msg = append(msg, ab...)
+
+ if len(msg) > MaxHandshakeRespSize {
+ return nil, fmt.Errorf("invalid message length: %d", len(msg))
+ }
+
+ return msg, nil
+}
+
+// UnmarshalAuthResponse it is a confirmation message to auth success
+func UnmarshalAuthResponse(msg []byte) (string, error) {
+ if len(msg) < headerSizeAuthResp+1 {
+ return "", ErrInvalidMessageLength
+ }
+ return string(msg), nil
+}
+
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go
index a4e7d9fae..6e917da71 100644
--- a/relay/messages/message_test.go
+++ b/relay/messages/message_test.go
@@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) {
}
}
+func TestMarshalAuthMsg(t *testing.T) {
+ peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
+ bHello, err := MarshalAuthMsg(peerID, []byte{})
+ if err != nil {
+ t.Fatalf("error: %v", err)
+ }
+
+ receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
+ if err != nil {
+ t.Fatalf("error: %v", err)
+ }
+ if string(receivedPeerID) != string(peerID) {
+ t.Errorf("expected %s, got %s", peerID, receivedPeerID)
+ }
+}
+
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
diff --git a/relay/server/relay.go b/relay/server/relay.go
index 6d88cbbb2..76c01a697 100644
--- a/relay/server/relay.go
+++ b/relay/server/relay.go
@@ -2,7 +2,6 @@ package server
import (
"context"
- "crypto/sha256"
"fmt"
"net"
"net/url"
@@ -14,7 +13,9 @@ import (
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
+ //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
+ //nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics"
)
@@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
- if msgType != messages.MsgTypeHello {
- return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
+ var (
+ responseMsg []byte
+ peerID []byte
+ )
+ switch msgType {
+ //nolint:staticcheck
+ case messages.MsgTypeHello:
+ peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
+ case messages.MsgTypeAuth:
+ peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
+ default:
+ return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
-
- peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
if err != nil {
- return nil, fmt.Errorf("unmarshal hello message: %w", err)
+ return nil, err
}
- authMsg, err := authmsg.UnmarshalMsg(authData)
- if err != nil {
- return nil, fmt.Errorf("unmarshal auth message: %w", err)
- }
-
- if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
- return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
- }
-
- addr := &address.Address{URL: r.instanceURL}
- addrData, err := addr.Marshal()
- if err != nil {
- return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
- }
-
- msg, err := messages.MarshalHelloResponse(addrData)
- if err != nil {
- return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
- }
-
- _, err = conn.Write(msg)
+ _, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
+
+func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
+ //nolint:staticcheck
+ rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
+ }
+
+ peerID := messages.HashIDToString(rawPeerID)
+ log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
+
+ authMsg, err := authmsg.UnmarshalMsg(authData)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
+ }
+
+ //nolint:staticcheck
+ if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
+ return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
+ }
+
+ addr := &address.Address{URL: r.instanceURL}
+ addrData, err := addr.Marshal()
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
+ }
+
+ //nolint:staticcheck
+ responseMsg, err := messages.MarshalHelloResponse(addrData)
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
+ }
+ return rawPeerID, responseMsg, nil
+}
+
+func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
+ rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
+ }
+
+ peerID := messages.HashIDToString(rawPeerID)
+
+ if err := r.validator.Validate(authPayload); err != nil {
+ return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
+ }
+
+ responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
+ }
+
+ return rawPeerID, responseMsg, nil
+}
From 9cfc8f8aa48d605467e1bfd80ed4608cc3d9e3ad Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Wed, 11 Sep 2024 18:36:19 +0200
Subject: [PATCH 40/89] [relay] change log levels (#2580)
---
relay/metrics/realy.go | 2 +-
relay/server/peer.go | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go
index 80e12ee6b..13799713a 100644
--- a/relay/metrics/realy.go
+++ b/relay/metrics/realy.go
@@ -103,7 +103,7 @@ func (m *Metrics) PeerActivity(peerID string) {
select {
case m.peerActivityChan <- peerID:
default:
- log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID)
+ log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID)
}
}
diff --git a/relay/server/peer.go b/relay/server/peer.go
index a9583700a..0de601996 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -184,7 +184,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
stringPeerID := messages.HashIDToString(peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok {
- p.log.Errorf("peer not found: %s", stringPeerID)
+ p.log.Debugf("peer not found: %s", stringPeerID)
return
}
From 47adb976f8e5d5eb866d347b6ba1031d219ad230 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 11 Sep 2024 18:59:19 +0200
Subject: [PATCH 41/89] Remove pre-release step from workflow (#2583)
---
.github/workflows/test-infrastructure-files.yml | 6 ------
1 file changed, 6 deletions(-)
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index d1aef3324..d1c2b3aef 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -221,9 +221,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- - name: handle insisting image # remove after release
- run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
-
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -259,9 +256,6 @@ jobs:
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- - name: handle insisting image gen CockroachDB # remove after release
- run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
-
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
From c59a39d27dc40e88d5c48cbed27e496d1ef1d9be Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 11 Sep 2024 19:05:10 +0200
Subject: [PATCH 42/89] Update service package version (#2582)
---
go.mod | 2 +-
go.sum | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/go.mod b/go.mod
index 7d5817769..16b5b55e4 100644
--- a/go.mod
+++ b/go.mod
@@ -232,7 +232,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect
)
-replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a
+replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
diff --git a/go.sum b/go.sum
index 7a587c0d1..2355f6f0c 100644
--- a/go.sum
+++ b/go.sum
@@ -523,8 +523,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
-github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk=
-github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
+github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
From cf6210a6f42355e88c422c624376f6fcdaea6729 Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Wed, 11 Sep 2024 20:09:57 +0300
Subject: [PATCH 43/89] [management] Add GCM encryption and migrate legacy
encrypted events (#2569)
* Add AES-GCM encryption
Signed-off-by: bcmmbaga
* migrate legacy encrypted data to AES-GCM encryption
Signed-off-by: bcmmbaga
* Refactor and use transaction when migrating data
Signed-off-by: bcmmbaga
* Add events migration tests
Signed-off-by: bcmmbaga
* fix lint
Signed-off-by: bcmmbaga
* skip migrating record on error
Signed-off-by: bcmmbaga
* Preallocate capacity for nonce to avoid allocations in Seal
Signed-off-by: bcmmbaga
---------
Signed-off-by: bcmmbaga
---
management/server/activity/sqlite/crypt.go | 49 +++++-
.../server/activity/sqlite/crypt_test.go | 38 ++++-
.../server/activity/sqlite/migration.go | 157 ++++++++++++++++++
.../server/activity/sqlite/migration_test.go | 84 ++++++++++
management/server/activity/sqlite/sqlite.go | 154 ++++++++---------
5 files changed, 396 insertions(+), 86 deletions(-)
create mode 100644 management/server/activity/sqlite/migration.go
create mode 100644 management/server/activity/sqlite/migration_test.go
diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go
index cf4dda746..852d9bc4a 100644
--- a/management/server/activity/sqlite/crypt.go
+++ b/management/server/activity/sqlite/crypt.go
@@ -6,6 +6,7 @@ import (
"crypto/cipher"
"crypto/rand"
"encoding/base64"
+ "errors"
"fmt"
)
@@ -13,6 +14,7 @@ var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
+ gcm cipher.AEAD
}
func GenerateKey() (string, error) {
@@ -35,14 +37,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil {
return nil, err
}
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, err
+ }
+
ec := &FieldEncrypt{
block: block,
+ gcm: gcm,
}
return ec, nil
}
-func (ec *FieldEncrypt) Encrypt(payload string) string {
+func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
@@ -50,7 +59,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText)
}
-func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
+// Encrypt encrypts plaintext using AES-GCM
+func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
+ plaintext := []byte(payload)
+ nonceSize := ec.gcm.NonceSize()
+
+ nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
+ if _, err := rand.Read(nonce); err != nil {
+ return "", err
+ }
+
+ ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
+
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
@@ -65,6 +89,27 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil
}
+// Decrypt decrypts ciphertext using AES-GCM
+func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
+ cipherText, err := base64.StdEncoding.DecodeString(data)
+ if err != nil {
+ return "", err
+ }
+
+ nonceSize := ec.gcm.NonceSize()
+ if len(cipherText) < nonceSize {
+ return "", errors.New("cipher text too short")
+ }
+
+ nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
+ plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
+ if err != nil {
+ return "", err
+ }
+
+ return string(plainText), nil
+}
+
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go
index efa740921..1033ab6ed 100644
--- a/management/server/activity/sqlite/crypt_test.go
+++ b/management/server/activity/sqlite/crypt_test.go
@@ -15,7 +15,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
- encrypted := ee.Encrypt(testData)
+ encrypted, err := ee.Encrypt(testData)
+ if err != nil {
+ t.Fatalf("failed to encrypt data: %s", err)
+ }
+
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -30,6 +34,32 @@ func TestGenerateKey(t *testing.T) {
}
}
+func TestGenerateKeyLegacy(t *testing.T) {
+ testData := "exampl@netbird.io"
+ key, err := GenerateKey()
+ if err != nil {
+ t.Fatalf("failed to generate key: %s", err)
+ }
+ ee, err := NewFieldEncrypt(key)
+ if err != nil {
+ t.Fatalf("failed to init email encryption: %s", err)
+ }
+
+ encrypted := ee.LegacyEncrypt(testData)
+ if encrypted == "" {
+ t.Fatalf("invalid encrypted text")
+ }
+
+ decrypted, err := ee.LegacyDecrypt(encrypted)
+ if err != nil {
+ t.Fatalf("failed to decrypt data: %s", err)
+ }
+
+ if decrypted != testData {
+ t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
+ }
+}
+
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
@@ -41,7 +71,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
- encrypted := ee.Encrypt(testData)
+ encrypted, err := ee.Encrypt(testData)
+ if err != nil {
+ t.Fatalf("failed to encrypt data: %s", err)
+ }
+
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
diff --git a/management/server/activity/sqlite/migration.go b/management/server/activity/sqlite/migration.go
new file mode 100644
index 000000000..28c5b3020
--- /dev/null
+++ b/management/server/activity/sqlite/migration.go
@@ -0,0 +1,157 @@
+package sqlite
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ log "github.com/sirupsen/logrus"
+)
+
+func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
+ if _, err := db.Exec(createTableQuery); err != nil {
+ return err
+ }
+
+ if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
+ return err
+ }
+
+ if err := updateDeletedUsersTable(ctx, db); err != nil {
+ return fmt.Errorf("failed to update deleted_users table: %v", err)
+ }
+
+ return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
+}
+
+// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
+func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
+ exists, err := checkColumnExists(db, "deleted_users", "name")
+ if err != nil {
+ return err
+ }
+
+ if !exists {
+ log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
+
+ _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
+ if err != nil {
+ return err
+ }
+
+ log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
+ }
+
+ exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
+ if err != nil {
+ return err
+ }
+
+ if !exists {
+ log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
+
+ _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
+ if err != nil {
+ return err
+ }
+
+ log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
+ }
+
+ return nil
+}
+
+// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
+// legacy CBC encryption with a static IV to the new GCM encryption method.
+func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
+ log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
+
+ tx, err := db.Begin()
+ if err != nil {
+ return fmt.Errorf("failed to begin transaction: %v", err)
+ }
+ defer func() {
+ _ = tx.Rollback()
+ }()
+
+ rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
+ if err != nil {
+ return fmt.Errorf("failed to execute select query: %v", err)
+ }
+ defer rows.Close()
+
+ updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
+ if err != nil {
+ return fmt.Errorf("failed to prepare update statement: %v", err)
+ }
+ defer updateStmt.Close()
+
+ if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
+ return err
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("failed to commit transaction: %v", err)
+ }
+
+ log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
+ return nil
+}
+
+// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
+func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
+ for rows.Next() {
+ var (
+ id, decryptedEmail, decryptedName string
+ email, name *string
+ )
+
+ err := rows.Scan(&id, &email, &name)
+ if err != nil {
+ return err
+ }
+
+ if email != nil {
+ decryptedEmail, err = crypt.LegacyDecrypt(*email)
+ if err != nil {
+ log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
+ id,
+ fmt.Errorf("failed to decrypt email: %w", err),
+ )
+ continue
+ }
+ }
+
+ if name != nil {
+ decryptedName, err = crypt.LegacyDecrypt(*name)
+ if err != nil {
+ log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
+ id,
+ fmt.Errorf("failed to decrypt name: %w", err),
+ )
+ continue
+ }
+ }
+
+ encryptedEmail, err := crypt.Encrypt(decryptedEmail)
+ if err != nil {
+ return fmt.Errorf("failed to encrypt email: %w", err)
+ }
+
+ encryptedName, err := crypt.Encrypt(decryptedName)
+ if err != nil {
+ return fmt.Errorf("failed to encrypt name: %w", err)
+ }
+
+ _, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
+ if err != nil {
+ return err
+ }
+ }
+
+ if err := rows.Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/management/server/activity/sqlite/migration_test.go b/management/server/activity/sqlite/migration_test.go
new file mode 100644
index 000000000..a03774fa8
--- /dev/null
+++ b/management/server/activity/sqlite/migration_test.go
@@ -0,0 +1,84 @@
+package sqlite
+
+import (
+ "context"
+ "database/sql"
+ "path/filepath"
+ "testing"
+ "time"
+
+ _ "github.com/mattn/go-sqlite3"
+ "github.com/netbirdio/netbird/management/server/activity"
+
+ "github.com/stretchr/testify/require"
+)
+
+func setupDatabase(t *testing.T) *sql.DB {
+ t.Helper()
+
+ dbFile := filepath.Join(t.TempDir(), eventSinkDB)
+ db, err := sql.Open("sqlite3", dbFile)
+ require.NoError(t, err, "Failed to open database")
+
+ t.Cleanup(func() {
+ _ = db.Close()
+ })
+
+ _, err = db.Exec(createTableQuery)
+ require.NoError(t, err, "Failed to create events table")
+
+ _, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
+ require.NoError(t, err, "Failed to create deleted_users table")
+
+ return db
+}
+
+func TestMigrate(t *testing.T) {
+ db := setupDatabase(t)
+
+ key, err := GenerateKey()
+ require.NoError(t, err, "Failed to generate key")
+
+ crypt, err := NewFieldEncrypt(key)
+ require.NoError(t, err, "Failed to initialize FieldEncrypt")
+
+ legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
+ legacyName := crypt.LegacyEncrypt("Test Account")
+
+ _, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
+ activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
+ require.NoError(t, err, "Failed to insert event")
+
+ _, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
+ require.NoError(t, err, "Failed to insert legacy encrypted data")
+
+ colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
+ require.NoError(t, err, "Failed to check if enc_algo column exists")
+ require.False(t, colExists, "enc_algo column should not exist before migration")
+
+ err = migrate(context.Background(), crypt, db)
+ require.NoError(t, err, "Migration failed")
+
+ colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
+ require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
+ require.True(t, colExists, "enc_algo column should exist after migration")
+
+ var encAlgo string
+ err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
+ require.NoError(t, err, "Failed to select updated data")
+ require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
+
+ store, err := createStore(crypt, db)
+ require.NoError(t, err, "Failed to create store")
+
+ events, err := store.Get(context.Background(), "accountID", 0, 1, false)
+ require.NoError(t, err, "Failed to get events")
+
+ require.Len(t, events, 1, "Should have one event")
+ require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
+ require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
+ require.Equal(t, "targetID", events[0].TargetID, "target id should match")
+ require.Equal(t, "accountID", events[0].AccountID, "account id should match")
+ require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
+ require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
+}
diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go
index fadf1eb07..823e0b4ac 100644
--- a/management/server/activity/sqlite/sqlite.go
+++ b/management/server/activity/sqlite/sqlite.go
@@ -26,7 +26,7 @@ const (
"meta TEXT," +
" target_id TEXT);"
- creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
+ creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events
@@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/
- insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
+ insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
+
+ gcmEncAlgo = "GCM"
)
// Store is the implementation of the activity.Store interface backed by SQLite
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err
}
- _, err = db.Exec(createTableQuery)
- if err != nil {
+ if err = migrate(ctx, crypt, db); err != nil {
_ = db.Close()
- return nil, err
+ return nil, fmt.Errorf("events database migration: %w", err)
}
- _, err = db.Exec(creatTableDeletedUsersQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- err = updateDeletedUsersTable(ctx, db)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- insertStmt, err := db.Prepare(insertQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- selectDescStmt, err := db.Prepare(selectDescQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- selectAscStmt, err := db.Prepare(selectAscQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- s := &Store{
- db: db,
- fieldEncrypt: crypt,
- insertStatement: insertStmt,
- selectDescStatement: selectDescStmt,
- selectAscStatement: selectAscStmt,
- deleteUserStmt: deleteUserStmt,
- }
-
- return s, nil
+ return createStore(crypt, db)
}
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil
}
- encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
- encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
- _, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
+ encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
+ if err != nil {
+ return nil, err
+ }
+ encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil {
return nil, err
}
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil
}
-func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
- log.WithContext(ctx).Debugf("check deleted_users table version")
- rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
+// createStore initializes and returns a new Store instance with prepared SQL statements.
+func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
+ insertStmt, err := db.Prepare(insertQuery)
if err != nil {
- return err
+ _ = db.Close()
+ return nil, err
+ }
+
+ selectDescStmt, err := db.Prepare(selectDescQuery)
+ if err != nil {
+ _ = db.Close()
+ return nil, err
+ }
+
+ selectAscStmt, err := db.Prepare(selectAscQuery)
+ if err != nil {
+ _ = db.Close()
+ return nil, err
+ }
+
+ deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
+ if err != nil {
+ _ = db.Close()
+ return nil, err
+ }
+
+ return &Store{
+ db: db,
+ fieldEncrypt: crypt,
+ insertStatement: insertStmt,
+ selectDescStatement: selectDescStmt,
+ selectAscStatement: selectAscStmt,
+ deleteUserStmt: deleteUserStmt,
+ }, nil
+}
+
+// checkColumnExists checks if a column exists in a specified table
+func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
+ query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
+ rows, err := db.Query(query)
+ if err != nil {
+ return false, fmt.Errorf("failed to query table info: %w", err)
}
defer rows.Close()
- found := false
+
for rows.Next() {
- var (
- cid int
- name string
- dataType string
- notNull int
- dfltVal sql.NullString
- pk int
- )
- err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk)
+ var cid int
+ var name, ctype string
+ var notnull, pk int
+ var dfltValue sql.NullString
+
+ err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk)
if err != nil {
- return err
+ return false, fmt.Errorf("failed to scan row: %w", err)
}
- if name == "name" {
- found = true
- break
+
+ if name == columnName {
+ return true, nil
}
}
- err = rows.Err()
- if err != nil {
- return err
+ if err = rows.Err(); err != nil {
+ return false, err
}
- if found {
- return nil
- }
-
- log.WithContext(ctx).Debugf("update delted_users table")
- _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
- return err
+ return false, nil
}
From afb9673bc472ca5dcb67757155806a14d435052b Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 11 Sep 2024 21:49:05 +0200
Subject: [PATCH 44/89] [misc] Update core github actions (#2584)
---
.github/workflows/golang-test-darwin.yml | 6 ++--
.github/workflows/golang-test-linux.yml | 14 ++++-----
.github/workflows/golang-test-windows.yml | 4 +--
.github/workflows/golangci-lint.yml | 8 ++---
.github/workflows/install-script-test.yml | 2 +-
.github/workflows/mobile-build-validation.yml | 14 ++++-----
.github/workflows/release.yml | 30 +++++++++----------
.../workflows/test-infrastructure-files.yml | 8 ++---
8 files changed, 43 insertions(+), 43 deletions(-)
diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml
index 2b4c43cb4..81c3105b2 100644
--- a/.github/workflows/golang-test-darwin.yml
+++ b/.github/workflows/golang-test-darwin.yml
@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 120b213e9..076cd061e 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -55,12 +55,12 @@ jobs:
runs-on: ubuntu-20.04
steps:
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -124,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
- run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
\ No newline at end of file
+ run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml
index 2d63acbcd..f51ddb09f 100644
--- a/.github/workflows/golang-test-windows.yml
+++ b/.github/workflows/golang-test-windows.yml
@@ -17,10 +17,10 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
id: go
with:
go-version: "1.21.x"
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 78b9f504f..2833ded20 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
@@ -32,13 +32,13 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
cache: false
@@ -49,4 +49,4 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
version: latest
- args: --timeout=12m
\ No newline at end of file
+ args: --timeout=12m
diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml
index dfb8a279b..04c222e87 100644
--- a/.github/workflows/install-script-test.yml
+++ b/.github/workflows/install-script-test.yml
@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: run install script
env:
diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml
index e5a5ff485..5bae3a3ec 100644
--- a/.github/workflows/mobile-build-validation.yml
+++ b/.github/workflows/mobile-build-validation.yml
@@ -15,9 +15,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
- name: Setup Android SDK
@@ -25,13 +25,13 @@ jobs:
with:
cmdline-tools-version: 8512546
- name: Setup Java
- uses: actions/setup-java@v3
+ uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -50,9 +50,9 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
- name: install gomobile
@@ -62,4 +62,4 @@ jobs:
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
- CGO_ENABLED: 0
\ No newline at end of file
+ CGO_ENABLED: 0
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index a8f7868d5..bb8887e6d 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -36,18 +36,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21"
cache: false
-
name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -93,28 +93,28 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -133,17 +133,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21"
cache: false
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -176,7 +176,7 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@@ -189,18 +189,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21"
cache: false
-
name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -225,7 +225,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index d1c2b3aef..d627adccd 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
go-version: "1.21.x"
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,7 +219,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
From 4c130a02917259fc331f32d95107009786a806f7 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Thu, 12 Sep 2024 13:46:28 +0200
Subject: [PATCH 45/89] Update Go version to 1.23 (#2588)
---
.github/workflows/golang-test-darwin.yml | 2 +-
.github/workflows/golang-test-linux.yml | 4 ++--
.github/workflows/golang-test-windows.yml | 2 +-
.github/workflows/golangci-lint.yml | 2 +-
.github/workflows/mobile-build-validation.yml | 4 ++--
.github/workflows/release.yml | 6 +++---
.github/workflows/test-infrastructure-files.yml | 2 +-
go.mod | 2 +-
8 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml
index 81c3105b2..2aaef7564 100644
--- a/.github/workflows/golang-test-darwin.yml
+++ b/.github/workflows/golang-test-darwin.yml
@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Checkout code
uses: actions/checkout@v4
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 076cd061e..263623bd1 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -21,7 +21,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Cache Go modules
@@ -57,7 +57,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml
index f51ddb09f..d378bec3f 100644
--- a/.github/workflows/golang-test-windows.yml
+++ b/.github/workflows/golang-test-windows.yml
@@ -23,7 +23,7 @@ jobs:
uses: actions/setup-go@v5
id: go
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 2833ded20..8b7136841 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -40,7 +40,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml
index 5bae3a3ec..dcf461a34 100644
--- a/.github/workflows/mobile-build-validation.yml
+++ b/.github/workflows/mobile-build-validation.yml
@@ -19,7 +19,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
@@ -54,7 +54,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index bb8887e6d..5f423f1c9 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -43,7 +43,7 @@ jobs:
name: Set up Go
uses: actions/setup-go@v5
with:
- go-version: "1.21"
+ go-version: "1.23"
cache: false
-
name: Cache Go modules
@@ -140,7 +140,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
- go-version: "1.21"
+ go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -196,7 +196,7 @@ jobs:
name: Set up Go
uses: actions/setup-go@v5
with:
- go-version: "1.21"
+ go-version: "1.23"
cache: false
-
name: Cache Go modules
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index d627adccd..da3ec746a 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -52,7 +52,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
diff --git a/go.mod b/go.mod
index 16b5b55e4..12709e50d 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird
-go 1.21.0
+go 1.23.0
require (
cunicu.li/go-rosenpass v0.4.0
From 170e842422a53e3f218f21bc59c2942c84e73c0e Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Thu, 12 Sep 2024 16:19:27 +0300
Subject: [PATCH 46/89] [management] Add accessible peers endpoint (#2579)
* move accessible peer to separate endpoint in api doc
Signed-off-by: bcmmbaga
* add endpoint to get accessible peers
Signed-off-by: bcmmbaga
* Update management/server/http/api/openapi.yml
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
* Update management/server/http/api/openapi.yml
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
* Update management/server/http/peers_handler.go
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
---------
Signed-off-by: bcmmbaga
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
---
management/server/http/api/openapi.yml | 75 ++++++++++++++++----
management/server/http/api/types.gen.go | 93 +++++--------------------
management/server/http/handler.go | 1 +
management/server/http/peers_handler.go | 81 +++++++++++++--------
4 files changed, 133 insertions(+), 117 deletions(-)
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index d32ec6167..156310a9b 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -251,7 +251,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
- PeerBase:
+ Peer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
- type: object
@@ -378,25 +378,40 @@ components:
description: User ID of the user that enrolled this peer
type: string
example: google-oauth2|277474792786460067937
+ os:
+ description: Peer's operating system and version
+ type: string
+ example: linux
+ country_code:
+ $ref: '#/components/schemas/CountryCode'
+ city_name:
+ $ref: '#/components/schemas/CityName'
+ geoname_id:
+ description: Unique identifier from the GeoNames database for a specific geographical location.
+ type: integer
+ example: 2643743
+ connected:
+ description: Peer to Management connection status
+ type: boolean
+ example: true
+ last_seen:
+ description: Last time peer connected to Netbird's management service
+ type: string
+ format: date-time
+ example: "2023-05-05T10:05:26.420578Z"
required:
- ip
- dns_label
- user_id
- Peer:
- allOf:
- - $ref: '#/components/schemas/PeerBase'
- - type: object
- properties:
- accessible_peers:
- description: List of accessible peers
- type: array
- items:
- $ref: '#/components/schemas/AccessiblePeer'
- required:
- - accessible_peers
+ - os
+ - country_code
+ - city_name
+ - geoname_id
+ - connected
+ - last_seen
PeerBatch:
allOf:
- - $ref: '#/components/schemas/PeerBase'
+ - $ref: '#/components/schemas/Peer'
- type: object
properties:
accessible_peers_count:
@@ -1806,6 +1821,38 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
+ /api/peers/{peerId}/accessible-peers:
+ get:
+ summary: List accessible Peers
+ description: Returns a list of peers that the specified peer can connect to within the network.
+ tags: [ Peers ]
+ security:
+ - BearerAuth: [ ]
+ - TokenAuth: [ ]
+ parameters:
+ - in: path
+ name: peerId
+ required: true
+ schema:
+ type: string
+ description: The unique identifier of a peer
+ responses:
+ '200':
+ description: A JSON Array of Accessible Peers
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/components/schemas/AccessiblePeer'
+ '400':
+ "$ref": "#/components/responses/bad_request"
+ '401':
+ "$ref": "#/components/responses/requires_authentication"
+ '403':
+ "$ref": "#/components/responses/forbidden"
+ '500':
+ "$ref": "#/components/responses/internal_error"
/api/setup-keys:
get:
summary: List all Setup Keys
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index a575ff54b..b219d38fd 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -152,18 +152,36 @@ const (
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
+ // CityName Commonly used English name of the city
+ CityName CityName `json:"city_name"`
+
+ // Connected Peer to Management connection status
+ Connected bool `json:"connected"`
+
+ // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
+ CountryCode CountryCode `json:"country_code"`
+
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
+ // GeonameId Unique identifier from the GeoNames database for a specific geographical location.
+ GeonameId int `json:"geoname_id"`
+
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
+ // LastSeen Last time peer connected to Netbird's management service
+ LastSeen time.Time `json:"last_seen"`
+
// Name Peer's hostname
Name string `json:"name"`
+ // Os Peer's operating system and version
+ Os string `json:"os"`
+
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
}
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer.
type Peer struct {
- // AccessiblePeers List of accessible peers
- AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
-
- // ApprovalRequired (Cloud only) Indicates whether peer needs approval
- ApprovalRequired bool `json:"approval_required"`
-
- // CityName Commonly used English name of the city
- CityName CityName `json:"city_name"`
-
- // Connected Peer to Management connection status
- Connected bool `json:"connected"`
-
- // ConnectionIp Peer's public connection IP address
- ConnectionIp string `json:"connection_ip"`
-
- // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
- CountryCode CountryCode `json:"country_code"`
-
- // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
- DnsLabel string `json:"dns_label"`
-
- // GeonameId Unique identifier from the GeoNames database for a specific geographical location.
- GeonameId int `json:"geoname_id"`
-
- // Groups Groups that the peer belongs to
- Groups []GroupMinimum `json:"groups"`
-
- // Hostname Hostname of the machine
- Hostname string `json:"hostname"`
-
- // Id Peer ID
- Id string `json:"id"`
-
- // Ip Peer's IP address
- Ip string `json:"ip"`
-
- // KernelVersion Peer's operating system kernel version
- KernelVersion string `json:"kernel_version"`
-
- // LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
- LastLogin time.Time `json:"last_login"`
-
- // LastSeen Last time peer connected to Netbird's management service
- LastSeen time.Time `json:"last_seen"`
-
- // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
- LoginExpirationEnabled bool `json:"login_expiration_enabled"`
-
- // LoginExpired Indicates whether peer's login expired or not
- LoginExpired bool `json:"login_expired"`
-
- // Name Peer's hostname
- Name string `json:"name"`
-
- // Os Peer's operating system and version
- Os string `json:"os"`
-
- // SerialNumber System serial number
- SerialNumber string `json:"serial_number"`
-
- // SshEnabled Indicates whether SSH server is enabled on this peer
- SshEnabled bool `json:"ssh_enabled"`
-
- // UiVersion Peer's desktop UI version
- UiVersion string `json:"ui_version"`
-
- // UserId User ID of the user that enrolled this peer
- UserId string `json:"user_id"`
-
- // Version Peer's daemon or cli version
- Version string `json:"version"`
-}
-
-// PeerBase defines model for PeerBase.
-type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index 366efa9b7..ef94f22b9 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -115,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
+ apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {
diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go
index 913d424d1..1487bbc39 100644
--- a/management/server/http/peers_handler.go
+++ b/management/server/http/peers_handler.go
@@ -71,12 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
- customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
- netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
- accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
-
_, valid := validPeers[peer.ID]
- util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
+ util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -117,13 +113,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return
}
- customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
- netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
- accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
-
_, valid := validPeers[peer.ID]
- util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
+ util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
}
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -220,32 +212,66 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
}
}
+// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
+func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
+ claims := h.claimsExtractor.FromRequestContext(r)
+ account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ vars := mux.Vars(r)
+ peerID := vars["peerId"]
+ if len(peerID) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
+ return
+ }
+
+ dnsDomain := h.accountManager.GetDNSDomain()
+
+ validPeers, err := h.accountManager.GetValidatedPeers(account)
+ if err != nil {
+ log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
+ util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
+ return
+ }
+
+ customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
+ netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
+
+ util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
+}
+
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
- ap := api.AccessiblePeer{
- Id: p.ID,
- Name: p.Name,
- Ip: p.IP.String(),
- DnsLabel: fqdn(p, dnsDomain),
- UserId: p.UserID,
- }
- accessiblePeers = append(accessiblePeers, ap)
+ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
for _, p := range netMap.OfflinePeers {
- ap := api.AccessiblePeer{
- Id: p.ID,
- Name: p.Name,
- Ip: p.IP.String(),
- DnsLabel: fqdn(p, dnsDomain),
- UserId: p.UserID,
- }
- accessiblePeers = append(accessiblePeers, ap)
+ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
+
return accessiblePeers
}
+func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
+ return api.AccessiblePeer{
+ CityName: peer.Location.CityName,
+ Connected: peer.Status.Connected,
+ CountryCode: peer.Location.CountryCode,
+ DnsLabel: fqdn(peer, dnsDomain),
+ GeonameId: int(peer.Location.GeoNameID),
+ Id: peer.ID,
+ Ip: peer.IP.String(),
+ LastSeen: peer.Status.LastSeen,
+ Name: peer.Name,
+ Os: peer.Meta.OS,
+ UserId: peer.UserID,
+ }
+}
+
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
@@ -270,7 +296,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo
}
-func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
+func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
@@ -296,7 +322,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
- AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
From 33c9b2d989ac8a11cbf3046f4f9fa13850c62ca5 Mon Sep 17 00:00:00 2001
From: Gianluca Boiano <491117+M0Rf30@users.noreply.github.com>
Date: Thu, 12 Sep 2024 17:32:47 +0200
Subject: [PATCH 47/89] fix: install.sh: avoid call of netbird executable after
rpm installation (#2589)
---
release_files/install.sh | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/release_files/install.sh b/release_files/install.sh
index 7b6774d84..d6aabebd8 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -300,11 +300,13 @@ install_netbird() {
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service
- if ! ${SUDO} netbird service install 2>&1; then
- echo "NetBird service has already been loaded"
- fi
- if ! ${SUDO} netbird service start 2>&1; then
- echo "NetBird service has already been started"
+ if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
+ if ! ${SUDO} netbird service install 2>&1; then
+ echo "NetBird service has already been loaded"
+ fi
+ if ! ${SUDO} netbird service start 2>&1; then
+ echo "NetBird service has already been started"
+ fi
fi
From ab892b8cf9b421b5e61a74a9545a8b394d6d3a09 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Thu, 12 Sep 2024 19:18:02 +0200
Subject: [PATCH 48/89] Fix wg handshake checking (#2590)
* Fix wg handshake checking
* Ensure in the initial handshake reading
* Change the handshake period
---
client/internal/peer/conn.go | 4 +-
client/internal/peer/worker_relay.go | 69 +++++++++++++++++-----------
2 files changed, 43 insertions(+), 30 deletions(-)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 8b8b3c5c0..f4a701f7f 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -484,11 +484,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
// switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
- conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay
}
@@ -551,6 +551,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
}
}
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
if err := wgProxy.CloseConn(); err != nil {
@@ -560,7 +561,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
return
}
wgConfigWorkaround()
- conn.workerRelay.EnableWgWatcher(conn.ctx)
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go
index 930a8f5b6..3457faa46 100644
--- a/client/internal/peer/worker_relay.go
+++ b/client/internal/peer/worker_relay.go
@@ -14,7 +14,7 @@ import (
)
var (
- wgHandshakePeriod = 2 * time.Minute
+ wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second
)
@@ -109,7 +109,7 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
}
ctx, ctxCancel := context.WithCancel(ctx)
- go w.wgStateCheck(ctx)
+ w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
@@ -157,37 +157,50 @@ func (w *WorkerRelay) CloseConn() {
}
}
-// wgStateCheck help to check the state of the wireguard handshake and relay connection
+// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
- timer := time.NewTimer(wgHandshakeOvertime)
- defer timer.Stop()
- expected := wgHandshakeOvertime
- for {
- select {
- case <-timer.C:
- lastHandshake, err := w.wgState()
- if err != nil {
- w.log.Errorf("failed to read wg stats: %v", err)
- continue
- }
- w.log.Tracef("last handshake: %v", lastHandshake)
+ lastHandshake, err := w.wgState()
+ if err != nil {
+ w.log.Errorf("failed to read wg stats: %v", err)
+ lastHandshake = time.Time{}
+ }
- if time.Since(lastHandshake) > expected {
- w.log.Infof("Wireguard handshake timed out, closing relay connection")
- w.relayLock.Lock()
- _ = w.relayedConn.Close()
- w.relayLock.Unlock()
- w.callBacks.OnDisconnected()
+ go func(lastHandshake time.Time) {
+ timer := time.NewTimer(wgHandshakeOvertime)
+ defer timer.Stop()
+
+ for {
+ select {
+ case <-timer.C:
+
+ handshake, err := w.wgState()
+ if err != nil {
+ w.log.Errorf("failed to read wg stats: %v", err)
+ timer.Reset(wgHandshakeOvertime)
+ continue
+ }
+
+ w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
+
+ if handshake.Equal(lastHandshake) {
+ w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
+ w.relayLock.Lock()
+ _ = w.relayedConn.Close()
+ w.relayLock.Unlock()
+ w.callBacks.OnDisconnected()
+ return
+ }
+
+ resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
+ lastHandshake = handshake
+ timer.Reset(resetTime)
+ case <-ctx.Done():
+ w.log.Debugf("WireGuard watcher stopped")
return
}
- resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
- timer.Reset(resetTime)
- expected = wgHandshakePeriod
- case <-ctx.Done():
- w.log.Debugf("WireGuard watcher stopped")
- return
}
- }
+ }(lastHandshake)
+
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
From f6d57e7a96fb7f91a4af1f6fb79d84352cf33792 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Thu, 12 Sep 2024 19:56:55 +0200
Subject: [PATCH 49/89] [misc] Support configurable max log size with var
NB_LOG_MAX_SIZE_MB (#2592)
* Support configurable max log size with var NB_LOG_MAX_SIZE_MB
* add better logs
---
util/log.go | 23 +++++++++++++++++++++--
1 file changed, 21 insertions(+), 2 deletions(-)
diff --git a/util/log.go b/util/log.go
index 4bce75e4a..7a9235ee6 100644
--- a/util/log.go
+++ b/util/log.go
@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"slices"
+ "strconv"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -12,6 +13,8 @@ import (
"github.com/netbirdio/netbird/formatter"
)
+const defaultLogSize = 5
+
// InitLog parses and sets log-level input
func InitLog(logLevel string, logPath string) error {
level, err := log.ParseLevel(logLevel)
@@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err
}
- customOutputs := []string{"console", "syslog"};
+ customOutputs := []string{"console", "syslog"}
if logPath != "" && !slices.Contains(customOutputs, logPath) {
+ maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath),
- MaxSize: 5, // MB
+ MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
@@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error {
log.SetLevel(level)
return nil
}
+
+func getLogMaxSize() int {
+ if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
+ size, err := strconv.ParseInt(sizeVar, 10, 64)
+ if err != nil {
+ log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
+ return defaultLogSize
+ }
+
+ log.Infof("Setting log file max size to %d MB", size)
+
+ return int(size)
+ }
+ return defaultLogSize
+}
From 1ef51a4ffa6cea6b9c5f6d0c2ee5fe12f12b793e Mon Sep 17 00:00:00 2001
From: Carlos Hernandez
Date: Fri, 13 Sep 2024 08:46:59 -0600
Subject: [PATCH 50/89] [client] Ensure engine is stopped before starting it
back (#2565)
Before starting a new instance of the engine, check if it is nil and stop the current instance
---
client/internal/connect.go | 10 +++++++++-
client/internal/engine.go | 32 ++++++++++++++++++++++++--------
2 files changed, 33 insertions(+), 9 deletions(-)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 515321f7f..36b340cfb 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
}
defer c.statusRecorder.ClientStop()
+ runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.isContextCancelled() {
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
+ if c.engine != nil && c.engine.ctx.Err() != nil {
+ log.Info("Stopping Netbird Engine")
+ if err := c.engine.Stop(); err != nil {
+ log.Errorf("Failed to stop engine: %v", err)
+ }
+ }
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -279,9 +286,10 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
- if runningChan != nil {
+ if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
+ runningChanOpen = false
}
<-engineCtx.Done()
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 47a36c4bf..b0deb5a29 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
- if e.dnsServer != nil {
- e.dnsServer.Stop()
- e.dnsServer = nil
- }
+ e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
}
func (e *Engine) restartEngine() {
+ log.Info("restarting engine")
+ CtxGetState(e.ctx).Set(StatusConnecting)
+
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
- if err := e.Start(); err != nil {
- log.Errorf("Failed to start engine: %v", err)
- }
+
+ _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
+ log.Infof("cancelling client, engine will be recreated")
+ e.clientCancel()
}
func (e *Engine) startNetworkMonitor() {
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock()
if debounceTimer != nil {
+ log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock()
defer mu.Unlock()
- log.Infof("Network monitor detected network change, restarting engine")
+ log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil
}
+func (e *Engine) stopDNSServer() {
+ err := fmt.Errorf("DNS server stopped")
+ nsGroupStates := e.statusRecorder.GetDNSStates()
+ for i := range nsGroupStates {
+ nsGroupStates[i].Enabled = false
+ nsGroupStates[i].Error = err
+ }
+ e.statusRecorder.UpdateDNSStates(nsGroupStates)
+ if e.dnsServer != nil {
+ e.dnsServer.Stop()
+ e.dnsServer = nil
+ }
+}
+
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
From b4c8cf0a678bbed161a405b99ca0676f6266f2b3 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Sat, 14 Sep 2024 10:12:54 +0200
Subject: [PATCH 51/89] Change heartbeat timeout (#2598)
---
relay/healthcheck/receiver.go | 2 +-
relay/healthcheck/sender.go | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go
index 2b9c9e2e0..59f780ed8 100644
--- a/relay/healthcheck/receiver.go
+++ b/relay/healthcheck/receiver.go
@@ -6,7 +6,7 @@ import (
)
var (
- heartbeatTimeout = healthCheckInterval + 3*time.Second
+ heartbeatTimeout = healthCheckInterval + 10*time.Second
)
// Receiver is a healthcheck receiver
diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go
index ec0560ef2..8d1716b2c 100644
--- a/relay/healthcheck/sender.go
+++ b/relay/healthcheck/sender.go
@@ -7,7 +7,7 @@ import (
var (
healthCheckInterval = 25 * time.Second
- healthCheckTimeout = 5 * time.Second
+ healthCheckTimeout = 20 * time.Second
)
// Sender is a healthcheck sender
From 9e041b7f824d26de2a4dd8207acd83e2db5cc08c Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Sat, 14 Sep 2024 10:27:37 +0200
Subject: [PATCH 52/89] Fix blocked net.Conn Close call (#2600)
---
relay/client/client.go | 27 +++++++++++--
relay/client/client_test.go | 81 +++++++++++++++++++++++++++++++++++++
2 files changed, 105 insertions(+), 3 deletions(-)
diff --git a/relay/client/client.go b/relay/client/client.go
index 6560c81e1..7ff17944f 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -63,32 +63,53 @@ type connContainer struct {
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
+ ctx context.Context
+ cancel context.CancelFunc
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
+ ctx, cancel := context.WithCancel(context.Background())
+
return &connContainer{
conn: conn,
messages: messages,
+ ctx: ctx,
+ cancel: cancel,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
+
if cc.closed {
+ msg.Free()
return
}
- cc.messages <- msg
+
+ select {
+ case cc.messages <- msg:
+ case <-cc.ctx.Done():
+ msg.Free()
+ }
}
func (cc *connContainer) close() {
+ cc.cancel()
+
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
+
if cc.closed {
return
}
- close(cc.messages)
+
cc.closed = true
+ close(cc.messages)
+
+ for msg := range cc.messages {
+ msg.Free()
+ }
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@@ -464,8 +485,8 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
- container.close()
delete(c.conns, id)
+ container.close()
return nil
}
diff --git a/relay/client/client_test.go b/relay/client/client_test.go
index b7f1a63ca..ef28203e9 100644
--- a/relay/client/client_test.go
+++ b/relay/client/client_test.go
@@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
}
}
+func TestCloseNotDrainedChannel(t *testing.T) {
+ ctx := context.Background()
+ idAlice := "alice"
+ idBob := "bob"
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close Alice client: %s", err)
+ }
+ }()
+
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
+ err = clientBob.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientBob.Close()
+ if err != nil {
+ t.Errorf("failed to close Bob client: %s", err)
+ }
+ }()
+
+ connAliceToBob, err := clientAlice.OpenConn(idBob)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ connBobToAlice, err := clientBob.OpenConn(idAlice)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ // the internal channel buffer size is 2. So we should overflow it
+ for i := 0; i < 5; i++ {
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ }
+
+ // wait for delivery
+ time.Sleep(1 * time.Second)
+ err = connBobToAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close channel: %s", err)
+ }
+}
+
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:
From f1171198de1a5f06e83e6d4b8e785f8e49fc80d9 Mon Sep 17 00:00:00 2001
From: benniekiss <63211101+benniekiss@users.noreply.github.com>
Date: Sat, 14 Sep 2024 04:34:32 -0400
Subject: [PATCH 53/89] [management] Add command flag to set metrics port for
signal and relay service, and update management port (#2599)
* add flags to customize metrics port for relay and signal
* change management default metrics port to match other services
---
management/cmd/root.go | 2 +-
relay/cmd/root.go | 8 +++-----
signal/cmd/run.go | 6 ++----
3 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/management/cmd/root.go b/management/cmd/root.go
index 9b6a96b82..86155a956 100644
--- a/management/cmd/root.go
+++ b/management/cmd/root.go
@@ -54,7 +54,7 @@ func Execute() error {
func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
- mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
+ mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
diff --git a/relay/cmd/root.go b/relay/cmd/root.go
index dcc1465d0..d603ff73b 100644
--- a/relay/cmd/root.go
+++ b/relay/cmd/root.go
@@ -23,15 +23,12 @@ import (
"github.com/netbirdio/netbird/util"
)
-const (
- metricsPort = 9090
-)
-
type Config struct {
ListenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port
ExposedAddress string
+ MetricsPort int
LetsencryptEmail string
LetsencryptDataDir string
LetsencryptDomains []string
@@ -80,6 +77,7 @@ func init() {
cobraConfig = &Config{}
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
+ rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
@@ -116,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}
- metricsServer, err := metrics.NewServer(metricsPort, "")
+ metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)
diff --git a/signal/cmd/run.go b/signal/cmd/run.go
index 61f7a32a7..0bdc62ead 100644
--- a/signal/cmd/run.go
+++ b/signal/cmd/run.go
@@ -29,12 +29,9 @@ import (
"google.golang.org/grpc/keepalive"
)
-const (
- metricsPort = 9090
-)
-
var (
signalPort int
+ metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
+ runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
From fa7767e612f60ced85ca60ad290084512f88d27a Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Sun, 15 Sep 2024 16:07:26 +0200
Subject: [PATCH 54/89] Fix get management and signal state race condition
(#2570)
* Fix get management and signal state race condition
* fix get full status lock
---
client/internal/peer/status.go | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go
index f116f3fef..915fa63f0 100644
--- a/client/internal/peer/status.go
+++ b/client/internal/peer/status.go
@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetManagementState() ManagementState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
+ d.mux.Lock()
+ defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) GetDNSStates() []NSGroupState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return d.nsGroupStates
}
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
- d.mux.Lock()
- defer d.mux.Unlock()
-
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
- LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ fullStatus.LocalPeerState = d.localPeer
+
for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status)
}
From 82739e2832e98577f49951fe98ea664b33046e0c Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Sun, 15 Sep 2024 17:22:46 +0300
Subject: [PATCH 55/89] [management] fix legacy decrypting of empty values
(#2595)
* allow legacy decrypting on empty values
* validate source size and padding limits
* added tests
---------
Signed-off-by: bcmmbaga
Co-authored-by: Maycon Santos
---
management/server/activity/sqlite/crypt.go | 20 +-
.../server/activity/sqlite/crypt_test.go | 213 ++++++++++++++++++
2 files changed, 228 insertions(+), 5 deletions(-)
diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go
index 852d9bc4a..096f49ea3 100644
--- a/management/server/activity/sqlite/crypt.go
+++ b/management/server/activity/sqlite/crypt.go
@@ -7,7 +7,6 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
- "fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
@@ -115,12 +114,23 @@ func pkcs5Padding(ciphertext []byte) []byte {
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
-
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
- paddingLen := int(src[srcLen-1])
- if paddingLen >= srcLen || paddingLen > aes.BlockSize {
- return nil, fmt.Errorf("padding size error")
+ if srcLen == 0 {
+ return nil, errors.New("input data is empty")
}
+
+ paddingLen := int(src[srcLen-1])
+ if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
+ return nil, errors.New("invalid padding size")
+ }
+
+ // Verify that all padding bytes are the same
+ for i := 0; i < paddingLen; i++ {
+ if src[srcLen-1-i] != byte(paddingLen) {
+ return nil, errors.New("invalid padding")
+ }
+ }
+
return src[:srcLen-paddingLen], nil
}
diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go
index 1033ab6ed..aff3a08b1 100644
--- a/management/server/activity/sqlite/crypt_test.go
+++ b/management/server/activity/sqlite/crypt_test.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "bytes"
"testing"
)
@@ -95,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
+
+func TestEncryptDecrypt(t *testing.T) {
+ // Generate a key for encryption/decryption
+ key, err := GenerateKey()
+ if err != nil {
+ t.Fatalf("Failed to generate key: %v", err)
+ }
+
+ // Initialize the FieldEncrypt with the generated key
+ ec, err := NewFieldEncrypt(key)
+ if err != nil {
+ t.Fatalf("Failed to create FieldEncrypt: %v", err)
+ }
+
+ // Test cases
+ testCases := []struct {
+ name string
+ input string
+ }{
+ {
+ name: "Empty String",
+ input: "",
+ },
+ {
+ name: "Short String",
+ input: "Hello",
+ },
+ {
+ name: "String with Spaces",
+ input: "Hello, World!",
+ },
+ {
+ name: "Long String",
+ input: "The quick brown fox jumps over the lazy dog.",
+ },
+ {
+ name: "Unicode Characters",
+ input: "こんにちは世界",
+ },
+ {
+ name: "Special Characters",
+ input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
+ },
+ {
+ name: "Numeric String",
+ input: "1234567890",
+ },
+ {
+ name: "Repeated Characters",
+ input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+ },
+ {
+ name: "Multi-block String",
+ input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
+ },
+ {
+ name: "Non-ASCII and ASCII Mix",
+ input: "Hello 世界 123",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name+" - Legacy", func(t *testing.T) {
+ // Legacy Encryption
+ encryptedLegacy := ec.LegacyEncrypt(tc.input)
+ if encryptedLegacy == "" {
+ t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
+ }
+
+ // Legacy Decryption
+ decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
+ if err != nil {
+ t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
+ }
+
+ // Verify that the decrypted value matches the original input
+ if decryptedLegacy != tc.input {
+ t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
+ }
+ })
+
+ t.Run(tc.name+" - New", func(t *testing.T) {
+ // New Encryption
+ encryptedNew, err := ec.Encrypt(tc.input)
+ if err != nil {
+ t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
+ }
+ if encryptedNew == "" {
+ t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
+ }
+
+ // New Decryption
+ decryptedNew, err := ec.Decrypt(encryptedNew)
+ if err != nil {
+ t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
+ }
+
+ // Verify that the decrypted value matches the original input
+ if decryptedNew != tc.input {
+ t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
+ }
+ })
+ }
+}
+
+func TestPKCS5UnPadding(t *testing.T) {
+ tests := []struct {
+ name string
+ input []byte
+ expected []byte
+ expectError bool
+ }{
+ {
+ name: "Valid Padding",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
+ expected: []byte("Hello, World!"),
+ },
+ {
+ name: "Empty Input",
+ input: []byte{},
+ expectError: true,
+ },
+ {
+ name: "Padding Length Zero",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
+ expectError: true,
+ },
+ {
+ name: "Padding Length Exceeds Block Size",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
+ expectError: true,
+ },
+ {
+ name: "Padding Length Exceeds Input Length",
+ input: []byte{5, 5, 5},
+ expectError: true,
+ },
+ {
+ name: "Invalid Padding Bytes",
+ input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
+ expectError: true,
+ },
+ {
+ name: "Valid Single Byte Padding",
+ input: append([]byte("Hello, World!"), byte(1)),
+ expected: []byte("Hello, World!"),
+ },
+ {
+ name: "Invalid Mixed Padding Bytes",
+ input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
+ expectError: true,
+ },
+ {
+ name: "Valid Full Block Padding",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
+ expected: []byte("Hello, World!"),
+ },
+ {
+ name: "Non-Padding Byte at End",
+ input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
+ expectError: true,
+ },
+ {
+ name: "Valid Padding with Different Text Length",
+ input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
+ expected: []byte("Test"),
+ },
+ {
+ name: "Padding Length Equal to Input Length",
+ input: bytes.Repeat([]byte{8}, 8),
+ expected: []byte{},
+ },
+ {
+ name: "Invalid Padding Length Zero (Again)",
+ input: append([]byte("Test"), byte(0)),
+ expectError: true,
+ },
+ {
+ name: "Padding Length Greater Than Input",
+ input: []byte{10},
+ expectError: true,
+ },
+ {
+ name: "Input Length Not Multiple of Block Size",
+ input: append([]byte("Invalid Length"), byte(1)),
+ expected: []byte("Invalid Length"),
+ },
+ {
+ name: "Valid Padding with Non-ASCII Characters",
+ input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
+ expected: []byte("こんにちは"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := pkcs5UnPadding(tt.input)
+ if tt.expectError {
+ if err == nil {
+ t.Errorf("Expected error but got nil")
+ }
+ } else {
+ if err != nil {
+ t.Errorf("Did not expect error but got: %v", err)
+ }
+ if !bytes.Equal(result, tt.expected) {
+ t.Errorf("Expected output %v, got %v", tt.expected, result)
+ }
+ }
+ })
+ }
+}
From 730dd1733e3a1651fa8cb0fc537327318cda0024 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Sun, 15 Sep 2024 16:46:55 +0200
Subject: [PATCH 56/89] [signal] Fix signal active peers metrics (#2591)
---
signal/peer/peer.go | 9 ++++++---
signal/server/signal.go | 3 ---
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/signal/peer/peer.go b/signal/peer/peer.go
index 85de91581..ed2360d67 100644
--- a/signal/peer/peer.go
+++ b/signal/peer/peer.go
@@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer)
+ return
}
+
log.Debugf("peer registered [%s]", peer.Id)
+ registry.metrics.ActivePeers.Add(context.Background(), 1)
// record time as milliseconds
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
@@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
peer.Id, pp.StreamID, peer.StreamID)
return
}
+ registry.metrics.ActivePeers.Add(context.Background(), -1)
+ log.Debugf("peer deregistered [%s]", peer.Id)
+ registry.metrics.Deregistrations.Add(context.Background(), 1)
}
- log.Debugf("peer deregistered [%s]", peer.Id)
-
- registry.metrics.Deregistrations.Add(context.Background(), 1)
}
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 69387cc69..b268aa3fc 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
- s.metrics.ActivePeers.Add(stream.Context(), 1)
-
return p, nil
} else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
@@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
- s.metrics.ActivePeers.Add(context.Background(), -1)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
From 6c50b0c84b66cfac198e0aed47bfc9688237fb7f Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 16 Sep 2024 15:47:03 +0200
Subject: [PATCH 57/89] [management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
---
.github/workflows/golang-test-linux.yml | 2 +-
management/server/account.go | 42 ++-
management/server/ephemeral_test.go | 3 +-
management/server/file_store.go | 168 +++++++++-
management/server/management_proto_test.go | 47 ++-
management/server/peer.go | 303 ++++++++++--------
management/server/peer_test.go | 185 +++++++++++
management/server/sql_store.go | 238 ++++++++++++--
management/server/sql_store_test.go | 160 +++++++++
management/server/status/error.go | 10 +
management/server/store.go | 27 +-
.../server/testdata/extended-store.json | 120 +++++++
management/server/user.go | 6 +-
13 files changed, 1095 insertions(+), 216 deletions(-)
create mode 100644 management/server/testdata/extended-store.json
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 263623bd1..2d5cf2856 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
- run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
+ run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
diff --git a/management/server/account.go b/management/server/account.go
index 7159aa9ac..208315643 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -263,6 +263,11 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
+// Subclass used in gorm to only load network and not whole account
+type AccountNetwork struct {
+ Network *Network `gorm:"embedded;embeddedPrefix:network_"`
+}
+
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
-func (a *Account) getUserGroups(userID string) ([]string, error) {
- user, err := a.FindUser(userID)
- if err != nil {
- return nil, err
- }
- return user.AutoGroups, nil
-}
-
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
@@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}
-func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
- key, err := a.FindSetupKey(setupKey)
- if err != nil {
- return nil, err
- }
- return key.AutoGroups, nil
-}
-
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
@@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
- user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
@@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil
}
+func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
+ existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return "", fmt.Errorf("failed to get peer dns labels: %w", err)
+ }
+
+ labelMap := ConvertSliceToMap(existingLabels)
+ newLabel, err := getPeerHostLabel(peerHostName, labelMap)
+ if err != nil {
+ return "", fmt.Errorf("failed to get new host label: %w", err)
+ }
+
+ if newLabel == "" {
+ return "", fmt.Errorf("failed to get new host label: %w", err)
+ }
+
+ return newLabel, nil
+}
+
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go
index 36c88f1d1..1390352a5 100644
--- a/management/server/ephemeral_test.go
+++ b/management/server/ephemeral_test.go
@@ -7,6 +7,7 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/status"
)
type MockStore struct {
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}
- return nil, fmt.Errorf("account not found")
+ return nil, status.NewPeerNotFoundError(peerId)
}
type MocAccountManager struct {
diff --git a/management/server/file_store.go b/management/server/file_store.go
index 1927568ef..95d5b4e6e 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -2,6 +2,8 @@ package server
import (
"context"
+ "errors"
+ "net"
"os"
"path/filepath"
"strings"
@@ -46,6 +48,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}
+func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
+ return f(s)
+}
+
+func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
+ if !ok {
+ return status.NewSetupKeyNotFoundError()
+ }
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return err
+ }
+
+ account.SetupKeys[setupKeyID].UsedTimes++
+
+ return s.SaveAccount(ctx, account)
+}
+
+func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return err
+ }
+
+ allGroup, err := account.GetGroupAll()
+ if err != nil || allGroup == nil {
+ return errors.New("all group not found")
+ }
+
+ allGroup.Peers = append(allGroup.Peers, peerID)
+
+ return nil
+}
+
+func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountId)
+ if err != nil {
+ return err
+ }
+
+ account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
+
+ return nil
+}
+
+func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, ok := s.Accounts[peer.AccountID]
+ if !ok {
+ return status.NewAccountNotFoundError(peer.AccountID)
+ }
+
+ account.Peers[peer.ID] = peer
+ return s.SaveAccount(ctx, account)
+}
+
+func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, ok := s.Accounts[accountId]
+ if !ok {
+ return status.NewAccountNotFoundError(accountId)
+ }
+
+ account.Network.Serial++
+
+ return s.SaveAccount(ctx, account)
+}
+
+func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
+ if !ok {
+ return nil, status.NewSetupKeyNotFoundError()
+ }
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ setupKey, ok := account.SetupKeys[key]
+ if !ok {
+ return nil, status.Errorf(status.NotFound, "setup key not found")
+ }
+
+ return setupKey, nil
+}
+
+func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ var takenIps []net.IP
+ for _, existingPeer := range account.Peers {
+ takenIps = append(takenIps, existingPeer.IP)
+ }
+
+ return takenIps, nil
+}
+
+func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ existingLabels := []string{}
+ for _, peer := range account.Peers {
+ if peer.DNSLabel != "" {
+ existingLabels = append(existingLabels, peer.DNSLabel)
+ }
+ }
+ return existingLabels, nil
+}
+
+func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ return account.Network, nil
+}
+
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
@@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
- return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
+ return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
@@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
-func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
+func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
- return nil, status.Errorf(status.NotFound, "account not found")
+ return nil, status.NewAccountNotFoundError(accountID)
}
return account, nil
@@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
- return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
+ return "", status.NewSetupKeyNotFoundError()
}
return accountID, nil
}
-func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
+func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}
-func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
+func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
-func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
+func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index 00ee4bda2..ff09129bd 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
- peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
+ peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
- if os.Getenv("CI") == "true" {
- t.Skip("Skipping on CI")
+ if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
+ t.Skip("Skipping test on CI or Windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1},
// {"L", 500, 1},
// {"XL", 750, 1},
- {"XXL", 2000, 1},
+ {"XXL", 5000, 1},
}
log.SetOutput(io.Discard)
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
}
defer mgmtServer.GracefulStop()
+ t.Logf("management setup complete, start registering peers")
+
var counter int32
var counterStart int32
- var wg sync.WaitGroup
+ var wgAccount sync.WaitGroup
var mu sync.Mutex
messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ {
- wg.Add(1)
+ wgAccount.Add(1)
+ var wgPeer sync.WaitGroup
go func(j int, counter *int32, counterStart *int32) {
- defer wg.Done()
+ defer wgAccount.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil {
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
return
}
+ startTime := time.Now()
for i := 0; i < bc.peers; i++ {
+ wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Logf("failed to generate key: %v", err)
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock()
messageCalls = append(messageCalls, login)
mu.Unlock()
- _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
- if err != nil {
- t.Logf("failed to login peer: %v", err)
- return
- }
- atomic.AddInt32(counterStart, 1)
- if *counterStart%100 == 0 {
- t.Logf("registered %d peers", *counterStart)
- }
+ go func(peerLogin PeerLogin, counterStart *int32) {
+ defer wgPeer.Done()
+ _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+ if err != nil {
+ t.Logf("failed to login peer: %v", err)
+ return
+ }
+
+ atomic.AddInt32(counterStart, 1)
+ if *counterStart%100 == 0 {
+ t.Logf("registered %d peers", *counterStart)
+ }
+ }(peerLogin, counterStart)
+
}
+ wgPeer.Wait()
+
+ t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart)
}
- wg.Wait()
+ wgAccount.Wait()
t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls)
diff --git a/management/server/peer.go b/management/server/peer.go
index 26e27617d..da9586734 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -11,6 +11,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto"
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}()
- var account *Account
- // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
- account, err = am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return nil, nil, nil, err
- }
-
- if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
- if am.idpManager != nil {
- userdata, err := am.lookupUserInCache(ctx, userID, account)
- if err == nil && userdata != nil {
- peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
- }
- }
- }
-
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
- _, err = account.FindPeerByPubKey(peer.Key)
+ _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
- AccountID: account.Id,
+ AccountID: accountID,
}
- var ephemeral bool
- setupKeyName := ""
- if !addedByUser {
- // validate the setup key if adding with a key
- sk, err := account.FindSetupKey(upperKey)
- if err != nil {
- return nil, nil, nil, err
- }
+ var newPeer *nbpeer.Peer
- if !sk.IsValid() {
- return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
- }
-
- account.SetupKeys[sk.Key] = sk.IncrementUsage()
- opEvent.InitiatorID = sk.Id
- opEvent.Activity = activity.PeerAddedWithSetupKey
- ephemeral = sk.Ephemeral
- setupKeyName = sk.Name
- } else {
- opEvent.InitiatorID = userID
- opEvent.Activity = activity.PeerAddedByUser
- }
-
- takenIps := account.getTakenIPs()
- existingLabels := account.getPeerDNSLabels()
-
- newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
- if err != nil {
- return nil, nil, nil, err
- }
-
- peer.DNSLabel = newLabel
- network := account.Network
- nextIp, err := AllocatePeerIP(network.Net, takenIps)
- if err != nil {
- return nil, nil, nil, err
- }
-
- registrationTime := time.Now().UTC()
-
- newPeer := &nbpeer.Peer{
- ID: xid.New().String(),
- Key: peer.Key,
- SetupKey: upperKey,
- IP: nextIp,
- Meta: peer.Meta,
- Name: peer.Meta.Hostname,
- DNSLabel: newLabel,
- UserID: userID,
- Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
- SSHEnabled: false,
- SSHKey: peer.SSHKey,
- LastLogin: registrationTime,
- CreatedAt: registrationTime,
- LoginExpirationEnabled: addedByUser,
- Ephemeral: ephemeral,
- Location: peer.Location,
- }
-
- if am.geo != nil && newPeer.Location.ConnectionIP != nil {
- location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
- if err != nil {
- log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ var groupsToAdd []string
+ var setupKeyID string
+ var setupKeyName string
+ var ephemeral bool
+ if addedByUser {
+ user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
+ if err != nil {
+ return fmt.Errorf("failed to get user groups: %w", err)
+ }
+ groupsToAdd = user.AutoGroups
+ opEvent.InitiatorID = userID
+ opEvent.Activity = activity.PeerAddedByUser
} else {
- newPeer.Location.CountryCode = location.Country.ISOCode
- newPeer.Location.CityName = location.City.Names.En
- newPeer.Location.GeoNameID = location.City.GeonameID
- }
- }
+ // Validate the setup key
+ sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
+ if err != nil {
+ return fmt.Errorf("failed to get setup key: %w", err)
+ }
- // add peer to 'All' group
- group, err := account.GetGroupAll()
- if err != nil {
- return nil, nil, nil, err
- }
- group.Peers = append(group.Peers, newPeer.ID)
+ if !sk.IsValid() {
+ return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
+ }
- var groupsToAdd []string
- if addedByUser {
- groupsToAdd, err = account.getUserGroups(userID)
- if err != nil {
- return nil, nil, nil, err
+ opEvent.InitiatorID = sk.Id
+ opEvent.Activity = activity.PeerAddedWithSetupKey
+ groupsToAdd = sk.AutoGroups
+ ephemeral = sk.Ephemeral
+ setupKeyID = sk.Id
+ setupKeyName = sk.Name
}
- } else {
- groupsToAdd, err = account.getSetupKeyGroups(upperKey)
- if err != nil {
- return nil, nil, nil, err
- }
- }
- if len(groupsToAdd) > 0 {
- for _, s := range groupsToAdd {
- if g, ok := account.Groups[s]; ok && g.Name != "All" {
- g.Peers = append(g.Peers, newPeer.ID)
+ if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
+ if am.idpManager != nil {
+ userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
+ if err == nil && userdata != nil {
+ peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
+ }
}
}
- }
- newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
-
- if addedByUser {
- user, err := account.FindUser(userID)
+ freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if err != nil {
- return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
+ return fmt.Errorf("failed to get free DNS label: %w", err)
}
- user.updateLastLogin(newPeer.LastLogin)
- }
- account.Peers[newPeer.ID] = newPeer
- account.Network.IncSerial()
- err = am.Store.SaveAccount(ctx, account)
+ freeIP, err := am.getFreeIP(ctx, transaction, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to get free IP: %w", err)
+ }
+
+ registrationTime := time.Now().UTC()
+ newPeer = &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: accountID,
+ Key: peer.Key,
+ SetupKey: upperKey,
+ IP: freeIP,
+ Meta: peer.Meta,
+ Name: peer.Meta.Hostname,
+ DNSLabel: freeLabel,
+ UserID: userID,
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
+ SSHEnabled: false,
+ SSHKey: peer.SSHKey,
+ LastLogin: registrationTime,
+ CreatedAt: registrationTime,
+ LoginExpirationEnabled: addedByUser,
+ Ephemeral: ephemeral,
+ Location: peer.Location,
+ }
+ opEvent.TargetID = newPeer.ID
+ opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
+ if !addedByUser {
+ opEvent.Meta["setup_key_name"] = setupKeyName
+ }
+
+ if am.geo != nil && newPeer.Location.ConnectionIP != nil {
+ location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
+ if err != nil {
+ log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
+ } else {
+ newPeer.Location.CountryCode = location.Country.ISOCode
+ newPeer.Location.CityName = location.City.Names.En
+ newPeer.Location.GeoNameID = location.City.GeonameID
+ }
+ }
+
+ settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to get account settings: %w", err)
+ }
+ newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
+
+ err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
+ if err != nil {
+ return fmt.Errorf("failed adding peer to All group: %w", err)
+ }
+
+ if len(groupsToAdd) > 0 {
+ for _, g := range groupsToAdd {
+ err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ err = transaction.AddPeerToAccount(ctx, newPeer)
+ if err != nil {
+ return fmt.Errorf("failed to add peer to account: %w", err)
+ }
+
+ err = transaction.IncrementNetworkSerial(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to increment network serial: %w", err)
+ }
+
+ if addedByUser {
+ err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
+ if err != nil {
+ return fmt.Errorf("failed to update user last login: %w", err)
+ }
+ } else {
+ err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
+ if err != nil {
+ return fmt.Errorf("failed to increment setup key usage: %w", err)
+ }
+ }
+
+ log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
+ return nil
+ })
+
if err != nil {
- return nil, nil, nil, err
+ return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
}
- // Account is saved, we can release the lock
- unlock()
- unlock = nil
-
- opEvent.TargetID = newPeer.ID
- opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
- if !addedByUser {
- opEvent.Meta["setup_key_name"] = setupKeyName
+ if newPeer == nil {
+ return nil, nil, nil, fmt.Errorf("new peer is nil")
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
+ unlock()
+ unlock = nil
+
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
+ }
+
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
- postureChecks := am.getPeerPostureChecks(account, peer)
+ postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
+func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
+ takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get taken IPs: %w", err)
+ }
+
+ network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
+ if err != nil {
+ return nil, fmt.Errorf("failed getting network: %w", err)
+ }
+
+ nextIp, err := AllocatePeerIP(network.Net, takenIps)
+ if err != nil {
+ return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
+ }
+
+ return nextIp, nil
+}
+
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
- peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
+ peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
- settings, err := am.Store.GetAccountSettings(ctx, accountID)
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
- peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
+ peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
- settings, err := am.Store.GetAccountSettings(ctx, accountID)
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err
}
- err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
+ err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
+
+func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
+ labelMap := make(map[string]struct{}, len(existingLabels))
+ for _, label := range existingLabels {
+ labelMap[label] = struct{}{}
+ }
+ return labelMap
+}
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 448e83a08..4b2ec66c6 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -7,20 +7,24 @@ import (
"net"
"net/netip"
"os"
+ "runtime"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
)
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}
+
+func Test_RegisterPeerByUser(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+
+ eventStore := &activity.InMemoryEventStore{}
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ assert.NoError(t, err)
+
+ am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
+ assert.NoError(t, err)
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
+
+ _, err = store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ newPeer := &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: existingAccountID,
+ Key: "newPeerKey",
+ SetupKey: "",
+ IP: net.IP{123, 123, 123, 123},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "newPeer",
+ GoOS: "linux",
+ },
+ Name: "newPeerName",
+ DNSLabel: "newPeer.test",
+ UserID: existingUserID,
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
+ SSHEnabled: false,
+ LastLogin: time.Now(),
+ }
+
+ addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
+ require.NoError(t, err)
+
+ peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
+ require.NoError(t, err)
+ assert.Equal(t, peer.AccountID, existingAccountID)
+ assert.Equal(t, peer.UserID, existingUserID)
+
+ account, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ assert.Contains(t, account.Peers, addedPeer.ID)
+ assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
+
+ assert.Equal(t, uint64(1), account.Network.Serial)
+
+ lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
+ assert.NoError(t, err)
+ assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
+}
+
+func Test_RegisterPeerBySetupKey(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+
+ eventStore := &activity.InMemoryEventStore{}
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ assert.NoError(t, err)
+
+ am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
+ assert.NoError(t, err)
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
+
+ _, err = store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ newPeer := &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: existingAccountID,
+ Key: "newPeerKey",
+ SetupKey: "existingSetupKey",
+ UserID: "",
+ IP: net.IP{123, 123, 123, 123},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "newPeer",
+ GoOS: "linux",
+ },
+ Name: "newPeerName",
+ DNSLabel: "newPeer.test",
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
+ SSHEnabled: false,
+ }
+
+ addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
+
+ require.NoError(t, err)
+
+ peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
+ require.NoError(t, err)
+ assert.Equal(t, peer.AccountID, existingAccountID)
+ assert.Equal(t, peer.SetupKey, existingSetupKeyID)
+
+ account, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ assert.Contains(t, account.Peers, addedPeer.ID)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
+
+ assert.Equal(t, uint64(1), account.Network.Serial)
+
+ lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
+ assert.NoError(t, err)
+ assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
+ assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
+
+}
+
+func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+
+ eventStore := &activity.InMemoryEventStore{}
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ assert.NoError(t, err)
+
+ am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
+ assert.NoError(t, err)
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
+
+ _, err = store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ newPeer := &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: existingAccountID,
+ Key: "newPeerKey",
+ SetupKey: "existingSetupKey",
+ UserID: "",
+ IP: net.IP{123, 123, 123, 123},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "newPeer",
+ GoOS: "linux",
+ },
+ Name: "newPeerName",
+ DNSLabel: "newPeer.test",
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
+ SSHEnabled: false,
+ }
+
+ _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
+ require.Error(t, err)
+
+ _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
+ require.Error(t, err)
+
+ account, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ assert.NotContains(t, account.Peers, newPeer.ID)
+ assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
+ assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
+
+ assert.Equal(t, uint64(0), account.Network.Serial)
+
+ lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
+ assert.NoError(t, err)
+ assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
+ assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
+}
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 0fb3d391f..6f1f66ef8 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net"
"os"
"path/filepath"
"runtime"
@@ -33,6 +34,7 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
+ keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
- result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
+ result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting setup key from store")
+ return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
@@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
-func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
+func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
- result := s.db.First(&user, idQueryCondition, userID)
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
+ First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
+ return nil, status.NewUserNotFoundError(userID)
}
- log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting user from store")
+ return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "account not found")
+ return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
- result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
+ result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
- result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
+ result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
- result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
+ result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
- result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
+ result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
- var key SetupKey
var accountID string
- result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
+ result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
- return "", status.Errorf(status.Internal, "issue getting setup key from store")
+ return "", status.NewSetupKeyNotFoundError()
+ }
+
+ if accountID == "" {
+ return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return accountID, nil
}
-func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
+func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
+ var ipJSONStrings []string
+
+ // Fetch the IP addresses as JSON strings
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
+ Where("account_id = ?", accountID).
+ Pluck("ip", &ipJSONStrings)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "no peers found for the account")
+ }
+ return nil, status.Errorf(status.Internal, "issue getting IPs from store")
+ }
+
+ // Convert the JSON strings to net.IP objects
+ ips := make([]net.IP, len(ipJSONStrings))
+ for i, ipJSON := range ipJSONStrings {
+ var ip net.IP
+ if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
+ return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
+ }
+ ips[i] = ip
+ }
+
+ return ips, nil
+}
+
+func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
+ var labels []string
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
+ Where("account_id = ?", accountID).
+ Pluck("dns_label", &labels)
+
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "no peers found for the account")
+ }
+ log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
+ }
+
+ return labels, nil
+}
+
+func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
+ var accountNetwork AccountNetwork
+
+ if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, status.NewAccountNotFoundError(accountID)
+ }
+ return nil, status.Errorf(status.Internal, "issue getting network from store")
+ }
+ return accountNetwork.Network, nil
+}
+
+func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
- result := s.db.First(&peer, "key = ?", peerKey)
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
-func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
+func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
- if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
+ if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
- log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
-func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
+func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
- result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
+ result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return status.Errorf(status.NotFound, "user %s not found", userID)
+ return status.NewUserNotFoundError(userID)
}
- return status.Errorf(status.Internal, "issue getting user from store")
+ return status.NewGetUserFromStoreError()
}
-
user.LastLogin = lastLogin
- return s.db.Save(user).Error
+ return s.db.Save(&user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
+
+func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
+ var setupKey SetupKey
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
+ First(&setupKey, keyQueryCondition, strings.ToUpper(key))
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "setup key not found")
+ }
+ return nil, status.NewSetupKeyNotFoundError()
+ }
+ return &setupKey, nil
+}
+
+func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
+ result := s.db.WithContext(ctx).Model(&SetupKey{}).
+ Where(idQueryCondition, setupKeyID).
+ Updates(map[string]interface{}{
+ "used_times": gorm.Expr("used_times + 1"),
+ "last_used": time.Now(),
+ })
+
+ if result.Error != nil {
+ return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
+ }
+
+ if result.RowsAffected == 0 {
+ return status.Errorf(status.NotFound, "setup key not found")
+ }
+
+ return nil
+}
+
+func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
+ var group nbgroup.Group
+
+ result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return status.Errorf(status.NotFound, "group 'All' not found for account")
+ }
+ return status.Errorf(status.Internal, "issue finding group 'All'")
+ }
+
+ for _, existingPeerID := range group.Peers {
+ if existingPeerID == peerID {
+ return nil
+ }
+ }
+
+ group.Peers = append(group.Peers, peerID)
+
+ if err := s.db.Save(&group).Error; err != nil {
+ return status.Errorf(status.Internal, "issue updating group 'All'")
+ }
+
+ return nil
+}
+
+func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
+ var group nbgroup.Group
+
+ result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return status.Errorf(status.NotFound, "group not found for account")
+ }
+ return status.Errorf(status.Internal, "issue finding group")
+ }
+
+ for _, existingPeerID := range group.Peers {
+ if existingPeerID == peerId {
+ return nil
+ }
+ }
+
+ group.Peers = append(group.Peers, peerId)
+
+ if err := s.db.Save(&group).Error; err != nil {
+ return status.Errorf(status.Internal, "issue updating group")
+ }
+
+ return nil
+}
+
+func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
+ if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
+ return status.Errorf(status.Internal, "issue adding peer to account")
+ }
+
+ return nil
+}
+
+func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
+ result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
+ if result.Error != nil {
+ return status.Errorf(status.Internal, "issue incrementing network serial count")
+ }
+ return nil
+}
+
+func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
+ tx := s.db.WithContext(ctx).Begin()
+ if tx.Error != nil {
+ return tx.Error
+ }
+ repo := s.withTx(tx)
+ err := operation(repo)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+ return tx.Commit().Error
+}
+
+func (s *SqlStore) withTx(tx *gorm.DB) Store {
+ return &SqlStore{
+ db: tx,
+ }
+}
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index ce4ee531a..64ef36831 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
+
+func TestSqlite_GetTakenIPs(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ defer store.Close(context.Background())
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ _, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ assert.Equal(t, []net.IP{}, takenIPs)
+
+ peer1 := &nbpeer.Peer{
+ ID: "peer1",
+ AccountID: existingAccountID,
+ IP: net.IP{1, 1, 1, 1},
+ }
+ err = store.AddPeerToAccount(context.Background(), peer1)
+ require.NoError(t, err)
+
+ takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ ip1 := net.IP{1, 1, 1, 1}.To16()
+ assert.Equal(t, []net.IP{ip1}, takenIPs)
+
+ peer2 := &nbpeer.Peer{
+ ID: "peer2",
+ AccountID: existingAccountID,
+ IP: net.IP{2, 2, 2, 2},
+ }
+ err = store.AddPeerToAccount(context.Background(), peer2)
+ require.NoError(t, err)
+
+ takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ ip2 := net.IP{2, 2, 2, 2}.To16()
+ assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
+
+}
+
+func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ defer store.Close(context.Background())
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ _, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ assert.Equal(t, []string{}, labels)
+
+ peer1 := &nbpeer.Peer{
+ ID: "peer1",
+ AccountID: existingAccountID,
+ DNSLabel: "peer1.domain.test",
+ }
+ err = store.AddPeerToAccount(context.Background(), peer1)
+ require.NoError(t, err)
+
+ labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ assert.Equal(t, []string{"peer1.domain.test"}, labels)
+
+ peer2 := &nbpeer.Peer{
+ ID: "peer2",
+ AccountID: existingAccountID,
+ DNSLabel: "peer2.domain.test",
+ }
+ err = store.AddPeerToAccount(context.Background(), peer2)
+ require.NoError(t, err)
+
+ labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
+}
+
+func TestSqlite_GetAccountNetwork(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ defer store.Close(context.Background())
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ _, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
+ require.NoError(t, err)
+ ip := net.IP{100, 64, 0, 0}.To16()
+ assert.Equal(t, ip, network.Net.IP)
+ assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
+ assert.Equal(t, "", network.Dns)
+ assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
+ assert.Equal(t, uint64(0), network.Serial)
+}
+
+func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ defer store.Close(context.Background())
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ _, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+ require.NoError(t, err)
+ assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
+ assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
+ assert.Equal(t, "Default key", setupKey.Name)
+}
+
+func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ defer store.Close(context.Background())
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ _, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+ require.NoError(t, err)
+ assert.Equal(t, 0, setupKey.UsedTimes)
+
+ err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
+ require.NoError(t, err)
+
+ setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+ require.NoError(t, err)
+ assert.Equal(t, 1, setupKey.UsedTimes)
+
+ err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
+ require.NoError(t, err)
+
+ setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+ require.NoError(t, err)
+ assert.Equal(t, 2, setupKey.UsedTimes)
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index 58b9a84a0..d7fde35b9 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
+
+// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
+func NewSetupKeyNotFoundError() error {
+ return Errorf(NotFound, "setup key not found")
+}
+
+// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
+func NewGetUserFromStoreError() error {
+ return Errorf(Internal, "issue getting user from store")
+}
diff --git a/management/server/store.go b/management/server/store.go
index a2b489391..84b3b140c 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -27,6 +27,15 @@ import (
"github.com/netbirdio/netbird/route"
)
+type LockingStrength string
+
+const (
+ LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
+ LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
+ LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
+ LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
+)
+
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
@@ -41,7 +50,7 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
- GetUserByUserID(ctx context.Context, userID string) (*User, error)
+ GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
@@ -60,14 +69,24 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
- SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
+ SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
- GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
- GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
+ GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
+ GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
+ GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
+ GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
+ IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
+ AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
+ GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
+ AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
+ AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
+ IncrementNetworkSerial(ctx context.Context, accountId string) error
+ GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
+ ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}
type StoreEngine string
diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json
new file mode 100644
index 000000000..7f96e57a8
--- /dev/null
+++ b/management/server/testdata/extended-store.json
@@ -0,0 +1,120 @@
+{
+ "Accounts": {
+ "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
+ "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
+ "CreatedBy": "",
+ "Domain": "test.com",
+ "DomainCategory": "private",
+ "IsDomainPrimaryAccount": true,
+ "SetupKeys": {
+ "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
+ "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
+ "AccountID": "",
+ "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
+ "Name": "Default key",
+ "Type": "reusable",
+ "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
+ "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
+ "UpdatedAt": "0001-01-01T00:00:00Z",
+ "Revoked": false,
+ "UsedTimes": 0,
+ "LastUsed": "0001-01-01T00:00:00Z",
+ "AutoGroups": ["cfefqs706sqkneg59g2g"],
+ "UsageLimit": 0,
+ "Ephemeral": false
+ },
+ "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
+ "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
+ "AccountID": "",
+ "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
+ "Name": "Faulty key with non existing group",
+ "Type": "reusable",
+ "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
+ "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
+ "UpdatedAt": "0001-01-01T00:00:00Z",
+ "Revoked": false,
+ "UsedTimes": 0,
+ "LastUsed": "0001-01-01T00:00:00Z",
+ "AutoGroups": ["abcd"],
+ "UsageLimit": 0,
+ "Ephemeral": false
+ }
+ },
+ "Network": {
+ "id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
+ "Net": {
+ "IP": "100.64.0.0",
+ "Mask": "//8AAA=="
+ },
+ "Dns": "",
+ "Serial": 0
+ },
+ "Peers": {},
+ "Users": {
+ "edafee4e-63fb-11ec-90d6-0242ac120003": {
+ "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
+ "AccountID": "",
+ "Role": "admin",
+ "IsServiceUser": false,
+ "ServiceUserName": "",
+ "AutoGroups": ["cfefqs706sqkneg59g3g"],
+ "PATs": {},
+ "Blocked": false,
+ "LastLogin": "0001-01-01T00:00:00Z"
+ },
+ "f4f6d672-63fb-11ec-90d6-0242ac120003": {
+ "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
+ "AccountID": "",
+ "Role": "user",
+ "IsServiceUser": false,
+ "ServiceUserName": "",
+ "AutoGroups": null,
+ "PATs": {
+ "9dj38s35-63fb-11ec-90d6-0242ac120003": {
+ "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
+ "UserID": "",
+ "Name": "",
+ "HashedToken": "SoMeHaShEdToKeN",
+ "ExpirationDate": "2023-02-27T00:00:00Z",
+ "CreatedBy": "user",
+ "CreatedAt": "2023-01-01T00:00:00Z",
+ "LastUsed": "2023-02-01T00:00:00Z"
+ }
+ },
+ "Blocked": false,
+ "LastLogin": "0001-01-01T00:00:00Z"
+ }
+ },
+ "Groups": {
+ "cfefqs706sqkneg59g4g": {
+ "ID": "cfefqs706sqkneg59g4g",
+ "Name": "All",
+ "Peers": []
+ },
+ "cfefqs706sqkneg59g3g": {
+ "ID": "cfefqs706sqkneg59g3g",
+ "Name": "AwesomeGroup1",
+ "Peers": []
+ },
+ "cfefqs706sqkneg59g2g": {
+ "ID": "cfefqs706sqkneg59g2g",
+ "Name": "AwesomeGroup2",
+ "Peers": []
+ }
+ },
+ "Rules": null,
+ "Policies": [],
+ "Routes": null,
+ "NameServerGroups": null,
+ "DNSSettings": null,
+ "Settings": {
+ "PeerLoginExpirationEnabled": false,
+ "PeerLoginExpiration": 86400000000000,
+ "GroupsPropagationEnabled": false,
+ "JWTGroupsEnabled": false,
+ "JWTGroupsClaimName": ""
+ }
+ }
+ },
+ "InstallationID": ""
+}
diff --git a/management/server/user.go b/management/server/user.go
index 727bc5c6b..9e60bb94b 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
-func (u *User) updateLastLogin(login time.Time) {
- u.LastLogin = login
-}
-
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
@@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
- err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
+ err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}
From 97e10e440cf172e94921c983395c56252c0e429e Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Mon, 16 Sep 2024 16:11:10 +0200
Subject: [PATCH 58/89] Fix leaked server connections (#2596)
Fix leaked server connections
close unused connections in the client lib
close deprecated connection in the server lib
The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes.
---------
Co-authored-by: Maycon Santos
* Add logging
---------
Co-authored-by: Maycon Santos
---
relay/client/client.go | 34 ++++++++------
relay/client/manager.go | 73 ++++++-----------------------
relay/client/picker.go | 94 ++++++++++++++++++++++++++++++++++++++
relay/server/peer.go | 9 ++++
relay/server/store.go | 6 ++-
relay/server/store_test.go | 49 +++++++++++++++++++-
6 files changed, 187 insertions(+), 78 deletions(-)
create mode 100644 relay/client/picker.go
diff --git a/relay/client/client.go b/relay/client/client.go
index 7ff17944f..3e5c0ba24 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -142,7 +142,7 @@ type Client struct {
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
- log: log.WithField("client_id", hashedStringId),
+ log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
@@ -159,7 +159,7 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
- c.log.Infof("connecting to relay server: %s", c.connectionURL)
+ c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
@@ -180,7 +180,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
- c.log.Infof("relay connection established with: %s", c.connectionURL)
+ c.log.Infof("relay connection established")
return nil
}
@@ -202,7 +202,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
return nil, ErrConnAlreadyExists
}
- log.Infof("open connection to peer: %s", hashedStringID)
+ c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
@@ -250,7 +250,7 @@ func (c *Client) connect() error {
if err != nil {
cErr := conn.Close()
if cErr != nil {
- log.Errorf("failed to close connection: %s", cErr)
+ c.log.Errorf("failed to close connection: %s", cErr)
}
return err
}
@@ -261,19 +261,19 @@ func (c *Client) connect() error {
func (c *Client) handShake() error {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
- log.Errorf("failed to marshal auth message: %s", err)
+ c.log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
- log.Errorf("failed to send auth message: %s", err)
+ c.log.Errorf("failed to send auth message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf)
if err != nil {
- log.Errorf("failed to read auth response: %s", err)
+ c.log.Errorf("failed to read auth response: %s", err)
return err
}
@@ -284,12 +284,12 @@ func (c *Client) handShake() error {
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
- log.Errorf("failed to determine message type: %s", err)
+ c.log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeAuthResponse {
- log.Errorf("unexpected message type: %s", msgType)
+ c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
@@ -318,6 +318,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
+ c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
@@ -364,7 +365,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
- log.Debugf("relay connection close by server")
+ c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
@@ -433,14 +434,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
- log.Errorf("failed to marshal transport message: %s", err)
+ c.log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
- log.Errorf("failed to write transport message: %s", err)
+ c.log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
@@ -459,7 +460,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
- log.Errorf("failed to teardown connection: %s", err)
+ c.log.Errorf("failed to teardown connection: %s", err)
}
return
}
@@ -499,10 +500,12 @@ func (c *Client) close(gracefullyExit bool) error {
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
+ c.log.Warn("relay connection was already marked as not running")
return nil
}
c.serviceIsRunning = false
+ c.log.Infof("closing all peer connections")
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
@@ -510,8 +513,9 @@ func (c *Client) close(gracefullyExit bool) error {
err = c.relayConn.Close()
c.mu.Unlock()
+ c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait()
- c.log.Infof("relay connection closed with: %s", c.connectionURL)
+ c.log.Infof("relay connection closed")
return err
}
diff --git a/relay/client/manager.go b/relay/client/manager.go
index a9d294160..4554c7c0f 100644
--- a/relay/client/manager.go
+++ b/relay/client/manager.go
@@ -3,7 +3,6 @@ package client
import (
"container/list"
"context"
- "errors"
"fmt"
"net"
"reflect"
@@ -17,8 +16,6 @@ import (
var (
relayCleanupInterval = 60 * time.Second
- connectionTimeout = 30 * time.Second
- maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
@@ -92,67 +89,23 @@ func (m *Manager) Serve() error {
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
- totalServers := len(m.serverURLs)
-
- successChan := make(chan *Client, 1)
- errChan := make(chan error, len(m.serverURLs))
-
- ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
- defer cancel()
-
- sem := make(chan struct{}, maxConcurrentServers)
-
- for _, url := range m.serverURLs {
- sem <- struct{}{}
- go func(url string) {
- defer func() { <-sem }()
- m.connect(m.ctx, url, successChan, errChan)
- }(url)
+ sp := ServerPicker{
+ TokenStore: m.tokenStore,
+ PeerID: m.peerID,
}
- var errCount int
-
- for {
- select {
- case client := <-successChan:
- log.Infof("Successfully connected to relay server: %s", client.connectionURL)
-
- m.relayClient = client
-
- m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
- m.relayClient.SetOnDisconnectListener(func() {
- m.onServerDisconnected(client.connectionURL)
- })
- m.startCleanupLoop()
- return nil
- case err := <-errChan:
- errCount++
- log.Warnf("Connection attempt failed: %v", err)
- if errCount == totalServers {
- return errors.New("failed to connect to any relay server: all attempts failed")
- }
- case <-ctx.Done():
- return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
- }
+ client, err := sp.PickServer(m.ctx, m.serverURLs)
+ if err != nil {
+ return err
}
-}
+ m.relayClient = client
-func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
- // TODO: abort the connection if another connection was successful
- relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
- if err := relayClient.Connect(); err != nil {
- errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
- return
- }
-
- select {
- case successChan <- relayClient:
- // This client was the first to connect successfully
- default:
- if err := relayClient.Close(); err != nil {
- log.Debugf("failed to close relay client: %s", err)
- }
- }
+ m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
+ m.relayClient.SetOnDisconnectListener(func() {
+ m.onServerDisconnected(client.connectionURL)
+ })
+ m.startCleanupLoop()
+ return nil
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
diff --git a/relay/client/picker.go b/relay/client/picker.go
new file mode 100644
index 000000000..b0888a4a0
--- /dev/null
+++ b/relay/client/picker.go
@@ -0,0 +1,94 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+)
+
+const (
+ connectionTimeout = 30 * time.Second
+ maxConcurrentServers = 7
+)
+
+type connResult struct {
+ RelayClient *Client
+ Url string
+ Err error
+}
+
+type ServerPicker struct {
+ TokenStore *auth.TokenStore
+ PeerID string
+}
+
+func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
+ ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
+ defer cancel()
+
+ totalServers := len(urls)
+
+ connResultChan := make(chan connResult, totalServers)
+ successChan := make(chan connResult, 1)
+
+ concurrentLimiter := make(chan struct{}, maxConcurrentServers)
+ for _, url := range urls {
+ concurrentLimiter <- struct{}{}
+ go func(url string) {
+ defer func() { <-concurrentLimiter }()
+ sp.startConnection(parentCtx, connResultChan, url)
+ }(url)
+ }
+
+ go sp.processConnResults(connResultChan, successChan)
+
+ select {
+ case cr, ok := <-successChan:
+ if !ok {
+ return nil, errors.New("failed to connect to any relay server: all attempts failed")
+ }
+ log.Infof("chosen home Relay server: %s", cr.Url)
+ return cr.RelayClient, nil
+ case <-ctx.Done():
+ return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
+ }
+}
+
+func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
+ log.Infof("try to connecting to relay server: %s", url)
+ relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
+ err := relayClient.Connect()
+ resultChan <- connResult{
+ RelayClient: relayClient,
+ Url: url,
+ Err: err,
+ }
+}
+
+func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
+ var hasSuccess bool
+ for cr := range resultChan {
+ if cr.Err != nil {
+ log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
+ continue
+ }
+ log.Infof("connected to Relay server: %s", cr.Url)
+
+ if hasSuccess {
+ log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
+ if err := cr.RelayClient.Close(); err != nil {
+ log.Errorf("failed to close connection to %s: %v", cr.Url, err)
+ }
+ continue
+ }
+
+ hasSuccess = true
+ successChan <- cr
+ }
+ close(successChan)
+}
diff --git a/relay/server/peer.go b/relay/server/peer.go
index 0de601996..00341e98b 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
// connection.
func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock()
+ defer p.connMu.Unlock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String())
@@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
+}
+func (p *Peer) Close() {
+ p.connMu.Lock()
defer p.connMu.Unlock()
+
+ if err := p.conn.Close(); err != nil {
+ p.log.Errorf("failed to close connection to peer: %s", err)
+ }
}
// String returns the peer ID
@@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
+ p.log.Info("peer connection closed due healthcheck timeout")
return
case <-ctx.Done():
return
diff --git a/relay/server/store.go b/relay/server/store.go
index 96879dae1..4288e62c5 100644
--- a/relay/server/store.go
+++ b/relay/server/store.go
@@ -19,10 +19,14 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
-// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
+ odlPeer, ok := s.peers[peer.String()]
+ if ok {
+ odlPeer.Close()
+ }
+
s.peers[peer.String()] = peer
}
diff --git a/relay/server/store_test.go b/relay/server/store_test.go
index 4a30bc131..41c7baa92 100644
--- a/relay/server/store_test.go
+++ b/relay/server/store_test.go
@@ -2,13 +2,57 @@ package server
import (
"context"
+ "net"
"testing"
+ "time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
+type mockConn struct {
+}
+
+func (m mockConn) Read(b []byte) (n int, err error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (m mockConn) Write(b []byte) (n int, err error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (m mockConn) Close() error {
+ return nil
+}
+
+func (m mockConn) LocalAddr() net.Addr {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (m mockConn) RemoteAddr() net.Addr {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (m mockConn) SetDeadline(t time.Time) error {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (m mockConn) SetReadDeadline(t time.Time) error {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (m mockConn) SetWriteDeadline(t time.Time) error {
+ //TODO implement me
+ panic("implement me")
+}
+
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
@@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
- p1 := NewPeer(m, []byte("peer_id"), nil, nil)
- p2 := NewPeer(m, []byte("peer_id"), nil, nil)
+ conn := &mockConn{}
+ p1 := NewPeer(m, []byte("peer_id"), conn, nil)
+ p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1)
s.AddPeer(p2)
From b74951f29e5929cd71618638a381b1bc79df84d8 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Mon, 16 Sep 2024 22:42:37 +0200
Subject: [PATCH 59/89] [client] Enforce permissions on Win (#2568)
Enforce folder permission on Windows, giving only administrators and system access to the NetBird folder.
---
client/internal/config.go | 11 ++++-
util/file.go | 87 ++++++++++++++++++++++++--------------
util/permission.go | 7 +++
util/permission_windows.go | 86 +++++++++++++++++++++++++++++++++++++
4 files changed, 158 insertions(+), 33 deletions(-)
create mode 100644 util/permission.go
create mode 100644 util/permission_windows.go
diff --git a/client/internal/config.go b/client/internal/config.go
index 725703c43..1df1e0547 100644
--- a/client/internal/config.go
+++ b/client/internal/config.go
@@ -117,6 +117,11 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) {
+ err := util.EnforcePermission(configPath)
+ if err != nil {
+ log.Errorf("failed to enforce permission on config dir: %v", err)
+ }
+
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
- err = WriteOutConfig(input.ConfigPath, cfg)
+ err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
+ err := util.EnforcePermission(input.ConfigPath)
+ if err != nil {
+ log.Errorf("failed to enforce permission on config dir: %v", err)
+ }
return update(input)
}
diff --git a/util/file.go b/util/file.go
index 2a6182556..8355488c9 100644
--- a/util/file.go
+++ b/util/file.go
@@ -10,51 +10,30 @@ import (
log "github.com/sirupsen/logrus"
)
-// WriteJson writes JSON config object to a file creating parent directories if required
-// The output JSON is pretty-formatted
-func WriteJson(file string, obj interface{}) error {
-
+// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
+func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
- // make it pretty
- bs, err := json.MarshalIndent(obj, "", " ")
+ err = EnforcePermission(file)
if err != nil {
return err
}
- tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
+ return writeJson(file, obj, configDir, configFileName)
+}
+
+// WriteJson writes JSON config object to a file creating parent directories if required
+// The output JSON is pretty-formatted
+func WriteJson(file string, obj interface{}) error {
+ configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
- tempFileName := tempFile.Name()
- // closing file ops as windows doesn't allow to move it
- err = tempFile.Close()
- if err != nil {
- return err
- }
-
- defer func() {
- _, err = os.Stat(tempFileName)
- if err == nil {
- os.Remove(tempFileName)
- }
- }()
-
- err = os.WriteFile(tempFileName, bs, 0600)
- if err != nil {
- return err
- }
-
- err = os.Rename(tempFileName, file)
- if err != nil {
- return err
- }
-
- return nil
+ return writeJson(file, obj, configDir, configFileName)
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
+func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
+
+ // make it pretty
+ bs, err := json.MarshalIndent(obj, "", " ")
+ if err != nil {
+ return err
+ }
+
+ tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
+ if err != nil {
+ return err
+ }
+
+ tempFileName := tempFile.Name()
+ // closing file ops as windows doesn't allow to move it
+ err = tempFile.Close()
+ if err != nil {
+ return err
+ }
+
+ defer func() {
+ _, err = os.Stat(tempFileName)
+ if err == nil {
+ os.Remove(tempFileName)
+ }
+ }()
+
+ err = os.WriteFile(tempFileName, bs, 0600)
+ if err != nil {
+ return err
+ }
+
+ err = os.Rename(tempFileName, file)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
}
err := os.MkdirAll(configDir, 0750)
+ if err != nil {
+ return "", "", err
+ }
+
return configDir, configFileName, err
}
diff --git a/util/permission.go b/util/permission.go
new file mode 100644
index 000000000..666998cff
--- /dev/null
+++ b/util/permission.go
@@ -0,0 +1,7 @@
+//go:build !windows
+
+package util
+
+func EnforcePermission(dirPath string) error {
+ return nil
+}
diff --git a/util/permission_windows.go b/util/permission_windows.go
new file mode 100644
index 000000000..548fef824
--- /dev/null
+++ b/util/permission_windows.go
@@ -0,0 +1,86 @@
+package util
+
+import (
+ "path/filepath"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ securityFlags = windows.OWNER_SECURITY_INFORMATION |
+ windows.GROUP_SECURITY_INFORMATION |
+ windows.DACL_SECURITY_INFORMATION |
+ windows.PROTECTED_DACL_SECURITY_INFORMATION
+)
+
+func EnforcePermission(file string) error {
+ dirPath := filepath.Dir(file)
+
+ user, group, err := sids()
+ if err != nil {
+ return err
+ }
+
+ adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
+ if err != nil {
+ return err
+ }
+
+ explicitAccess := []windows.EXPLICIT_ACCESS{
+ {
+ AccessPermissions: windows.GENERIC_ALL,
+ AccessMode: windows.SET_ACCESS,
+ Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
+ Trustee: windows.TRUSTEE{
+ MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
+ TrusteeForm: windows.TRUSTEE_IS_SID,
+ TrusteeType: windows.TRUSTEE_IS_USER,
+ TrusteeValue: windows.TrusteeValueFromSID(user),
+ },
+ },
+ {
+ AccessPermissions: windows.GENERIC_ALL,
+ AccessMode: windows.SET_ACCESS,
+ Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
+ Trustee: windows.TRUSTEE{
+ MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
+ TrusteeForm: windows.TRUSTEE_IS_SID,
+ TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
+ TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
+ },
+ },
+ }
+
+ dacl, err := windows.ACLFromEntries(explicitAccess, nil)
+ if err != nil {
+ return err
+ }
+
+ return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
+}
+
+func sids() (*windows.SID, *windows.SID, error) {
+ var token windows.Token
+ err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer func() {
+ if err := token.Close(); err != nil {
+ log.Errorf("failed to close process token: %v", err)
+ }
+ }()
+
+ tu, err := token.GetTokenUser()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ pg, err := token.GetTokenPrimaryGroup()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return tu.User.Sid, pg.PrimaryGroup, nil
+}
From 5bc601111da6de26d658f45b1863018f25e06fe7 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Tue, 17 Sep 2024 10:04:17 +0200
Subject: [PATCH 60/89] [relay] Add health check attempt threshold (#2609)
* Add health check attempt threshold for receiver
* Add health check attempt threshold for sender
---
relay/client/client.go | 2 +-
relay/healthcheck/receiver.go | 34 ++++++---
relay/healthcheck/receiver_test.go | 61 +++++++++++++++-
relay/healthcheck/sender.go | 60 +++++++++++++---
relay/healthcheck/sender_test.go | 110 +++++++++++++++++++++++++++--
relay/server/peer.go | 2 +-
6 files changed, 240 insertions(+), 29 deletions(-)
diff --git a/relay/client/client.go b/relay/client/client.go
index 3e5c0ba24..e431c029d 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -306,7 +306,7 @@ func (c *Client) handShake() error {
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
- hc := healthcheck.NewReceiver()
+ hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go
index 59f780ed8..b3503d5db 100644
--- a/relay/healthcheck/receiver.go
+++ b/relay/healthcheck/receiver.go
@@ -3,6 +3,8 @@ package healthcheck
import (
"context"
"time"
+
+ log "github.com/sirupsen/logrus"
)
var (
@@ -14,23 +16,26 @@ var (
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
- OnTimeout chan struct{}
-
- ctx context.Context
- ctxCancel context.CancelFunc
- heartbeat chan struct{}
- alive bool
+ OnTimeout chan struct{}
+ log *log.Entry
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ heartbeat chan struct{}
+ alive bool
+ attemptThreshold int
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
-func NewReceiver() *Receiver {
+func NewReceiver(log *log.Entry) *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
- OnTimeout: make(chan struct{}, 1),
- ctx: ctx,
- ctxCancel: ctxCancel,
- heartbeat: make(chan struct{}, 1),
+ OnTimeout: make(chan struct{}, 1),
+ log: log,
+ ctx: ctx,
+ ctxCancel: ctxCancel,
+ heartbeat: make(chan struct{}, 1),
+ attemptThreshold: getAttemptThresholdFromEnv(),
}
go r.waitForHealthcheck()
@@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() {
defer r.ctxCancel()
defer close(r.OnTimeout)
+ failureCounter := 0
for {
select {
case <-r.heartbeat:
r.alive = true
+ failureCounter = 0
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
+ failureCounter++
+ if failureCounter < r.attemptThreshold {
+ r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
+ continue
+ }
r.notifyTimeout()
return
case <-r.ctx.Done():
diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go
index 4b4123416..3b3e32fe6 100644
--- a/relay/healthcheck/receiver_test.go
+++ b/relay/healthcheck/receiver_test.go
@@ -1,13 +1,18 @@
package healthcheck
import (
+ "context"
+ "fmt"
+ "os"
"testing"
"time"
+
+ log "github.com/sirupsen/logrus"
)
func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second
- r := NewReceiver()
+ r := NewReceiver(log.WithContext(context.Background()))
select {
case <-r.OnTimeout:
@@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) {
func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second
- r := NewReceiver()
+ r := NewReceiver(log.WithContext(context.Background()))
select {
case <-r.OnTimeout:
@@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second
- r := NewReceiver()
+ r := NewReceiver(log.WithContext(context.Background()))
r.Heartbeat()
@@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) {
case <-time.After(3 * time.Second):
}
}
+
+func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
+ testsCases := []struct {
+ name string
+ threshold int
+ resetCounterOnce bool
+ }{
+ {"Default attempt threshold", defaultAttemptThreshold, false},
+ {"Custom attempt threshold", 3, false},
+ {"Should reset threshold once", 2, true},
+ }
+
+ for _, tc := range testsCases {
+ t.Run(tc.name, func(t *testing.T) {
+ originalInterval := healthCheckInterval
+ originalTimeout := heartbeatTimeout
+ healthCheckInterval = 1 * time.Second
+ heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
+ defer func() {
+ healthCheckInterval = originalInterval
+ heartbeatTimeout = originalTimeout
+ }()
+ //nolint:tenv
+ os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
+ defer os.Unsetenv(defaultAttemptThresholdEnv)
+
+ receiver := NewReceiver(log.WithField("test_name", tc.name))
+
+ testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
+
+ if tc.resetCounterOnce {
+ receiver.Heartbeat()
+ t.Logf("reset counter once")
+ }
+
+ select {
+ case <-receiver.OnTimeout:
+ if tc.resetCounterOnce {
+ t.Fatalf("should not have timed out before %s", testTimeout)
+ }
+ case <-time.After(testTimeout):
+ if tc.resetCounterOnce {
+ return
+ }
+ t.Fatalf("should have timed out before %s", testTimeout)
+ }
+
+ })
+ }
+}
diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go
index 8d1716b2c..57b3015ec 100644
--- a/relay/healthcheck/sender.go
+++ b/relay/healthcheck/sender.go
@@ -2,7 +2,16 @@ package healthcheck
import (
"context"
+ "os"
+ "strconv"
"time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ defaultAttemptThreshold = 1
+ defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
var (
@@ -15,20 +24,25 @@ var (
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
+ log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
- ack chan struct{}
+ ack chan struct{}
+ alive bool
+ attemptThreshold int
}
// NewSender creates a new healthcheck sender
-func NewSender() *Sender {
+func NewSender(log *log.Entry) *Sender {
hc := &Sender{
- HealthCheck: make(chan struct{}, 1),
- Timeout: make(chan struct{}, 1),
- ack: make(chan struct{}, 1),
+ log: log,
+ HealthCheck: make(chan struct{}, 1),
+ Timeout: make(chan struct{}, 1),
+ ack: make(chan struct{}, 1),
+ attemptThreshold: getAttemptThresholdFromEnv(),
}
return hc
@@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
- timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
- defer timeoutTimer.Stop()
+ timeoutTicker := time.NewTicker(hc.getTimeoutTime())
+ defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
+ failureCounter := 0
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
- case <-timeoutTimer.C:
+ case <-timeoutTicker.C:
+ if hc.alive {
+ hc.alive = false
+ continue
+ }
+
+ failureCounter++
+ if failureCounter < hc.attemptThreshold {
+ hc.log.Warnf("Health check failed attempt %d.", failureCounter)
+ continue
+ }
hc.Timeout <- struct{}{}
return
case <-hc.ack:
- timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
+ failureCounter = 0
+ hc.alive = true
case <-ctx.Done():
return
}
}
}
+
+func (hc *Sender) getTimeoutTime() time.Duration {
+ return healthCheckInterval + healthCheckTimeout
+}
+
+func getAttemptThresholdFromEnv() int {
+ if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
+ threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
+ if err != nil {
+ log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
+ return defaultAttemptThreshold
+ }
+ return int(threshold)
+ }
+ return defaultAttemptThreshold
+}
diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go
index 7a105c308..f21167025 100644
--- a/relay/healthcheck/sender_test.go
+++ b/relay/healthcheck/sender_test.go
@@ -2,9 +2,12 @@ package healthcheck
import (
"context"
+ "fmt"
"os"
"testing"
"time"
+
+ log "github.com/sirupsen/logrus"
)
func TestMain(m *testing.M) {
@@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := NewSender()
+ hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := NewSender()
+ hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
select {
@@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) {
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
- hc := NewSender()
+ hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := NewSender()
+ hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) {
t.Fatalf("is not exited")
}
}
+
+func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
+ testsCases := []struct {
+ name string
+ threshold int
+ resetCounterOnce bool
+ }{
+ {"Default attempt threshold", defaultAttemptThreshold, false},
+ {"Custom attempt threshold", 3, false},
+ {"Should reset threshold once", 2, true},
+ }
+
+ for _, tc := range testsCases {
+ t.Run(tc.name, func(t *testing.T) {
+ originalInterval := healthCheckInterval
+ originalTimeout := healthCheckTimeout
+ healthCheckInterval = 1 * time.Second
+ healthCheckTimeout = 500 * time.Millisecond
+ defer func() {
+ healthCheckInterval = originalInterval
+ healthCheckTimeout = originalTimeout
+ }()
+
+ //nolint:tenv
+ os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
+ defer os.Unsetenv(defaultAttemptThresholdEnv)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ sender := NewSender(log.WithField("test_name", tc.name))
+ go sender.StartHealthCheck(ctx)
+
+ go func() {
+ responded := false
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case _, ok := <-sender.HealthCheck:
+ if !ok {
+ return
+ }
+ if tc.resetCounterOnce && !responded {
+ responded = true
+ sender.OnHCResponse()
+ }
+ }
+ }
+ }()
+
+ testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
+
+ select {
+ case <-sender.Timeout:
+ if tc.resetCounterOnce {
+ t.Fatalf("should not have timed out before %s", testTimeout)
+ }
+ case <-time.After(testTimeout):
+ if tc.resetCounterOnce {
+ return
+ }
+ t.Fatalf("should have timed out before %s", testTimeout)
+ }
+
+ })
+ }
+
+}
+
+//nolint:tenv
+func TestGetAttemptThresholdFromEnv(t *testing.T) {
+ tests := []struct {
+ name string
+ envValue string
+ expected int
+ }{
+ {"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
+ {"Custom attempt threshold when env is set to a valid integer", "3", 3},
+ {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.envValue == "" {
+ os.Unsetenv(defaultAttemptThresholdEnv)
+ } else {
+ os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
+ }
+
+ result := getAttemptThresholdFromEnv()
+ if result != tt.expected {
+ t.Fatalf("Expected %d, got %d", tt.expected, result)
+ }
+
+ os.Unsetenv(defaultAttemptThresholdEnv)
+ })
+ }
+}
diff --git a/relay/server/peer.go b/relay/server/peer.go
index 00341e98b..a9c542f84 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -49,7 +49,7 @@ func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := healthcheck.NewSender()
+ hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx)
go p.handleHealthcheckEvents(ctx, hc)
From 1104c9c0487200fb4ffb1c1981c73551722a20d8 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Tue, 17 Sep 2024 11:15:14 +0200
Subject: [PATCH 61/89] [client] Fix race condition while read/write conn
status in peer conn (#2607)
---
client/internal/peer/conn.go | 40 ++++++++++++++---------------
client/internal/peer/conn_status.go | 35 +++++++++++++++++++++++--
client/internal/peer/conn_test.go | 9 +++++--
3 files changed, 60 insertions(+), 24 deletions(-)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index f4a701f7f..9eb881087 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -89,8 +89,8 @@ type Conn struct {
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string)
- statusRelay ConnStatus
- statusICE ConnStatus
+ statusRelay *AtomicConnStatus
+ statusICE *AtomicConnStatus
currentConnPriority ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
- statusRelay: StatusDisconnected,
- statusICE: StatusDisconnected,
+ statusRelay: NewAtomicConnStatus(),
+ statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
- if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
+ if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
- if conn.statusICE == StatusDisconnected {
+ if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
- conn.statusICE = StatusConnected
+ conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
@@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay
}
- changed := conn.statusICE != newState && newState != StatusConnecting
- conn.statusICE = newState
+ changed := conn.statusICE.Get() != newState && newState != StatusConnecting
+ conn.statusICE.Set(newState)
select {
case conn.iCEDisconnected <- changed:
@@ -522,7 +522,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
}
conn.log.Debugf("Relay connection is ready to use")
- conn.statusRelay = StatusConnected
+ conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
@@ -538,7 +538,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
- if conn.statusICE == StatusConnected {
+ if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
@@ -594,8 +594,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.wgProxyRelay = nil
}
- changed := conn.statusRelay != StatusDisconnected
- conn.statusRelay = StatusDisconnected
+ changed := conn.statusRelay.Get() != StatusDisconnected
+ conn.statusRelay.Set(StatusDisconnected)
select {
case conn.relayDisconnected <- changed:
@@ -661,8 +661,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
}
func (conn *Conn) setStatusToDisconnected() {
- conn.statusRelay = StatusDisconnected
- conn.statusICE = StatusDisconnected
+ conn.statusRelay.Set(StatusDisconnected)
+ conn.statusICE.Set(StatusDisconnected)
peerState := State{
PubKey: conn.config.Key,
@@ -706,7 +706,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
}
func (conn *Conn) isRelayed() bool {
- if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
+ if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
return false
}
@@ -718,11 +718,11 @@ func (conn *Conn) isRelayed() bool {
}
func (conn *Conn) evalStatus() ConnStatus {
- if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
+ if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
return StatusConnected
}
- if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
+ if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting
}
@@ -733,12 +733,12 @@ func (conn *Conn) isConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
- if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
+ if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
return false
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
- if conn.statusRelay != StatusConnected {
+ if conn.statusRelay.Get() != StatusConnected {
return false
}
}
diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go
index 639117c89..3c747864f 100644
--- a/client/internal/peer/conn_status.go
+++ b/client/internal/peer/conn_status.go
@@ -1,6 +1,10 @@
package peer
-import log "github.com/sirupsen/logrus"
+import (
+ "sync/atomic"
+
+ log "github.com/sirupsen/logrus"
+)
const (
// StatusConnected indicate the peer is in connected state
@@ -12,7 +16,34 @@ const (
)
// ConnStatus describe the status of a peer's connection
-type ConnStatus int
+type ConnStatus int32
+
+// AtomicConnStatus is a thread-safe wrapper for ConnStatus
+type AtomicConnStatus struct {
+ status atomic.Int32
+}
+
+// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
+func NewAtomicConnStatus() *AtomicConnStatus {
+ acs := &AtomicConnStatus{}
+ acs.Set(StatusDisconnected)
+ return acs
+}
+
+// Get returns the current connection status
+func (acs *AtomicConnStatus) Get() ConnStatus {
+ return ConnStatus(acs.status.Load())
+}
+
+// Set updates the connection status
+func (acs *AtomicConnStatus) Set(status ConnStatus) {
+ acs.status.Store(int32(status))
+}
+
+// String returns the string representation of the current status
+func (acs *AtomicConnStatus) String() string {
+ return acs.Get().String()
+}
func (s ConnStatus) String() string {
switch s {
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index 59f249b82..80c25f63c 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
- conn.statusICE = table.statusIce
- conn.statusRelay = table.statusRelay
+ si := NewAtomicConnStatus()
+ si.Set(table.statusIce)
+ conn.statusICE = si
+
+ sr := NewAtomicConnStatus()
+ sr.Set(table.statusRelay)
+ conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
From 28cbb4b70f370e0d3d23818fc8246ceee0b358c8 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Tue, 17 Sep 2024 12:10:17 +0200
Subject: [PATCH 62/89] [client] Cancel the context of wg watcher when the go
routine exit (#2612)
---
client/internal/peer/conn.go | 4 ++--
client/internal/peer/worker_relay.go | 9 +++++----
2 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 9eb881087..5327e31d2 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -484,11 +484,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
// switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
- conn.workerRelay.EnableWgWatcher(conn.ctx)
err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay
}
@@ -551,7 +551,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
}
}
- conn.workerRelay.EnableWgWatcher(conn.ctx)
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
if err := wgProxy.CloseConn(); err != nil {
@@ -560,6 +559,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return
}
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go
index 3457faa46..6bb385d3e 100644
--- a/client/internal/peer/worker_relay.go
+++ b/client/internal/peer/worker_relay.go
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
}
ctx, ctxCancel := context.WithCancel(ctx)
- w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
+ w.wgStateCheck(ctx, ctxCancel)
}
func (w *WorkerRelay) DisableWgWatcher() {
@@ -158,21 +158,22 @@ func (w *WorkerRelay) CloseConn() {
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
-func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
+func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
+ w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
- w.log.Errorf("failed to read wg stats: %v", err)
+ w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
+ defer ctxCancel()
for {
select {
case <-timer.C:
-
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
From 6f0fd1d1b33dddaf4a927cede4f6854765431ea6 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Thu, 19 Sep 2024 13:49:09 +0200
Subject: [PATCH 63/89] - Increase queue size and drop the overflowed messages
(#2617)
- Explicit close the net.Conn in user space wgProxy when close the wgProxy
- Add extra logs
---
client/internal/peer/conn.go | 9 +++++---
client/internal/wgproxy/proxy_userspace.go | 18 +++++++++++----
relay/client/client.go | 27 ++++++++++++++++------
3 files changed, 40 insertions(+), 14 deletions(-)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 5327e31d2..911ddd228 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -518,6 +518,9 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
+ if err := rci.relayedConn.Close(); err != nil {
+ log.Warnf("failed to close unnecessary relayed connection: %v", err)
+ }
return
}
@@ -530,6 +533,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
+ conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
@@ -775,9 +779,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
- err = wgProxy.CloseConn()
- if err != nil {
- conn.log.Warnf("failed to close turn proxy connection: %v", err)
+ if errClose := wgProxy.CloseConn(); errClose != nil {
+ conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
}
return nil, nil, err
}
diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go
index c2c8a9b51..701f615b9 100644
--- a/client/internal/wgproxy/proxy_userspace.go
+++ b/client/internal/wgproxy/proxy_userspace.go
@@ -32,8 +32,8 @@ func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
}
// AddTurnConn start the proxy with the given remote conn
-func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
- p.remoteConn = turnConn
+func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
+ p.remoteConn = remoteConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
@@ -54,6 +54,14 @@ func (p *WGUserSpaceProxy) CloseConn() error {
if p.localConn == nil {
return nil
}
+
+ if p.remoteConn == nil {
+ return nil
+ }
+
+ if err := p.remoteConn.Close(); err != nil {
+ log.Warnf("failed to close remote conn: %s", err)
+ }
return p.localConn.Close()
}
@@ -65,6 +73,8 @@ func (p *WGUserSpaceProxy) Free() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
+ defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
+
buf := make([]byte, 1500)
for {
select {
@@ -93,7 +103,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
-
+ defer p.cancel()
+ defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500)
for {
select {
@@ -103,7 +114,6 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
- p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)
diff --git a/relay/client/client.go b/relay/client/client.go
index e431c029d..90bc3ac41 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -58,7 +58,10 @@ func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
+// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
+// server and forwarding them to the upper layer content reader.
type connContainer struct {
+ log *log.Entry
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
@@ -67,10 +70,10 @@ type connContainer struct {
cancel context.CancelFunc
}
-func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
+func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
-
return &connContainer{
+ log: log,
conn: conn,
messages: messages,
ctx: ctx,
@@ -91,6 +94,10 @@ func (cc *connContainer) writeMsg(msg Msg) {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
+ default:
+ msg.Free()
+ cc.log.Infof("message queue is full")
+ // todo consider to close the connection
}
}
@@ -141,8 +148,8 @@ type Client struct {
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
- return &Client{
- log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
+ c := &Client{
+ log: log.WithFields(log.Fields{"relay": serverURL}),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
@@ -155,6 +162,8 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
},
conns: make(map[string]*connContainer),
}
+ c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
+ return c
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
@@ -203,10 +212,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
}
c.log.Infof("open connection to peer: %s", hashedStringID)
- msgChannel := make(chan Msg, 2)
+ msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
- c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
+ c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil
}
@@ -455,7 +464,10 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
- _ = conn.Close() // ignore the err because the readLoop will handle it
+ if err := conn.Close(); err != nil {
+ // ignore the err handling because the readLoop will handle it
+ c.log.Warnf("failed to close connection: %s", err)
+ }
return
case <-c.parentCtx.Done():
err := c.close(true)
@@ -486,6 +498,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
+ c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id)
container.close()
From fc4b37f7bcdc2de36f279c458ce79da312d8d29e Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Thu, 19 Sep 2024 13:49:28 +0200
Subject: [PATCH 64/89] Exit from processConnResults after all tries (#2621)
* Exit from processConnResults after all tries
If all server is unavailable then the server picker never return
because we never close the result channel.
Count the number of the results and exit when we reached the
expected size
---
relay/client/picker.go | 10 +++++++---
relay/client/picker_test.go | 31 +++++++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 3 deletions(-)
create mode 100644 relay/client/picker_test.go
diff --git a/relay/client/picker.go b/relay/client/picker.go
index b0888a4a0..13b0547aa 100644
--- a/relay/client/picker.go
+++ b/relay/client/picker.go
@@ -35,12 +35,15 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*C
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
-
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
+
for _, url := range urls {
+ // todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
- defer func() { <-concurrentLimiter }()
+ defer func() {
+ <-concurrentLimiter
+ }()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
@@ -72,7 +75,8 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
- for cr := range resultChan {
+ for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
+ cr := <-resultChan
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go
new file mode 100644
index 000000000..f5649d700
--- /dev/null
+++ b/relay/client/picker_test.go
@@ -0,0 +1,31 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+)
+
+func TestServerPicker_UnavailableServers(t *testing.T) {
+ sp := ServerPicker{
+ TokenStore: nil,
+ PeerID: "test",
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ go func() {
+ _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
+ if err == nil {
+ t.Error(err)
+ }
+ cancel()
+ }()
+
+ <-ctx.Done()
+ if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+ t.Errorf("PickServer() took too long to complete")
+ }
+}
From 35c892aea365fa9e3e710ecc45dc7c0a353b8c71 Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Fri, 20 Sep 2024 12:36:58 +0300
Subject: [PATCH 65/89] [management] Restrict accessible peers to user-owned
peers for non-admins (#2618)
* Restrict accessible peers to user-owned peers for non-admin users
Signed-off-by: bcmmbaga
* add tests
Signed-off-by: bcmmbaga
* add service user test
Signed-off-by: bcmmbaga
* reuse account from token
Signed-off-by: bcmmbaga
* return error when peer not found
Signed-off-by: bcmmbaga
---------
Signed-off-by: bcmmbaga
---
management/server/http/peers_handler.go | 20 +-
management/server/http/peers_handler_test.go | 194 +++++++++++++++++--
2 files changed, 198 insertions(+), 16 deletions(-)
diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go
index 1487bbc39..5a2190d83 100644
--- a/management/server/http/peers_handler.go
+++ b/management/server/http/peers_handler.go
@@ -7,8 +7,6 @@ import (
"net/http"
"github.com/gorilla/mux"
- log "github.com/sirupsen/logrus"
-
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
+ log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
return
}
+ // If the user is regular user and does not own the peer
+ // with the given peerID return an empty list
+ if !user.HasAdminPower() && !user.IsServiceUser {
+ peer, ok := account.Peers[peerID]
+ if !ok {
+ util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
+ return
+ }
+
+ if peer.UserID != user.Id {
+ util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
+ return
+ }
+ }
+
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go
index 153c8f03a..dae264fff 100644
--- a/management/server/http/peers_handler_test.go
+++ b/management/server/http/peers_handler_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "fmt"
"io"
"net"
"net/http"
@@ -12,20 +13,30 @@ import (
"time"
"github.com/gorilla/mux"
-
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims"
- "github.com/magiconair/properties/assert"
+ "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
-const testPeerID = "test_peer"
-const noUpdateChannelTestPeerID = "no-update-channel"
+type ctxKey string
+
+const (
+ testPeerID = "test_peer"
+ noUpdateChannelTestPeerID = "no-update-channel"
+
+ adminUser = "admin_user"
+ regularUser = "regular_user"
+ serviceUser = "service_user"
+ userIDKey ctxKey = "user_id"
+)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
- return &server.Account{
+ peersMap := make(map[string]*nbpeer.Peer)
+ for _, peer := range peers {
+ peersMap[peer.ID] = peer.Copy()
+ }
+
+ policy := &server.Policy{
+ ID: "policy",
+ AccountID: claims.AccountId,
+ Name: "policy",
+ Enabled: true,
+ Rules: []*server.PolicyRule{
+ {
+ ID: "rule",
+ Name: "rule",
+ Enabled: true,
+ Action: "accept",
+ Destinations: []string{"group1"},
+ Sources: []string{"group1"},
+ Bidirectional: true,
+ Protocol: "all",
+ Ports: []string{"80"},
+ },
+ },
+ }
+
+ srvUser := server.NewRegularUser(serviceUser)
+ srvUser.IsServiceUser = true
+
+ account := &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
- Peers: map[string]*nbpeer.Peer{
- peers[0].ID: peers[0],
- peers[1].ID: peers[1],
- },
+ Peers: peersMap,
Users: map[string]*server.User{
- "test_user": user,
+ adminUser: server.NewAdminUser(adminUser),
+ regularUser: server.NewRegularUser(regularUser),
+ serviceUser: srvUser,
+ },
+ Groups: map[string]*nbgroup.Group{
+ "group1": {
+ ID: "group1",
+ AccountID: claims.AccountId,
+ Name: "group1",
+ Issued: "api",
+ Peers: maps.Keys(peersMap),
+ },
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
+ Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
Serial: 51,
},
- }, user, nil
+ }
+
+ return account, account.Users[claims.UserId], nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
+ userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
+ UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
+ req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
})
}
}
+
+func TestGetAccessiblePeers(t *testing.T) {
+ peer1 := &nbpeer.Peer{
+ ID: "peer1",
+ Key: "key1",
+ IP: net.ParseIP("100.64.0.1"),
+ Status: &nbpeer.PeerStatus{Connected: true},
+ Name: "peer1",
+ LoginExpirationEnabled: false,
+ UserID: regularUser,
+ }
+
+ peer2 := &nbpeer.Peer{
+ ID: "peer2",
+ Key: "key2",
+ IP: net.ParseIP("100.64.0.2"),
+ Status: &nbpeer.PeerStatus{Connected: true},
+ Name: "peer2",
+ LoginExpirationEnabled: false,
+ UserID: adminUser,
+ }
+
+ peer3 := &nbpeer.Peer{
+ ID: "peer3",
+ Key: "key3",
+ IP: net.ParseIP("100.64.0.3"),
+ Status: &nbpeer.PeerStatus{Connected: true},
+ Name: "peer3",
+ LoginExpirationEnabled: false,
+ UserID: regularUser,
+ }
+
+ p := initTestMetaData(peer1, peer2, peer3)
+
+ tt := []struct {
+ name string
+ peerID string
+ callerUserID string
+ expectedStatus int
+ expectedPeers []string
+ }{
+ {
+ name: "non admin user can access owned peer",
+ peerID: "peer1",
+ callerUserID: regularUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer2", "peer3"},
+ },
+ {
+ name: "non admin user can't access unowned peer",
+ peerID: "peer2",
+ callerUserID: regularUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{},
+ },
+ {
+ name: "admin user can access owned peer",
+ peerID: "peer2",
+ callerUserID: adminUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer1", "peer3"},
+ },
+ {
+ name: "admin user can access unowned peer",
+ peerID: "peer3",
+ callerUserID: adminUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer1", "peer2"},
+ },
+ {
+ name: "service user can access unowned peer",
+ peerID: "peer3",
+ callerUserID: serviceUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer1", "peer2"},
+ },
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
+ ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
+ req = req.WithContext(ctx)
+
+ router := mux.NewRouter()
+ router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
+ router.ServeHTTP(recorder, req)
+
+ res := recorder.Result()
+ if res.StatusCode != tc.expectedStatus {
+ t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
+ }
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("failed to read response body: %v", err)
+ }
+ defer res.Body.Close()
+
+ var accessiblePeers []api.AccessiblePeer
+ err = json.Unmarshal(body, &accessiblePeers)
+ if err != nil {
+ t.Fatalf("failed to unmarshal response: %v", err)
+ }
+
+ peerIDs := make([]string, len(accessiblePeers))
+ for i, peer := range accessiblePeers {
+ peerIDs[i] = peer.Id
+ }
+
+ assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
+ })
+ }
+}
From d47be154ea2fa5792d8e79fd75f9700338f59dec Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 23 Sep 2024 10:02:03 +0200
Subject: [PATCH 66/89] [misc] Fix ip range posture check example (#2628)
---
management/server/http/api/openapi.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 156310a9b..2463f830e 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -950,7 +950,7 @@ components:
type: array
items:
type: string
- example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
+ example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action:
description: Action to take upon policy match
type: string
From ab82302c9590783d685c467ad2d1967c2fbbee26 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Tue, 24 Sep 2024 12:29:15 +0200
Subject: [PATCH 67/89] [client] Remove usage of custom dialer for localhost
(#2639)
* Downgrade error log level for network monitor warnings
* Do not use custom dialer for localhost
---
client/internal/networkmonitor/monitor_bsd.go | 10 +++++-----
client/internal/wgproxy/proxy_userspace.go | 5 ++---
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go
index 51135a729..4dc2c1aa3 100644
--- a/client/internal/networkmonitor/monitor_bsd.go
+++ b/client/internal/networkmonitor/monitor_bsd.go
@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
- log.Errorf("Network monitor: failed to close routing socket: %v", err)
+ log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
- log.Debugf("Network monitor: closed routing socket")
+ log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
- log.Errorf("Network monitor: failed to read from routing socket: %v", err)
+ log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
- log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
+ log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
- log.Errorf("Network monitor: error parsing routing message: %v", err)
+ log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go
index 701f615b9..8fc640b6a 100644
--- a/client/internal/wgproxy/proxy_userspace.go
+++ b/client/internal/wgproxy/proxy_userspace.go
@@ -7,8 +7,6 @@ import (
"net"
log "github.com/sirupsen/logrus"
-
- nbnet "github.com/netbirdio/netbird/util/net"
)
// WGUserSpaceProxy proxies
@@ -36,7 +34,8 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn
var err error
- p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
+ dialer := &net.Dialer{}
+ p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
From e7d52c8c95aa0a0520442bc6c984be2343d70ee8 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Tue, 24 Sep 2024 20:57:56 +0200
Subject: [PATCH 68/89] [client] Fix error count formatting (#2641)
---
client/errors/errors.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/client/errors/errors.go b/client/errors/errors.go
index cef999ac8..8faadbda5 100644
--- a/client/errors/errors.go
+++ b/client/errors/errors.go
@@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
- if len(es) == 0 {
- return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
+ if len(es) == 1 {
+ return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
From b51d75204b13a191ecb2ab01ad81d1a73b49b5c5 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Tue, 24 Sep 2024 20:58:18 +0200
Subject: [PATCH 69/89] [client] Anonymize relay address in status peers view
(#2640)
---
client/cmd/status.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/client/cmd/status.go b/client/cmd/status.go
index 1ef8b4913..ed3daa2b5 100644
--- a/client/cmd/status.go
+++ b/client/cmd/status.go
@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
+
+ peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
+
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
From 1e4a0f77e27710e57c66ef775f9ccd1e97e82a84 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Wed, 25 Sep 2024 18:22:27 +0200
Subject: [PATCH 70/89] Add get DB method to store (#2650)
---
management/server/sql_store.go | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 6f1f66ef8..8fa5f9d05 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1024,3 +1024,7 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
db: tx,
}
}
+
+func (s *SqlStore) GetDB() *gorm.DB {
+ return s.db
+}
From 4ebf6e1c4c5b549ad6983b2a7a36874fd8a85dc4 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Wed, 25 Sep 2024 18:50:10 +0200
Subject: [PATCH 71/89] [client] Close the remote conn in proxy (#2626)
Port the conn close call to eBPF proxy
---
client/internal/engine.go | 2 +-
client/internal/peer/conn.go | 8 +-
client/internal/peer/conn_test.go | 8 +-
.../internal/wgproxy/{ => ebpf}/portlookup.go | 2 +-
.../wgproxy/{ => ebpf}/portlookup_test.go | 2 +-
.../wgproxy/{proxy_ebpf.go => ebpf/proxy.go} | 169 ++++++++++--------
.../proxy_test.go} | 9 +-
client/internal/wgproxy/ebpf/wrapper.go | 44 +++++
client/internal/wgproxy/factory.go | 22 ---
client/internal/wgproxy/factory_linux.go | 33 +++-
client/internal/wgproxy/factory_nonlinux.go | 16 +-
client/internal/wgproxy/proxy.go | 6 +-
client/internal/wgproxy/proxy_test.go | 128 +++++++++++++
client/internal/wgproxy/proxy_userspace.go | 129 -------------
client/internal/wgproxy/usp/proxy.go | 146 +++++++++++++++
relay/client/picker_test.go | 2 +-
16 files changed, 469 insertions(+), 257 deletions(-)
rename client/internal/wgproxy/{ => ebpf}/portlookup.go (96%)
rename client/internal/wgproxy/{ => ebpf}/portlookup_test.go (97%)
rename client/internal/wgproxy/{proxy_ebpf.go => ebpf/proxy.go} (65%)
rename client/internal/wgproxy/{proxy_ebpf_test.go => ebpf/proxy_test.go} (86%)
create mode 100644 client/internal/wgproxy/ebpf/wrapper.go
delete mode 100644 client/internal/wgproxy/factory.go
create mode 100644 client/internal/wgproxy/proxy_test.go
delete mode 100644 client/internal/wgproxy/proxy_userspace.go
create mode 100644 client/internal/wgproxy/usp/proxy.go
diff --git a/client/internal/engine.go b/client/internal/engine.go
index b0deb5a29..463507ad8 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
- e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
+ e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 911ddd228..ea6d892b9 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -527,8 +527,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay.Set(StatusConnected)
- wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
- endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
+ wgProxy := conn.wgProxyFactory.GetProxy()
+ endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
@@ -775,8 +775,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
- wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
- ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
+ wgProxy := conn.wgProxyFactory.GetProxy()
+ ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index 80c25f63c..22e5409f8 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/ebpf/portlookup.go
similarity index 96%
rename from client/internal/wgproxy/portlookup.go
rename to client/internal/wgproxy/ebpf/portlookup.go
index 6f3d33487..0e2c20c99 100644
--- a/client/internal/wgproxy/portlookup.go
+++ b/client/internal/wgproxy/ebpf/portlookup.go
@@ -1,4 +1,4 @@
-package wgproxy
+package ebpf
import (
"fmt"
diff --git a/client/internal/wgproxy/portlookup_test.go b/client/internal/wgproxy/ebpf/portlookup_test.go
similarity index 97%
rename from client/internal/wgproxy/portlookup_test.go
rename to client/internal/wgproxy/ebpf/portlookup_test.go
index 6a386f330..92f4b8eee 100644
--- a/client/internal/wgproxy/portlookup_test.go
+++ b/client/internal/wgproxy/ebpf/portlookup_test.go
@@ -1,4 +1,4 @@
-package wgproxy
+package ebpf
import (
"fmt"
diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/ebpf/proxy.go
similarity index 65%
rename from client/internal/wgproxy/proxy_ebpf.go
rename to client/internal/wgproxy/ebpf/proxy.go
index d385cc4ca..4bd4bfff6 100644
--- a/client/internal/wgproxy/proxy_ebpf.go
+++ b/client/internal/wgproxy/ebpf/proxy.go
@@ -1,6 +1,6 @@
//go:build linux && !android
-package wgproxy
+package ebpf
import (
"context"
@@ -13,47 +13,49 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
+ "github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
+ nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
+const (
+ loopbackAddr = "127.0.0.1"
+)
+
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
- ebpfManager ebpfMgr.Manager
-
- ctx context.Context
- cancel context.CancelFunc
-
- lastUsedPort uint16
localWGListenPort int
+ ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
- rawConn net.PacketConn
- conn transport.UDPConn
+ lastUsedPort uint16
+ rawConn net.PacketConn
+ conn transport.UDPConn
+
+ ctx context.Context
+ ctxCancel context.CancelFunc
}
// NewWGEBPFProxy create new WGEBPFProxy instance
-func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
+func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(),
- lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
- wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
-
return wgProxy
}
-// listen load ebpf program and listen the proxy
-func (p *WGEBPFProxy) listen() error {
+// Listen load ebpf program and listen the proxy
+func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
@@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{
Port: wgPorxyPort,
- IP: net.ParseIP("127.0.0.1"),
+ IP: net.ParseIP(loopbackAddr),
}
+ p.ctx, p.ctxCancel = context.WithCancel(context.Background())
+
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
@@ -91,108 +95,112 @@ func (p *WGEBPFProxy) listen() error {
}
// AddTurnConn add new turn connection for the proxy
-func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
+func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
- go p.proxyToLocal(wgEndpointPort, turnConn)
+ go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
- IP: net.ParseIP("127.0.0.1"),
+ IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
-// CloseConn doing nothing because this type of proxy implementation does not store the connection
-func (p *WGEBPFProxy) CloseConn() error {
- return nil
-}
-
-// Free resources
+// Free resources except the remoteConns will be keep open.
func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy")
- var err1, err2, err3 error
- if p.conn != nil {
- err1 = p.conn.Close()
+ if p.ctx != nil && p.ctx.Err() != nil {
+ //nolint
+ return nil
}
- err2 = p.ebpfManager.FreeWGProxy()
- if p.rawConn != nil {
- err3 = p.rawConn.Close()
+ p.ctxCancel()
+
+ var result *multierror.Error
+ if err := p.conn.Close(); err != nil {
+ result = multierror.Append(result, err)
}
- if err1 != nil {
- return err1
+ if err := p.ebpfManager.FreeWGProxy(); err != nil {
+ result = multierror.Append(result, err)
}
- if err2 != nil {
- return err2
+ if err := p.rawConn.Close(); err != nil {
+ result = multierror.Append(result, err)
}
-
- return err3
+ return nberrors.FormatErrorOrNil(result)
}
-func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
+func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
+ defer p.removeTurnConn(endpointPort)
+
+ var (
+ err error
+ n int
+ )
buf := make([]byte, 1500)
- var err error
- defer func() {
- p.removeTurnConn(endpointPort)
- }()
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- var n int
- n, err = remoteConn.Read(buf)
- if err != nil {
- if err != io.EOF {
- log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
- }
+ for ctx.Err() == nil {
+ n, err = remoteConn.Read(buf)
+ if err != nil {
+ if ctx.Err() != nil {
return
}
- err = p.sendPkg(buf[:n], endpointPort)
- if err != nil {
- log.Errorf("failed to write out turn pkg to local conn: %v", err)
+ if err != io.EOF {
+ log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
+ return
+ }
+
+ if err := p.sendPkg(buf[:n], endpointPort); err != nil {
+ if ctx.Err() != nil || p.ctx.Err() != nil {
+ return
+ }
+ log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
+// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- n, addr, err := p.conn.ReadFromUDP(buf)
- if err != nil {
- log.Errorf("failed to read UDP pkg from WG: %s", err)
+ for p.ctx.Err() == nil {
+ if err := p.readAndForwardPacket(buf); err != nil {
+ if p.ctx.Err() != nil {
return
}
-
- p.turnConnMutex.Lock()
- conn, ok := p.turnConnStore[uint16(addr.Port)]
- p.turnConnMutex.Unlock()
- if !ok {
- log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
- continue
- }
-
- _, err = conn.Write(buf[:n])
- if err != nil {
- log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
- }
+ log.Errorf("failed to proxy packet to remote conn: %s", err)
}
}
}
+func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
+ n, addr, err := p.conn.ReadFromUDP(buf)
+ if err != nil {
+ return fmt.Errorf("failed to read UDP packet from WG: %w", err)
+ }
+
+ p.turnConnMutex.Lock()
+ conn, ok := p.turnConnStore[uint16(addr.Port)]
+ p.turnConnMutex.Unlock()
+ if !ok {
+ if p.ctx.Err() == nil {
+ log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
+ }
+ return nil
+ }
+
+ if _, err := conn.Write(buf[:n]); err != nil {
+ return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
+ }
+ return nil
+}
+
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
@@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
- log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
- delete(p.turnConnStore, turnConnID)
+ _, ok := p.turnConnStore[turnConnID]
+ if ok {
+ log.Debugf("remove turn conn from store by port: %d", turnConnID)
+ }
+ delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
diff --git a/client/internal/wgproxy/proxy_ebpf_test.go b/client/internal/wgproxy/ebpf/proxy_test.go
similarity index 86%
rename from client/internal/wgproxy/proxy_ebpf_test.go
rename to client/internal/wgproxy/ebpf/proxy_test.go
index 821e64218..b15bc686c 100644
--- a/client/internal/wgproxy/proxy_ebpf_test.go
+++ b/client/internal/wgproxy/ebpf/proxy_test.go
@@ -1,14 +1,13 @@
//go:build linux && !android
-package wgproxy
+package ebpf
import (
- "context"
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
- wgProxy := NewWGEBPFProxy(context.Background(), 1)
+ wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
- wgProxy := NewWGEBPFProxy(context.Background(), 1)
+ wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
- wgProxy := NewWGEBPFProxy(context.Background(), 1)
+ wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)
diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go
new file mode 100644
index 000000000..c5639f840
--- /dev/null
+++ b/client/internal/wgproxy/ebpf/wrapper.go
@@ -0,0 +1,44 @@
+//go:build linux && !android
+
+package ebpf
+
+import (
+ "context"
+ "fmt"
+ "net"
+)
+
+// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
+type ProxyWrapper struct {
+ WgeBPFProxy *WGEBPFProxy
+
+ remoteConn net.Conn
+ cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
+}
+
+func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
+ ctxConn, cancel := context.WithCancel(ctx)
+ addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
+
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("add turn conn: %w", err)
+ }
+ e.remoteConn = remoteConn
+ e.cancel = cancel
+ return addr, err
+}
+
+// CloseConn close the remoteConn and automatically remove the conn instance from the map
+func (e *ProxyWrapper) CloseConn() error {
+ if e.cancel == nil {
+ return fmt.Errorf("proxy not started")
+ }
+
+ e.cancel()
+
+ if err := e.remoteConn.Close(); err != nil {
+ return fmt.Errorf("failed to close remote conn: %w", err)
+ }
+ return nil
+}
diff --git a/client/internal/wgproxy/factory.go b/client/internal/wgproxy/factory.go
deleted file mode 100644
index f4eb150b0..000000000
--- a/client/internal/wgproxy/factory.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package wgproxy
-
-import "context"
-
-type Factory struct {
- wgPort int
- ebpfProxy Proxy
-}
-
-func (w *Factory) GetProxy(ctx context.Context) Proxy {
- if w.ebpfProxy != nil {
- return w.ebpfProxy
- }
- return NewWGUserSpaceProxy(ctx, w.wgPort)
-}
-
-func (w *Factory) Free() error {
- if w.ebpfProxy != nil {
- return w.ebpfProxy.Free()
- }
- return nil
-}
diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go
index d01ae7e74..369ba99db 100644
--- a/client/internal/wgproxy/factory_linux.go
+++ b/client/internal/wgproxy/factory_linux.go
@@ -3,20 +3,26 @@
package wgproxy
import (
- "context"
-
log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
+ "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
)
-func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
+type Factory struct {
+ wgPort int
+ ebpfProxy *ebpf.WGEBPFProxy
+}
+
+func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
- ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
- err := ebpfProxy.listen()
+ ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
+ err := ebpfProxy.Listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy
return f
}
+
+func (w *Factory) GetProxy() Proxy {
+ if w.ebpfProxy != nil {
+ p := &ebpf.ProxyWrapper{
+ WgeBPFProxy: w.ebpfProxy,
+ }
+ return p
+ }
+ return usp.NewWGUserSpaceProxy(w.wgPort)
+}
+
+func (w *Factory) Free() error {
+ if w.ebpfProxy == nil {
+ return nil
+ }
+ return w.ebpfProxy.Free()
+}
diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go
index d1640c97d..f930b09b3 100644
--- a/client/internal/wgproxy/factory_nonlinux.go
+++ b/client/internal/wgproxy/factory_nonlinux.go
@@ -2,8 +2,20 @@
package wgproxy
-import "context"
+import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
-func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
+type Factory struct {
+ wgPort int
+}
+
+func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}
+
+func (w *Factory) GetProxy() Proxy {
+ return usp.NewWGUserSpaceProxy(w.wgPort)
+}
+
+func (w *Factory) Free() error {
+ return nil
+}
diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go
index b88df73a0..96fae8dd1 100644
--- a/client/internal/wgproxy/proxy.go
+++ b/client/internal/wgproxy/proxy.go
@@ -1,12 +1,12 @@
package wgproxy
import (
+ "context"
"net"
)
-// Proxy is a transfer layer between the Turn connection and the WireGuard
+// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
- AddTurnConn(turnConn net.Conn) (net.Addr, error)
+ AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error
- Free() error
}
diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go
new file mode 100644
index 000000000..b09e6be55
--- /dev/null
+++ b/client/internal/wgproxy/proxy_test.go
@@ -0,0 +1,128 @@
+//go:build linux
+
+package wgproxy
+
+import (
+ "context"
+ "io"
+ "net"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
+ "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
+ "github.com/netbirdio/netbird/util"
+)
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("trace", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
+type mocConn struct {
+ closeChan chan struct{}
+ closed bool
+}
+
+func newMockConn() *mocConn {
+ return &mocConn{
+ closeChan: make(chan struct{}),
+ }
+}
+
+func (m *mocConn) Read(b []byte) (n int, err error) {
+ <-m.closeChan
+ return 0, io.EOF
+}
+
+func (m *mocConn) Write(b []byte) (n int, err error) {
+ <-m.closeChan
+ return 0, io.EOF
+}
+
+func (m *mocConn) Close() error {
+ if m.closed == true {
+ return nil
+ }
+
+ m.closed = true
+ close(m.closeChan)
+ return nil
+}
+
+func (m *mocConn) LocalAddr() net.Addr {
+ panic("implement me")
+}
+
+func (m *mocConn) RemoteAddr() net.Addr {
+ return &net.UDPAddr{
+ IP: net.ParseIP("172.16.254.1"),
+ }
+}
+
+func (m *mocConn) SetDeadline(t time.Time) error {
+ panic("implement me")
+}
+
+func (m *mocConn) SetReadDeadline(t time.Time) error {
+ panic("implement me")
+}
+
+func (m *mocConn) SetWriteDeadline(t time.Time) error {
+ panic("implement me")
+}
+
+func TestProxyCloseByRemoteConn(t *testing.T) {
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ proxy Proxy
+ }{
+ {
+ name: "userspace proxy",
+ proxy: usp.NewWGUserSpaceProxy(51830),
+ },
+ }
+
+ if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
+ ebpfProxy := ebpf.NewWGEBPFProxy(51831)
+ if err := ebpfProxy.Listen(); err != nil {
+ t.Fatalf("failed to initialize ebpf proxy: %s", err)
+ }
+ defer func() {
+ if err := ebpfProxy.Free(); err != nil {
+ t.Errorf("failed to free ebpf proxy: %s", err)
+ }
+ }()
+ proxyWrapper := &ebpf.ProxyWrapper{
+ WgeBPFProxy: ebpfProxy,
+ }
+
+ tests = append(tests, struct {
+ name string
+ proxy Proxy
+ }{
+ name: "ebpf proxy",
+ proxy: proxyWrapper,
+ })
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ relayedConn := newMockConn()
+ _, err := tt.proxy.AddTurnConn(ctx, relayedConn)
+ if err != nil {
+ t.Errorf("error: %v", err)
+ }
+
+ _ = relayedConn.Close()
+ if err := tt.proxy.CloseConn(); err != nil {
+ t.Errorf("error: %v", err)
+ }
+ })
+ }
+}
diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go
deleted file mode 100644
index 8fc640b6a..000000000
--- a/client/internal/wgproxy/proxy_userspace.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package wgproxy
-
-import (
- "context"
- "fmt"
- "io"
- "net"
-
- log "github.com/sirupsen/logrus"
-)
-
-// WGUserSpaceProxy proxies
-type WGUserSpaceProxy struct {
- localWGListenPort int
- ctx context.Context
- cancel context.CancelFunc
-
- remoteConn net.Conn
- localConn net.Conn
-}
-
-// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
-func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
- log.Debugf("Initializing new user space proxy with port %d", wgPort)
- p := &WGUserSpaceProxy{
- localWGListenPort: wgPort,
- }
- p.ctx, p.cancel = context.WithCancel(ctx)
- return p
-}
-
-// AddTurnConn start the proxy with the given remote conn
-func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
- p.remoteConn = remoteConn
-
- var err error
- dialer := &net.Dialer{}
- p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
- if err != nil {
- log.Errorf("failed dialing to local Wireguard port %s", err)
- return nil, err
- }
-
- go p.proxyToRemote()
- go p.proxyToLocal()
-
- return p.localConn.LocalAddr(), err
-}
-
-// CloseConn close the localConn
-func (p *WGUserSpaceProxy) CloseConn() error {
- p.cancel()
- if p.localConn == nil {
- return nil
- }
-
- if p.remoteConn == nil {
- return nil
- }
-
- if err := p.remoteConn.Close(); err != nil {
- log.Warnf("failed to close remote conn: %s", err)
- }
- return p.localConn.Close()
-}
-
-// Free doing nothing because this implementation of proxy does not have global state
-func (p *WGUserSpaceProxy) Free() error {
- return nil
-}
-
-// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
-// blocks
-func (p *WGUserSpaceProxy) proxyToRemote() {
- defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
-
- buf := make([]byte, 1500)
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- n, err := p.localConn.Read(buf)
- if err != nil {
- log.Debugf("failed to read from wg interface conn: %s", err)
- continue
- }
-
- _, err = p.remoteConn.Write(buf[:n])
- if err != nil {
- if err == io.EOF {
- p.cancel()
- } else {
- log.Debugf("failed to write to remote conn: %s", err)
- }
- continue
- }
- }
- }
-}
-
-// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
-// blocks
-func (p *WGUserSpaceProxy) proxyToLocal() {
- defer p.cancel()
- defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
- buf := make([]byte, 1500)
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- n, err := p.remoteConn.Read(buf)
- if err != nil {
- if err == io.EOF {
- return
- }
- log.Errorf("failed to read from remote conn: %s", err)
- continue
- }
-
- _, err = p.localConn.Write(buf[:n])
- if err != nil {
- log.Debugf("failed to write to wg interface conn: %s", err)
- continue
- }
- }
- }
-}
diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go
new file mode 100644
index 000000000..83a8725d8
--- /dev/null
+++ b/client/internal/wgproxy/usp/proxy.go
@@ -0,0 +1,146 @@
+package usp
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/errors"
+)
+
+// WGUserSpaceProxy proxies
+type WGUserSpaceProxy struct {
+ localWGListenPort int
+ ctx context.Context
+ cancel context.CancelFunc
+
+ remoteConn net.Conn
+ localConn net.Conn
+ closeMu sync.Mutex
+ closed bool
+}
+
+// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
+func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
+ log.Debugf("Initializing new user space proxy with port %d", wgPort)
+ p := &WGUserSpaceProxy{
+ localWGListenPort: wgPort,
+ }
+ return p
+}
+
+// AddTurnConn start the proxy with the given remote conn
+func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
+ p.ctx, p.cancel = context.WithCancel(ctx)
+
+ p.remoteConn = remoteConn
+
+ var err error
+ dialer := net.Dialer{}
+ p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
+ if err != nil {
+ log.Errorf("failed dialing to local Wireguard port %s", err)
+ return nil, err
+ }
+
+ go p.proxyToRemote()
+ go p.proxyToLocal()
+
+ return p.localConn.LocalAddr(), err
+}
+
+// CloseConn close the localConn
+func (p *WGUserSpaceProxy) CloseConn() error {
+ if p.cancel == nil {
+ return fmt.Errorf("proxy not started")
+ }
+ return p.close()
+}
+
+func (p *WGUserSpaceProxy) close() error {
+ p.closeMu.Lock()
+ defer p.closeMu.Unlock()
+
+ // prevent double close
+ if p.closed {
+ return nil
+ }
+ p.closed = true
+
+ p.cancel()
+
+ var result *multierror.Error
+ if err := p.remoteConn.Close(); err != nil {
+ result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
+ }
+
+ if err := p.localConn.Close(); err != nil {
+ result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
+ }
+ return errors.FormatErrorOrNil(result)
+}
+
+// proxyToRemote proxies from Wireguard to the RemoteKey
+func (p *WGUserSpaceProxy) proxyToRemote() {
+ defer func() {
+ if err := p.close(); err != nil {
+ log.Warnf("error in proxy to remote loop: %s", err)
+ }
+ }()
+
+ buf := make([]byte, 1500)
+ for p.ctx.Err() == nil {
+ n, err := p.localConn.Read(buf)
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+ log.Debugf("failed to read from wg interface conn: %s", err)
+ return
+ }
+
+ _, err = p.remoteConn.Write(buf[:n])
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+
+ log.Debugf("failed to write to remote conn: %s", err)
+ return
+ }
+ }
+}
+
+// proxyToLocal proxies from the Remote peer to local WireGuard
+func (p *WGUserSpaceProxy) proxyToLocal() {
+ defer func() {
+ if err := p.close(); err != nil {
+ log.Warnf("error in proxy to local loop: %s", err)
+ }
+ }()
+
+ buf := make([]byte, 1500)
+ for p.ctx.Err() == nil {
+ n, err := p.remoteConn.Read(buf)
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+ log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
+ return
+ }
+
+ _, err = p.localConn.Write(buf[:n])
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+ log.Debugf("failed to write to wg interface conn: %s", err)
+ continue
+ }
+ }
+}
diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go
index f5649d700..eb14581e0 100644
--- a/relay/client/picker_test.go
+++ b/relay/client/picker_test.go
@@ -13,7 +13,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
PeerID: "test",
}
- ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
go func() {
From acb73bd64abeca70d1ae39d0c088fbf2ffcb80aa Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Fri, 27 Sep 2024 17:10:50 +0300
Subject: [PATCH 72/89] [management] Remove redundant get account calls in
GetAccountFromToken (#2615)
* refactor access control middleware and user access by JWT groups
Signed-off-by: bcmmbaga
* refactor jwt groups extractor
Signed-off-by: bcmmbaga
* refactor handlers to get account when necessary
Signed-off-by: bcmmbaga
* refactor getAccountFromToken
Signed-off-by: bcmmbaga
* refactor getAccountWithAuthorizationClaims
Signed-off-by: bcmmbaga
* fix merge
Signed-off-by: bcmmbaga
* revert handles change
Signed-off-by: bcmmbaga
* remove GetUserByID from account manager
Signed-off-by: bcmmbaga
* fix tests
Signed-off-by: bcmmbaga
* refactor getAccountWithAuthorizationClaims to return account id
Signed-off-by: bcmmbaga
* refactor handlers to use GetAccountIDFromToken
Signed-off-by: bcmmbaga
* fix tests
Signed-off-by: bcmmbaga
* remove locks
Signed-off-by: bcmmbaga
* refactor
Signed-off-by: bcmmbaga
* add GetGroupByName from store
Signed-off-by: bcmmbaga
* add GetGroupByID from store and refactor
Signed-off-by: bcmmbaga
* Refactor retrieval of policy and posture checks
Signed-off-by: bcmmbaga
* Refactor user permissions and retrieves PAT
Signed-off-by: bcmmbaga
* Refactor route, setupkey, nameserver and dns to get record(s) from store
Signed-off-by: bcmmbaga
* Refactor store
Signed-off-by: bcmmbaga
* fix lint
Signed-off-by: bcmmbaga
* fix tests
Signed-off-by: bcmmbaga
* fix add missing policy source posture checks
Signed-off-by: bcmmbaga
* add store lock
Signed-off-by: bcmmbaga
* fix tests
Signed-off-by: bcmmbaga
* add get account
Signed-off-by: bcmmbaga
---------
Signed-off-by: bcmmbaga
---
management/server/account.go | 395 +++++++++++-------
management/server/account_test.go | 107 +++--
management/server/dns.go | 16 +-
management/server/file_store.go | 101 ++++-
management/server/group.go | 106 ++---
management/server/grpcserver.go | 2 +-
management/server/http/accounts_handler.go | 46 +-
.../server/http/accounts_handler_test.go | 7 +-
.../server/http/dns_settings_handler.go | 8 +-
.../server/http/dns_settings_handler_test.go | 4 +-
management/server/http/events_handler.go | 6 +-
management/server/http/events_handler_test.go | 14 +-
.../server/http/geolocation_handler_test.go | 15 +-
.../server/http/geolocations_handler.go | 7 +-
management/server/http/groups_handler.go | 119 +++---
management/server/http/groups_handler_test.go | 65 ++-
management/server/http/nameservers_handler.go | 20 +-
.../server/http/nameservers_handler_test.go | 13 +-
management/server/http/pat_handler.go | 18 +-
management/server/http/pat_handler_test.go | 6 +-
management/server/http/peers_handler.go | 52 ++-
management/server/http/peers_handler_test.go | 28 +-
management/server/http/policies_handler.go | 174 ++++----
.../server/http/policies_handler_test.go | 16 +-
.../server/http/posture_checks_handler.go | 46 +-
.../http/posture_checks_handler_test.go | 12 +-
management/server/http/routes_handler.go | 40 +-
management/server/http/routes_handler_test.go | 16 +-
management/server/http/setupkeys_handler.go | 18 +-
.../server/http/setupkeys_handler_test.go | 18 +-
management/server/http/users_handler.go | 32 +-
management/server/http/users_handler_test.go | 7 +-
management/server/mock_server/account_mock.go | 70 +++-
management/server/nameserver.go | 42 +-
management/server/peer_test.go | 4 +-
management/server/policy.go | 107 +++--
management/server/posture_checks.go | 34 +-
management/server/route.go | 52 +--
management/server/route_test.go | 2 +-
management/server/setupkey.go | 56 +--
management/server/sql_store.go | 189 ++++++++-
management/server/store.go | 73 +++-
management/server/user.go | 86 ++--
management/server/user_test.go | 11 +-
44 files changed, 1279 insertions(+), 981 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 208315643..710b6f62f 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
- gocache "github.com/patrickmn/go-cache"
- "github.com/rs/xid"
- log "github.com/sirupsen/logrus"
- "golang.org/x/exp/maps"
-
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
@@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
+ gocache "github.com/patrickmn/go-cache"
+ "github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/exp/maps"
)
const (
@@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface {
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
+ GetAccount(ctx context.Context, accountID string) (*Account, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
@@ -75,12 +75,14 @@ type AccountManager interface {
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
- GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
- GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
+ GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
+ GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
+ GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
+ GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -107,7 +109,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
- SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
+ SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -145,6 +147,7 @@ type AccountManager interface {
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
+ GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
}
type DefaultAccountManager struct {
@@ -268,6 +271,11 @@ type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
+// AccountDNSSettings used in gorm to only load dns settings and not whole account
+type AccountDNSSettings struct {
+ DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
+}
+
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}
-// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
-// userID doesn't have an account associated with it, one account is created
-// domain is used to create a new account if no account is found
-func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
+// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
+// If an accountID is provided, it checks if the account exists and returns it.
+// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
+// If the user doesn't have an account, it creates one using the provided domain.
+// Returns the account ID or an error if none is found or created.
+func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
- return am.Store.GetAccount(ctx, accountID)
- } else if userID != "" {
- account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
+ exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
- return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
+ return "", err
}
- err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
- if err != nil {
- return nil, err
+ if !exists {
+ return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
- return account, nil
+ return accountID, nil
}
- return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
+ if userID != "" {
+ account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
+ if err != nil {
+ return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
+ }
+
+ if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
+ return "", err
+ }
+
+ return account.Id, nil
+ }
+
+ return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
}
func isNil(i idp.Manager) bool {
@@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
}
// redeemInvite checks whether user has been invited and redeems the invite
-func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
+func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil
}
+ account, err := am.Store.GetAccount(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil {
return err
@@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return am.Store.SaveAccount(ctx, account)
}
+// GetAccount returns an account associated with this account ID.
+func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
+ return am.Store.GetAccount(ctx, accountID)
+}
+
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength {
@@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
return account, user, pat, nil
}
-// GetAccountFromToken returns an account associated with this token
-func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
+// GetAccountByID returns an account associated with this account ID.
+func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
+ return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
+ }
+
+ return am.Store.GetAccount(ctx, accountID)
+}
+
+// GetAccountIDFromToken returns an account ID associated with this token.
+func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
- return nil, nil, fmt.Errorf("user ID is empty")
+ return "", "", fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
@@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
- newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
+ accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
if err != nil {
- return nil, nil, err
- }
- unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
- alreadyUnlocked := false
- defer func() {
- if !alreadyUnlocked {
- unlock()
- }
- }()
-
- account, err := am.Store.GetAccount(ctx, newAcc.Id)
- if err != nil {
- return nil, nil, err
+ return "", "", err
}
- user := account.Users[claims.UserId]
- if user == nil {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
+ if err != nil {
// this is not really possible because we got an account by user ID
- return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
+ return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
if !user.IsServiceUser && claims.Invited {
- err = am.redeemInvite(ctx, account, claims.UserId)
+ err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
- return nil, nil, err
+ return "", "", err
}
}
- if account.Settings.JWTGroupsEnabled {
- if account.Settings.JWTGroupsClaimName == "" {
- log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
- return account, user, nil
- }
- if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
- if slice, ok := claim.([]interface{}); ok {
- var groupsNames []string
- for _, item := range slice {
- if g, ok := item.(string); ok {
- groupsNames = append(groupsNames, g)
- } else {
- log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
- }
- }
-
- oldGroups := make([]string, len(user.AutoGroups))
- copy(oldGroups, user.AutoGroups)
- // if groups were added or modified, save the account
- if account.SetJWTGroups(claims.UserId, groupsNames) {
- if account.Settings.GroupsPropagationEnabled {
- if user, err := account.FindUser(claims.UserId); err == nil {
- addNewGroups := difference(user.AutoGroups, oldGroups)
- removeOldGroups := difference(oldGroups, user.AutoGroups)
- account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
- account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
- account.Network.IncSerial()
- if err := am.Store.SaveAccount(ctx, account); err != nil {
- log.WithContext(ctx).Errorf("failed to save account: %v", err)
- } else {
- log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
- am.updateAccountPeers(ctx, account)
- unlock()
- alreadyUnlocked = true
- for _, g := range addNewGroups {
- if group := account.GetGroup(g); group != nil {
- am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
- map[string]any{
- "group": group.Name,
- "group_id": group.ID,
- "is_service_user": user.IsServiceUser,
- "user_name": user.ServiceUserName})
- }
- }
- for _, g := range removeOldGroups {
- if group := account.GetGroup(g); group != nil {
- am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
- map[string]any{
- "group": group.Name,
- "group_id": group.ID,
- "is_service_user": user.IsServiceUser,
- "user_name": user.ServiceUserName})
- }
- }
- }
- }
- } else {
- if err := am.Store.SaveAccount(ctx, account); err != nil {
- log.WithContext(ctx).Errorf("failed to save account: %v", err)
- }
- }
- }
- } else {
- log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
- }
- } else {
- log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
- }
+ if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
+ return "", "", err
}
- return account, user, nil
+ return accountID, user.Id, nil
}
-// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
+// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
+// and propagates changes to peers if group propagation is enabled.
+func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
+
+ if settings == nil || !settings.JWTGroupsEnabled {
+ return nil
+ }
+
+ if settings.JWTGroupsClaimName == "" {
+ log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
+ return nil
+ }
+
+ // TODO: Remove GetAccount after refactoring account peer's update
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
+ account, err := am.Store.GetAccount(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
+ jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
+
+ oldGroups := make([]string, len(user.AutoGroups))
+ copy(oldGroups, user.AutoGroups)
+
+ // Update the account if group membership changes
+ if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
+ addNewGroups := difference(user.AutoGroups, oldGroups)
+ removeOldGroups := difference(oldGroups, user.AutoGroups)
+
+ if settings.GroupsPropagationEnabled {
+ account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
+ account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
+ account.Network.IncSerial()
+ }
+
+ if err := am.Store.SaveAccount(ctx, account); err != nil {
+ log.WithContext(ctx).Errorf("failed to save account: %v", err)
+ return nil
+ }
+
+ // Propagate changes to peers if group propagation is enabled
+ if settings.GroupsPropagationEnabled {
+ log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
+ am.updateAccountPeers(ctx, account)
+ }
+
+ for _, g := range addNewGroups {
+ if group := account.GetGroup(g); group != nil {
+ am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
+ map[string]any{
+ "group": group.Name,
+ "group_id": group.ID,
+ "is_service_user": user.IsServiceUser,
+ "user_name": user.ServiceUserName})
+ }
+ }
+
+ for _, g := range removeOldGroups {
+ if group := account.GetGroup(g); group != nil {
+ am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
+ map[string]any{
+ "group": group.Name,
+ "group_id": group.ID,
+ "is_service_user": user.IsServiceUser,
+ "user_name": user.ServiceUserName})
+ }
+ }
+ }
+
+ return nil
+}
+
+// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
@@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
-func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
+func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" {
- return nil, fmt.Errorf("user ID is empty")
+ return "", fmt.Errorf("user ID is empty")
}
+
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
- return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
+ return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
- accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
+ userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
- return nil, err
+ return "", err
}
- if _, ok := accountFromID.Users[claims.UserId]; !ok {
- return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
+
+ if userAccountID != claims.AccountId {
+ return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
- if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
- return accountFromID, nil
+
+ domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
+ if err != nil {
+ return "", err
+ }
+
+ if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
+ return userAccountID, nil
}
}
@@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already
- domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
+ domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
- return nil, err
+ return "", err
}
}
- account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
+ userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil {
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
defer unlockAccount()
- account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
+ account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
- return nil, err
+ return "", err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
- primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
-
- err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
- if err != nil {
- return nil, err
+ primaryDomain := domainAccountID == "" || account.Id == domainAccountID
+ if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
+ return "", err
}
- return account, nil
+
+ return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
- if domainAccount != nil {
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
+ var domainAccount *Account
+ if domainAccountID != "" {
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
- return nil, err
+ return "", err
}
}
- return am.handleNewUserAccount(ctx, domainAccount, claims)
+
+ account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
+ if err != nil {
+ return "", err
+ }
+ return account.Id, nil
} else {
// other error
- return nil, err
+ return "", err
}
}
@@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
- account, _, err := am.GetAccountFromToken(ctx, claims)
+ accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
+ if err != nil {
+ return err
+ }
+
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
- if account.Settings != nil && account.Settings.JWTGroupsEnabled {
- if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
- userJWTGroups := make([]string, 0)
-
- if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
- if claimGroups, ok := claim.([]interface{}); ok {
- for _, g := range claimGroups {
- if group, ok := g.(string); ok {
- userJWTGroups = append(userJWTGroups, group)
- }
- }
- }
- }
+ if settings != nil && settings.JWTGroupsEnabled {
+ if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
+ userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
@@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
return newLabel, nil
}
+func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
+ return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
+ }
+
+ return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+}
+
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
@@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
return acc
}
+// extractJWTGroups extracts the group names from a JWT token's claims.
+func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
+ userJWTGroups := make([]string, 0)
+
+ if claim, ok := claims.Raw[claimName]; ok {
+ if claimGroups, ok := claim.([]interface{}); ok {
+ for _, g := range claimGroups {
+ if group, ok := g.(string); ok {
+ userJWTGroups = append(userJWTGroups, group)
+ } else {
+ log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
+ }
+ }
+ }
+ } else {
+ log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
+ }
+
+ return userJWTGroups
+}
+
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 03b5fa83e..303261bea 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID)
}
-func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
+func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
@@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
+ initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get init account failed")
+
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
@@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
- account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
+ accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
- acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
+ accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
- // that happens inside the GetAccountByUserOrAccountID where the id is getting generated
+ // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
- initAccount = acc
+ initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
@@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
- account, _, err := manager.GetAccountFromToken(context.Background(), claims)
+ accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
@@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
- account, _, err := manager.GetAccountFromToken(context.Background(), claims)
+ accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
- account, _, err := manager.GetAccountFromToken(context.Background(), claims)
+ accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{}
@@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
- if account == nil {
+ if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId)
return
}
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
- t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
+ t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
- if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
+ if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
- assert.NotNil(t, account.Settings)
- assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
- assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
+ settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
+ require.NoError(t, err, "unable to get account settings")
+
+ assert.NotNil(t, settings)
+ assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
+ assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "unable to get the account")
+
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
- account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+
+ account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
- _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
- account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "unable to get the account")
+
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "unable to get the account")
+
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
- updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
- account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
+ accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
- assert.False(t, account.Settings.PeerLoginExpirationEnabled)
- assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
+ settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
+ require.NoError(t, err, "unable to get account settings")
- _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ assert.False(t, settings.PeerLoginExpirationEnabled)
+ assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
+
+ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
- _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
diff --git a/management/server/dns.go b/management/server/dns.go
index 1d156c90a..7410aaa15 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
}
- dnsSettings := account.DNSSettings.Copy()
- return &dnsSettings, nil
+
+ return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
diff --git a/management/server/file_store.go b/management/server/file_store.go
index 95d5b4e6e..994a4b1ee 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -10,14 +10,15 @@ import (
"sync"
"time"
- "github.com/rs/xid"
- log "github.com/sirupsen/logrus"
-
+ "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
+ "github.com/netbirdio/netbird/route"
+ "github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
@@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
return nil, err
}
- return account.Users[userID].Copy(), nil
+ user := account.Users[userID].Copy()
+ pat := make([]PersonalAccessToken, 0, len(user.PATs))
+ for _, token := range user.PATs {
+ if token != nil {
+ pat = append(pat, *token)
+ }
+ }
+ user.PATsG = pat
+
+ return user, nil
}
-func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
+func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
@@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin
return nil
}
-func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
+func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
}
@@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
-func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
+func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
-func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
+func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}
+
+func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
+ return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
+}
+
+func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return "", "", err
+ }
+
+ return account.Domain, account.DomainCategory, nil
+}
+
+// AccountExists checks whether an account exists by the given ID.
+func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
+ _, exists := s.Accounts[id]
+ return exists, nil
+}
+
+func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
+}
+
+func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
+ return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
+}
+
+func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
+ return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
+}
+
+func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
+ return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
+}
+
+func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
+ return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
+
+}
+
+func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
+}
+
+func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
+ return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
+}
+
+func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
+}
+
+func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
+ return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
+}
+
+func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
+}
+
+func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
+ return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
+}
+
+func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
+}
+
+func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
+ return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
+}
diff --git a/management/server/group.go b/management/server/group.go
index 49720f347..aa387c058 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
}
-// GetGroup object of the peers
+// CheckGroupPermissions validates if a user has the necessary permissions to view groups
+func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
+
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return err
+ }
+
+ if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
+ return status.Errorf(status.PermissionDenied, "groups are blocked for users")
+ }
+
+ return nil
+}
+
+// GetGroup returns a specific group by groupID in an account
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
+ if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
- return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
- }
-
- group, ok := account.Groups[groupID]
- if ok {
- return group, nil
- }
-
- return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
+ return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
}
// GetAllGroups returns all groups in an account
-func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
+func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
+ if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
- return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
- }
-
- groups := make([]*nbgroup.Group, 0, len(account.Groups))
- for _, item := range account.Groups {
- groups = append(groups, item)
- }
-
- return groups, nil
+ return am.Store.GetAccountGroups(ctx, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return nil, err
- }
-
- matchingGroups := make([]*nbgroup.Group, 0)
- for _, group := range account.Groups {
- if group.Name == groupName {
- matchingGroups = append(matchingGroups, group)
- }
- }
-
- if len(matchingGroups) == 0 {
- return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
- }
-
- maxPeers := -1
- var groupWithMostPeers *nbgroup.Group
- for i, group := range matchingGroups {
- if len(group.Peers) > maxPeers {
- maxPeers = len(group.Peers)
- groupWithMostPeers = matchingGroups[i]
- }
- }
-
- return groupWithMostPeers, nil
+ return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
}
// SaveGroup object of the peers
@@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return nil
}
+ allGroup, err := account.GetGroupAll()
+ if err != nil {
+ return err
+ }
+
+ if allGroup.ID == groupID {
+ return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
+ }
+
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index 5d7094b6a..cda3bc748 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
- _, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
+ _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go
index ffa5b9a28..91caa1512 100644
--- a/management/server/http/accounts_handler.go
+++ b/management/server/http/accounts_handler.go
@@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
+ settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
- resp := toAccountResponse(account)
+ resp := toAccountResponse(accountID, settings)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
- updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
+ updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- resp := toAccountResponse(updatedAccount)
+ resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodDelete {
- util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
- return
- }
-
claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
@@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
-func toAccountResponse(account *server.Account) *api.Account {
- jwtAllowGroups := account.Settings.JWTAllowGroups
+func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
+ jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}
- settings := api.AccountSettings{
- PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
- PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
- GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
- JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
- JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
+ apiSettings := api.AccountSettings{
+ PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
+ PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
+ GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
+ JwtGroupsEnabled: &settings.JWTGroupsEnabled,
+ JwtGroupsClaimName: &settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
- RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
+ RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
}
- if account.Settings.Extra != nil {
- settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
+ if settings.Extra != nil {
+ apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
}
return &api.Account{
- Id: account.Id,
- Settings: settings,
+ Id: accountID,
+ Settings: apiSettings,
}
}
diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go
index 45c7679e5..cacb3d430 100644
--- a/management/server/http/accounts_handler_test.go
+++ b/management/server/http/accounts_handler_test.go
@@ -23,8 +23,11 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return account, admin, nil
+ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return account.Id, admin.Id, nil
+ },
+ GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
+ return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go
index 74b0e1a55..13c2101a7 100644
--- a/management/server/http/dns_settings_handler.go
+++ b/management/server/http/dns_settings_handler.go
@@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
+ dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
- err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
+ err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go
index 897ae63dc..8baea7b15 100644
--- a/management/server/http/dns_settings_handler_test.go
+++ b/management/server/http/dns_settings_handler_test.go
@@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
- GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
+ GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
+ return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go
index 428b4c164..ee0c63f28 100644
--- a/management/server/http/events_handler.go
+++ b/management/server/http/events_handler.go
@@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
+ accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e)
}
- err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
+ err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go
index 8bdd508bf..e525cf2ee 100644
--- a/management/server/http/events_handler_test.go
+++ b/management/server/http/events_handler_test.go
@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
-func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
+func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
@@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}
return []*activity.Event{}, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return &server.Account{
- Id: claims.AccountId,
- Domain: "hotmail.com",
- Users: map[string]*server.User{
- user.Id: user,
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
@@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id)
- handler := initEventsTestData(accountID, adminUser, events...)
+ handler := initEventsTestData(accountID, events...)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go
index 7f4d6dc7c..19c916dd2 100644
--- a/management/server/http/geolocation_handler_test.go
+++ b/management/server/http/geolocation_handler_test.go
@@ -11,9 +11,9 @@ import (
"testing"
"github.com/gorilla/mux"
+ "github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
- return &server.Account{
- Id: claims.AccountId,
- Users: map[string]*server.User{
- "test_user": user,
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
+ return server.NewAdminUser(id), nil
},
},
geolocationManager: geo,
diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go
index af4d3116f..418228abf 100644
--- a/management/server/http/geolocations_handler.go
+++ b/management/server/http/geolocations_handler.go
@@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
- _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
+ _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ return err
+ }
+
+ user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
return err
}
diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go
index c622d873a..f369d1a00 100644
--- a/management/server/http/groups_handler.go
+++ b/management/server/http/groups_handler.go
@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/gorilla/mux"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
@@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
+ groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse := make([]*api.Group, 0, len(groups))
for _, group := range groups {
- groupsResponse = append(groupsResponse, toGroupResponse(account, group))
+ groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
}
util.WriteJSONObject(r.Context(), w, groupsResponse)
@@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
// UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return
}
- eg, ok := account.Groups[groupID]
- if !ok {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
- return
- }
-
- allGroup, err := account.GetGroupAll()
+ existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+
+ allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
@@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
ID: groupID,
Name: req.Name,
Peers: peers,
- Issued: eg.Issued,
- IntegrationReference: eg.IntegrationReference,
+ Issued: existingGroup.Issued,
+ IntegrationReference: existingGroup.IntegrationReference,
}
- if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
- log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
+ if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
+ log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
}
- util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI,
}
- err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
+ err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
}
// DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- aID := account.Id
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
@@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return
}
- allGroup, err := account.GetGroupAll()
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- if allGroup.ID == groupID {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
- return
- }
-
- err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
+ err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
@@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
// GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+ groupID := mux.Vars(r)["groupId"]
+ if len(groupID) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
+ return
+ }
+
+ group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- switch r.Method {
- case http.MethodGet:
- groupID := mux.Vars(r)["groupId"]
- if len(groupID) == 0 {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
- return
- }
-
- group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
- default:
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
+
+ util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
+
}
-func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
+func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
+ peersMap := make(map[string]*nbpeer.Peer, len(peers))
+ for _, peer := range peers {
+ peersMap[peer.ID] = peer
+ }
+
cache := make(map[string]api.PeerMinimum)
gr := api.Group{
Id: group.ID,
@@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
for _, pid := range group.Peers {
_, ok := cache[pid]
if !ok {
- peer, ok := account.Peers[pid]
+ peer, ok := peersMap[pid]
if !ok {
continue
}
diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go
index d5ed07c9e..7f3c81f18 100644
--- a/management/server/http/groups_handler_test.go
+++ b/management/server/http/groups_handler_test.go
@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
+ "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
-func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
+func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
@@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return nil
},
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
- if groupID != "idofthegroup" {
+ groups := map[string]*nbgroup.Group{
+ "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
+ "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
+ "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
+ }
+
+ for _, group := range initGroups {
+ groups[group.ID] = group
+ }
+
+ group, ok := groups[groupID]
+ if !ok {
return nil, status.Errorf(status.NotFound, "not found")
}
- if groupID == "id-jwt-group" {
- return &nbgroup.Group{
- ID: "id-jwt-group",
- Name: "Default Group",
- Issued: nbgroup.GroupIssuedJWT,
- }, nil
- }
- return &nbgroup.Group{
- ID: "idofthegroup",
- Name: "Group",
- Issued: nbgroup.GroupIssuedAPI,
- }, nil
+
+ return group, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return &server.Account{
- Id: claims.AccountId,
- Domain: "hotmail.com",
- Peers: TestPeers,
- Users: map[string]*server.User{
- user.Id: user,
- },
- Groups: map[string]*nbgroup.Group{
- "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
- "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
- "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
+ if groupName == "All" {
+ return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
+ }
+
+ return nil, fmt.Errorf("unknown group name")
+ },
+ GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
+ return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
@@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
- adminUser := server.NewAdminUser("test_user")
- p := initGroupTestData(adminUser, group)
+ p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
- adminUser := server.NewAdminUser("test_user")
- p := initGroupTestData(adminUser)
+ p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
- adminUser := server.NewAdminUser("test_user")
- p := initGroupTestData(adminUser)
+ p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go
index c6e00bb2d..e7a2bc2ae 100644
--- a/management/server/http/nameservers_handler.go
+++ b/management/server/http/nameservers_handler.go
@@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
+ nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
// CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
return
}
- nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
+ nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
- err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
+ err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return
}
- err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
+ err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
// GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return
}
- nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
+ nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go
index 28b080571..98c2e402d 100644
--- a/management/server/http/nameservers_handler_test.go
+++ b/management/server/http/nameservers_handler_test.go
@@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -29,14 +28,6 @@ const (
testNSGroupAccountID = "test_id"
)
-var testingNSAccount = &server.Account{
- Id: testNSGroupAccountID,
- Domain: "hotmail.com",
- Users: map[string]*server.User{
- "test_user": server.NewAdminUser("test_user"),
- },
-}
-
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
@@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
- GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testingNSAccount, testingAccount.Users["test_user"], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go
index 9d8448d3d..dfa9563e3 100644
--- a/management/server/http/pat_handler.go
+++ b/management/server/http/pat_handler.go
@@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
- userID := vars["userId"]
+ targetUserID := vars["userId"]
if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
- pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
+ pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
// GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
return
}
- pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
+ pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
// CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return
}
- pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
+ pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
+ err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go
index b72f71468..c28228a50 100644
--- a/management/server/http/pat_handler_test.go
+++ b/management/server/http/pat_handler_test.go
@@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
}, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testAccount, testAccount.Users[existingUserID], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
@@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
- AccountId: testNSGroupAccountID,
+ AccountId: existingAccountID,
}
}),
),
diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go
index 5a2190d83..4fbbc3106 100644
--- a/management/server/http/peers_handler.go
+++ b/management/server/http/peers_handler.go
@@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
-func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
+func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
}
}
- peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
+ peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
if err != nil {
util.WriteError(ctx, err, w)
return
@@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodDelete:
- h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
+ h.deletePeer(r.Context(), accountID, userID, peerID, w)
return
- case http.MethodPut:
- h.updatePeer(r.Context(), account, user, peerID, w, r)
- return
- case http.MethodGet:
- h.getPeer(r.Context(), account, peerID, user.Id, w)
+ case http.MethodGet, http.MethodPut:
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ if r.Method == http.MethodGet {
+ h.getPeer(r.Context(), account, peerID, userID, w)
+ } else {
+ h.updatePeer(r.Context(), account, userID, peerID, w, r)
+ }
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
- return
- }
-
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
- respBody := make([]*api.PeerBatch, 0, len(peers))
- for _, peer := range peers {
+ respBody := make([]*api.PeerBatch, 0, len(account.Peers))
+ for _, peer := range account.Peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
return
}
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ user, err := account.FindUser(userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser {
diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go
index dae264fff..f933eee14 100644
--- a/management/server/http/peers_handler_test.go
+++ b/management/server/http/peers_handler_test.go
@@ -13,16 +13,15 @@ import (
"time"
"github.com/gorilla/mux"
+ "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
+ "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
-
"github.com/stretchr/testify/assert"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
@@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
policy := &server.Policy{
ID: "policy",
- AccountID: claims.AccountId,
+ AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
@@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
srvUser.IsServiceUser = true
account := &server.Account{
- Id: claims.AccountId,
+ Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
@@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
- AccountID: claims.AccountId,
+ AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
@@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
}
- return account, account.Users[claims.UserId], nil
+ return account, nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
- assert.Equal(t, respBody[1].Connected, false)
- got = respBody[0]
+ for _, peer := range respBody {
+ if peer.Id == testPeerID {
+ got = peer
+ } else {
+ assert.Equal(t, peer.Connected, false)
+ }
+ }
+
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)
diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go
index 9622668f4..225d7e1f3 100644
--- a/management/server/http/policies_handler.go
+++ b/management/server/http/policies_handler.go
@@ -6,6 +6,7 @@ import (
"strconv"
"github.com/gorilla/mux"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
@@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
+ listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- policies := []*api.Policy{}
- for _, policy := range accountPolicies {
- resp := toPolicyResponse(account, policy)
+ allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ policies := make([]*api.Policy, 0, len(listPolicies))
+ for _, policy := range listPolicies {
+ resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
@@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
- policyIdx := -1
- for i, policy := range account.Policies {
- if policy.ID == policyID {
- policyIdx = i
- break
- }
- }
- if policyIdx < 0 {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
- return
- }
-
- h.savePolicy(w, r, account, user, policyID)
-}
-
-// CreatePolicy handles policy creation request
-func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- h.savePolicy(w, r, account, user, "")
+ h.savePolicy(w, r, accountID, userID, policyID)
+}
+
+// CreatePolicy handles policy creation request
+func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
+ claims := h.claimsExtractor.FromRequestContext(r)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ h.savePolicy(w, r, accountID, userID, "")
}
// savePolicy handles policy creation and update
-func (h *Policies) savePolicy(
- w http.ResponseWriter,
- r *http.Request,
- account *server.Account,
- user *server.User,
- policyID string,
-) {
+func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
return
}
+ isUpdate := policyID != ""
+
if policyID == "" {
policyID = xid.New().String()
}
@@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
- Destinations: groupMinimumsToStrings(account, rule.Destinations),
- Sources: groupMinimumsToStrings(account, rule.Sources),
+ Destinations: rule.Destinations,
+ Sources: rule.Sources,
Bidirectional: rule.Bidirectional,
}
@@ -207,15 +204,21 @@ func (h *Policies) savePolicy(
}
if req.SourcePostureChecks != nil {
- policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
+ policy.SourcePostureChecks = *req.SourcePostureChecks
}
- if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
+ if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
util.WriteError(r.Context(), err, w)
return
}
- resp := toPolicyResponse(account, &policy)
+ allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ resp := toPolicyResponse(allGroups, &policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
@@ -227,12 +230,11 @@ func (h *Policies) savePolicy(
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- aID := account.Id
vars := mux.Vars(r)
policyID := vars["policyId"]
@@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
return
}
- if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
+ if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- switch r.Method {
- case http.MethodGet:
- vars := mux.Vars(r)
- policyID := vars["policyId"]
- if len(policyID) == 0 {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
- return
- }
-
- policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- resp := toPolicyResponse(account, policy)
- if len(resp.Rules) == 0 {
- util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
- return
- }
-
- util.WriteJSONObject(r.Context(), w, resp)
- default:
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
+ vars := mux.Vars(r)
+ policyID := vars["policyId"]
+ if len(policyID) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
+ return
}
+
+ policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ resp := toPolicyResponse(allGroups, policy)
+ if len(resp.Rules) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, resp)
}
-func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
+func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
+ groupsMap := make(map[string]*nbgroup.Group)
+ for _, group := range groups {
+ groupsMap[group.ID] = group
+ }
+
cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{
Id: &policy.ID,
@@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
Protocol: api.PolicyRuleProtocol(r.Protocol),
Action: api.PolicyRuleAction(r.Action),
}
+
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy
}
+
for _, gid := range r.Sources {
_, ok := cache[gid]
if ok {
continue
}
- if group, ok := account.Groups[gid]; ok {
+ if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
@@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
cache[gid] = minimum
}
}
+
for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid]
if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum)
continue
}
- if group, ok := account.Groups[gid]; ok {
+ if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
@@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
}
return ap
}
-
-func groupMinimumsToStrings(account *server.Account, gm []string) []string {
- result := make([]string, 0, len(gm))
- for _, g := range gm {
- if _, ok := account.Groups[g]; !ok {
- continue
- }
- result = append(result, g)
- }
- return result
-}
-
-func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
- result := make([]string, 0, len(postureChecksIds))
- for _, id := range postureChecksIds {
- for _, postureCheck := range account.PostureChecks {
- if id == postureCheck.ID {
- result = append(result, id)
- continue
- }
- }
-
- }
- return result
-}
diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go
index 06274fb07..228ebcbce 100644
--- a/management/server/http/policies_handler_test.go
+++ b/management/server/http/policies_handler_test.go
@@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
- SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
+ SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
+ GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
+ return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
+ },
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
+ user := server.NewAdminUser(userID)
return &server.Account{
- Id: claims.AccountId,
+ Id: accountID,
Domain: "hotmail.com",
Policies: []*server.Policy{
{ID: "id-existed"},
@@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Users: map[string]*server.User{
"test_user": user,
},
- }, user, nil
+ }, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go
index 059cb3b80..1d020e9bc 100644
--- a/management/server/http/posture_checks_handler.go
+++ b/management/server/http/posture_checks_handler.go
@@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
+ listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- postureChecks := []*api.PostureCheck{}
- for _, postureCheck := range accountPostureChecks {
+ postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
+ for _, postureCheck := range listPostureChecks {
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}
@@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
// UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return
}
- postureChecksIdx := -1
- for i, postureCheck := range account.PostureChecks {
- if postureCheck.ID == postureChecksID {
- postureChecksIdx = i
- break
- }
- }
- if postureChecksIdx < 0 {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
+ _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
- p.savePostureChecks(w, r, account, user, postureChecksID)
+ p.savePostureChecks(w, r, accountID, userID, postureChecksID)
}
// CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- p.savePostureChecks(w, r, account, user, "")
+ p.savePostureChecks(w, r, accountID, userID, "")
}
// GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return
}
- postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
+ postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
// DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
return
}
- if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
+ if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
}
// savePostureChecks handles posture checks create and update
-func (p *PostureChecksHandler) savePostureChecks(
- w http.ResponseWriter,
- r *http.Request,
- account *server.Account,
- user *server.User,
- postureChecksID string,
-) {
+func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
var (
err error
req api.PostureCheckUpdate
@@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
return
}
- if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
+ if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
util.WriteError(r.Context(), err, w)
return
}
diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go
index 974edafde..02f0f0d83 100644
--- a/management/server/http/posture_checks_handler_test.go
+++ b/management/server/http/posture_checks_handler_test.go
@@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return accountPostureChecks, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
- return &server.Account{
- Id: claims.AccountId,
- Users: map[string]*server.User{
- "test_user": user,
- },
- PostureChecks: postureChecks,
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Geolocation{},
diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go
index 18c347334..0932e6445 100644
--- a/management/server/http/routes_handler.go
+++ b/management/server/http/routes_handler.go
@@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
+ routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
// CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerGroupIds = *req.PeerGroups
}
- // Do not allow non-Linux peers
- if peer := account.GetPeer(peerId); peer != nil {
- if peer.Meta.GoOS != "linux" {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
- return
- }
- }
-
- newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
+ newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
+ req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
+ )
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
- _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
+ _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
peerID = *req.Peer
}
- // do not allow non Linux peers
- if peer := account.GetPeer(peerID); peer != nil {
- if peer.Meta.GoOS != "linux" {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
- return
- }
- }
-
newRoute := &route.Route{
ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId),
@@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups
}
- err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
+ err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
+ err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
// GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return
}
- foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
+ foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return
diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go
index 40075eb9d..2c367cac3 100644
--- a/management/server/http/routes_handler_test.go
+++ b/management/server/http/routes_handler_test.go
@@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
}
+ if peerID != "" {
+ if peerID == nonLinuxExistingPeerID {
+ return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+ }
+
return &route.Route{
ID: existingRouteID,
NetID: netID,
@@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
}
+
+ if r.Peer == nonLinuxExistingPeerID {
+ return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+
return nil
},
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
@@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
}
return nil
},
- GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testingAccount, testingAccount.Users["test_user"], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
+ //return testingAccount, testingAccount.Users["test_user"], nil
+ return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go
index 8ee7dfaba..8514f0b55 100644
--- a/management/server/http/setupkeys_handler.go
+++ b/management/server/http/setupkeys_handler.go
@@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil {
ephemeral = *req.Ephemeral
}
- setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
- req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
+ setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
+ req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
// GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
- key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
+ key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name
newKey.Id = keyID
- newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
+ newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
// GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
+ setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go
index bfa0ec008..2d15287af 100644
--- a/management/server/http/setupkeys_handler_test.go
+++ b/management/server/http/setupkeys_handler_test.go
@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
- nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler {
return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return &server.Account{
- Id: testAccountID,
- Domain: "hotmail.com",
- Users: map[string]*server.User{
- user.Id: user,
- },
- SetupKeys: map[string]*server.SetupKey{
- defaultKey.Key: defaultKey,
- },
- Groups: map[string]*nbgroup.Group{
- "group-1": {ID: "group-1", Peers: []string{"A", "B"}},
- "id-all": {ID: "id-all", Name: "All"},
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,
diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go
index 2c2aed842..6e151a0da 100644
--- a/management/server/http/users_handler.go
+++ b/management/server/http/users_handler.go
@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
- userID := vars["userId"]
- if len(userID) == 0 {
+ targetUserID := vars["userId"]
+ if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
- existingUser, ok := account.Users[userID]
- if !ok {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
+ existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
@@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
- newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
- Id: userID,
+ newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
+ Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
@@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
+ err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
- newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
+ newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
+ data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
+ err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go
index a78ac3a4e..f3d989da1 100644
--- a/management/server/http/users_handler_test.go
+++ b/management/server/http/users_handler_test.go
@@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler {
return &UsersHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return usersTestAccount.Id, claims.UserId, nil
+ },
+ GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
+ return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index 495325252..df12ec1c4 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -23,10 +23,11 @@ import (
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
+ GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
- GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
+ GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -48,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
- SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
+ SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@@ -79,7 +80,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
- GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
+ GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
@@ -105,6 +106,9 @@ type MockAccountManager struct {
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
+ GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
+ GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
+ GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
@@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
-// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
-func (am *MockAccountManager) GetAccountByUserOrAccountID(
- ctx context.Context, userId, accountId, domain string,
-) (*server.Account, error) {
- if am.GetAccountByUserOrAccountIdFunc != nil {
- return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
+// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
+func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
+ if am.GetAccountIDByUserOrAccountIdFunc != nil {
+ return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
}
- return nil, status.Errorf(
+ return "", status.Errorf(
codes.Unimplemented,
- "method GetAccountByUserOrAccountID is not implemented",
+ "method GetAccountIDByUserOrAccountID is not implemented",
)
}
@@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
-func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
+func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
if am.SavePolicyFunc != nil {
- return am.SavePolicyFunc(ctx, accountID, userID, policy)
+ return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
}
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
@@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
-// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
-func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
- error,
-) {
- if am.GetAccountFromTokenFunc != nil {
- return am.GetAccountFromTokenFunc(ctx, claims)
+// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
+func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ if am.GetAccountIDFromTokenFunc != nil {
+ return am.GetAccountIDFromTokenFunc(ctx, claims)
}
- return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
+ return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
}
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
@@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
}
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
}
+
+// GetAccountByID mocks GetAccountByID of the AccountManager interface
+func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
+ if am.GetAccountByIDFunc != nil {
+ return am.GetAccountByIDFunc(ctx, accountID, userID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
+}
+
+// GetUserByID mocks GetUserByID of the AccountManager interface
+func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
+ if am.GetUserByIDFunc != nil {
+ return am.GetUserByIDFunc(ctx, id)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
+}
+
+func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
+ if am.GetAccountSettingsFunc != nil {
+ return am.GetAccountSettingsFunc(ctx, accountID, userID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
+}
+
+func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
+ if am.GetAccountFunc != nil {
+ return am.GetAccountFunc(ctx, accountID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
+}
diff --git a/management/server/nameserver.go b/management/server/nameserver.go
index 636f7cfee..0eb5d9ae4 100644
--- a/management/server/nameserver.go
+++ b/management/server/nameserver.go
@@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
-
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups")
- }
-
- nsGroup, found := account.NameServerGroups[nsGroupID]
- if found {
- return nsGroup.Copy(), nil
- }
-
- return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
+ return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
-
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
- nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
- for _, item := range account.NameServerGroups {
- nsGroups = append(nsGroups, item.Copy())
- }
-
- return nsGroups, nil
+ return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 4b2ec66c6..d329e04bc 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
}
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
+ err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
+ err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
diff --git a/management/server/policy.go b/management/server/policy.go
index aaf9b6e72..5d07ba8f8 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -3,6 +3,7 @@ package server
import (
"context"
_ "embed"
+ "slices"
"strconv"
"strings"
@@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
- for _, policy := range account.Policies {
- if policy.ID == policyID {
- return policy, nil
- }
- }
-
- return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
+ return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
}
// SavePolicy in the store
-func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
+func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
- exists := am.savePolicy(account, policy)
+ if err = am.savePolicy(account, policy, isUpdate); err != nil {
+ return err
+ }
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
@@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
action := activity.PolicyAdded
- if exists {
+ if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
@@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
- }
-
- return account.Policies, nil
+ return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
@@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
return policy, nil
}
-func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) {
- for i, p := range account.Policies {
- if p.ID == policy.ID {
- account.Policies[i] = policy
- exists = true
- break
+// savePolicy saves or updates a policy in the given account.
+// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
+func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
+ for index, rule := range policyToSave.Rules {
+ rule.Sources = filterValidGroupIDs(account, rule.Sources)
+ rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
+ policyToSave.Rules[index] = rule
+ }
+
+ if policyToSave.SourcePostureChecks != nil {
+ policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
+ }
+
+ if isUpdate {
+ policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
+ if policyIdx < 0 {
+ return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
+
+ // Update the existing policy
+ account.Policies[policyIdx] = policyToSave
+ return nil
}
- if !exists {
- account.Policies = append(account.Policies, policy)
- }
- return
+
+ // Add the new policy to the account
+ account.Policies = append(account.Policies, policyToSave)
+
+ return nil
}
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
@@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
}
return nil
}
+
+// filterValidPostureChecks filters and returns the posture check IDs from the given list
+// that are valid within the provided account.
+func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
+ result := make([]string, 0, len(postureChecksIds))
+ for _, id := range postureChecksIds {
+ for _, postureCheck := range account.PostureChecks {
+ if id == postureCheck.ID {
+ result = append(result, id)
+ continue
+ }
+ }
+ }
+ return result
+}
+
+// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
+func filterValidGroupIDs(account *Account, groupIDs []string) []string {
+ result := make([]string, 0, len(groupIDs))
+ for _, groupID := range groupIDs {
+ if _, exists := account.Groups[groupID]; exists {
+ result = append(result, groupID)
+ }
+ }
+ return result
+}
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index 4180550e6..9a4b679ce 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -15,30 +15,16 @@ const (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
+ if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
- for _, postureChecks := range account.PostureChecks {
- if postureChecks.ID == postureChecksID {
- return postureChecks, nil
- }
- }
-
- return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
+ return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
@@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
+ if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
- return account.PostureChecks, nil
+ return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
diff --git a/management/server/route.go b/management/server/route.go
index 064f3c105..6c1c8b1b3 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -17,29 +17,16 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
- wantedRoute, found := account.Routes[routeID]
- if found {
- return wantedRoute, nil
- }
-
- return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
+ return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
+ // Do not allow non-Linux peers
+ if peer := account.GetPeer(peerID); peer != nil {
+ if peer.Meta.GoOS != "linux" {
+ return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+ }
+
if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
+ // Do not allow non-Linux peers
+ if peer := account.GetPeer(routeToSave.Peer); peer != nil {
+ if peer.Meta.GoOS != "linux" {
+ return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+ }
+
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
- routes := make([]*route.Route, 0, len(account.Routes))
- for _, item := range account.Routes {
- routes = append(routes, item)
- }
-
- return routes, nil
+ return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
}
func toProtocolRoute(route *route.Route) *proto.Route {
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 506bfb0a8..4533c6b7e 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
- err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
+ err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index 859f1b0b9..9521e22d3 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
+ }
+
+ setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
- if !user.HasAdminPower() && !user.IsServiceUser {
- return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
- }
-
- keys := make([]*SetupKey, 0, len(account.SetupKeys))
- for _, key := range account.SetupKeys {
+ keys := make([]*SetupKey, 0, len(setupKeys))
+ for _, key := range setupKeys {
var k *SetupKey
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() {
k = key.HiddenCopy(999)
} else {
k = key.Copy()
@@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
+ }
+
+ setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil {
return nil, err
}
- if !user.HasAdminPower() && !user.IsServiceUser {
- return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
- }
-
- var foundKey *SetupKey
- for _, key := range account.SetupKeys {
- if key.Id == keyID {
- foundKey = key.Copy()
- break
- }
- }
- if foundKey == nil {
- return nil, status.Errorf(status.NotFound, "setup key not found")
- }
-
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
- if foundKey.UpdatedAt.IsZero() {
- foundKey.UpdatedAt = foundKey.CreatedAt
+ if setupKey.UpdatedAt.IsZero() {
+ setupKey.UpdatedAt = setupKey.CreatedAt
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- foundKey = foundKey.HiddenCopy(999)
+ if !user.IsAdminOrServiceUser() {
+ setupKey = setupKey.HiddenCopy(999)
}
- return foundKey, nil
+ return setupKey, nil
}
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 8fa5f9d05..85c68ef44 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -36,6 +36,7 @@ const (
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
+ accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
- var account Account
-
- result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
- strings.ToLower(domain), true, PrivateCategory)
- if result.Error != nil {
- if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
- }
- log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
+ if err != nil {
+ return nil, err
}
// TODO: rework to not call GetAccount
- return s.GetAccount(ctx, account.Id)
+ return s.GetAccount(ctx, accountID)
+}
+
+func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
+ var accountID string
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
+ Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
+ strings.ToLower(domain), true, PrivateCategory,
+ ).First(&accountID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
+ }
+ log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
+ return "", status.Errorf(status.Internal, "issue getting account from store")
+ }
+
+ return accountID, nil
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
@@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
- First(&user, idQueryCondition, userID)
+ Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
- result := s.db.Find(&groups, idQueryCondition, accountID)
+ result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
- var user User
var accountID string
- result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
+ result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
+
+func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
+ var accountDNSSettings AccountDNSSettings
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+ First(&accountDNSSettings, idQueryCondition, accountID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "dns settings not found")
+ }
+ return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
+ }
+ return &accountDNSSettings.DNSSettings, nil
+}
+
+// AccountExists checks whether an account exists by the given ID.
+func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
+ var accountID string
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+ Select("id").First(&accountID, idQueryCondition, id)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return false, nil
+ }
+ return false, result.Error
+ }
+
+ return accountID != "", nil
+}
+
+// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
+func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
+ var account Account
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
+ Where(idQueryCondition, accountID).First(&account)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return "", "", status.Errorf(status.NotFound, "account not found")
+ }
+ return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
+ }
+
+ return account.Domain, account.DomainCategory, nil
+}
+
+// GetGroupByID retrieves a group by ID and account ID.
+func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
+ return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
+}
+
+// GetGroupByName retrieves a group by name and account ID.
+func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
+ var group nbgroup.Group
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
+ Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
+ if err := result.Error; err != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "group not found")
+ }
+ return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
+ }
+ return &group, nil
+}
+
+// GetAccountPolicies retrieves policies for an account.
+func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
+ return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
+}
+
+// GetPolicyByID retrieves a policy by its ID and account ID.
+func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
+ return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
+}
+
+// GetAccountPostureChecks retrieves posture checks for an account.
+func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
+ return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetPostureChecksByID retrieves posture checks by their ID and account ID.
+func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
+ return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
+}
+
+// GetAccountRoutes retrieves network routes for an account.
+func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
+ return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetRouteByID retrieves a route by its ID and account ID.
+func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
+ return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
+}
+
+// GetAccountSetupKeys retrieves setup keys for an account.
+func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
+ return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetSetupKeyByID retrieves a setup key by its ID and account ID.
+func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
+ return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
+}
+
+// GetAccountNameServerGroups retrieves name server groups for an account.
+func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
+ return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
+func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
+ return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
+}
+
+// getRecords retrieves records from the database based on the account ID.
+func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
+ var record []T
+
+ result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
+ if err := result.Error; err != nil {
+ parts := strings.Split(fmt.Sprintf("%T", record), ".")
+ recordType := parts[len(parts)-1]
+
+ return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
+ }
+
+ return record, nil
+}
+
+// getRecordByID retrieves a record by its ID and account ID from the database.
+func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
+ var record T
+
+ result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ First(&record, accountAndIDQueryCondition, accountID, recordID)
+ if err := result.Error; err != nil {
+ parts := strings.Split(fmt.Sprintf("%T", record), ".")
+ recordType := parts[len(parts)-1]
+
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "%s not found", recordType)
+ }
+ return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
+ }
+ return &record, nil
+}
diff --git a/management/server/store.go b/management/server/store.go
index 84b3b140c..f34a73c2d 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -12,6 +12,7 @@ import (
"strings"
"time"
+ "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
@@ -39,53 +40,81 @@ const (
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
- DeleteAccount(ctx context.Context, account *Account) error
+ AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
+ GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
- GetAccountIDByUserID(peerKey string) (string, error)
+ GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
- GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
+ GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
+ GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
+ GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
+ SaveAccount(ctx context.Context, account *Account) error
+ DeleteAccount(ctx context.Context, account *Account) error
+
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
- GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
- GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
- SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
- SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
+ SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
+ GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
+
+ GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
+ GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
+ GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
+ SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
+
+ GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
+ GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
+
+ GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
+ GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
+ GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
+
+ GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
+ AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
+ AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
+ AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
+ GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
+ SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
+ SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
+ SavePeerLocation(accountID string, peer *nbpeer.Peer) error
+
+ GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
+ IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
+ GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
+ GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
+
+ GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
+ GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
+
+ GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
+ GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
+
+ GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
+ IncrementNetworkSerial(ctx context.Context, accountId string) error
+ GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
+
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
+
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
- SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
- SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
- SavePeerLocation(accountID string, peer *nbpeer.Peer) error
- SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
+
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
- GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
- GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
- GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
- GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
- IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
- AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
- GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
- AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
- AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
- IncrementNetworkSerial(ctx context.Context, accountId string) error
- GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}
diff --git a/management/server/user.go b/management/server/user.go
index 9e60bb94b..6d01561c6 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
}
+// IsAdminOrServiceUser checks if the user has admin power or is a service user.
+func (u *User) IsAdminOrServiceUser() bool {
+ return u.HasAdminPower() || u.IsServiceUser
+}
+
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
@@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return newUser.ToUserInfo(idpUser, account.Settings)
}
+func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
+ return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
+}
+
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
- account, _, err := am.GetAccountFromToken(ctx, claims)
+ accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
- unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
- defer unlock()
-
- account, err = am.Store.GetAccount(ctx, account.Id)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
- return nil, fmt.Errorf("failed to get an account from store %v", err)
+ return nil, err
}
- user, ok := account.Users[claims.UserId]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
- }
-
- // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
+ // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
- err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
+ err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}
if newLogin {
meta := map[string]any{"timestamp": claims.LastLogin}
- am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta)
+ am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
}
return user, nil
@@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
- return nil, status.Errorf(status.NotFound, "account not found: %s", err)
+ return nil, err
}
- targetUser, ok := account.Users[targetUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
+ targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
+ if err != nil {
+ return nil, err
}
- executingUser, ok := account.Users[initiatorUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
+ if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
+ return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
- if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
- return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
+ for _, pat := range targetUser.PATsG {
+ if pat.ID == tokenID {
+ return pat.Copy(), nil
+ }
}
- pat := targetUser.PATs[tokenID]
- if pat == nil {
- return nil, status.Errorf(status.NotFound, "PAT not found")
- }
-
- return pat, nil
+ return nil, status.Errorf(status.NotFound, "PAT not found")
}
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
- return nil, status.Errorf(status.NotFound, "account not found: %s", err)
+ return nil, err
}
- targetUser, ok := account.Users[targetUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
+ targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
+ if err != nil {
+ return nil, err
}
- executingUser, ok := account.Users[initiatorUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
- }
-
- if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
+ if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
- var pats []*PersonalAccessToken
- for _, pat := range targetUser.PATs {
- pats = append(pats, pat)
+ pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
+ for _, pat := range targetUser.PATsG {
+ pats = append(pats, pat.Copy())
}
return pats, nil
diff --git a/management/server/user_test.go b/management/server/user_test.go
index 272060276..e394ef840 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -199,7 +199,8 @@ func TestUser_GetPAT(t *testing.T) {
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
- Id: mockUserID,
+ Id: mockUserID,
+ AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
@@ -231,7 +232,8 @@ func TestUser_GetAllPATs(t *testing.T) {
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
- Id: mockUserID,
+ Id: mockUserID,
+ AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
@@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err)
}
- acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
+ accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
+ assert.NoError(t, err)
+
+ acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
From 58ff7ab797fcde081b3a0a802487f13a4ab4945a Mon Sep 17 00:00:00 2001
From: adasauce <60991921+adasauce@users.noreply.github.com>
Date: Fri, 27 Sep 2024 16:21:34 -0300
Subject: [PATCH 73/89] [management] improve zitadel idp error response detail
by decoding errors (#2634)
* [management] improve zitadel idp error response detail by decoding errors
* [management] extend readZitadelError to be used for requestJWTToken
more generically parse the error returned by zitadel.
* fix lint
---------
Co-authored-by: bcmmbaga
---
management/server/idp/zitadel.go | 49 +++++++++++++++++++++++++--
management/server/idp/zitadel_test.go | 10 +++---
2 files changed, 50 insertions(+), 9 deletions(-)
diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go
index 729b49733..9d7626844 100644
--- a/management/server/idp/zitadel.go
+++ b/management/server/idp/zitadel.go
@@ -2,10 +2,12 @@ package idp
import (
"context"
+ "errors"
"fmt"
"io"
"net/http"
"net/url"
+ "slices"
"strings"
"sync"
"time"
@@ -97,6 +99,42 @@ type zitadelUserResponse struct {
PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"`
}
+// readZitadelError parses errors returned by the zitadel APIs from a response.
+func readZitadelError(body io.ReadCloser) error {
+ bodyBytes, err := io.ReadAll(body)
+ if err != nil {
+ return fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ helper := JsonParser{}
+ var target map[string]interface{}
+ err = helper.Unmarshal(bodyBytes, &target)
+ if err != nil {
+ return fmt.Errorf("error unparsable body: %s", string(bodyBytes))
+ }
+
+ // ensure keys are ordered for consistent logging behaviour.
+ errorKeys := make([]string, 0, len(target))
+ for k := range target {
+ errorKeys = append(errorKeys, k)
+ }
+ slices.Sort(errorKeys)
+
+ var errsOut []string
+ for _, k := range errorKeys {
+ if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded {
+ continue
+ }
+ errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k]))
+ }
+
+ if len(errsOut) == 0 {
+ return errors.New("unknown error")
+ }
+
+ return errors.New(strings.Join(errsOut, " "))
+}
+
// NewZitadelManager creates a new instance of the ZitadelManager.
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon
}
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
+ zErr := readZitadelError(resp.Body)
+ return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr)
}
return resp, nil
@@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
- return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
+ zErr := readZitadelError(resp.Body)
+
+ return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
}
return io.ReadAll(resp.Body)
@@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
- return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
+ zErr := readZitadelError(resp.Body)
+
+ return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
}
return io.ReadAll(resp.Body)
diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go
index 6bc612e78..722f94fe0 100644
--- a/management/server/idp/zitadel_test.go
+++ b/management/server/idp/zitadel_test.go
@@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) {
}
func TestZitadelRequestJWTToken(t *testing.T) {
-
type requestJWTTokenTest struct {
name string
inputCode int
@@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) {
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
- inputRespBody: "{}",
+ inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}",
helper: JsonParser{},
- expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
+ expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
-
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
@@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) {
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
- inputRespBody: "",
+ inputRespBody: "{}",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
@@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) {
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
- expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
+ expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"),
expectedCode: 200,
expectedToken: "",
}
From 52ae693c9e5eff72082d4330eab6723251559546 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Sun, 29 Sep 2024 00:22:47 +0200
Subject: [PATCH 74/89] [signal] add context to signal-dispatcher (#2662)
---
client/cmd/testutil_test.go | 2 +-
client/internal/engine_test.go | 2 +-
client/server/server_test.go | 2 +-
go.mod | 2 +-
go.sum | 4 ++--
signal/client/client_test.go | 2 +-
signal/cmd/run.go | 2 +-
signal/server/signal.go | 4 ++--
8 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index 780cc8b04..f0dc8bf21 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
- srv, err := sig.NewServer(otel.Meter(""))
+ srv, err := sig.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index f30566380..95aadf141 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -1056,7 +1056,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
- srv, err := signalServer.NewServer(otel.Meter(""))
+ srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 795060fab..9b18df4d3 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
- srv, err := signalServer.NewServer(otel.Meter(""))
+ srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
diff --git a/go.mod b/go.mod
index 12709e50d..cf3b610bd 100644
--- a/go.mod
+++ b/go.mod
@@ -60,7 +60,7 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
- github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080
+ github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
diff --git a/go.sum b/go.sum
index 2355f6f0c..089629cdf 100644
--- a/go.sum
+++ b/go.sum
@@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 h1:NZm4JvvjKuEh3p7daHUy3rWKhKsnUzzYpGv1qT4dYLc=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
diff --git a/signal/client/client_test.go b/signal/client/client_test.go
index 2525493b4..f7d4ebc50 100644
--- a/signal/client/client_test.go
+++ b/signal/client/client_test.go
@@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) {
panic(err)
}
s := grpc.NewServer()
- srv, err := server.NewServer(otel.Meter(""))
+ srv, err := server.NewServer(context.Background(), otel.Meter(""))
if err != nil {
panic(err)
}
diff --git a/signal/cmd/run.go b/signal/cmd/run.go
index 0bdc62ead..1bb2f1d0c 100644
--- a/signal/cmd/run.go
+++ b/signal/cmd/run.go
@@ -102,7 +102,7 @@ var (
}
}()
- srv, err := server.NewServer(metricsServer.Meter)
+ srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
if err != nil {
return fmt.Errorf("creating signal server: %v", err)
}
diff --git a/signal/server/signal.go b/signal/server/signal.go
index b268aa3fc..c020c5604 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -47,13 +47,13 @@ type Server struct {
}
// NewServer creates a new Signal server
-func NewServer(meter metric.Meter) (*Server, error) {
+func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
- dispatcher, err := dispatcher.NewDispatcher()
+ dispatcher, err := dispatcher.NewDispatcher(ctx)
if err != nil {
return nil, fmt.Errorf("creating dispatcher: %v", err)
}
From cfbcf507fb0ae039c270af48822679a754b8c530 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Sun, 29 Sep 2024 20:23:34 +0200
Subject: [PATCH 75/89] propagate meter (#2668)
---
go.mod | 2 +-
go.sum | 4 ++--
signal/server/signal.go | 2 +-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/go.mod b/go.mod
index cf3b610bd..edee0ede4 100644
--- a/go.mod
+++ b/go.mod
@@ -60,7 +60,7 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
- github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086
+ github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
diff --git a/go.sum b/go.sum
index 089629cdf..2160fa1f8 100644
--- a/go.sum
+++ b/go.sum
@@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 h1:NZm4JvvjKuEh3p7daHUy3rWKhKsnUzzYpGv1qT4dYLc=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
diff --git a/signal/server/signal.go b/signal/server/signal.go
index c020c5604..386ce7238 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -53,7 +53,7 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
- dispatcher, err := dispatcher.NewDispatcher(ctx)
+ dispatcher, err := dispatcher.NewDispatcher(ctx, meter)
if err != nil {
return nil, fmt.Errorf("creating dispatcher: %v", err)
}
From 3dca6099d4f1a32c2e2ddbabe88a49d786fb3c41 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Mon, 30 Sep 2024 10:34:57 +0200
Subject: [PATCH 76/89] Fix ebpf close function (#2672)
---
client/internal/wgproxy/ebpf/proxy.go | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go
index 4bd4bfff6..27ede3ef1 100644
--- a/client/internal/wgproxy/ebpf/proxy.go
+++ b/client/internal/wgproxy/ebpf/proxy.go
@@ -81,8 +81,7 @@ func (p *WGEBPFProxy) Listen() error {
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
- cErr := p.Free()
- if cErr != nil {
+ if cErr := p.Free(); cErr != nil {
log.Errorf("Failed to close the wgproxy: %s", cErr)
}
return err
@@ -122,8 +121,10 @@ func (p *WGEBPFProxy) Free() error {
p.ctxCancel()
var result *multierror.Error
- if err := p.conn.Close(); err != nil {
- result = multierror.Append(result, err)
+ if p.conn != nil { // p.conn will be nil if we have failed to listen
+ if err := p.conn.Close(); err != nil {
+ result = multierror.Append(result, err)
+ }
}
if err := p.ebpfManager.FreeWGProxy(); err != nil {
From 2fd60b2cb46a77f16b5e1e1f72a1a09f03f0ecbe Mon Sep 17 00:00:00 2001
From: Gianluca Boiano <491117+M0Rf30@users.noreply.github.com>
Date: Mon, 30 Sep 2024 16:43:34 +0200
Subject: [PATCH 77/89] Specify goreleaser version and update to 2 (#2673)
---
.github/workflows/release.yml | 72 +++++++++++++----------------------
.goreleaser.yaml | 32 ++++++++--------
.goreleaser_ui.yaml | 9 +++--
.goreleaser_ui_darwin.yaml | 6 ++-
CONTRIBUTING.md | 2 +-
5 files changed, 52 insertions(+), 69 deletions(-)
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 5f423f1c9..162e488c3 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -3,15 +3,14 @@ name: Release
on:
push:
tags:
- - 'v*'
+ - "v*"
branches:
- main
pull_request:
-
env:
SIGN_PIPE_VER: "v0.0.14"
- GORELEASER_VER: "v1.14.1"
+ GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -34,19 +33,16 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
- name: Checkout
+ - name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- -
- name: Set up Go
+ - name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
cache: false
- -
- name: Cache Go modules
+ - name: Cache Go modules
uses: actions/cache@v4
with:
path: |
@@ -55,20 +51,15 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-releaser-
- -
- name: Install modules
+ - name: Install modules
run: go mod tidy
- -
- name: check git status
+ - name: check git status
run: git --no-pager diff --exit-code
- -
- name: Set up QEMU
+ - name: Set up QEMU
uses: docker/setup-qemu-action@v2
- -
- name: Set up Docker Buildx
+ - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- -
- name: Login to Docker hub
+ - name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
@@ -85,35 +76,31 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
- args: release --rm-dist ${{ env.flags }}
+ args: release --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
- name: upload non tags for debug purposes
+ - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
- -
- name: upload linux packages
+ - name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
- -
- name: upload windows packages
+ - name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
- -
- name: upload macos packages
+ - name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
@@ -145,7 +132,7 @@ jobs:
- name: Cache Go modules
uses: actions/cache@v4
with:
- path: |
+ path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
@@ -169,7 +156,7 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
- args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
+ args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -187,19 +174,16 @@ jobs:
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
- name: Checkout
+ - name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- -
- name: Set up Go
+ - name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
cache: false
- -
- name: Cache Go modules
+ - name: Cache Go modules
uses: actions/cache@v4
with:
path: |
@@ -208,23 +192,19 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin-
- -
- name: Install modules
+ - name: Install modules
run: go mod tidy
- -
- name: check git status
+ - name: check git status
run: git --no-pager diff --exit-code
- -
- name: Run GoReleaser
+ - name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
- args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
+ args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
- name: upload non tags for debug purposes
+ - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
@@ -233,7 +213,7 @@ jobs:
trigger_signer:
runs-on: ubuntu-latest
- needs: [release,release_ui,release_ui_darwin]
+ needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 068864d6e..cf2ce4f4f 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -1,3 +1,5 @@
+version: 2
+
project_name: netbird
builds:
- id: netbird
@@ -22,7 +24,7 @@ builds:
goarch: 386
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@@ -42,19 +44,19 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
- id: netbird-mgmt
dir: management
env:
- - CGO_ENABLED=1
- - >-
- {{- if eq .Runtime.Goos "linux" }}
- {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
- {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
- {{- end }}
+ - CGO_ENABLED=1
+ - >-
+ {{- if eq .Runtime.Goos "linux" }}
+ {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
+ {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
+ {{- end }}
binary: netbird-mgmt
goos:
- linux
@@ -64,7 +66,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-signal
dir: signal
@@ -78,7 +80,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-relay
dir: relay
@@ -92,7 +94,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- builds:
@@ -100,7 +102,6 @@ archives:
- netbird-static
nfpms:
-
- maintainer: Netbird
description: Netbird client.
homepage: https://netbird.io/
@@ -416,10 +417,9 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64
brews:
- -
- ids:
+ - ids:
- default
- tap:
+ repository:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
@@ -436,7 +436,7 @@ brews:
uploads:
- name: debian
ids:
- - netbird-deb
+ - netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml
index fd92b5328..06577f4e3 100644
--- a/.goreleaser_ui.yaml
+++ b/.goreleaser_ui.yaml
@@ -1,3 +1,5 @@
+version: 2
+
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -11,7 +13,7 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows
dir: client/ui
@@ -26,7 +28,7 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- id: linux-arch
@@ -39,7 +41,6 @@ archives:
- netbird-ui-windows
nfpms:
-
- maintainer: Netbird
description: Netbird client UI.
homepage: https://netbird.io/
@@ -77,7 +78,7 @@ nfpms:
uploads:
- name: debian
ids:
- - netbird-ui-deb
+ - netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml
index 2c3afa91b..bccb7f471 100644
--- a/.goreleaser_ui_darwin.yaml
+++ b/.goreleaser_ui_darwin.yaml
@@ -1,3 +1,5 @@
+version: 2
+
project_name: netbird-ui
builds:
- id: netbird-ui-darwin
@@ -17,7 +19,7 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@@ -28,4 +30,4 @@ archives:
checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog:
- skip: true
\ No newline at end of file
+ disable: true
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 492aa5c2e..c82cfc763 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser**
```shell
-goreleaser --snapshot --rm-dist
+goreleaser build --snapshot --clean
```
**golangci-lint**
```shell
From e27f85b317a97721921933659a80c8be35c785e1 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Mon, 30 Sep 2024 20:07:21 +0200
Subject: [PATCH 78/89] Update docker creds (#2677)
---
.github/workflows/release.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 162e488c3..7af6d3e4d 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -63,7 +63,7 @@ jobs:
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
- username: netbirdio
+ username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
From 16179db599ef6fb42e709597bc260101dfa7cd74 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Mon, 30 Sep 2024 22:18:10 +0200
Subject: [PATCH 79/89] [management] Propagate metrics (#2667)
---
go.mod | 2 +-
go.sum | 4 ++--
management/server/http/handler.go | 2 +-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/go.mod b/go.mod
index edee0ede4..c29ba0763 100644
--- a/go.mod
+++ b/go.mod
@@ -59,7 +59,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
+ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
diff --git a/go.sum b/go.sum
index 2160fa1f8..1f6cbb785 100644
--- a/go.sum
+++ b/go.sum
@@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk=
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index ef94f22b9..3f8a8554d 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
- if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
+ if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
From 24c0aaa745bc2ac46bdcf1f855834306a886db95 Mon Sep 17 00:00:00 2001
From: Simen <97337442+simen64@users.noreply.github.com>
Date: Tue, 1 Oct 2024 13:32:58 +0200
Subject: [PATCH 80/89] Install sh alpine fixes (#2678)
* Made changes to the peer install script that makes it work on alpine linux without changes
* fix small oversight with doas fix
* use try catch approach when curling binaries
---
release_files/install.sh | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/release_files/install.sh b/release_files/install.sh
index d6aabebd8..5dd0f67bb 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -21,6 +21,8 @@ SUDO=""
if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then
SUDO="sudo"
+elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then
+ SUDO="doas"
fi
if [ -z ${NETBIRD_RELEASE+x} ]; then
@@ -68,7 +70,7 @@ download_release_binary() {
if [ -n "$GITHUB_TOKEN" ]; then
cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL"
else
- cd /tmp && curl -LO "$DOWNLOAD_URL"
+ cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL"
fi
@@ -316,7 +318,7 @@ install_netbird() {
}
version_greater_equal() {
- printf '%s\n%s\n' "$2" "$1" | sort -V -C
+ printf '%s\n%s\n' "$2" "$1" | sort -V -c
}
is_bin_package_manager() {
From ee0ea86a0a9394b2632ed2be3149d45c04baca67 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Tue, 1 Oct 2024 16:22:18 +0200
Subject: [PATCH 81/89] [relay-client] Fix Relay disconnection handling (#2680)
* Fix Relay disconnection handling
If has an active P2P connection meanwhile the Relay connection broken with the server then we removed the WireGuard peer configuration.
* Change logs
---
client/internal/peer/conn.go | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index ea6d892b9..baff1372a 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -586,13 +586,17 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return
}
- if conn.wgProxyRelay != nil {
- log.Debugf("relayed connection is closed, clean up WireGuard config")
+ log.Debugf("relay connection is disconnected")
+
+ if conn.currentConnPriority == connPriorityRelay {
+ log.Debugf("clean up WireGuard config")
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
+ }
+ if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil
From 5932298ce03ccda417cbf954020665fdc096baaa Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 2 Oct 2024 11:48:09 +0200
Subject: [PATCH 82/89] Add log setting to Caddy container (#2684)
This avoids full disk on busy systems
---
infrastructure_files/getting-started-with-zitadel.sh | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh
index c0275536b..2c5c35d53 100644
--- a/infrastructure_files/getting-started-with-zitadel.sh
+++ b/infrastructure_files/getting-started-with-zitadel.sh
@@ -793,6 +793,11 @@ services:
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile
+ logging:
+ driver: "json-file"
+ options:
+ max-size: "500m"
+ max-file: "2"
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
From a3a479429eb13dc53b9d9dd7bfb1b0710c5055c0 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 2 Oct 2024 11:48:42 +0200
Subject: [PATCH 83/89] Use the pkgs to get the latest version (#2682)
* Use the pkgs to get the latest version
* disable fail fast
---
.github/workflows/install-script-test.yml | 1 +
release_files/install.sh | 6 ++++--
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml
index 04c222e87..22d002a48 100644
--- a/.github/workflows/install-script-test.yml
+++ b/.github/workflows/install-script-test.yml
@@ -13,6 +13,7 @@ concurrency:
jobs:
test-install-script:
strategy:
+ fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
diff --git a/release_files/install.sh b/release_files/install.sh
index 5dd0f67bb..b7a6c08f9 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -33,14 +33,16 @@ get_release() {
local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then
local TAG="latest"
+ local URL="https://pkgs.netbird.io/releases/latest"
else
local TAG="tags/${RELEASE}"
+ local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi
if [ -n "$GITHUB_TOKEN" ]; then
- curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \
+ curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
else
- curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \
+ curl -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
fi
}
From ff7863785f81c64ce0570b28950f806b75800c6a Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Wed, 2 Oct 2024 14:41:00 +0300
Subject: [PATCH 84/89] [management, client] Add access control support to
network routes (#2100)
---
.github/workflows/golangci-lint.yml | 2 +-
client/firewall/iface.go | 4 +-
client/firewall/iptables/acl_linux.go | 174 +--
client/firewall/iptables/manager_linux.go | 64 +-
.../firewall/iptables/manager_linux_test.go | 54 +-
client/firewall/iptables/router_linux.go | 616 ++++++----
client/firewall/iptables/router_linux_test.go | 270 ++--
client/firewall/manager/firewall.go | 125 +-
client/firewall/manager/firewall_test.go | 192 +++
client/firewall/manager/routerpair.go | 16 +-
client/firewall/nftables/acl_linux.go | 549 +--------
client/firewall/nftables/manager_linux.go | 121 +-
.../firewall/nftables/manager_linux_test.go | 78 +-
client/firewall/nftables/route_linux.go | 431 -------
client/firewall/nftables/router_linux.go | 798 ++++++++++++
client/firewall/nftables/router_linux_test.go | 605 +++++++--
client/firewall/test/cases_linux.go | 20 +-
client/firewall/uspfilter/uspfilter.go | 42 +-
client/firewall/uspfilter/uspfilter_test.go | 20 +-
client/internal/acl/id/id.go | 25 +
client/internal/acl/manager.go | 255 ++--
client/internal/acl/manager_test.go | 170 +--
client/internal/engine.go | 9 +-
client/internal/routemanager/dynamic/route.go | 2 +-
client/internal/routemanager/manager.go | 6 +-
.../routemanager/refcounter/refcounter.go | 197 ++-
.../internal/routemanager/refcounter/types.go | 6 +-
.../routemanager/server_nonandroid.go | 16 +-
client/internal/routemanager/static/route.go | 2 +-
.../routemanager/systemops/systemops.go | 2 +-
.../systemops/systemops_generic.go | 4 +-
management/proto/management.pb.go | 1087 +++++++++++------
management/proto/management.proto | 84 +-
management/server/account.go | 4 +-
management/server/account_test.go | 7 +-
management/server/grpcserver.go | 4 +
management/server/http/api/openapi.yml | 30 +-
management/server/http/api/types.gen.go | 30 +-
management/server/http/policies_handler.go | 33 +-
management/server/http/routes_handler.go | 16 +-
management/server/http/routes_handler_test.go | 48 +-
management/server/mock_server/account_mock.go | 8 +-
management/server/network.go | 13 +-
management/server/peer_test.go | 7 +-
management/server/policy.go | 48 +-
management/server/route.go | 292 ++++-
management/server/route_test.go | 536 ++++++--
route/route.go | 5 +-
48 files changed, 4683 insertions(+), 2444 deletions(-)
create mode 100644 client/firewall/manager/firewall_test.go
delete mode 100644 client/firewall/nftables/route_linux.go
create mode 100644 client/firewall/nftables/router_linux.go
create mode 100644 client/internal/acl/id/id.go
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 8b7136841..2d743f790 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
- ignore_words_list: erro,clienta,hastable,
+ ignore_words_list: erro,clienta,hastable,iif
skip: go.mod,go.sum
only_warn: 1
golangci:
diff --git a/client/firewall/iface.go b/client/firewall/iface.go
index 882daef75..d0b5209c0 100644
--- a/client/firewall/iface.go
+++ b/client/firewall/iface.go
@@ -1,6 +1,8 @@
package firewall
-import "github.com/netbirdio/netbird/iface"
+import (
+ "github.com/netbirdio/netbird/iface"
+)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go
index b77cc8f43..c6a96a876 100644
--- a/client/firewall/iptables/acl_linux.go
+++ b/client/firewall/iptables/acl_linux.go
@@ -19,24 +19,22 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
-
- postRoutingMark = "0x000007e4"
)
type aclManager struct {
- iptablesClient *iptables.IPTables
- wgIface iFaceMapper
- routeingFwChainName string
+ iptablesClient *iptables.IPTables
+ wgIface iFaceMapper
+ routingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
}
-func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
+func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
m := &aclManager{
- iptablesClient: iptablesClient,
- wgIface: wgIface,
- routeingFwChainName: routeingFwChainName,
+ iptablesClient: iptablesClient,
+ wgIface: wgIface,
+ routingFwChainName: routingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
@@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route
return m, nil
}
-func (m *aclManager) AddFiltering(
+func (m *aclManager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering(
return nil, fmt.Errorf("rule already exists")
}
- if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
+ if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
return nil, err
}
@@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering(
chain: chain,
}
- if !shouldAddToPrerouting(protocol, dPort, direction) {
- return []firewall.Rule{rule}, nil
- }
-
- rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
- if err != nil {
- return []firewall.Rule{rule}, err
- }
- return []firewall.Rule{rule, rulePrerouting}, nil
+ return []firewall.Rule{rule}, nil
}
-// DeleteRule from the firewall by rule definition
-func (m *aclManager) DeleteRule(rule firewall.Rule) error {
+// DeletePeerRule from the firewall by rule definition
+func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
- if r.chain == "PREROUTING" {
- goto DELETERULE
- }
-
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
@@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error {
}
}
-DELETERULE:
- var table string
- if r.chain == "PREROUTING" {
- table = "mangle"
- } else {
- table = "filter"
- }
- err := m.iptablesClient.Delete(table, r.chain, r.specs...)
+ err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
@@ -203,44 +182,6 @@ func (m *aclManager) Reset() error {
return m.cleanChains()
}
-func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
- var src []string
- if ipsetName != "" {
- src = []string{"-m", "set", "--set", ipsetName, "src"}
- } else {
- src = []string{"-s", ip.String()}
- }
- specs := []string{
- "-d", m.wgIface.Address().IP.String(),
- "-p", protocol,
- "--dport", port,
- "-j", "MARK", "--set-mark", postRoutingMark,
- }
-
- specs = append(src, specs...)
-
- ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
- if err != nil {
- return nil, fmt.Errorf("failed to check rule: %w", err)
- }
- if ok {
- return nil, fmt.Errorf("rule already exists")
- }
-
- if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
- return nil, err
- }
-
- rule := &Rule{
- ruleID: uuid.New().String(),
- specs: specs,
- ipsetName: ipsetName,
- ip: ip.String(),
- chain: "PREROUTING",
- }
- return rule, nil
-}
-
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
@@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error {
}
}
- ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
- if err != nil {
- log.Debugf("failed to list chains: %s", err)
- return err
- }
- if ok {
- for _, rule := range m.entries["PREROUTING"] {
- err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
- if err != nil {
- log.Errorf("failed to delete rule: %v, %s", rule, err)
- }
- }
- err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
- if err != nil {
- log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
- return err
- }
- }
-
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@@ -338,17 +260,9 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries {
for _, rule := range rules {
- if chainName == "FORWARD" {
- // position 2 because we add it after router's, jump rule
- if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
- log.Debugf("failed to create input chain jump rule: %s", err)
- return err
- }
- } else {
- if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
- log.Debugf("failed to create input chain jump rule: %s", err)
- return err
- }
+ if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
+ log.Debugf("failed to create input chain jump rule: %s", err)
+ return err
}
}
}
@@ -356,40 +270,29 @@ func (m *aclManager) createDefaultChains() error {
return nil
}
+// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
+// We want to make sure our traffic is not dropped by existing rules.
+
+// The existing FORWARD rules/policies decide outbound traffic towards our interface.
+// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
+
+// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {
- m.appendToEntries("INPUT",
- []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
- m.appendToEntries("INPUT",
- []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-
- m.appendToEntries("INPUT",
- []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
+ established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
-
- m.appendToEntries("OUTPUT",
- []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-
- m.appendToEntries("OUTPUT",
- []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-
- m.appendToEntries("OUTPUT",
- []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
+ m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
+ m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
+ m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
+ m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
+ m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
- m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
- m.appendToEntries("FORWARD",
- []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
- m.appendToEntries("FORWARD",
- []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
- m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
- m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
-
- m.appendToEntries("PREROUTING",
- []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
+ m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
+ m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
@@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName
}
}
-
-func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
- if proto == "all" {
- return false
- }
-
- if direction != firewall.RuleDirectionIN {
- return false
- }
-
- if dPort == nil {
- return false
- }
- return true
-}
diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go
index 2d231ec45..fae41d9c5 100644
--- a/client/firewall/iptables/manager_linux.go
+++ b/client/firewall/iptables/manager_linux.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"sync"
"github.com/coreos/go-iptables/iptables"
@@ -21,7 +22,7 @@ type Manager struct {
ipv4Client *iptables.IPTables
aclMgr *aclManager
- router *routerManager
+ router *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
- m.router, err = newRouterManager(context, iptablesClient)
+ m.router, err = newRouter(context, iptablesClient, wgIface)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
- m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
+ m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
@@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
-func (m *Manager) AddFiltering(
+func (m *Manager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@@ -73,33 +74,62 @@ func (m *Manager) AddFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
+ return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
}
-// DeleteRule from the firewall by rule definition
-func (m *Manager) DeleteRule(rule firewall.Rule) error {
+func (m *Manager) AddRouteFiltering(
+ sources [] netip.Prefix,
+ destination netip.Prefix,
+ proto firewall.Protocol,
+ sPort *firewall.Port,
+ dPort *firewall.Port,
+ action firewall.Action,
+) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.aclMgr.DeleteRule(rule)
+ if !destination.Addr().Is4() {
+ return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
+ }
+
+ return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
+}
+
+// DeletePeerRule from the firewall by rule definition
+func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.aclMgr.DeletePeerRule(rule)
+}
+
+func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
-func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.InsertRoutingRules(pair)
+ return m.router.AddNatRule(pair)
}
-func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.RemoveRoutingRules(pair)
+ return m.router.RemoveNatRule(pair)
+}
+
+func (m *Manager) SetLegacyManagement(isLegacy bool) error {
+ return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error {
return nil
}
- _, err := m.AddFiltering(
+ _, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
@@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error {
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
- _, err = m.AddFiltering(
+ _, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
@@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
+
+func getConntrackEstablished() []string {
+ return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
+}
diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go
index ceb116c62..0072aa159 100644
--- a/client/firewall/iptables/manager_linux_test.go
+++ b/client/firewall/iptables/manager_linux_test.go
@@ -14,6 +14,21 @@ import (
"github.com/netbirdio/netbird/iface"
)
+var ifaceMock = &iFaceMock{
+ NameFunc: func() string {
+ return "lo"
+ },
+ AddressFunc: func() iface.WGAddress {
+ return iface.WGAddress{
+ IP: net.ParseIP("10.20.0.1"),
+ Network: &net.IPNet{
+ IP: net.ParseIP("10.20.0.0"),
+ Mask: net.IPv4Mask(255, 255, 255, 0),
+ },
+ }
+ },
+}
+
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
@@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
- mock := &iFaceMock{
- NameFunc: func() string {
- return "lo"
- },
- AddressFunc: func() iface.WGAddress {
- return iface.WGAddress{
- IP: net.ParseIP("10.20.0.1"),
- Network: &net.IPNet{
- IP: net.ParseIP("10.20.0.0"),
- Mask: net.IPv4Mask(255, 255, 255, 0),
- },
- }
- },
- }
-
// just check on the local interface
- manager, err := Create(context.Background(), mock)
+ manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
- rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{
Values: []int{8043: 8046},
}
- rule2, err = manager.AddFiltering(
+ rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
@@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
- _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
+ _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset()
@@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
- rule1, err = manager.AddFiltering(
+ rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
@@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{
Values: []int{443},
}
- rule2, err = manager.AddFiltering(
+ rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
@@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go
index e8f09a106..737b20785 100644
--- a/client/firewall/iptables/router_linux.go
+++ b/client/firewall/iptables/router_linux.go
@@ -5,368 +5,478 @@ package iptables
import (
"context"
"fmt"
+ "net/netip"
+ "strconv"
"strings"
"github.com/coreos/go-iptables/iptables"
+ "github.com/hashicorp/go-multierror"
+ "github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
+ nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/internal/acl/id"
+ "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
- Ipv4Forwarding = "netbird-rt-forwarding"
- ipv4Nat = "netbird-rt-nat"
+ ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
- chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
+
+ matchSet = "--match-set"
)
-type routerManager struct {
- ctx context.Context
- stop context.CancelFunc
- iptablesClient *iptables.IPTables
- rules map[string][]string
+type routeFilteringRuleParams struct {
+ Sources []netip.Prefix
+ Destination netip.Prefix
+ Proto firewall.Protocol
+ SPort *firewall.Port
+ DPort *firewall.Port
+ Direction firewall.RuleDirection
+ Action firewall.Action
+ SetName string
}
-func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
+type router struct {
+ ctx context.Context
+ stop context.CancelFunc
+ iptablesClient *iptables.IPTables
+ rules map[string][]string
+ ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
+ wgIface iFaceMapper
+ legacyManagement bool
+}
+
+func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
- m := &routerManager{
+ r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
+ wgIface: wgIface,
}
- err := m.cleanUpDefaultForwardRules()
+ r.ipsetCounter = refcounter.New(
+ r.createIpSet,
+ func(name string, _ struct{}) error {
+ return r.deleteIpSet(name)
+ },
+ )
+
+ if err := ipset.Init(); err != nil {
+ return nil, fmt.Errorf("init ipset: %w", err)
+ }
+
+ err := r.cleanUpDefaultForwardRules()
if err != nil {
- log.Errorf("failed to cleanup routing rules: %s", err)
+ log.Errorf("cleanup routing rules: %s", err)
return nil, err
}
- err = m.createContainers()
+ err = r.createContainers()
if err != nil {
- log.Errorf("failed to create containers for route: %s", err)
+ log.Errorf("create containers for route: %s", err)
}
- return m, err
+ return r, err
}
-// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
-func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
- err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
- if err != nil {
- return err
+func (r *router) AddRouteFiltering(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto firewall.Protocol,
+ sPort *firewall.Port,
+ dPort *firewall.Port,
+ action firewall.Action,
+) (firewall.Rule, error) {
+ ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
+ if _, ok := r.rules[string(ruleKey)]; ok {
+ return ruleKey, nil
}
- err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
- if err != nil {
- return err
+ var setName string
+ if len(sources) > 1 {
+ setName = firewall.GenerateSetName(sources)
+ if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
+ return nil, fmt.Errorf("create or get ipset: %w", err)
+ }
+ }
+
+ params := routeFilteringRuleParams{
+ Sources: sources,
+ Destination: destination,
+ Proto: proto,
+ SPort: sPort,
+ DPort: dPort,
+ Action: action,
+ SetName: setName,
+ }
+
+ rule := genRouteFilteringRuleSpec(params)
+ if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
+ return nil, fmt.Errorf("add route rule: %v", err)
+ }
+
+ r.rules[string(ruleKey)] = rule
+
+ return ruleKey, nil
+}
+
+func (r *router) DeleteRouteRule(rule firewall.Rule) error {
+ ruleKey := rule.GetRuleID()
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ setName := r.findSetNameInRule(rule)
+
+ if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
+ return fmt.Errorf("delete route rule: %v", err)
+ }
+ delete(r.rules, ruleKey)
+
+ if setName != "" {
+ if _, err := r.ipsetCounter.Decrement(setName); err != nil {
+ return fmt.Errorf("failed to remove ipset: %w", err)
+ }
+ }
+ } else {
+ log.Debugf("route rule %s not found", ruleKey)
+ }
+
+ return nil
+}
+
+func (r *router) findSetNameInRule(rule []string) string {
+ for i, arg := range rule {
+ if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
+ return rule[i+3]
+ }
+ }
+ return ""
+}
+
+func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
+ if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
+ return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
+ }
+
+ for _, prefix := range sources {
+ if err := ipset.AddPrefix(setName, prefix); err != nil {
+ return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
+ }
+ }
+
+ return struct{}{}, nil
+}
+
+func (r *router) deleteIpSet(setName string) error {
+ if err := ipset.Destroy(setName); err != nil {
+ return fmt.Errorf("destroy set %s: %w", setName, err)
+ }
+ return nil
+}
+
+// AddNatRule inserts an iptables rule pair into the nat chain
+func (r *router) AddNatRule(pair firewall.RouterPair) error {
+ if r.legacyManagement {
+ log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
+ if err := r.addLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("add legacy routing rule: %w", err)
+ }
}
if !pair.Masquerade {
return nil
}
- err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
- if err != nil {
- return err
+ if err := r.addNatRule(pair); err != nil {
+ return fmt.Errorf("add nat rule: %w", err)
}
- err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
- if err != nil {
- return err
+ if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("add inverse nat rule: %w", err)
}
return nil
}
-// insertRoutingRule inserts an iptables rule
-func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
- var err error
+// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
+func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
+ if err := r.removeNatRule(pair); err != nil {
+ return fmt.Errorf("remove nat rule: %w", err)
+ }
- ruleKey := firewall.GenKey(keyFormat, pair.ID)
- rule := genRuleSpec(jump, pair.Source, pair.Destination)
- existingRule, found := i.rules[ruleKey]
- if found {
- err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
- if err != nil {
- return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
+ if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("remove inverse nat rule: %w", err)
+ }
+
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("remove legacy routing rule: %w", err)
+ }
+
+ return nil
+}
+
+// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
+func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return err
+ }
+
+ rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
+ if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
+ return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
+ }
+
+ r.rules[ruleKey] = rule
+
+ return nil
+}
+
+func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
+ return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
- delete(i.rules, ruleKey)
- }
-
- err = i.iptablesClient.Insert(table, chain, 1, rule...)
- if err != nil {
- return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
- }
-
- i.rules[ruleKey] = rule
-
- return nil
-}
-
-// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
-func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
- err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
- if err != nil {
- return err
- }
-
- err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
- if err != nil {
- return err
- }
-
- if !pair.Masquerade {
- return nil
- }
-
- err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
- if err != nil {
- return err
- }
-
- err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
- if err != nil {
- return err
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("legacy forwarding rule %s not found", ruleKey)
}
return nil
}
-func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
- var err error
+// GetLegacyManagement returns the current legacy management mode
+func (r *router) GetLegacyManagement() bool {
+ return r.legacyManagement
+}
- ruleKey := firewall.GenKey(keyFormat, pair.ID)
- existingRule, found := i.rules[ruleKey]
- if found {
- err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
- if err != nil {
- return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
+// SetLegacyManagement sets the route manager to use legacy management mode
+func (r *router) SetLegacyManagement(isLegacy bool) {
+ r.legacyManagement = isLegacy
+}
+
+// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
+func (r *router) RemoveAllLegacyRouteRules() error {
+ var merr *multierror.Error
+ for k, rule := range r.rules {
+ if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
+ continue
+ }
+ if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
}
}
- delete(i.rules, ruleKey)
-
- return nil
+ return nberrors.FormatErrorOrNil(merr)
}
-func (i *routerManager) RouteingFwChainName() string {
- return chainRTFWD
-}
-
-func (i *routerManager) Reset() error {
- err := i.cleanUpDefaultForwardRules()
- if err != nil {
- return err
+func (r *router) Reset() error {
+ var merr *multierror.Error
+ if err := r.cleanUpDefaultForwardRules(); err != nil {
+ merr = multierror.Append(merr, err)
}
- i.rules = make(map[string][]string)
- return nil
+ r.rules = make(map[string][]string)
+
+ if err := r.ipsetCounter.Flush(); err != nil {
+ merr = multierror.Append(merr, err)
+ }
+
+ return nberrors.FormatErrorOrNil(merr)
}
-func (i *routerManager) cleanUpDefaultForwardRules() error {
- err := i.cleanJumpRules()
+func (r *router) cleanUpDefaultForwardRules() error {
+ err := r.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
- ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
- if err != nil {
- log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
- return err
- } else if ok {
- err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
+ for _, chain := range []string{chainRTFWD, chainRTNAT} {
+ table := tableFilter
+ if chain == chainRTNAT {
+ table = tableNat
+ }
+
+ ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
- log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
+ log.Errorf("failed check chain %s, error: %v", chain, err)
return err
+ } else if ok {
+ err = r.iptablesClient.ClearAndDeleteChain(table, chain)
+ if err != nil {
+ log.Errorf("failed cleaning chain %s, error: %v", chain, err)
+ return err
+ }
}
}
- ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
- if err != nil {
- log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
- return err
- } else if ok {
- err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
- if err != nil {
- log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
- return err
- }
- }
- return nil
-}
-
-func (i *routerManager) createContainers() error {
- if i.rules[Ipv4Forwarding] != nil {
- return nil
- }
-
- errMSGFormat := "failed creating chain %s,error: %v"
- err := i.createChain(tableFilter, chainRTFWD)
- if err != nil {
- return fmt.Errorf(errMSGFormat, chainRTFWD, err)
- }
-
- err = i.createChain(tableNat, chainRTNAT)
- if err != nil {
- return fmt.Errorf(errMSGFormat, chainRTNAT, err)
- }
-
- err = i.addJumpRules()
- if err != nil {
- return fmt.Errorf("error while creating jump rules: %v", err)
- }
-
return nil
}
-// addJumpRules create jump rules to send packets to NetBird chains
-func (i *routerManager) addJumpRules() error {
- rule := []string{"-j", chainRTFWD}
- err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
+func (r *router) createContainers() error {
+ for _, chain := range []string{chainRTFWD, chainRTNAT} {
+ if err := r.createAndSetupChain(chain); err != nil {
+ return fmt.Errorf("create chain %s: %v", chain, err)
+ }
+ }
+
+ if err := r.insertEstablishedRule(chainRTFWD); err != nil {
+ return fmt.Errorf("insert established rule: %v", err)
+ }
+
+ return r.addJumpRules()
+}
+
+func (r *router) createAndSetupChain(chain string) error {
+ table := r.getTableForChain(chain)
+
+ if err := r.iptablesClient.NewChain(table, chain); err != nil {
+ return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
+ }
+
+ return nil
+}
+
+func (r *router) getTableForChain(chain string) string {
+ if chain == chainRTNAT {
+ return tableNat
+ }
+ return tableFilter
+}
+
+func (r *router) insertEstablishedRule(chain string) error {
+ establishedRule := getConntrackEstablished()
+
+ err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
+ if err != nil {
+ return fmt.Errorf("failed to insert established rule: %v", err)
+ }
+
+ ruleKey := "established-" + chain
+ r.rules[ruleKey] = establishedRule
+
+ return nil
+}
+
+func (r *router) addJumpRules() error {
+ rule := []string{"-j", chainRTNAT}
+ err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
- i.rules[Ipv4Forwarding] = rule
-
- rule = []string{"-j", chainRTNAT}
- err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
- if err != nil {
- return err
- }
- i.rules[ipv4Nat] = rule
+ r.rules[ipv4Nat] = rule
return nil
}
-// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
-func (i *routerManager) cleanJumpRules() error {
- var err error
- errMSGFormat := "failed cleaning rule from chain %s,err: %v"
- rule, found := i.rules[Ipv4Forwarding]
+func (r *router) cleanJumpRules() error {
+ rule, found := r.rules[ipv4Nat]
if found {
- err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
+ err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
- return fmt.Errorf(errMSGFormat, chainFORWARD, err)
- }
- }
- rule, found = i.rules[ipv4Nat]
- if found {
- err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
- if err != nil {
- return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
+ return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
}
}
- rules, err := i.iptablesClient.List("nat", "POSTROUTING")
- if err != nil {
- return fmt.Errorf("failed to list rules: %s", err)
- }
-
- for _, ruleString := range rules {
- if !strings.Contains(ruleString, "NETBIRD") {
- continue
- }
- rule := strings.Fields(ruleString)
- err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
- if err != nil {
- return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
- }
- }
-
- rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
- if err != nil {
- return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
- }
-
- for _, ruleString := range rules {
- if !strings.Contains(ruleString, "NETBIRD") {
- continue
- }
- rule := strings.Fields(ruleString)
- err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
- if err != nil {
- return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
- }
- }
return nil
}
-func (i *routerManager) createChain(table, newChain string) error {
- chains, err := i.iptablesClient.ListChains(table)
- if err != nil {
- return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
- }
+func (r *router) addNatRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
- shouldCreateChain := true
- for _, chain := range chains {
- if chain == newChain {
- shouldCreateChain = false
- }
- }
-
- if shouldCreateChain {
- err = i.iptablesClient.NewChain(table, newChain)
- if err != nil {
- return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
- }
-
- // Add the loopback return rule to the NAT chain
- loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
- err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
- if err != nil {
- return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
- }
-
- err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
- if err != nil {
- return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
- }
-
- }
- return nil
-}
-
-// addNATRule appends an iptables rule pair to the nat chain
-func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
- ruleKey := firewall.GenKey(keyFormat, pair.ID)
- rule := genRuleSpec(jump, pair.Source, pair.Destination)
- existingRule, found := i.rules[ruleKey]
- if found {
- err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
- if err != nil {
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
- delete(i.rules, ruleKey)
+ delete(r.rules, ruleKey)
}
- // inserting after loopback ignore rule
- err := i.iptablesClient.Insert(table, chain, 2, rule...)
- if err != nil {
+ rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
+ if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
- i.rules[ruleKey] = rule
+ r.rules[ruleKey] = rule
return nil
}
-// genRuleSpec generates rule specification
-func genRuleSpec(jump, source, destination string) []string {
- return []string{"-s", source, "-d", destination, "-j", jump}
+func (r *router) removeNatRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
+ return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
+ }
+
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("nat rule %s not found", ruleKey)
+ }
+
+ return nil
}
-func getIptablesRuleType(table string) string {
- ruleType := "forwarding"
- if table == tableNat {
- ruleType = "nat"
+func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
+ intdir := "-i"
+ if inverse {
+ intdir = "-o"
}
- return ruleType
+ return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
+}
+
+func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
+ var rule []string
+
+ if params.SetName != "" {
+ rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
+ } else if len(params.Sources) > 0 {
+ source := params.Sources[0]
+ rule = append(rule, "-s", source.String())
+ }
+
+ rule = append(rule, "-d", params.Destination.String())
+
+ if params.Proto != firewall.ProtocolALL {
+ rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
+ rule = append(rule, applyPort("--sport", params.SPort)...)
+ rule = append(rule, applyPort("--dport", params.DPort)...)
+ }
+
+ rule = append(rule, "-j", actionToStr(params.Action))
+
+ return rule
+}
+
+func applyPort(flag string, port *firewall.Port) []string {
+ if port == nil {
+ return nil
+ }
+
+ if port.IsRange && len(port.Values) == 2 {
+ return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
+ }
+
+ if len(port.Values) > 1 {
+ portList := make([]string, len(port.Values))
+ for i, p := range port.Values {
+ portList[i] = strconv.Itoa(p)
+ }
+ return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
+ }
+
+ return []string{flag, strconv.Itoa(port.Values[0])}
}
diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go
index 79b970c36..6cede09e2 100644
--- a/client/firewall/iptables/router_linux_test.go
+++ b/client/firewall/iptables/router_linux_test.go
@@ -4,11 +4,13 @@ package iptables
import (
"context"
+ "net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
- manager, err := newRouterManager(context.TODO(), iptablesClient)
+ manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager")
defer func() {
@@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules, 2, "should have created rules map")
- exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
- require.True(t, exists, "forwarding rule should exist")
-
- exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
+ exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
- Source: "100.100.100.1/32",
- Destination: "100.100.100.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
- forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
+ forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
- nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
+ nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
@@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
}
-func TestIptablesManager_InsertRoutingRules(t *testing.T) {
+func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
@@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
- manager, err := newRouterManager(context.TODO(), iptablesClient)
+ manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
defer func() {
@@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
}()
- err = manager.InsertRoutingRules(testCase.InputPair)
+ err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
- forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
- forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
+ natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
- exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.True(t, exists, "forwarding rule should exist")
-
- foundRule, found := manager.rules[forwardRuleKey]
- require.True(t, found, "forwarding rule should exist in the manager map")
- require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
-
- inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
- inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
-
- exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.True(t, exists, "income forwarding rule should exist")
-
- foundRule, found = manager.rules[inForwardRuleKey]
- require.True(t, found, "income forwarding rule should exist in the manager map")
- require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
-
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
- natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
-
- exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
+ exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
@@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.False(t, foundNat, "nat rule should not exist in the map")
}
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
- inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
+ inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
}
-func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
+func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
@@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
- manager, err := newRouterManager(context.TODO(), iptablesClient)
+ manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
defer func() {
_ = manager.Reset()
@@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
- forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
- forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
-
- err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
- require.NoError(t, err, "inserting rule should not return error")
-
- inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
- inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
-
- err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
- require.NoError(t, err, "inserting rule should not return error")
-
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
- natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
+ natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
- inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
+ inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
@@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
- err = manager.RemoveRoutingRules(testCase.InputPair)
+ err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
- exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.False(t, exists, "forwarding rule should not exist")
-
- _, found := manager.rules[forwardRuleKey]
- require.False(t, found, "forwarding rule should exist in the manager map")
-
- exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.False(t, exists, "income forwarding rule should not exist")
-
- _, found = manager.rules[inForwardRuleKey]
- require.False(t, found, "income forwarding rule should exist in the manager map")
-
- exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
+ exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
- _, found = manager.rules[natRuleKey]
+ _, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
@@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
-
+ })
+ }
+}
+
+func TestRouter_AddRouteFiltering(t *testing.T) {
+ if !isIptablesSupported() {
+ t.Skip("iptables not supported on this system")
+ }
+
+ iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
+ require.NoError(t, err, "Failed to create iptables client")
+
+ r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
+ require.NoError(t, err, "Failed to create router manager")
+
+ defer func() {
+ err := r.Reset()
+ require.NoError(t, err, "Failed to reset router")
+ }()
+
+ tests := []struct {
+ name string
+ sources []netip.Prefix
+ destination netip.Prefix
+ proto firewall.Protocol
+ sPort *firewall.Port
+ dPort *firewall.Port
+ direction firewall.RuleDirection
+ action firewall.Action
+ expectSet bool
+ }{
+ {
+ name: "Basic TCP rule with single source",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolTCP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{80}},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with multiple sources",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("172.16.0.0/16"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolUDP,
+ sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionDrop,
+ expectSet: true,
+ },
+ {
+ name: "All protocols rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
+ destination: netip.MustParsePrefix("0.0.0.0/0"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "ICMP rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolICMP,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with multiple source ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
+ destination: netip.MustParsePrefix("192.168.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{80, 443, 8080}},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with single IP and port range",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolUDP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with source and destination ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
+ destination: netip.MustParsePrefix("172.16.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
+ dPort: &firewall.Port{Values: []int{22}},
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "Drop all incoming traffic",
+ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
+ destination: netip.MustParsePrefix("192.168.0.0/24"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
+ require.NoError(t, err, "AddRouteFiltering failed")
+
+ // Check if the rule is in the internal map
+ rule, ok := r.rules[ruleKey.GetRuleID()]
+ assert.True(t, ok, "Rule not found in internal map")
+
+ // Log the internal rule
+ t.Logf("Internal rule: %v", rule)
+
+ // Check if the rule exists in iptables
+ exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
+ assert.NoError(t, err, "Failed to check rule existence")
+ assert.True(t, exists, "Rule not found in iptables")
+
+ // Verify rule content
+ params := routeFilteringRuleParams{
+ Sources: tt.sources,
+ Destination: tt.destination,
+ Proto: tt.proto,
+ SPort: tt.sPort,
+ DPort: tt.dPort,
+ Action: tt.action,
+ SetName: "",
+ }
+
+ expectedRule := genRouteFilteringRuleSpec(params)
+
+ if tt.expectSet {
+ setName := firewall.GenerateSetName(tt.sources)
+ params.SetName = setName
+ expectedRule = genRouteFilteringRuleSpec(params)
+
+ // Check if the set was created
+ _, exists := r.ipsetCounter.Get(setName)
+ assert.True(t, exists, "IPSet not created")
+ }
+
+ assert.Equal(t, expectedRule, rule, "Rule content mismatch")
+
+ // Clean up
+ err = r.DeleteRouteRule(ruleKey)
+ require.NoError(t, err, "Failed to delete rule")
})
}
}
diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go
index 6e4edb63e..a6185d370 100644
--- a/client/firewall/manager/firewall.go
+++ b/client/firewall/manager/firewall.go
@@ -1,15 +1,21 @@
package manager
import (
+ "crypto/sha256"
+ "encoding/hex"
"fmt"
"net"
+ "net/netip"
+ "sort"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
)
const (
- NatFormat = "netbird-nat-%s"
- ForwardingFormat = "netbird-fwd-%s"
- InNatFormat = "netbird-nat-in-%s"
- InForwardingFormat = "netbird-fwd-in-%s"
+ ForwardingFormatPrefix = "netbird-fwd-"
+ ForwardingFormat = "netbird-fwd-%s-%t"
+ NatFormat = "netbird-nat-%s-%t"
)
// Rule abstraction should be implemented by each firewall manager
@@ -49,11 +55,11 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
- // AddFiltering rule to the firewall
+ // AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
- AddFiltering(
+ AddPeerFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
@@ -64,17 +70,25 @@ type Manager interface {
comment string,
) ([]Rule, error)
- // DeleteRule from the firewall by rule definition
- DeleteRule(rule Rule) error
+ // DeletePeerRule from the firewall by rule definition
+ DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
- // InsertRoutingRules inserts a routing firewall rule
- InsertRoutingRules(pair RouterPair) error
+ AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
- // RemoveRoutingRules removes a routing firewall rule
- RemoveRoutingRules(pair RouterPair) error
+ // DeleteRouteRule deletes a routing rule
+ DeleteRouteRule(rule Rule) error
+
+ // AddNatRule inserts a routing NAT rule
+ AddNatRule(pair RouterPair) error
+
+ // RemoveNatRule removes a routing NAT rule
+ RemoveNatRule(pair RouterPair) error
+
+ // SetLegacyManagement sets the legacy management mode
+ SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset() error
@@ -83,6 +97,89 @@ type Manager interface {
Flush() error
}
-func GenKey(format string, input string) string {
- return fmt.Sprintf(format, input)
+func GenKey(format string, pair RouterPair) string {
+ return fmt.Sprintf(format, pair.ID, pair.Inverse)
+}
+
+// LegacyManager defines the interface for legacy management operations
+type LegacyManager interface {
+ RemoveAllLegacyRouteRules() error
+ GetLegacyManagement() bool
+ SetLegacyManagement(bool)
+}
+
+// SetLegacyManagement sets the route manager to use legacy management
+func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
+ oldLegacy := router.GetLegacyManagement()
+
+ if oldLegacy != isLegacy {
+ router.SetLegacyManagement(isLegacy)
+ log.Debugf("Set legacy management to %v", isLegacy)
+ }
+
+ // client reconnected to a newer mgmt, we need to clean up the legacy rules
+ if !isLegacy && oldLegacy {
+ if err := router.RemoveAllLegacyRouteRules(); err != nil {
+ return fmt.Errorf("remove legacy routing rules: %v", err)
+ }
+
+ log.Debugf("Legacy routing rules removed")
+ }
+
+ return nil
+}
+
+// GenerateSetName generates a unique name for an ipset based on the given sources.
+func GenerateSetName(sources []netip.Prefix) string {
+ // sort for consistent naming
+ sortPrefixes(sources)
+
+ var sourcesStr strings.Builder
+ for _, src := range sources {
+ sourcesStr.WriteString(src.String())
+ }
+
+ hash := sha256.Sum256([]byte(sourcesStr.String()))
+ shortHash := hex.EncodeToString(hash[:])[:8]
+
+ return fmt.Sprintf("nb-%s", shortHash)
+}
+
+// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
+func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
+ if len(prefixes) == 0 {
+ return prefixes
+ }
+
+ merged := []netip.Prefix{prefixes[0]}
+ for _, prefix := range prefixes[1:] {
+ last := merged[len(merged)-1]
+ if last.Contains(prefix.Addr()) {
+ // If the current prefix is contained within the last merged prefix, skip it
+ continue
+ }
+ if prefix.Contains(last.Addr()) {
+ // If the current prefix contains the last merged prefix, replace it
+ merged[len(merged)-1] = prefix
+ } else {
+ // Otherwise, add the current prefix to the merged list
+ merged = append(merged, prefix)
+ }
+ }
+
+ return merged
+}
+
+// sortPrefixes sorts the given slice of netip.Prefix in place.
+// It sorts first by IP address, then by prefix length (most specific to least specific).
+func sortPrefixes(prefixes []netip.Prefix) {
+ sort.Slice(prefixes, func(i, j int) bool {
+ addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
+ if addrCmp != 0 {
+ return addrCmp < 0
+ }
+
+ // If IP addresses are the same, compare prefix lengths (longer prefixes first)
+ return prefixes[i].Bits() > prefixes[j].Bits()
+ })
}
diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go
new file mode 100644
index 000000000..3f47d6679
--- /dev/null
+++ b/client/firewall/manager/firewall_test.go
@@ -0,0 +1,192 @@
+package manager_test
+
+import (
+ "net/netip"
+ "reflect"
+ "regexp"
+ "testing"
+
+ "github.com/netbirdio/netbird/client/firewall/manager"
+)
+
+func TestGenerateSetName(t *testing.T) {
+ t.Run("Different orders result in same hash", func(t *testing.T) {
+ prefixes1 := []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ }
+ prefixes2 := []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ }
+
+ result1 := manager.GenerateSetName(prefixes1)
+ result2 := manager.GenerateSetName(prefixes2)
+
+ if result1 != result2 {
+ t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
+ }
+ })
+
+ t.Run("Result format is correct", func(t *testing.T) {
+ prefixes := []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ }
+
+ result := manager.GenerateSetName(prefixes)
+
+ matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
+ if err != nil {
+ t.Fatalf("Error matching regex: %v", err)
+ }
+ if !matched {
+ t.Errorf("Result format is incorrect: %s", result)
+ }
+ })
+
+ t.Run("Empty input produces consistent result", func(t *testing.T) {
+ result1 := manager.GenerateSetName([]netip.Prefix{})
+ result2 := manager.GenerateSetName([]netip.Prefix{})
+
+ if result1 != result2 {
+ t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
+ }
+ })
+
+ t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
+ prefixes1 := []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("2001:db8::/32"),
+ }
+ prefixes2 := []netip.Prefix{
+ netip.MustParsePrefix("2001:db8::/32"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ }
+
+ result1 := manager.GenerateSetName(prefixes1)
+ result2 := manager.GenerateSetName(prefixes2)
+
+ if result1 != result2 {
+ t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
+ }
+ })
+}
+
+func TestMergeIPRanges(t *testing.T) {
+ tests := []struct {
+ name string
+ input []netip.Prefix
+ expected []netip.Prefix
+ }{
+ {
+ name: "Empty input",
+ input: []netip.Prefix{},
+ expected: []netip.Prefix{},
+ },
+ {
+ name: "Single range",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ },
+ {
+ name: "Two non-overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ },
+ },
+ {
+ name: "One range containing another",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ {
+ name: "One range containing another (different order)",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ {
+ name: "Overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.1.128/25"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ },
+ {
+ name: "Overlapping ranges (different order)",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.128/25"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ },
+ {
+ name: "Multiple overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.2.0/24"),
+ netip.MustParsePrefix("192.168.1.128/25"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ {
+ name: "Partially overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/23"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.2.0/25"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/23"),
+ netip.MustParsePrefix("192.168.2.0/25"),
+ },
+ },
+ {
+ name: "IPv6 ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("2001:db8::/32"),
+ netip.MustParsePrefix("2001:db8:1::/48"),
+ netip.MustParsePrefix("2001:db8:2::/48"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("2001:db8::/32"),
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := manager.MergeIPRanges(tt.input)
+ if !reflect.DeepEqual(result, tt.expected) {
+ t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
+ }
+ })
+ }
+}
diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go
index b63a9f104..8c94b7dd4 100644
--- a/client/firewall/manager/routerpair.go
+++ b/client/firewall/manager/routerpair.go
@@ -1,18 +1,26 @@
package manager
+import (
+ "net/netip"
+
+ "github.com/netbirdio/netbird/route"
+)
+
type RouterPair struct {
- ID string
- Source string
- Destination string
+ ID route.ID
+ Source netip.Prefix
+ Destination netip.Prefix
Masquerade bool
+ Inverse bool
}
-func GetInPair(pair RouterPair) RouterPair {
+func GetInversePair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
+ Inverse: true,
}
}
diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go
index 1fa41b63a..85cba9e1c 100644
--- a/client/firewall/nftables/acl_linux.go
+++ b/client/firewall/nftables/acl_linux.go
@@ -33,9 +33,10 @@ const (
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
+const flushError = "flush: %w"
+
var (
- anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
- postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
+ anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct {
@@ -48,7 +49,6 @@ type AclManager struct {
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
- chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@@ -64,7 +64,7 @@ type iFaceMapper interface {
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
- // and is permanent. Using same connection for booth type of operations
+ // and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
@@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
return m, nil
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
-func (m *AclManager) AddFiltering(
+func (m *AclManager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering(
}
newRules = append(newRules, ioRule)
- if !shouldAddToPrerouting(proto, dPort, direction) {
- return newRules, nil
- }
-
- preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
- if err != nil {
- return newRules, err
- }
- newRules = append(newRules, preroutingRule)
return newRules, nil
}
-// DeleteRule from the firewall by rule definition
-func (m *AclManager) DeleteRule(rule firewall.Rule) error {
+// DeletePeerRule from the firewall by rule definition
+func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error {
return nil
}
-// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
-// input and output chains
+// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
@@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1,
DestRegister: 1,
Len: 4,
- Mask: []byte{0x00, 0x00, 0x00, 0x00},
- Xor: zeroXor,
+ Mask: []byte{0, 0, 0, 0},
+ Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
- Data: []byte{0x00, 0x00, 0x00, 0x00},
+ Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
@@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1,
DestRegister: 1,
Len: 4,
- Mask: []byte{0x00, 0x00, 0x00, 0x00},
- Xor: zeroXor,
+ Mask: []byte{0, 0, 0, 0},
+ Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
- Data: []byte{0x00, 0x00, 0x00, 0x00},
+ Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
@@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut,
})
- err := m.rConn.Flush()
- if err != nil {
- log.Debugf("failed to create default allow rules: %s", err)
- return err
+ if err := m.rConn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
}
return nil
}
@@ -290,15 +278,11 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
- if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
- log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
- }
-
return nil
}
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
- ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
+ ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
@@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil
}
- ifaceKey := expr.MetaKeyIIFNAME
- if direction == firewall.RuleDirectionOUT {
- ifaceKey = expr.MetaKeyOIFNAME
- }
- expressions := []expr.Any{
- &expr.Meta{Key: ifaceKey, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- }
+ var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
@@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1),
})
- var protoData []byte
- switch proto {
- case firewall.ProtocolTCP:
- protoData = []byte{unix.IPPROTO_TCP}
- case firewall.ProtocolUDP:
- protoData = []byte{unix.IPPROTO_UDP}
- case firewall.ProtocolICMP:
- protoData = []byte{unix.IPPROTO_ICMP}
- default:
- return nil, fmt.Errorf("unsupported protocol: %s", proto)
+ protoData, err := protoToInt(proto)
+ if err != nil {
+ return nil, fmt.Errorf("convert protocol to number: %v", err)
}
+
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
- Data: protoData,
+ Data: []byte{protoData},
})
}
@@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else {
chain = m.chainOutputRules
}
- nftRule := m.rConn.InsertRule(&nftables.Rule{
+ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
- Position: 0,
Exprs: expressions,
UserData: userData,
})
@@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil
}
-func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
- var protoData []byte
- switch proto {
- case firewall.ProtocolTCP:
- protoData = []byte{unix.IPPROTO_TCP}
- case firewall.ProtocolUDP:
- protoData = []byte{unix.IPPROTO_UDP}
- case firewall.ProtocolICMP:
- protoData = []byte{unix.IPPROTO_ICMP}
- default:
- return nil, fmt.Errorf("unsupported protocol: %s", proto)
- }
-
- ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
- if r, ok := m.rules[ruleId]; ok {
- return &Rule{
- r.nftRule,
- r.nftSet,
- r.ruleID,
- ip,
- }, nil
- }
-
- var ipExpression expr.Any
- // add individual IP for match if no ipset defined
- rawIP := ip.To4()
- if ipset == nil {
- ipExpression = &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: rawIP,
- }
- } else {
- ipExpression = &expr.Lookup{
- SourceRegister: 1,
- SetName: ipset.Name,
- SetID: ipset.ID,
- }
- }
-
- expressions := []expr.Any{
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- ipExpression,
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: m.wgIface.Address().IP.To4(),
- },
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: uint32(9),
- Len: uint32(1),
- },
- &expr.Cmp{
- Register: 1,
- Op: expr.CmpOpEq,
- Data: protoData,
- },
- }
-
- if port != nil {
- expressions = append(expressions,
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseTransportHeader,
- Offset: 2,
- Len: 2,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: encodePort(*port),
- },
- )
- }
-
- expressions = append(expressions,
- &expr.Immediate{
- Register: 1,
- Data: postroutingMark,
- },
- &expr.Meta{
- Key: expr.MetaKeyMARK,
- SourceRegister: true,
- Register: 1,
- },
- )
-
- nftRule := m.rConn.InsertRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainPrerouting,
- Position: 0,
- Exprs: expressions,
- UserData: []byte(ruleId),
- })
-
- if err := m.rConn.Flush(); err != nil {
- return nil, fmt.Errorf("flush insert rule: %v", err)
- }
-
- rule := &Rule{
- nftRule: nftRule,
- nftSet: ipset,
- ruleID: ruleId,
- ip: ip,
- }
-
- m.rules[ruleId] = rule
- if ipset != nil {
- m.ipsetStore.AddReferenceToIpset(ipset.Name)
- }
- return rule, nil
-}
-
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
- return err
+ return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
@@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
- //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
- m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
- m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
@@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
- m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
@@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
- m.addJumpRulesToRtForward() // to
- m.addMarkAccept()
- m.addJumpRuleToInputChain() // to netbird-acl-input-rules
+ m.addJumpRulesToRtForward() // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
+
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
- return err
+ return fmt.Errorf(flushError, err)
}
- // netbird-acl-output-filter
- // type filter hook output priority filter; policy accept;
- m.chainPrerouting = m.createPreroutingMangle()
- err = m.rConn.Flush()
- if err != nil {
- log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
- return err
- }
return nil
}
@@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() {
Chain: m.chainFwFilter,
Exprs: expressions,
})
-
- expressions = []expr.Any{
- &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Verdict{
- Kind: expr.VerdictJump,
- Chain: m.routeingFwChainName,
- },
- }
-
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainFwFilter,
- Exprs: expressions,
- })
-}
-
-func (m *AclManager) addMarkAccept() {
- // oifname "wt0" meta mark 0x000007e4 accept
- // iifname "wt0" meta mark 0x000007e4 accept
- ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
- for _, iface := range ifaces {
- expressions := []expr.Any{
- &expr.Meta{Key: iface, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Meta{
- Key: expr.MetaKeyMARK,
- Register: 1,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: postroutingMark,
- },
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
- }
-
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainFwFilter,
- Exprs: expressions,
- })
- }
}
func (m *AclManager) createChain(name string) *nftables.Chain {
@@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
}
chain = m.rConn.AddChain(chain)
+
+ insertReturnTrafficRule(m.rConn, m.workTable, chain)
+
return chain
}
@@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
return m.rConn.AddChain(chain)
}
-func (m *AclManager) createPreroutingMangle() *nftables.Chain {
- polAccept := nftables.ChainPolicyAccept
- chain := &nftables.Chain{
- Name: "netbird-acl-prerouting-filter",
- Table: m.workTable,
- Hooknum: nftables.ChainHookPrerouting,
- Priority: nftables.ChainPriorityMangle,
- Type: nftables.ChainTypeFilter,
- Policy: &polAccept,
- }
-
- chain = m.rConn.AddChain(chain)
-
- ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
- expressions := []expr.Any{
- &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: expr.CmpOpNeq,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: m.wgIface.Address().IP.To4(),
- },
- &expr.Immediate{
- Register: 1,
- Data: postroutingMark,
- },
- &expr.Meta{
- Key: expr.MetaKeyMARK,
- SourceRegister: true,
- Register: 1,
- },
- }
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: chain,
- Exprs: expressions,
- })
- return chain
-}
-
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
@@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}
-func (m *AclManager) addJumpRuleToInputChain() {
- expressions := []expr.Any{
- &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Verdict{
- Kind: expr.VerdictJump,
- Chain: m.chainInputRules.Name,
- },
- }
-
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainFwFilter,
- Exprs: expressions,
- })
-}
-
-func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
- ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
- var srcOp, dstOp expr.CmpOp
- if netIfName == expr.MetaKeyIIFNAME {
- srcOp = expr.CmpOpEq
- dstOp = expr.CmpOpNeq
- } else {
- srcOp = expr.CmpOpNeq
- dstOp = expr.CmpOpEq
- }
- expressions := []expr.Any{
- &expr.Meta{Key: netIfName, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: srcOp,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: dstOp,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
- }
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: chain.Table,
- Chain: chain,
- Exprs: expressions,
- })
-}
-
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
- var srcOp, dstOp expr.CmpOp
- if iifname == expr.MetaKeyIIFNAME {
- srcOp = expr.CmpOpNeq
- dstOp = expr.CmpOpEq
- } else {
- srcOp = expr.CmpOpEq
- dstOp = expr.CmpOpNeq
- }
+ dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
@@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: srcOp,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
@@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
- ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
@@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
+
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
@@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil
}
-func generateRuleId(
+func generatePeerRuleId(
ip net.IP,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -1155,33 +718,6 @@ func generateRuleId(
}
return "set:" + ipset.Name + rulesetID
}
-func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
- // case of icmp port is empty
- var p string
- if port != nil {
- p = port.String()
- }
- if ipset != nil {
- return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
- } else {
- return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
- }
-}
-
-func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
- if proto == "all" {
- return false
- }
-
- if direction != firewall.RuleDirectionIN {
- return false
- }
-
- if dPort == nil && proto != firewall.ProtocolICMP {
- return false
- }
- return true
-}
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
@@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte {
b := make([]byte, 16)
- copy(b, []byte(n+"\x00"))
+ copy(b, n+"\x00")
return b
}
+
+func protoToInt(protocol firewall.Protocol) (uint8, error) {
+ switch protocol {
+ case firewall.ProtocolTCP:
+ return unix.IPPROTO_TCP, nil
+ case firewall.ProtocolUDP:
+ return unix.IPPROTO_UDP, nil
+ case firewall.ProtocolICMP:
+ return unix.IPPROTO_ICMP, nil
+ }
+
+ return 0, fmt.Errorf("unsupported protocol: %s", protocol)
+}
diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go
index a376c98c3..d2258ae08 100644
--- a/client/firewall/nftables/manager_linux.go
+++ b/client/firewall/nftables/manager_linux.go
@@ -5,9 +5,11 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"sync"
"github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
@@ -15,8 +17,11 @@ import (
)
const (
- // tableName is the name of the table that is used for filtering by the Netbird client
- tableName = "netbird"
+ // tableNameNetbird is the name of the table that is used for filtering by the Netbird client
+ tableNameNetbird = "netbird"
+
+ tableNameFilter = "filter"
+ chainNameInput = "INPUT"
)
// Manager of iptables firewall
@@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return nil, err
}
- m.router, err = newRouter(context, workTable)
+ m.router, err = newRouter(context, workTable, wgIface)
if err != nil {
return nil, err
}
- m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
+ m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, err
}
@@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
-func (m *Manager) AddFiltering(
+func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -76,33 +81,52 @@ func (m *Manager) AddFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
- return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
+ return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
}
-// DeleteRule from the firewall by rule definition
-func (m *Manager) DeleteRule(rule firewall.Rule) error {
+func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.aclManager.DeleteRule(rule)
+ if !destination.Addr().Is4() {
+ return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
+ }
+
+ return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
+}
+
+// DeletePeerRule from the firewall by rule definition
+func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.aclManager.DeletePeerRule(rule)
+}
+
+// DeleteRouteRule deletes a routing rule
+func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
-func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.AddRoutingRules(pair)
+ return m.router.AddNatRule(pair)
}
-func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.RemoveRoutingRules(pair)
+ return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic
@@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
- if c.Table.Name == "filter" && c.Name == "INPUT" {
+ if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
chain = c
break
}
@@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error {
return nil
}
+// SetLegacyManagement sets the route manager to use legacy management
+func (m *Manager) SetLegacyManagement(isLegacy bool) error {
+ oldLegacy := m.router.legacyManagement
+
+ if oldLegacy != isLegacy {
+ m.router.legacyManagement = isLegacy
+ log.Debugf("Set legacy management to %v", isLegacy)
+ }
+
+ // client reconnected to a newer mgmt, we need to cleanup the legacy rules
+ if !isLegacy && oldLegacy {
+ if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
+ return fmt.Errorf("remove legacy routing rules: %v", err)
+ }
+
+ log.Debugf("Legacy routing rules removed")
+ }
+
+ return nil
+}
+
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
@@ -185,14 +230,16 @@ func (m *Manager) Reset() error {
}
}
- m.router.ResetForwardRules()
+ if err := m.router.Reset(); err != nil {
+ return fmt.Errorf("reset forward rules: %v", err)
+ }
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
@@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
- table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
+ table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
@@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
+ &expr.Verdict{},
},
UserData: []byte(allowNetbirdInputRuleID),
}
@@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
- if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
+ if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
@@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
}
return nil
}
+
+func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
+ rule := &nftables.Rule{
+ Table: table,
+ Chain: chain,
+ Exprs: []expr.Any{
+ &expr.Ct{
+ Key: expr.CtKeySTATE,
+ Register: 1,
+ },
+ &expr.Bitwise{
+ SourceRegister: 1,
+ DestRegister: 1,
+ Len: 4,
+ Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
+ Xor: binaryutil.NativeEndian.PutUint32(0),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpNeq,
+ Register: 1,
+ Data: []byte{0, 0, 0, 0},
+ },
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
+ },
+ }
+
+ conn.InsertRule(rule)
+}
diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go
index 1f226e315..7f78a9a2e 100644
--- a/client/firewall/nftables/manager_linux_test.go
+++ b/client/firewall/nftables/manager_linux_test.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
@@ -17,6 +18,21 @@ import (
"github.com/netbirdio/netbird/iface"
)
+var ifaceMock = &iFaceMock{
+ NameFunc: func() string {
+ return "lo"
+ },
+ AddressFunc: func() iface.WGAddress {
+ return iface.WGAddress{
+ IP: net.ParseIP("100.96.0.1"),
+ Network: &net.IPNet{
+ IP: net.ParseIP("100.96.0.0"),
+ Mask: net.IPv4Mask(255, 255, 255, 0),
+ },
+ }
+ },
+}
+
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
@@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
- mock := &iFaceMock{
- NameFunc: func() string {
- return "lo"
- },
- AddressFunc: func() iface.WGAddress {
- return iface.WGAddress{
- IP: net.ParseIP("100.96.0.1"),
- Network: &net.IPNet{
- IP: net.ParseIP("100.96.0.0"),
- Mask: net.IPv4Mask(255, 255, 255, 0),
- },
- }
- },
- }
// just check on the local interface
- manager, err := Create(context.Background(), mock)
+ manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
@@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
- rule, err := manager.AddFiltering(
+ rule, err := manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
@@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
- require.Len(t, rules, 1, "expected 1 rules")
+ require.Len(t, rules, 2, "expected 2 rules")
+
+ expectedExprs1 := []expr.Any{
+ &expr.Ct{
+ Key: expr.CtKeySTATE,
+ Register: 1,
+ },
+ &expr.Bitwise{
+ SourceRegister: 1,
+ DestRegister: 1,
+ Len: 4,
+ Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
+ Xor: binaryutil.NativeEndian.PutUint32(0),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpNeq,
+ Register: 1,
+ Data: []byte{0, 0, 0, 0},
+ },
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
+ }
+ require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
- expectedExprs := []expr.Any{
- &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname("lo"),
- },
+ expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) {
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
- require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
+ require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
for _, r := range rule {
- err = manager.DeleteRule(r)
+ err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
@@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
- require.Len(t, rules, 0, "expected 0 rules after deletion")
+ // established rule remains
+ require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
@@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go
deleted file mode 100644
index 71d5ac88e..000000000
--- a/client/firewall/nftables/route_linux.go
+++ /dev/null
@@ -1,431 +0,0 @@
-package nftables
-
-import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "net"
- "net/netip"
-
- "github.com/google/nftables"
- "github.com/google/nftables/binaryutil"
- "github.com/google/nftables/expr"
- log "github.com/sirupsen/logrus"
-
- "github.com/netbirdio/netbird/client/firewall/manager"
-)
-
-const (
- chainNameRouteingFw = "netbird-rt-fwd"
- chainNameRoutingNat = "netbird-rt-nat"
-
- userDataAcceptForwardRuleSrc = "frwacceptsrc"
- userDataAcceptForwardRuleDst = "frwacceptdst"
-
- loopbackInterface = "lo\x00"
-)
-
-// some presets for building nftable rules
-var (
- zeroXor = binaryutil.NativeEndian.PutUint32(0)
-
- exprCounterAccept = []expr.Any{
- &expr.Counter{},
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
- }
-
- errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
-)
-
-type router struct {
- ctx context.Context
- stop context.CancelFunc
- conn *nftables.Conn
- workTable *nftables.Table
- filterTable *nftables.Table
- chains map[string]*nftables.Chain
- // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
- rules map[string]*nftables.Rule
- isDefaultFwdRulesEnabled bool
-}
-
-func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
- ctx, cancel := context.WithCancel(parentCtx)
-
- r := &router{
- ctx: ctx,
- stop: cancel,
- conn: &nftables.Conn{},
- workTable: workTable,
- chains: make(map[string]*nftables.Chain),
- rules: make(map[string]*nftables.Rule),
- }
-
- var err error
- r.filterTable, err = r.loadFilterTable()
- if err != nil {
- if errors.Is(err, errFilterTableNotFound) {
- log.Warnf("table 'filter' not found for forward rules")
- } else {
- return nil, err
- }
- }
-
- err = r.cleanUpDefaultForwardRules()
- if err != nil {
- log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
- }
-
- err = r.createContainers()
- if err != nil {
- log.Errorf("failed to create containers for route: %s", err)
- }
- return r, err
-}
-
-func (r *router) RouteingFwChainName() string {
- return chainNameRouteingFw
-}
-
-// ResetForwardRules cleans existing nftables default forward rules from the system
-func (r *router) ResetForwardRules() {
- err := r.cleanUpDefaultForwardRules()
- if err != nil {
- log.Errorf("failed to reset forward rules: %s", err)
- }
-}
-
-func (r *router) loadFilterTable() (*nftables.Table, error) {
- tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
- if err != nil {
- return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
- }
-
- for _, table := range tables {
- if table.Name == "filter" {
- return table, nil
- }
- }
-
- return nil, errFilterTableNotFound
-}
-
-func (r *router) createContainers() error {
-
- r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
- Name: chainNameRouteingFw,
- Table: r.workTable,
- })
-
- r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
- Name: chainNameRoutingNat,
- Table: r.workTable,
- Hooknum: nftables.ChainHookPostrouting,
- Priority: nftables.ChainPriorityNATSource - 1,
- Type: nftables.ChainTypeNAT,
- })
-
- // Add RETURN rule for loopback interface
- loRule := &nftables.Rule{
- Table: r.workTable,
- Chain: r.chains[chainNameRoutingNat],
- Exprs: []expr.Any{
- &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: []byte(loopbackInterface),
- },
- &expr.Verdict{Kind: expr.VerdictReturn},
- },
- }
- r.conn.InsertRule(loRule)
-
- err := r.refreshRulesMap()
- if err != nil {
- log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
- }
-
- err = r.conn.Flush()
- if err != nil {
- return fmt.Errorf("nftables: unable to initialize table: %v", err)
- }
- return nil
-}
-
-// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
-func (r *router) AddRoutingRules(pair manager.RouterPair) error {
- err := r.refreshRulesMap()
- if err != nil {
- return err
- }
-
- err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
- if err != nil {
- return err
- }
- err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
- if err != nil {
- return err
- }
-
- if pair.Masquerade {
- err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
- if err != nil {
- return err
- }
- err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
- if err != nil {
- return err
- }
- }
-
- if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
- log.Debugf("add default accept forward rule")
- r.acceptForwardRule(pair.Source)
- }
-
- err = r.conn.Flush()
- if err != nil {
- return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
- }
- return nil
-}
-
-// addRoutingRule inserts a nftable rule to the conn client flush queue
-func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
- sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
- destExp := generateCIDRMatcherExpressions(false, pair.Destination)
-
- var expression []expr.Any
- if isNat {
- expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
- } else {
- expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
- }
-
- ruleKey := manager.GenKey(format, pair.ID)
-
- _, exists := r.rules[ruleKey]
- if exists {
- err := r.removeRoutingRule(format, pair)
- if err != nil {
- return err
- }
- }
-
- r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
- Table: r.workTable,
- Chain: r.chains[chainName],
- Exprs: expression,
- UserData: []byte(ruleKey),
- })
- return nil
-}
-
-func (r *router) acceptForwardRule(sourceNetwork string) {
- src := generateCIDRMatcherExpressions(true, sourceNetwork)
- dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
-
- var exprs []expr.Any
- exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
- Kind: expr.VerdictAccept,
- })...)
-
- rule := &nftables.Rule{
- Table: r.filterTable,
- Chain: &nftables.Chain{
- Name: "FORWARD",
- Table: r.filterTable,
- Type: nftables.ChainTypeFilter,
- Hooknum: nftables.ChainHookForward,
- Priority: nftables.ChainPriorityFilter,
- },
- Exprs: exprs,
- UserData: []byte(userDataAcceptForwardRuleSrc),
- }
-
- r.conn.AddRule(rule)
-
- src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
- dst = generateCIDRMatcherExpressions(false, sourceNetwork)
-
- exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
- Kind: expr.VerdictAccept,
- })...)
-
- rule = &nftables.Rule{
- Table: r.filterTable,
- Chain: &nftables.Chain{
- Name: "FORWARD",
- Table: r.filterTable,
- Type: nftables.ChainTypeFilter,
- Hooknum: nftables.ChainHookForward,
- Priority: nftables.ChainPriorityFilter,
- },
- Exprs: exprs,
- UserData: []byte(userDataAcceptForwardRuleDst),
- }
- r.conn.AddRule(rule)
- r.isDefaultFwdRulesEnabled = true
-}
-
-// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
-func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
- err := r.refreshRulesMap()
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.ForwardingFormat, pair)
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.NatFormat, pair)
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
- if err != nil {
- return err
- }
-
- if len(r.rules) == 0 {
- err := r.cleanUpDefaultForwardRules()
- if err != nil {
- log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
- }
- }
-
- err = r.conn.Flush()
- if err != nil {
- return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
- }
- log.Debugf("nftables: removed rules for %s", pair.Destination)
- return nil
-}
-
-// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
-func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
- ruleKey := manager.GenKey(format, pair.ID)
-
- rule, found := r.rules[ruleKey]
- if found {
- ruleType := "forwarding"
- if rule.Chain.Type == nftables.ChainTypeNAT {
- ruleType = "nat"
- }
-
- err := r.conn.DelRule(rule)
- if err != nil {
- return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
- }
-
- log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
-
- delete(r.rules, ruleKey)
- }
- return nil
-}
-
-// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
-// duplicates and to get missing attributes that we don't have when adding new rules
-func (r *router) refreshRulesMap() error {
- for _, chain := range r.chains {
- rules, err := r.conn.GetRules(chain.Table, chain)
- if err != nil {
- return fmt.Errorf("nftables: unable to list rules: %v", err)
- }
- for _, rule := range rules {
- if len(rule.UserData) > 0 {
- r.rules[string(rule.UserData)] = rule
- }
- }
- }
- return nil
-}
-
-func (r *router) cleanUpDefaultForwardRules() error {
- if r.filterTable == nil {
- r.isDefaultFwdRulesEnabled = false
- return nil
- }
-
- chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
- if err != nil {
- return err
- }
-
- var rules []*nftables.Rule
- for _, chain := range chains {
- if chain.Table.Name != r.filterTable.Name {
- continue
- }
- if chain.Name != "FORWARD" {
- continue
- }
-
- rules, err = r.conn.GetRules(r.filterTable, chain)
- if err != nil {
- return err
- }
- }
-
- for _, rule := range rules {
- if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
- err := r.conn.DelRule(rule)
- if err != nil {
- return err
- }
- }
- }
- r.isDefaultFwdRulesEnabled = false
- return r.conn.Flush()
-}
-
-// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
-func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
- ip, network, _ := net.ParseCIDR(cidr)
- ipToAdd, _ := netip.AddrFromSlice(ip)
- add := ipToAdd.Unmap()
-
- var offSet uint32
- if source {
- offSet = 12 // src offset
- } else {
- offSet = 16 // dst offset
- }
-
- return []expr.Any{
- // fetch src add
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: offSet,
- Len: 4,
- },
- // net mask
- &expr.Bitwise{
- DestRegister: 1,
- SourceRegister: 1,
- Len: 4,
- Mask: network.Mask,
- Xor: zeroXor,
- },
- // net address
- &expr.Cmp{
- Register: 1,
- Data: add.AsSlice(),
- },
- }
-}
diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go
new file mode 100644
index 000000000..aa61e1858
--- /dev/null
+++ b/client/firewall/nftables/router_linux.go
@@ -0,0 +1,798 @@
+package nftables
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "strings"
+
+ "github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
+ "github.com/google/nftables/expr"
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+
+ nberrors "github.com/netbirdio/netbird/client/errors"
+ firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/internal/acl/id"
+ "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
+)
+
+const (
+ chainNameRoutingFw = "netbird-rt-fwd"
+ chainNameRoutingNat = "netbird-rt-nat"
+ chainNameForward = "FORWARD"
+
+ userDataAcceptForwardRuleIif = "frwacceptiif"
+ userDataAcceptForwardRuleOif = "frwacceptoif"
+)
+
+const refreshRulesMapError = "refresh rules map: %w"
+
+var (
+ errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
+)
+
+type router struct {
+ ctx context.Context
+ stop context.CancelFunc
+ conn *nftables.Conn
+ workTable *nftables.Table
+ filterTable *nftables.Table
+ chains map[string]*nftables.Chain
+ // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
+ rules map[string]*nftables.Rule
+ ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
+
+ wgIface iFaceMapper
+ legacyManagement bool
+}
+
+func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
+ ctx, cancel := context.WithCancel(parentCtx)
+
+ r := &router{
+ ctx: ctx,
+ stop: cancel,
+ conn: &nftables.Conn{},
+ workTable: workTable,
+ chains: make(map[string]*nftables.Chain),
+ rules: make(map[string]*nftables.Rule),
+ wgIface: wgIface,
+ }
+
+ r.ipsetCounter = refcounter.New(
+ r.createIpSet,
+ r.deleteIpSet,
+ )
+
+ var err error
+ r.filterTable, err = r.loadFilterTable()
+ if err != nil {
+ if errors.Is(err, errFilterTableNotFound) {
+ log.Warnf("table 'filter' not found for forward rules")
+ } else {
+ return nil, err
+ }
+ }
+
+ err = r.cleanUpDefaultForwardRules()
+ if err != nil {
+ log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
+ }
+
+ err = r.createContainers()
+ if err != nil {
+ log.Errorf("failed to create containers for route: %s", err)
+ }
+ return r, err
+}
+
+// Reset cleans existing nftables default forward rules from the system
+func (r *router) Reset() error {
+ // clear without deleting the ipsets, the nf table will be deleted by the caller
+ r.ipsetCounter.Clear()
+
+ return r.cleanUpDefaultForwardRules()
+}
+
+func (r *router) cleanUpDefaultForwardRules() error {
+ if r.filterTable == nil {
+ return nil
+ }
+
+ chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
+ if err != nil {
+ return fmt.Errorf("list chains: %v", err)
+ }
+
+ for _, chain := range chains {
+ if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
+ continue
+ }
+
+ rules, err := r.conn.GetRules(r.filterTable, chain)
+ if err != nil {
+ return fmt.Errorf("get rules: %v", err)
+ }
+
+ for _, rule := range rules {
+ if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
+ bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
+ if err := r.conn.DelRule(rule); err != nil {
+ return fmt.Errorf("delete rule: %v", err)
+ }
+ }
+ }
+ }
+
+ return r.conn.Flush()
+}
+
+func (r *router) loadFilterTable() (*nftables.Table, error) {
+ tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
+ if err != nil {
+ return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
+ }
+
+ for _, table := range tables {
+ if table.Name == "filter" {
+ return table, nil
+ }
+ }
+
+ return nil, errFilterTableNotFound
+}
+
+func (r *router) createContainers() error {
+
+ r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
+ Name: chainNameRoutingFw,
+ Table: r.workTable,
+ })
+
+ insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
+
+ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
+ Name: chainNameRoutingNat,
+ Table: r.workTable,
+ Hooknum: nftables.ChainHookPostrouting,
+ Priority: nftables.ChainPriorityNATSource - 1,
+ Type: nftables.ChainTypeNAT,
+ })
+
+ r.acceptForwardRules()
+
+ err := r.refreshRulesMap()
+ if err != nil {
+ log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
+ }
+
+ err = r.conn.Flush()
+ if err != nil {
+ return fmt.Errorf("nftables: unable to initialize table: %v", err)
+ }
+ return nil
+}
+
+// AddRouteFiltering appends a nftables rule to the routing chain
+func (r *router) AddRouteFiltering(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto firewall.Protocol,
+ sPort *firewall.Port,
+ dPort *firewall.Port,
+ action firewall.Action,
+) (firewall.Rule, error) {
+ ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
+ if _, ok := r.rules[string(ruleKey)]; ok {
+ return ruleKey, nil
+ }
+
+ chain := r.chains[chainNameRoutingFw]
+ var exprs []expr.Any
+
+ switch {
+ case len(sources) == 1 && sources[0].Bits() == 0:
+ // If it's 0.0.0.0/0, we don't need to add any source matching
+ case len(sources) == 1:
+ // If there's only one source, we can use it directly
+ exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
+ default:
+ // If there are multiple sources, create or get an ipset
+ var err error
+ exprs, err = r.getIpSetExprs(sources, exprs)
+ if err != nil {
+ return nil, fmt.Errorf("get ipset expressions: %w", err)
+ }
+ }
+
+ // Handle destination
+ exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
+
+ // Handle protocol
+ if proto != firewall.ProtocolALL {
+ protoNum, err := protoToInt(proto)
+ if err != nil {
+ return nil, fmt.Errorf("convert protocol to number: %w", err)
+ }
+ exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
+ exprs = append(exprs, &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: []byte{protoNum},
+ })
+
+ exprs = append(exprs, applyPort(sPort, true)...)
+ exprs = append(exprs, applyPort(dPort, false)...)
+ }
+
+ exprs = append(exprs, &expr.Counter{})
+
+ var verdict expr.VerdictKind
+ if action == firewall.ActionAccept {
+ verdict = expr.VerdictAccept
+ } else {
+ verdict = expr.VerdictDrop
+ }
+ exprs = append(exprs, &expr.Verdict{Kind: verdict})
+
+ rule := &nftables.Rule{
+ Table: r.workTable,
+ Chain: chain,
+ Exprs: exprs,
+ UserData: []byte(ruleKey),
+ }
+
+ r.rules[string(ruleKey)] = r.conn.AddRule(rule)
+
+ return ruleKey, r.conn.Flush()
+}
+
+func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
+ setName := firewall.GenerateSetName(sources)
+ ref, err := r.ipsetCounter.Increment(setName, sources)
+ if err != nil {
+ return nil, fmt.Errorf("create or get ipset for sources: %w", err)
+ }
+
+ exprs = append(exprs,
+ &expr.Payload{
+ DestRegister: 1,
+ Base: expr.PayloadBaseNetworkHeader,
+ Offset: 12,
+ Len: 4,
+ },
+ &expr.Lookup{
+ SourceRegister: 1,
+ SetName: ref.Out.Name,
+ SetID: ref.Out.ID,
+ },
+ )
+ return exprs, nil
+}
+
+func (r *router) DeleteRouteRule(rule firewall.Rule) error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ ruleKey := rule.GetRuleID()
+ nftRule, exists := r.rules[ruleKey]
+ if !exists {
+ log.Debugf("route rule %s not found", ruleKey)
+ return nil
+ }
+
+ setName := r.findSetNameInRule(nftRule)
+
+ if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
+ return fmt.Errorf("delete: %w", err)
+ }
+
+ if setName != "" {
+ if _, err := r.ipsetCounter.Decrement(setName); err != nil {
+ return fmt.Errorf("decrement ipset reference: %w", err)
+ }
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
+ }
+
+ return nil
+}
+
+func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
+ // overlapping prefixes will result in an error, so we need to merge them
+ sources = firewall.MergeIPRanges(sources)
+
+ set := &nftables.Set{
+ Name: setName,
+ Table: r.workTable,
+ // required for prefixes
+ Interval: true,
+ KeyType: nftables.TypeIPAddr,
+ }
+
+ var elements []nftables.SetElement
+ for _, prefix := range sources {
+ // TODO: Implement IPv6 support
+ if prefix.Addr().Is6() {
+ log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
+ continue
+ }
+
+ // nftables needs half-open intervals [firstIP, lastIP) for prefixes
+ // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
+ firstIP := prefix.Addr()
+ lastIP := calculateLastIP(prefix).Next()
+
+ elements = append(elements,
+ // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
+ // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
+ nftables.SetElement{Key: firstIP.AsSlice()},
+ nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
+ )
+ }
+
+ if err := r.conn.AddSet(set, elements); err != nil {
+ return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return nil, fmt.Errorf("flush error: %w", err)
+ }
+
+ log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
+
+ return set, nil
+}
+
+// calculateLastIP determines the last IP in a given prefix.
+func calculateLastIP(prefix netip.Prefix) netip.Addr {
+ hostMask := ^uint32(0) >> prefix.Masked().Bits()
+ lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
+
+ return netip.AddrFrom4(uint32ToBytes(lastIP))
+}
+
+// Utility function to convert netip.Addr to uint32.
+func uint32FromNetipAddr(addr netip.Addr) uint32 {
+ b := addr.As4()
+ return binary.BigEndian.Uint32(b[:])
+}
+
+// Utility function to convert uint32 to a netip-compatible byte slice.
+func uint32ToBytes(ip uint32) [4]byte {
+ var b [4]byte
+ binary.BigEndian.PutUint32(b[:], ip)
+ return b
+}
+
+func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
+ r.conn.DelSet(set)
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
+ }
+
+ log.Debugf("Deleted unused ipset %s", setName)
+ return nil
+}
+
+func (r *router) findSetNameInRule(rule *nftables.Rule) string {
+ for _, e := range rule.Exprs {
+ if lookup, ok := e.(*expr.Lookup); ok {
+ return lookup.SetName
+ }
+ }
+ return ""
+}
+
+func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
+ if err := r.conn.DelRule(rule); err != nil {
+ return fmt.Errorf("delete rule %s: %w", ruleKey, err)
+ }
+ delete(r.rules, ruleKey)
+
+ log.Debugf("removed route rule %s", ruleKey)
+
+ return nil
+}
+
+// AddNatRule appends a nftables rule pair to the nat chain
+func (r *router) AddNatRule(pair firewall.RouterPair) error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ if r.legacyManagement {
+ log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
+ if err := r.addLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("add legacy routing rule: %w", err)
+ }
+ }
+
+ if pair.Masquerade {
+ if err := r.addNatRule(pair); err != nil {
+ return fmt.Errorf("add nat rule: %w", err)
+ }
+
+ if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("add inverse nat rule: %w", err)
+ }
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
+ }
+
+ return nil
+}
+
+// addNatRule inserts a nftables rule to the conn client flush queue
+func (r *router) addNatRule(pair firewall.RouterPair) error {
+ sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
+ destExp := generateCIDRMatcherExpressions(false, pair.Destination)
+
+ dir := expr.MetaKeyIIFNAME
+ if pair.Inverse {
+ dir = expr.MetaKeyOIFNAME
+ }
+
+ intf := ifname(r.wgIface.Name())
+ exprs := []expr.Any{
+ &expr.Meta{
+ Key: dir,
+ Register: 1,
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: intf,
+ },
+ }
+
+ exprs = append(exprs, sourceExp...)
+ exprs = append(exprs, destExp...)
+ exprs = append(exprs,
+ &expr.Counter{}, &expr.Masq{},
+ )
+
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
+
+ if _, exists := r.rules[ruleKey]; exists {
+ if err := r.removeNatRule(pair); err != nil {
+ return fmt.Errorf("remove routing rule: %w", err)
+ }
+ }
+
+ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
+ Table: r.workTable,
+ Chain: r.chains[chainNameRoutingNat],
+ Exprs: exprs,
+ UserData: []byte(ruleKey),
+ })
+ return nil
+}
+
+// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
+func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
+ sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
+ destExp := generateCIDRMatcherExpressions(false, pair.Destination)
+
+ exprs := []expr.Any{
+ &expr.Counter{},
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
+ }
+
+ expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
+
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if _, exists := r.rules[ruleKey]; exists {
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("remove legacy routing rule: %w", err)
+ }
+ }
+
+ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
+ Table: r.workTable,
+ Chain: r.chains[chainNameRoutingFw],
+ Exprs: expression,
+ UserData: []byte(ruleKey),
+ })
+ return nil
+}
+
+// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
+func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.conn.DelRule(rule); err != nil {
+ return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
+ }
+
+ log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
+
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
+ }
+
+ return nil
+}
+
+// GetLegacyManagement returns the route manager's legacy management mode
+func (r *router) GetLegacyManagement() bool {
+ return r.legacyManagement
+}
+
+// SetLegacyManagement sets the route manager to use legacy management mode
+func (r *router) SetLegacyManagement(isLegacy bool) {
+ r.legacyManagement = isLegacy
+}
+
+// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
+func (r *router) RemoveAllLegacyRouteRules() error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ var merr *multierror.Error
+ for k, rule := range r.rules {
+ if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
+ continue
+ }
+ if err := r.conn.DelRule(rule); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
+ }
+ }
+ return nberrors.FormatErrorOrNil(merr)
+}
+
+// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
+// that our traffic is not dropped by existing rules there.
+// The existing FORWARD rules/policies decide outbound traffic towards our interface.
+// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
+func (r *router) acceptForwardRules() {
+ if r.filterTable == nil {
+ log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
+ return
+ }
+
+ intf := ifname(r.wgIface.Name())
+
+ // Rule for incoming interface (iif) with counter
+ iifRule := &nftables.Rule{
+ Table: r.filterTable,
+ Chain: &nftables.Chain{
+ Name: "FORWARD",
+ Table: r.filterTable,
+ Type: nftables.ChainTypeFilter,
+ Hooknum: nftables.ChainHookForward,
+ Priority: nftables.ChainPriorityFilter,
+ },
+ Exprs: []expr.Any{
+ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: intf,
+ },
+ &expr.Counter{},
+ &expr.Verdict{Kind: expr.VerdictAccept},
+ },
+ UserData: []byte(userDataAcceptForwardRuleIif),
+ }
+ r.conn.InsertRule(iifRule)
+
+ // Rule for outgoing interface (oif) with counter
+ oifRule := &nftables.Rule{
+ Table: r.filterTable,
+ Chain: &nftables.Chain{
+ Name: "FORWARD",
+ Table: r.filterTable,
+ Type: nftables.ChainTypeFilter,
+ Hooknum: nftables.ChainHookForward,
+ Priority: nftables.ChainPriorityFilter,
+ },
+ Exprs: []expr.Any{
+ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: intf,
+ },
+ &expr.Ct{
+ Key: expr.CtKeySTATE,
+ Register: 2,
+ },
+ &expr.Bitwise{
+ SourceRegister: 2,
+ DestRegister: 2,
+ Len: 4,
+ Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
+ Xor: binaryutil.NativeEndian.PutUint32(0),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpNeq,
+ Register: 2,
+ Data: []byte{0, 0, 0, 0},
+ },
+ &expr.Counter{},
+ &expr.Verdict{Kind: expr.VerdictAccept},
+ },
+ UserData: []byte(userDataAcceptForwardRuleOif),
+ }
+
+ r.conn.InsertRule(oifRule)
+}
+
+// RemoveNatRule removes a nftables rule pair from nat chains
+func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ if err := r.removeNatRule(pair); err != nil {
+ return fmt.Errorf("remove nat rule: %w", err)
+ }
+
+ if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("remove inverse nat rule: %w", err)
+ }
+
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("remove legacy routing rule: %w", err)
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
+ }
+
+ log.Debugf("nftables: removed rules for %s", pair.Destination)
+ return nil
+}
+
+// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
+func (r *router) removeNatRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ err := r.conn.DelRule(rule)
+ if err != nil {
+ return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
+ }
+
+ log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
+
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("nftables: nat rule %s not found", ruleKey)
+ }
+
+ return nil
+}
+
+// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
+// duplicates and to get missing attributes that we don't have when adding new rules
+func (r *router) refreshRulesMap() error {
+ for _, chain := range r.chains {
+ rules, err := r.conn.GetRules(chain.Table, chain)
+ if err != nil {
+ return fmt.Errorf("nftables: unable to list rules: %v", err)
+ }
+ for _, rule := range rules {
+ if len(rule.UserData) > 0 {
+ r.rules[string(rule.UserData)] = rule
+ }
+ }
+ }
+ return nil
+}
+
+// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
+func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
+ var offset uint32
+ if source {
+ offset = 12 // src offset
+ } else {
+ offset = 16 // dst offset
+ }
+
+ ones := prefix.Bits()
+ // 0.0.0.0/0 doesn't need extra expressions
+ if ones == 0 {
+ return nil
+ }
+
+ mask := net.CIDRMask(ones, 32)
+
+ return []expr.Any{
+ &expr.Payload{
+ DestRegister: 1,
+ Base: expr.PayloadBaseNetworkHeader,
+ Offset: offset,
+ Len: 4,
+ },
+ // netmask
+ &expr.Bitwise{
+ DestRegister: 1,
+ SourceRegister: 1,
+ Len: 4,
+ Mask: mask,
+ Xor: []byte{0, 0, 0, 0},
+ },
+ // net address
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: prefix.Masked().Addr().AsSlice(),
+ },
+ }
+}
+
+func applyPort(port *firewall.Port, isSource bool) []expr.Any {
+ if port == nil {
+ return nil
+ }
+
+ var exprs []expr.Any
+
+ offset := uint32(2) // Default offset for destination port
+ if isSource {
+ offset = 0 // Offset for source port
+ }
+
+ exprs = append(exprs, &expr.Payload{
+ DestRegister: 1,
+ Base: expr.PayloadBaseTransportHeader,
+ Offset: offset,
+ Len: 2,
+ })
+
+ if port.IsRange && len(port.Values) == 2 {
+ // Handle port range
+ exprs = append(exprs,
+ &expr.Cmp{
+ Op: expr.CmpOpGte,
+ Register: 1,
+ Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpLte,
+ Register: 1,
+ Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
+ },
+ )
+ } else {
+ // Handle single port or multiple ports
+ for i, p := range port.Values {
+ if i > 0 {
+ // Add a bitwise OR operation between port checks
+ exprs = append(exprs, &expr.Bitwise{
+ SourceRegister: 1,
+ DestRegister: 1,
+ Len: 4,
+ Mask: []byte{0x00, 0x00, 0xff, 0xff},
+ Xor: []byte{0x00, 0x00, 0x00, 0x00},
+ })
+ }
+ exprs = append(exprs, &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: binaryutil.BigEndian.PutUint16(uint16(p)),
+ })
+ }
+ }
+
+ return exprs
+}
diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go
index 913fbd5d2..bbf92f3be 100644
--- a/client/firewall/nftables/router_linux_test.go
+++ b/client/firewall/nftables/router_linux_test.go
@@ -4,11 +4,15 @@ package nftables
import (
"context"
+ "encoding/binary"
+ "net/netip"
+ "os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/expr"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -24,56 +28,50 @@ const (
NFTABLES
)
-func TestNftablesManager_InsertRoutingRules(t *testing.T) {
+func TestNftablesManager_AddNatRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
- manager, err := newRouter(context.TODO(), table)
+ manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
- defer manager.ResetForwardRules()
+ defer func(manager *router) {
+ require.NoError(t, manager.Reset(), "failed to reset rules")
+ }(manager)
require.NoError(t, err, "shouldn't return error")
- err = manager.AddRoutingRules(testCase.InputPair)
- defer func() {
- _ = manager.RemoveRoutingRules(testCase.InputPair)
- }()
- require.NoError(t, err, "forwarding pair should be inserted")
+ err = manager.AddNatRule(testCase.InputPair)
+ require.NoError(t, err, "pair should be inserted")
- sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
- destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
- testingExpression := append(sourceExp, destExp...) //nolint:gocritic
- fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
-
- found := 0
- for _, chain := range manager.chains {
- rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
- require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
- for _, rule := range rules {
- if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
- require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
- found = 1
- }
- }
- }
-
- require.Equal(t, 1, found, "should find at least 1 rule to test")
+ defer func(manager *router, pair firewall.RouterPair) {
+ require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
+ }(manager, testCase.InputPair)
if testCase.InputPair.Masquerade {
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
+ sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
+ destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
+ testingExpression := append(sourceExp, destExp...) //nolint:gocritic
+ testingExpression = append(testingExpression,
+ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: ifname(ifaceMock.Name()),
+ },
+ )
+
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
- sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
- destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
- testingExpression = append(sourceExp, destExp...) //nolint:gocritic
- inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
-
- found = 0
- for _, chain := range manager.chains {
- rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
- require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
- for _, rule := range rules {
- if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
- require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
- found = 1
- }
- }
- }
-
- require.Equal(t, 1, found, "should find at least 1 rule to test")
-
if testCase.InputPair.Masquerade {
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
+ sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
+ destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
+ testingExpression := append(sourceExp, destExp...) //nolint:gocritic
+ testingExpression = append(testingExpression,
+ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: ifname(ifaceMock.Name()),
+ },
+ )
+
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
+
})
}
}
-func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
+func TestNftablesManager_RemoveNatRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
- manager, err := newRouter(context.TODO(), table)
+ manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
- defer manager.ResetForwardRules()
+ defer func(manager *router) {
+ require.NoError(t, manager.Reset(), "failed to reset rules")
+ }(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
- forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
- forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
- insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
- Table: manager.workTable,
- Chain: manager.chains[chainNameRouteingFw],
- Exprs: forwardExp,
- UserData: []byte(forwardRuleKey),
- })
-
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
@@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey),
})
- sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
- destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
-
- forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
- inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
- insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
- Table: manager.workTable,
- Chain: manager.chains[chainNameRouteingFw],
- Exprs: forwardExp,
- UserData: []byte(inForwardRuleKey),
- })
+ sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
+ destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
@@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
- manager.ResetForwardRules()
+ err = manager.Reset()
+ require.NoError(t, err, "shouldn't return error")
- err = manager.RemoveRoutingRules(testCase.InputPair)
+ err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
@@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
- require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
- require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
@@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
}
}
+func TestRouter_AddRouteFiltering(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ workTable, err := createWorkTable()
+ require.NoError(t, err, "Failed to create work table")
+
+ defer deleteWorkTable()
+
+ r, err := newRouter(context.Background(), workTable, ifaceMock)
+ require.NoError(t, err, "Failed to create router")
+
+ defer func(r *router) {
+ require.NoError(t, r.Reset(), "Failed to reset rules")
+ }(r)
+
+ tests := []struct {
+ name string
+ sources []netip.Prefix
+ destination netip.Prefix
+ proto firewall.Protocol
+ sPort *firewall.Port
+ dPort *firewall.Port
+ direction firewall.RuleDirection
+ action firewall.Action
+ expectSet bool
+ }{
+ {
+ name: "Basic TCP rule with single source",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolTCP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{80}},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with multiple sources",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("172.16.0.0/16"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolUDP,
+ sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionDrop,
+ expectSet: true,
+ },
+ {
+ name: "All protocols rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
+ destination: netip.MustParsePrefix("0.0.0.0/0"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "ICMP rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolICMP,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with multiple source ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
+ destination: netip.MustParsePrefix("192.168.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{80, 443, 8080}},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with single IP and port range",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolUDP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with source and destination ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
+ destination: netip.MustParsePrefix("172.16.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
+ dPort: &firewall.Port{Values: []int{22}},
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "Drop all incoming traffic",
+ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
+ destination: netip.MustParsePrefix("192.168.0.0/24"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
+ require.NoError(t, err, "AddRouteFiltering failed")
+
+ // Check if the rule is in the internal map
+ rule, ok := r.rules[ruleKey.GetRuleID()]
+ assert.True(t, ok, "Rule not found in internal map")
+
+ t.Log("Internal rule expressions:")
+ for i, expr := range rule.Exprs {
+ t.Logf(" [%d] %T: %+v", i, expr, expr)
+ }
+
+ // Verify internal rule content
+ verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
+
+ // Check if the rule exists in nftables and verify its content
+ rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
+ require.NoError(t, err, "Failed to get rules from nftables")
+
+ var nftRule *nftables.Rule
+ for _, rule := range rules {
+ if string(rule.UserData) == ruleKey.GetRuleID() {
+ nftRule = rule
+ break
+ }
+ }
+
+ require.NotNil(t, nftRule, "Rule not found in nftables")
+ t.Log("Actual nftables rule expressions:")
+ for i, expr := range nftRule.Exprs {
+ t.Logf(" [%d] %T: %+v", i, expr, expr)
+ }
+
+ // Verify actual nftables rule content
+ verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
+
+ // Clean up
+ err = r.DeleteRouteRule(ruleKey)
+ require.NoError(t, err, "Failed to delete rule")
+ })
+ }
+}
+
+func TestNftablesCreateIpSet(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ workTable, err := createWorkTable()
+ require.NoError(t, err, "Failed to create work table")
+
+ defer deleteWorkTable()
+
+ r, err := newRouter(context.Background(), workTable, ifaceMock)
+ require.NoError(t, err, "Failed to create router")
+
+ defer func() {
+ require.NoError(t, r.Reset(), "Failed to reset router")
+ }()
+
+ tests := []struct {
+ name string
+ sources []netip.Prefix
+ expected []netip.Prefix
+ }{
+ {
+ name: "Single IP",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
+ },
+ {
+ name: "Multiple IPs",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.1/32"),
+ netip.MustParsePrefix("10.0.0.1/32"),
+ netip.MustParsePrefix("172.16.0.1/32"),
+ },
+ },
+ {
+ name: "Single Subnet",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
+ },
+ {
+ name: "Multiple Subnets with Various Prefix Lengths",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("172.16.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("203.0.113.0/26"),
+ },
+ },
+ {
+ name: "Mix of Single IPs and Subnets in Different Positions",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.1/32"),
+ netip.MustParsePrefix("10.0.0.0/16"),
+ netip.MustParsePrefix("172.16.0.1/32"),
+ netip.MustParsePrefix("203.0.113.0/24"),
+ },
+ },
+ {
+ name: "Overlapping IPs/Subnets",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("10.0.0.0/16"),
+ netip.MustParsePrefix("10.0.0.1/32"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.1.1/32"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ }
+
+ // Add this helper function inside TestNftablesCreateIpSet
+ printNftSets := func() {
+ cmd := exec.Command("nft", "list", "sets")
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Logf("Failed to run 'nft list sets': %v", err)
+ } else {
+ t.Logf("Current nft sets:\n%s", output)
+ }
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ setName := firewall.GenerateSetName(tt.sources)
+ set, err := r.createIpSet(setName, tt.sources)
+ if err != nil {
+ t.Logf("Failed to create IP set: %v", err)
+ printNftSets()
+ require.NoError(t, err, "Failed to create IP set")
+ }
+ require.NotNil(t, set, "Created set is nil")
+
+ // Verify set properties
+ assert.Equal(t, setName, set.Name, "Set name mismatch")
+ assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
+ assert.True(t, set.Interval, "Set interval property should be true")
+ assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
+
+ // Fetch the created set from nftables
+ fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
+ require.NoError(t, err, "Failed to fetch created set")
+ require.NotNil(t, fetchedSet, "Fetched set is nil")
+
+ // Verify set elements
+ elements, err := r.conn.GetSetElements(fetchedSet)
+ require.NoError(t, err, "Failed to get set elements")
+
+ // Count the number of unique prefixes (excluding interval end markers)
+ uniquePrefixes := make(map[string]bool)
+ for _, elem := range elements {
+ if !elem.IntervalEnd {
+ ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
+ uniquePrefixes[ip.String()] = true
+ }
+ }
+
+ // Check against expected merged prefixes
+ expectedCount := len(tt.expected)
+ if expectedCount == 0 {
+ expectedCount = len(tt.sources)
+ }
+ assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
+
+ // Verify each expected prefix is in the set
+ for _, expected := range tt.expected {
+ found := false
+ for _, elem := range elements {
+ if !elem.IntervalEnd {
+ ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
+ if expected.Contains(ip) {
+ found = true
+ break
+ }
+ }
+ }
+ assert.True(t, found, "Expected prefix %s not found in set", expected)
+ }
+
+ r.conn.DelSet(set)
+ if err := r.conn.Flush(); err != nil {
+ t.Logf("Failed to delete set: %v", err)
+ printNftSets()
+ }
+ require.NoError(t, err, "Failed to delete set")
+ })
+ }
+}
+
+func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
+ t.Helper()
+
+ assert.NotNil(t, rule, "Rule should not be nil")
+
+ // Verify sources and destination
+ if expectSet {
+ assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
+ } else if len(sources) == 1 && sources[0].Bits() != 0 {
+ if direction == firewall.RuleDirectionIN {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
+ } else {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
+ }
+ }
+
+ if direction == firewall.RuleDirectionIN {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
+ } else {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
+ }
+
+ // Verify protocol
+ if proto != firewall.ProtocolALL {
+ assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
+ }
+
+ // Verify ports
+ if sPort != nil {
+ assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
+ }
+ if dPort != nil {
+ assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
+ }
+
+ // Verify action
+ assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
+}
+
+func containsSetLookup(exprs []expr.Any) bool {
+ for _, e := range exprs {
+ if _, ok := e.(*expr.Lookup); ok {
+ return true
+ }
+ }
+ return false
+}
+
+func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
+ var offset uint32
+ if isSource {
+ offset = 12 // src offset
+ } else {
+ offset = 16 // dst offset
+ }
+
+ var payloadFound, bitwiseFound, cmpFound bool
+ for _, e := range exprs {
+ switch ex := e.(type) {
+ case *expr.Payload:
+ if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
+ payloadFound = true
+ }
+ case *expr.Bitwise:
+ if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
+ bitwiseFound = true
+ }
+ case *expr.Cmp:
+ if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
+ cmpFound = true
+ }
+ }
+ }
+ return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
+}
+
+func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
+ var offset uint32 = 2 // Default offset for destination port
+ if isSource {
+ offset = 0 // Offset for source port
+ }
+
+ var payloadFound, portMatchFound bool
+ for _, e := range exprs {
+ switch ex := e.(type) {
+ case *expr.Payload:
+ if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
+ payloadFound = true
+ }
+ case *expr.Cmp:
+ if port.IsRange {
+ if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
+ portMatchFound = true
+ }
+ } else {
+ if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
+ portValue := binary.BigEndian.Uint16(ex.Data)
+ for _, p := range port.Values {
+ if uint16(p) == portValue {
+ portMatchFound = true
+ break
+ }
+ }
+ }
+ }
+ }
+ if payloadFound && portMatchFound {
+ return true
+ }
+ }
+ return false
+}
+
+func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
+ var metaFound, cmpFound bool
+ expectedProto, _ := protoToInt(proto)
+ for _, e := range exprs {
+ switch ex := e.(type) {
+ case *expr.Meta:
+ if ex.Key == expr.MetaKeyL4PROTO {
+ metaFound = true
+ }
+ case *expr.Cmp:
+ if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
+ cmpFound = true
+ }
+ }
+ }
+ return metaFound && cmpFound
+}
+
+func containsAction(exprs []expr.Any, action firewall.Action) bool {
+ for _, e := range exprs {
+ if verdict, ok := e.(*expr.Verdict); ok {
+ switch action {
+ case firewall.ActionAccept:
+ return verdict.Kind == expr.VerdictAccept
+ case firewall.ActionDrop:
+ return verdict.Kind == expr.VerdictDrop
+ }
+ }
+ }
+ return false
+}
+
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int {
nf := nftables.Conn{}
@@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
- table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
+ table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
@@ -273,7 +708,7 @@ func deleteWorkTable() {
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go
index 432d113dd..267e93efd 100644
--- a/client/firewall/test/cases_linux.go
+++ b/client/firewall/test/cases_linux.go
@@ -1,8 +1,10 @@
-//go:build !android
-
package test
-import firewall "github.com/netbirdio/netbird/client/firewall/manager"
+import (
+ "net/netip"
+
+ firewall "github.com/netbirdio/netbird/client/firewall/manager"
+)
var (
InsertRuleTestCases = []struct {
@@ -13,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
- Source: "100.100.100.1/32",
- Destination: "100.100.200.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: false,
},
},
@@ -22,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
- Source: "100.100.100.1/32",
- Destination: "100.100.200.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},
@@ -38,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
- Source: "100.100.100.1/32",
- Destination: "100.100.200.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},
diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go
index 75792e9c0..681058ea9 100644
--- a/client/firewall/uspfilter/uspfilter.go
+++ b/client/firewall/uspfilter/uspfilter.go
@@ -3,6 +3,7 @@ package uspfilter
import (
"fmt"
"net"
+ "net/netip"
"sync"
"github.com/google/gopacket"
@@ -103,26 +104,26 @@ func (m *Manager) IsServerRouteSupported() bool {
}
}
-func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
- return m.nativeFirewall.InsertRoutingRules(pair)
+ return m.nativeFirewall.AddNatRule(pair)
}
-// RemoveRoutingRules removes a routing firewall rule
-func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
+// RemoveNatRule removes a routing firewall rule
+func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
- return m.nativeFirewall.RemoveRoutingRules(pair)
+ return m.nativeFirewall.RemoveNatRule(pair)
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
-func (m *Manager) AddFiltering(
+func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -188,8 +189,22 @@ func (m *Manager) AddFiltering(
return []firewall.Rule{&r}, nil
}
-// DeleteRule from the firewall by rule definition
-func (m *Manager) DeleteRule(rule firewall.Rule) error {
+func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
+ if m.nativeFirewall == nil {
+ return nil, errRouteNotSupported
+ }
+ return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
+}
+
+func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
+ if m.nativeFirewall == nil {
+ return errRouteNotSupported
+ }
+ return m.nativeFirewall.DeleteRouteRule(rule)
+}
+
+// DeletePeerRule from the firewall by rule definition
+func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -215,6 +230,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil
}
+// SetLegacyManagement doesn't need to be implemented for this manager
+func (m *Manager) SetLegacyManagement(_ bool) error {
+ return nil
+}
+
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@@ -395,7 +415,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr {
if r.id == hookID {
rule := r
- return m.DeleteRule(&rule)
+ return m.DeletePeerRule(&rule)
}
}
}
@@ -403,7 +423,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr {
if r.id == hookID {
rule := r
- return m.DeleteRule(&rule)
+ return m.DeletePeerRule(&rule)
}
}
}
diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go
index 514a90539..dd7366fe9 100644
--- a/client/firewall/uspfilter/uspfilter_test.go
+++ b/client/firewall/uspfilter/uspfilter_test.go
@@ -49,7 +49,7 @@ func TestManagerCreate(t *testing.T) {
}
}
-func TestManagerAddFiltering(t *testing.T) {
+func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error {
@@ -71,7 +71,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
- rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -106,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
- rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -119,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop
comment = "Test rule 2"
- rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
- err = m.DeleteRule(r)
+ err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
@@ -140,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
- err = m.DeleteRule(r)
+ err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
@@ -252,7 +252,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
- _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -290,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept
comment := "Test rule"
- _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
+ _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -406,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go
new file mode 100644
index 000000000..e27fce439
--- /dev/null
+++ b/client/internal/acl/id/id.go
@@ -0,0 +1,25 @@
+package id
+
+import (
+ "fmt"
+ "net/netip"
+
+ "github.com/netbirdio/netbird/client/firewall/manager"
+)
+
+type RuleID string
+
+func (r RuleID) GetRuleID() string {
+ return string(r)
+}
+
+func GenerateRouteRuleKey(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto manager.Protocol,
+ sPort *manager.Port,
+ dPort *manager.Port,
+ action manager.Action,
+) RuleID {
+ return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
+}
diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go
index fd2c2c875..ce2a12af1 100644
--- a/client/internal/acl/manager.go
+++ b/client/internal/acl/manager.go
@@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"net"
+ "net/netip"
"strconv"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -23,16 +25,18 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
- firewall firewall.Manager
- ipsetCounter int
- rulesPairs map[string][]firewall.Rule
- mutex sync.Mutex
+ firewall firewall.Manager
+ ipsetCounter int
+ peerRulesPairs map[id.RuleID][]firewall.Rule
+ routeRules map[id.RuleID]struct{}
+ mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
- firewall: fm,
- rulesPairs: make(map[string][]firewall.Rule),
+ firewall: fm,
+ peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
+ routeRules: make(map[id.RuleID]struct{}),
}
}
@@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
start := time.Now()
defer func() {
total := 0
- for _, pairs := range d.rulesPairs {
+ for _, pairs := range d.peerRulesPairs {
total += len(pairs)
}
log.Infof(
@@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return
}
- defer func() {
- if err := d.firewall.Flush(); err != nil {
- log.Error("failed to flush firewall rules: ", err)
- }
- }()
+ d.applyPeerACLs(networkMap)
+ // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
+ // then the mgmt server is older than the client, and we need to allow all traffic for routes
+ isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
+ if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
+ log.Errorf("failed to set legacy management flag: %v", err)
+ }
+
+ if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
+ log.Errorf("Failed to apply route ACLs: %v", err)
+ }
+
+ if err := d.firewall.Flush(); err != nil {
+ log.Error("failed to flush firewall rules: ", err)
+ }
+}
+
+func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap)
- enableSSH := (networkMap.PeerConfig != nil &&
+ enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
- networkMap.PeerConfig.SshConfig.SshEnabled)
- if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
+ networkMap.PeerConfig.SshConfig.SshEnabled
+ if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
- if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
+ if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
@@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort),
})
}
@@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
rules = append(rules,
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
)
}
- newRulePairs := make(map[string][]firewall.Rule)
+ newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
@@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
break
}
if len(rules) > 0 {
- d.rulesPairs[pairID] = rulePair
+ d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
}
- for pairID, rules := range d.rulesPairs {
+ for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
- if err := d.firewall.DeleteRule(rule); err != nil {
- log.Errorf("failed to delete firewall rule: %v", err)
+ if err := d.firewall.DeletePeerRule(rule); err != nil {
+ log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
}
- delete(d.rulesPairs, pairID)
+ delete(d.peerRulesPairs, pairID)
}
}
- d.rulesPairs = newRulePairs
+ d.peerRulesPairs = newRulePairs
+}
+
+func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
+ var newRouteRules = make(map[id.RuleID]struct{})
+ for _, rule := range rules {
+ id, err := d.applyRouteACL(rule)
+ if err != nil {
+ return fmt.Errorf("apply route ACL: %w", err)
+ }
+ newRouteRules[id] = struct{}{}
+ }
+
+ for id := range d.routeRules {
+ if _, ok := newRouteRules[id]; !ok {
+ if err := d.firewall.DeleteRouteRule(id); err != nil {
+ log.Errorf("failed to delete route firewall rule: %v", err)
+ continue
+ }
+ delete(d.routeRules, id)
+ }
+ }
+ d.routeRules = newRouteRules
+ return nil
+}
+
+func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
+ if len(rule.SourceRanges) == 0 {
+ return "", fmt.Errorf("source ranges is empty")
+ }
+
+ var sources []netip.Prefix
+ for _, sourceRange := range rule.SourceRanges {
+ source, err := netip.ParsePrefix(sourceRange)
+ if err != nil {
+ return "", fmt.Errorf("parse source range: %w", err)
+ }
+ sources = append(sources, source)
+ }
+
+ var destination netip.Prefix
+ if rule.IsDynamic {
+ destination = getDefault(sources[0])
+ } else {
+ var err error
+ destination, err = netip.ParsePrefix(rule.Destination)
+ if err != nil {
+ return "", fmt.Errorf("parse destination: %w", err)
+ }
+ }
+
+ protocol, err := convertToFirewallProtocol(rule.Protocol)
+ if err != nil {
+ return "", fmt.Errorf("invalid protocol: %w", err)
+ }
+
+ action, err := convertFirewallAction(rule.Action)
+ if err != nil {
+ return "", fmt.Errorf("invalid action: %w", err)
+ }
+
+ dPorts := convertPortInfo(rule.PortInfo)
+
+ addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
+ if err != nil {
+ return "", fmt.Errorf("add route rule: %w", err)
+ }
+
+ return id.RuleID(addedRule.GetRuleID()), nil
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
-) (string, []firewall.Rule, error) {
+) (id.RuleID, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
}
- ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
- if rulesPair, ok := d.rulesPairs[ruleID]; ok {
+ ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
+ if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
- case mgmProto.FirewallRule_IN:
+ case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
- case mgmProto.FirewallRule_OUT:
+ case mgmProto.RuleDirection_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
- rule, err := d.firewall.AddFiltering(
+ rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules(
return rules, nil
}
- rule, err = d.firewall.AddFiltering(
+ rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
- rule, err := d.firewall.AddFiltering(
+ rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules(
return rules, nil
}
- rule, err = d.firewall.AddFiltering(
+ rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules(
return append(rules, rule...), nil
}
-// getRuleID() returns unique ID for the rule based on its parameters.
-func (d *DefaultManager) getRuleID(
+// getPeerRuleID() returns unique ID for the rule based on its parameters.
+func (d *DefaultManager) getPeerRuleID(
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
-) string {
+) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
- return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
+ return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
@@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID(
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
-) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
+) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
@@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules(
}
}
- type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
+ type protoMatch map[mgmProto.RuleProtocol]map[string]int
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
- squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
+ squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
@@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules(
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
- drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
+ drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
protocols[r.Protocol] = map[string]int{}
return
@@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules(
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
- if r.Direction == mgmProto.FirewallRule_IN {
+ if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
@@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules(
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
- protocolOrders := []mgmProto.FirewallRuleProtocol{
- mgmProto.FirewallRule_ALL,
- mgmProto.FirewallRule_ICMP,
- mgmProto.FirewallRule_TCP,
- mgmProto.FirewallRule_UDP,
+ protocolOrders := []mgmProto.RuleProtocol{
+ mgmProto.RuleProtocol_ALL,
+ mgmProto.RuleProtocol_ICMP,
+ mgmProto.RuleProtocol_TCP,
+ mgmProto.RuleProtocol_UDP,
}
- squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
+ squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if :
@@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules(
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
- Action: mgmProto.FirewallRule_ACCEPT,
+ Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
})
squashedProtocols[protocol] = struct{}{}
- if protocol == mgmProto.FirewallRule_ALL {
+ if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
@@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules(
}
}
- squash(in, mgmProto.FirewallRule_IN)
- squash(out, mgmProto.FirewallRule_OUT)
+ squash(in, mgmProto.RuleDirection_IN)
+ squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
- if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
+ if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
@@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
-func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
+func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
- if err := d.firewall.DeleteRule(rule); err != nil {
+ if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
-func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
+func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
switch protocol {
- case mgmProto.FirewallRule_TCP:
+ case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil
- case mgmProto.FirewallRule_UDP:
+ case mgmProto.RuleProtocol_UDP:
return firewall.ProtocolUDP, nil
- case mgmProto.FirewallRule_ICMP:
+ case mgmProto.RuleProtocol_ICMP:
return firewall.ProtocolICMP, nil
- case mgmProto.FirewallRule_ALL:
+ case mgmProto.RuleProtocol_ALL:
return firewall.ProtocolALL, nil
default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
@@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
-func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
+func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
switch action {
- case mgmProto.FirewallRule_ACCEPT:
+ case mgmProto.RuleAction_ACCEPT:
return firewall.ActionAccept, nil
- case mgmProto.FirewallRule_DROP:
+ case mgmProto.RuleAction_DROP:
return firewall.ActionDrop, nil
default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}
+
+func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
+ if portInfo == nil {
+ return nil
+ }
+
+ if portInfo.GetPort() != 0 {
+ return &firewall.Port{
+ Values: []int{int(portInfo.GetPort())},
+ }
+ }
+
+ if portInfo.GetRange() != nil {
+ return &firewall.Port{
+ IsRange: true,
+ Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
+ }
+ }
+
+ return nil
+}
+
+func getDefault(prefix netip.Prefix) netip.Prefix {
+ if prefix.Addr().Is6() {
+ return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
+ }
+ return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
+}
diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go
index 494d54bf2..eec3d3b8c 100644
--- a/client/internal/acl/manager_test.go
+++ b/client/internal/acl/manager_test.go
@@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_DROP,
- Protocol: mgmProto.FirewallRule_UDP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_DROP,
+ Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
@@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
- if len(acl.rulesPairs) != 2 {
- t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
+ if len(acl.peerRulesPairs) != 2 {
+ t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
return
}
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
- for id := range acl.rulesPairs {
- existedPairs[id] = struct{}{}
+ for id := range acl.peerRulesPairs {
+ existedPairs[id.GetRuleID()] = struct{}{}
}
// remove first rule
@@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules,
&mgmProto.FirewallRule{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_DROP,
- Protocol: mgmProto.FirewallRule_ICMP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_DROP,
+ Protocol: mgmProto.RuleProtocol_ICMP,
},
)
acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules
- if len(acl.rulesPairs) != 2 {
+ if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied")
return
}
// check that old rule was removed
previousCount := 0
- for id := range acl.rulesPairs {
- if _, ok := existedPairs[id]; ok {
+ for id := range acl.peerRulesPairs {
+ if _, ok := existedPairs[id.GetRuleID()]; ok {
previousCount++
}
}
@@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
- if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
- t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
+ if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
+ t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
- if len(acl.rulesPairs) != 2 {
- t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
+ if len(acl.peerRulesPairs) != 2 {
+ t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
}
})
@@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
@@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
- case r.Direction != mgmProto.FirewallRule_IN:
+ case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
- case r.Protocol != mgmProto.FirewallRule_ALL:
+ case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
- case r.Action != mgmProto.FirewallRule_ACCEPT:
+ case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
- case r.Direction != mgmProto.FirewallRule_OUT:
+ case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
- case r.Protocol != mgmProto.FirewallRule_ALL:
+ case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
- case r.Action != mgmProto.FirewallRule_ACCEPT:
+ case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_UDP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
@@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_UDP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
@@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap)
- if len(acl.rulesPairs) != 4 {
- t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
+ if len(acl.peerRulesPairs) != 4 {
+ t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return
}
}
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 463507ad8..998cbce2d 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -704,6 +704,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
+ // Apply ACLs in the beginning to avoid security leaks
+ if e.acl != nil {
+ e.acl.ApplyFiltering(networkMap)
+ }
+
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
@@ -770,10 +775,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
- if e.acl != nil {
- e.acl.ApplyFiltering(networkMap)
- }
-
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go
index 5897031e7..e86a52810 100644
--- a/client/internal/routemanager/dynamic/route.go
+++ b/client/internal/routemanager/dynamic/route.go
@@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti
var merr *multierror.Error
for _, prefix := range prefixes {
- if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
+ if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index cdfd322bd..d97fe631f 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -87,10 +87,10 @@ func NewManager(
}
dm.routeRefCounter = refcounter.New(
- func(prefix netip.Prefix, _ any) (any, error) {
- return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
+ func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
+ return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
- func(prefix netip.Prefix, _ any) error {
+ func(prefix netip.Prefix, _ struct{}) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go
index f1d696ad9..65ea0f708 100644
--- a/client/internal/routemanager/refcounter/refcounter.go
+++ b/client/internal/routemanager/refcounter/refcounter.go
@@ -3,7 +3,8 @@ package refcounter
import (
"errors"
"fmt"
- "net/netip"
+ "runtime"
+ "strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,118 +13,153 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
)
-// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
+const logLevel = log.TraceLevel
+
+// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key.
var ErrIgnore = errors.New("ignore")
+// Ref holds the reference count and associated data for a key.
type Ref[O any] struct {
Count int
Out O
}
-type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
-type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
+// AddFunc is the function type for adding a new key.
+// Key is the type of the key (e.g., netip.Prefix).
+type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error)
-type Counter[I, O any] struct {
- // refCountMap keeps track of the reference Ref for prefixes
- refCountMap map[netip.Prefix]Ref[O]
+// RemoveFunc is the function type for removing a key.
+type RemoveFunc[Key, O any] func(key Key, out O) error
+
+// Counter is a generic reference counter for managing keys and their associated data.
+// Key: The type of the key (e.g., netip.Prefix, string).
+//
+// I: The input type for the AddFunc. It is the input type for additional data needed
+// when adding a key, it is passed as the second argument to AddFunc.
+//
+// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc.
+// It is stored and passed to RemoveFunc when the reference count reaches 0.
+//
+// The types can be aliased to a specific type using the following syntax:
+//
+// type RouteRefCounter = Counter[netip.Prefix, any, any]
+type Counter[Key comparable, I, O any] struct {
+ // refCountMap keeps track of the reference Ref for keys
+ refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
- // idMap keeps track of the prefixes associated with an ID for removal
- idMap map[string][]netip.Prefix
+ // idMap keeps track of the keys associated with an ID for removal
+ idMap map[string][]Key
idMu sync.Mutex
- add AddFunc[I, O]
- remove RemoveFunc[I, O]
+ add AddFunc[Key, I, O]
+ remove RemoveFunc[Key, O]
}
-// New creates a new Counter instance
-func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
- return &Counter[I, O]{
- refCountMap: map[netip.Prefix]Ref[O]{},
- idMap: map[string][]netip.Prefix{},
+// New creates a new Counter instance.
+// Usage example:
+//
+// counter := New[netip.Prefix, string, string](
+// func(key netip.Prefix, in string) (out string, err error) { ... },
+// func(key netip.Prefix, out string) error { ... },`
+// )
+func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] {
+ return &Counter[Key, I, O]{
+ refCountMap: map[Key]Ref[O]{},
+ idMap: map[string][]Key{},
add: add,
remove: remove,
}
}
-// Increment increments the reference count for the given prefix.
-// If this is the first reference to the prefix, the AddFunc is called.
-func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
+// Get retrieves the current reference count and associated data for a key.
+// If the key doesn't exist, it returns a zero value Ref and false.
+func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
- ref := rm.refCountMap[prefix]
- log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
+ ref, ok := rm.refCountMap[key]
+ return ref, ok
+}
- // Call AddFunc only if it's a new prefix
+// Increment increments the reference count for the given key.
+// If this is the first reference to the key, the AddFunc is called.
+func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
+ rm.refCountMu.Lock()
+ defer rm.refCountMu.Unlock()
+
+ ref := rm.refCountMap[key]
+ logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
+
+ // Call AddFunc only if it's a new key
if ref.Count == 0 {
- log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
- out, err := rm.add(prefix, in)
+ logCallerF("Calling add for key %v", key)
+ out, err := rm.add(key, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
- return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
+ return ref, fmt.Errorf("failed to add for key %v: %w", key, err)
}
ref.Out = out
}
ref.Count++
- rm.refCountMap[prefix] = ref
+ rm.refCountMap[key] = ref
return ref, nil
}
-// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
-// If this is the first reference to the prefix, the AddFunc is called.
-func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
+// IncrementWithID increments the reference count for the given key and groups it under the given ID.
+// If this is the first reference to the key, the AddFunc is called.
+func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
- ref, err := rm.Increment(prefix, in)
+ ref, err := rm.Increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
- rm.idMap[id] = append(rm.idMap[id], prefix)
+ rm.idMap[id] = append(rm.idMap[id], key)
return ref, nil
}
-// Decrement decrements the reference count for the given prefix.
+// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
-func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
+func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
- ref, ok := rm.refCountMap[prefix]
+ ref, ok := rm.refCountMap[key]
if !ok {
- log.Tracef("No reference found for prefix %s", prefix)
+ logCallerF("No reference found for key %v", key)
return ref, nil
}
- log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
+ logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out)
if ref.Count == 1 {
- log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
- if err := rm.remove(prefix, ref.Out); err != nil {
- return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
+ logCallerF("Calling remove for key %v", key)
+ if err := rm.remove(key, ref.Out); err != nil {
+ return ref, fmt.Errorf("remove for key %v: %w", key, err)
}
- delete(rm.refCountMap, prefix)
+ delete(rm.refCountMap, key)
} else {
ref.Count--
- rm.refCountMap[prefix] = ref
+ rm.refCountMap[key] = ref
}
return ref, nil
}
-// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
+// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
-func (rm *Counter[I, O]) DecrementWithID(id string) error {
+func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
- for _, prefix := range rm.idMap[id] {
- if _, err := rm.Decrement(prefix); err != nil {
+ for _, key := range rm.idMap[id] {
+ if _, err := rm.Decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error {
return nberrors.FormatErrorOrNil(merr)
}
-// Flush removes all references and calls RemoveFunc for each prefix.
-func (rm *Counter[I, O]) Flush() error {
+// Flush removes all references and calls RemoveFunc for each key.
+func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
- for prefix := range rm.refCountMap {
- log.Tracef("Removing for prefix %s", prefix)
- ref := rm.refCountMap[prefix]
- if err := rm.remove(prefix, ref.Out); err != nil {
- merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
+ for key := range rm.refCountMap {
+ logCallerF("Calling remove for key %v", key)
+ ref := rm.refCountMap[key]
+ if err := rm.remove(key, ref.Out); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err))
}
}
- rm.refCountMap = map[netip.Prefix]Ref[O]{}
- rm.idMap = map[string][]netip.Prefix{}
+ clear(rm.refCountMap)
+ clear(rm.idMap)
return nberrors.FormatErrorOrNil(merr)
}
+
+// Clear removes all references without calling RemoveFunc.
+func (rm *Counter[Key, I, O]) Clear() {
+ rm.refCountMu.Lock()
+ defer rm.refCountMu.Unlock()
+ rm.idMu.Lock()
+ defer rm.idMu.Unlock()
+
+ clear(rm.refCountMap)
+ clear(rm.idMap)
+}
+
+func getCallerInfo(depth int, maxDepth int) (string, bool) {
+ if depth >= maxDepth {
+ return "", false
+ }
+
+ pc, _, _, ok := runtime.Caller(depth)
+ if !ok {
+ return "", false
+ }
+
+ if details := runtime.FuncForPC(pc); details != nil {
+ name := details.Name()
+
+ lastDotIndex := strings.LastIndex(name, "/")
+ if lastDotIndex != -1 {
+ name = name[lastDotIndex+1:]
+ }
+
+ if strings.HasPrefix(name, "refcounter.") {
+ // +2 to account for recursion
+ return getCallerInfo(depth+2, maxDepth)
+ }
+
+ return name, true
+ }
+
+ return "", false
+}
+
+// logCaller logs a message with the package name and method of the function that called the current function.
+func logCallerF(format string, args ...interface{}) {
+ if log.GetLevel() < logLevel {
+ return
+ }
+
+ if callerName, ok := getCallerInfo(3, 18); ok {
+ format = fmt.Sprintf("[%s] %s", callerName, format)
+ }
+
+ log.StandardLogger().Logf(logLevel, format, args...)
+}
diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go
index 6753b64ef..aadac3e25 100644
--- a/client/internal/routemanager/refcounter/types.go
+++ b/client/internal/routemanager/refcounter/types.go
@@ -1,7 +1,9 @@
package refcounter
+import "net/netip"
+
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
-type RouteRefCounter = Counter[any, any]
+type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
-type AllowedIPsRefCounter = Counter[string, string]
+type AllowedIPsRefCounter = Counter[netip.Prefix, string, string]
diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go
index 43a266cd2..1d1a4b063 100644
--- a/client/internal/routemanager/server_nonandroid.go
+++ b/client/internal/routemanager/server_nonandroid.go
@@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
return fmt.Errorf("parse prefix: %w", err)
}
- err = m.firewall.RemoveRoutingRules(routerPair)
+ err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
@@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
return fmt.Errorf("parse prefix: %w", err)
}
- err = m.firewall.InsertRoutingRules(routerPair)
+ err = m.firewall.AddNatRule(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
@@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() {
continue
}
- err = m.firewall.RemoveRoutingRules(routerPair)
+ err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
@@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
- destination := route.Network.Masked().String()
+ destination := route.Network.Masked()
if route.IsDynamic() {
- // TODO: add ipv6
- destination = "0.0.0.0/0"
+ // TODO: add ipv6 additionally
+ destination = getDefaultPrefix(destination)
}
return firewall.RouterPair{
- ID: string(route.ID),
- Source: source.String(),
+ ID: route.ID,
+ Source: source,
Destination: destination,
Masquerade: route.Masquerade,
}, nil
diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go
index 88cca522a..98c34dbee 100644
--- a/client/internal/routemanager/static/route.go
+++ b/client/internal/routemanager/static/route.go
@@ -30,7 +30,7 @@ func (r *Route) String() string {
}
func (r *Route) AddRoute(context.Context) error {
- _, err := r.routeRefCounter.Increment(r.route.Network, nil)
+ _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
return err
}
diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go
index ae27b0123..10944c1e2 100644
--- a/client/internal/routemanager/systemops/systemops.go
+++ b/client/internal/routemanager/systemops/systemops.go
@@ -15,7 +15,7 @@ type Nexthop struct {
Intf *net.Interface
}
-type ExclusionCounter = refcounter.Counter[any, Nexthop]
+type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index d76824c10..90f06ba78 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
}
refCounter := refcounter.New(
- func(prefix netip.Prefix, _ any) (Nexthop, error) {
+ func(prefix netip.Prefix, _ struct{}) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
@@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("convert ip to prefix: %w", err)
}
- if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
+ if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go
index 48f048c4c..672b2a102 100644
--- a/management/proto/management.pb.go
+++ b/management/proto/management.pb.go
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
-// protoc v3.21.12
+// protoc v4.23.4
// source: management.proto
package proto
@@ -21,6 +21,153 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
+type RuleProtocol int32
+
+const (
+ RuleProtocol_UNKNOWN RuleProtocol = 0
+ RuleProtocol_ALL RuleProtocol = 1
+ RuleProtocol_TCP RuleProtocol = 2
+ RuleProtocol_UDP RuleProtocol = 3
+ RuleProtocol_ICMP RuleProtocol = 4
+)
+
+// Enum value maps for RuleProtocol.
+var (
+ RuleProtocol_name = map[int32]string{
+ 0: "UNKNOWN",
+ 1: "ALL",
+ 2: "TCP",
+ 3: "UDP",
+ 4: "ICMP",
+ }
+ RuleProtocol_value = map[string]int32{
+ "UNKNOWN": 0,
+ "ALL": 1,
+ "TCP": 2,
+ "UDP": 3,
+ "ICMP": 4,
+ }
+)
+
+func (x RuleProtocol) Enum() *RuleProtocol {
+ p := new(RuleProtocol)
+ *p = x
+ return p
+}
+
+func (x RuleProtocol) String() string {
+ return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (RuleProtocol) Descriptor() protoreflect.EnumDescriptor {
+ return file_management_proto_enumTypes[0].Descriptor()
+}
+
+func (RuleProtocol) Type() protoreflect.EnumType {
+ return &file_management_proto_enumTypes[0]
+}
+
+func (x RuleProtocol) Number() protoreflect.EnumNumber {
+ return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use RuleProtocol.Descriptor instead.
+func (RuleProtocol) EnumDescriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{0}
+}
+
+type RuleDirection int32
+
+const (
+ RuleDirection_IN RuleDirection = 0
+ RuleDirection_OUT RuleDirection = 1
+)
+
+// Enum value maps for RuleDirection.
+var (
+ RuleDirection_name = map[int32]string{
+ 0: "IN",
+ 1: "OUT",
+ }
+ RuleDirection_value = map[string]int32{
+ "IN": 0,
+ "OUT": 1,
+ }
+)
+
+func (x RuleDirection) Enum() *RuleDirection {
+ p := new(RuleDirection)
+ *p = x
+ return p
+}
+
+func (x RuleDirection) String() string {
+ return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (RuleDirection) Descriptor() protoreflect.EnumDescriptor {
+ return file_management_proto_enumTypes[1].Descriptor()
+}
+
+func (RuleDirection) Type() protoreflect.EnumType {
+ return &file_management_proto_enumTypes[1]
+}
+
+func (x RuleDirection) Number() protoreflect.EnumNumber {
+ return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use RuleDirection.Descriptor instead.
+func (RuleDirection) EnumDescriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{1}
+}
+
+type RuleAction int32
+
+const (
+ RuleAction_ACCEPT RuleAction = 0
+ RuleAction_DROP RuleAction = 1
+)
+
+// Enum value maps for RuleAction.
+var (
+ RuleAction_name = map[int32]string{
+ 0: "ACCEPT",
+ 1: "DROP",
+ }
+ RuleAction_value = map[string]int32{
+ "ACCEPT": 0,
+ "DROP": 1,
+ }
+)
+
+func (x RuleAction) Enum() *RuleAction {
+ p := new(RuleAction)
+ *p = x
+ return p
+}
+
+func (x RuleAction) String() string {
+ return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (RuleAction) Descriptor() protoreflect.EnumDescriptor {
+ return file_management_proto_enumTypes[2].Descriptor()
+}
+
+func (RuleAction) Type() protoreflect.EnumType {
+ return &file_management_proto_enumTypes[2]
+}
+
+func (x RuleAction) Number() protoreflect.EnumNumber {
+ return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use RuleAction.Descriptor instead.
+func (RuleAction) EnumDescriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{2}
+}
+
type HostConfig_Protocol int32
const (
@@ -60,11 +207,11 @@ func (x HostConfig_Protocol) String() string {
}
func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor {
- return file_management_proto_enumTypes[0].Descriptor()
+ return file_management_proto_enumTypes[3].Descriptor()
}
func (HostConfig_Protocol) Type() protoreflect.EnumType {
- return &file_management_proto_enumTypes[0]
+ return &file_management_proto_enumTypes[3]
}
func (x HostConfig_Protocol) Number() protoreflect.EnumNumber {
@@ -103,11 +250,11 @@ func (x DeviceAuthorizationFlowProvider) String() string {
}
func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor {
- return file_management_proto_enumTypes[1].Descriptor()
+ return file_management_proto_enumTypes[4].Descriptor()
}
func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType {
- return &file_management_proto_enumTypes[1]
+ return &file_management_proto_enumTypes[4]
}
func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber {
@@ -119,153 +266,6 @@ func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{21, 0}
}
-type FirewallRuleDirection int32
-
-const (
- FirewallRule_IN FirewallRuleDirection = 0
- FirewallRule_OUT FirewallRuleDirection = 1
-)
-
-// Enum value maps for FirewallRuleDirection.
-var (
- FirewallRuleDirection_name = map[int32]string{
- 0: "IN",
- 1: "OUT",
- }
- FirewallRuleDirection_value = map[string]int32{
- "IN": 0,
- "OUT": 1,
- }
-)
-
-func (x FirewallRuleDirection) Enum() *FirewallRuleDirection {
- p := new(FirewallRuleDirection)
- *p = x
- return p
-}
-
-func (x FirewallRuleDirection) String() string {
- return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
-}
-
-func (FirewallRuleDirection) Descriptor() protoreflect.EnumDescriptor {
- return file_management_proto_enumTypes[2].Descriptor()
-}
-
-func (FirewallRuleDirection) Type() protoreflect.EnumType {
- return &file_management_proto_enumTypes[2]
-}
-
-func (x FirewallRuleDirection) Number() protoreflect.EnumNumber {
- return protoreflect.EnumNumber(x)
-}
-
-// Deprecated: Use FirewallRuleDirection.Descriptor instead.
-func (FirewallRuleDirection) EnumDescriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{31, 0}
-}
-
-type FirewallRuleAction int32
-
-const (
- FirewallRule_ACCEPT FirewallRuleAction = 0
- FirewallRule_DROP FirewallRuleAction = 1
-)
-
-// Enum value maps for FirewallRuleAction.
-var (
- FirewallRuleAction_name = map[int32]string{
- 0: "ACCEPT",
- 1: "DROP",
- }
- FirewallRuleAction_value = map[string]int32{
- "ACCEPT": 0,
- "DROP": 1,
- }
-)
-
-func (x FirewallRuleAction) Enum() *FirewallRuleAction {
- p := new(FirewallRuleAction)
- *p = x
- return p
-}
-
-func (x FirewallRuleAction) String() string {
- return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
-}
-
-func (FirewallRuleAction) Descriptor() protoreflect.EnumDescriptor {
- return file_management_proto_enumTypes[3].Descriptor()
-}
-
-func (FirewallRuleAction) Type() protoreflect.EnumType {
- return &file_management_proto_enumTypes[3]
-}
-
-func (x FirewallRuleAction) Number() protoreflect.EnumNumber {
- return protoreflect.EnumNumber(x)
-}
-
-// Deprecated: Use FirewallRuleAction.Descriptor instead.
-func (FirewallRuleAction) EnumDescriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{31, 1}
-}
-
-type FirewallRuleProtocol int32
-
-const (
- FirewallRule_UNKNOWN FirewallRuleProtocol = 0
- FirewallRule_ALL FirewallRuleProtocol = 1
- FirewallRule_TCP FirewallRuleProtocol = 2
- FirewallRule_UDP FirewallRuleProtocol = 3
- FirewallRule_ICMP FirewallRuleProtocol = 4
-)
-
-// Enum value maps for FirewallRuleProtocol.
-var (
- FirewallRuleProtocol_name = map[int32]string{
- 0: "UNKNOWN",
- 1: "ALL",
- 2: "TCP",
- 3: "UDP",
- 4: "ICMP",
- }
- FirewallRuleProtocol_value = map[string]int32{
- "UNKNOWN": 0,
- "ALL": 1,
- "TCP": 2,
- "UDP": 3,
- "ICMP": 4,
- }
-)
-
-func (x FirewallRuleProtocol) Enum() *FirewallRuleProtocol {
- p := new(FirewallRuleProtocol)
- *p = x
- return p
-}
-
-func (x FirewallRuleProtocol) String() string {
- return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
-}
-
-func (FirewallRuleProtocol) Descriptor() protoreflect.EnumDescriptor {
- return file_management_proto_enumTypes[4].Descriptor()
-}
-
-func (FirewallRuleProtocol) Type() protoreflect.EnumType {
- return &file_management_proto_enumTypes[4]
-}
-
-func (x FirewallRuleProtocol) Number() protoreflect.EnumNumber {
- return protoreflect.EnumNumber(x)
-}
-
-// Deprecated: Use FirewallRuleProtocol.Descriptor instead.
-func (FirewallRuleProtocol) EnumDescriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{31, 2}
-}
-
type EncryptedMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -1482,6 +1482,10 @@ type NetworkMap struct {
FirewallRules []*FirewallRule `protobuf:"bytes,8,rep,name=FirewallRules,proto3" json:"FirewallRules,omitempty"`
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
FirewallRulesIsEmpty bool `protobuf:"varint,9,opt,name=firewallRulesIsEmpty,proto3" json:"firewallRulesIsEmpty,omitempty"`
+ // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
+ RoutesFirewallRules []*RouteFirewallRule `protobuf:"bytes,10,rep,name=routesFirewallRules,proto3" json:"routesFirewallRules,omitempty"`
+ // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
+ RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"`
}
func (x *NetworkMap) Reset() {
@@ -1579,6 +1583,20 @@ func (x *NetworkMap) GetFirewallRulesIsEmpty() bool {
return false
}
+func (x *NetworkMap) GetRoutesFirewallRules() []*RouteFirewallRule {
+ if x != nil {
+ return x.RoutesFirewallRules
+ }
+ return nil
+}
+
+func (x *NetworkMap) GetRoutesFirewallRulesIsEmpty() bool {
+ if x != nil {
+ return x.RoutesFirewallRulesIsEmpty
+ }
+ return false
+}
+
// RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure WireGuard Peers sections
type RemotePeerConfig struct {
@@ -2487,11 +2505,11 @@ type FirewallRule struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
- PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"`
- Direction FirewallRuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.FirewallRuleDirection" json:"Direction,omitempty"`
- Action FirewallRuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.FirewallRuleAction" json:"Action,omitempty"`
- Protocol FirewallRuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.FirewallRuleProtocol" json:"Protocol,omitempty"`
- Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"`
+ PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"`
+ Direction RuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.RuleDirection" json:"Direction,omitempty"`
+ Action RuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.RuleAction" json:"Action,omitempty"`
+ Protocol RuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.RuleProtocol" json:"Protocol,omitempty"`
+ Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"`
}
func (x *FirewallRule) Reset() {
@@ -2533,25 +2551,25 @@ func (x *FirewallRule) GetPeerIP() string {
return ""
}
-func (x *FirewallRule) GetDirection() FirewallRuleDirection {
+func (x *FirewallRule) GetDirection() RuleDirection {
if x != nil {
return x.Direction
}
- return FirewallRule_IN
+ return RuleDirection_IN
}
-func (x *FirewallRule) GetAction() FirewallRuleAction {
+func (x *FirewallRule) GetAction() RuleAction {
if x != nil {
return x.Action
}
- return FirewallRule_ACCEPT
+ return RuleAction_ACCEPT
}
-func (x *FirewallRule) GetProtocol() FirewallRuleProtocol {
+func (x *FirewallRule) GetProtocol() RuleProtocol {
if x != nil {
return x.Protocol
}
- return FirewallRule_UNKNOWN
+ return RuleProtocol_UNKNOWN
}
func (x *FirewallRule) GetPort() string {
@@ -2663,6 +2681,236 @@ func (x *Checks) GetFiles() []string {
return nil
}
+type PortInfo struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // Types that are assignable to PortSelection:
+ //
+ // *PortInfo_Port
+ // *PortInfo_Range_
+ PortSelection isPortInfo_PortSelection `protobuf_oneof:"portSelection"`
+}
+
+func (x *PortInfo) Reset() {
+ *x = PortInfo{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[34]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *PortInfo) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*PortInfo) ProtoMessage() {}
+
+func (x *PortInfo) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[34]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead.
+func (*PortInfo) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{34}
+}
+
+func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection {
+ if m != nil {
+ return m.PortSelection
+ }
+ return nil
+}
+
+func (x *PortInfo) GetPort() uint32 {
+ if x, ok := x.GetPortSelection().(*PortInfo_Port); ok {
+ return x.Port
+ }
+ return 0
+}
+
+func (x *PortInfo) GetRange() *PortInfo_Range {
+ if x, ok := x.GetPortSelection().(*PortInfo_Range_); ok {
+ return x.Range
+ }
+ return nil
+}
+
+type isPortInfo_PortSelection interface {
+ isPortInfo_PortSelection()
+}
+
+type PortInfo_Port struct {
+ Port uint32 `protobuf:"varint,1,opt,name=port,proto3,oneof"`
+}
+
+type PortInfo_Range_ struct {
+ Range *PortInfo_Range `protobuf:"bytes,2,opt,name=range,proto3,oneof"`
+}
+
+func (*PortInfo_Port) isPortInfo_PortSelection() {}
+
+func (*PortInfo_Range_) isPortInfo_PortSelection() {}
+
+// RouteFirewallRule signifies a firewall rule applicable for a routed network.
+type RouteFirewallRule struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // sourceRanges IP ranges of the routing peers.
+ SourceRanges []string `protobuf:"bytes,1,rep,name=sourceRanges,proto3" json:"sourceRanges,omitempty"`
+ // Action to be taken by the firewall when the rule is applicable.
+ Action RuleAction `protobuf:"varint,2,opt,name=action,proto3,enum=management.RuleAction" json:"action,omitempty"`
+ // Network prefix for the routed network.
+ Destination string `protobuf:"bytes,3,opt,name=destination,proto3" json:"destination,omitempty"`
+ // Protocol of the routed network.
+ Protocol RuleProtocol `protobuf:"varint,4,opt,name=protocol,proto3,enum=management.RuleProtocol" json:"protocol,omitempty"`
+ // Details about the port.
+ PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"`
+ // IsDynamic indicates if the route is a DNS route.
+ IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"`
+}
+
+func (x *RouteFirewallRule) Reset() {
+ *x = RouteFirewallRule{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[35]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *RouteFirewallRule) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*RouteFirewallRule) ProtoMessage() {}
+
+func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[35]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead.
+func (*RouteFirewallRule) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{35}
+}
+
+func (x *RouteFirewallRule) GetSourceRanges() []string {
+ if x != nil {
+ return x.SourceRanges
+ }
+ return nil
+}
+
+func (x *RouteFirewallRule) GetAction() RuleAction {
+ if x != nil {
+ return x.Action
+ }
+ return RuleAction_ACCEPT
+}
+
+func (x *RouteFirewallRule) GetDestination() string {
+ if x != nil {
+ return x.Destination
+ }
+ return ""
+}
+
+func (x *RouteFirewallRule) GetProtocol() RuleProtocol {
+ if x != nil {
+ return x.Protocol
+ }
+ return RuleProtocol_UNKNOWN
+}
+
+func (x *RouteFirewallRule) GetPortInfo() *PortInfo {
+ if x != nil {
+ return x.PortInfo
+ }
+ return nil
+}
+
+func (x *RouteFirewallRule) GetIsDynamic() bool {
+ if x != nil {
+ return x.IsDynamic
+ }
+ return false
+}
+
+type PortInfo_Range struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
+ End uint32 `protobuf:"varint,2,opt,name=end,proto3" json:"end,omitempty"`
+}
+
+func (x *PortInfo_Range) Reset() {
+ *x = PortInfo_Range{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[36]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *PortInfo_Range) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*PortInfo_Range) ProtoMessage() {}
+
+func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[36]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead.
+func (*PortInfo_Range) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{34, 0}
+}
+
+func (x *PortInfo_Range) GetStart() uint32 {
+ if x != nil {
+ return x.Start
+ }
+ return 0
+}
+
+func (x *PortInfo_Range) GetEnd() uint32 {
+ if x != nil {
+ return x.End
+ }
+ return 0
+}
+
var File_management_proto protoreflect.FileDescriptor
var file_management_proto_rawDesc = []byte{
@@ -2835,7 +3083,7 @@ var file_management_proto_rawDesc = []byte{
0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73,
0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18,
- 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xe2, 0x03, 0x0a, 0x0a,
+ 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a,
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65,
0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69,
0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
@@ -2866,184 +3114,219 @@ var file_management_proto_rawDesc = []byte{
0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45,
0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65,
0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79,
- 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43,
- 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65,
- 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65,
- 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18,
- 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70,
- 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03,
- 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
- 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53,
- 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e,
- 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68,
- 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75,
- 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50,
- 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41,
- 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77,
- 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69,
- 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
- 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18,
- 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
- 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69,
- 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a,
- 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18,
- 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a,
- 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43,
- 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c,
- 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43,
- 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c,
- 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f,
- 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69,
- 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69,
- 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69,
- 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53,
- 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69,
- 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d,
- 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69,
- 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20,
- 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a,
- 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f,
- 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63,
- 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a,
- 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f,
- 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65,
- 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55,
- 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74,
- 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69,
- 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72,
- 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12,
- 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18,
- 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55,
- 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
- 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
- 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
- 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f,
- 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65,
- 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65,
- 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a,
- 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d,
- 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72,
- 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75,
- 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44,
- 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f,
- 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75,
- 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f,
- 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62,
- 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
- 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53,
- 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28,
- 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e,
- 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10,
- 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73,
- 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18,
- 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43,
- 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75,
- 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61,
- 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e,
- 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28,
- 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53,
- 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63,
- 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65,
- 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65,
- 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05,
- 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61,
- 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52,
- 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20,
- 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e,
- 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38,
- 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20,
- 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
- 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d,
- 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d,
- 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61,
- 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20,
- 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14,
- 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61,
- 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72,
- 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64,
- 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e,
- 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16,
- 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06,
- 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03,
- 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46,
- 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50,
- 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65,
- 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e,
- 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
- 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65,
- 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65,
- 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18,
- 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e,
- 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d,
- 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e,
- 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69,
- 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
- 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a,
- 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72,
- 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06,
- 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22,
- 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43,
- 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22,
- 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
- 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10,
- 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44,
- 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x22, 0x38, 0x0a,
- 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12,
- 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
- 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b,
- 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09,
- 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61,
- 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a,
- 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
- 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
- 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
- 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
- 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d,
- 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
- 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
- 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c,
- 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d,
- 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a,
- 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72,
- 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
- 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79,
- 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d,
- 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69,
- 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
- 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
+ 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61,
+ 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e,
+ 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65,
+ 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f,
+ 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65,
+ 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77,
+ 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18,
+ 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72,
+ 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74,
+ 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72,
+ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b,
+ 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b,
+ 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73,
+ 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49,
+ 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18,
+ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
+ 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73,
+ 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18,
+ 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53,
+ 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45,
+ 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73,
+ 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50,
+ 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68,
+ 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65,
+ 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f,
+ 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76,
+ 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e,
+ 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
+ 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72,
+ 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76,
+ 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42,
+ 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
+ 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
+ 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
+ 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a,
+ 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b,
+ 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
+ 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b,
+ 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
+ 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43,
+ 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65,
+ 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65,
+ 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76,
+ 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c,
+ 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c,
+ 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
+ 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c,
+ 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f,
+ 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61,
+ 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e,
+ 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70,
+ 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69,
+ 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24,
+ 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18,
+ 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70,
+ 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73,
+ 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a,
+ 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75,
+ 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f,
+ 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f,
+ 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
+ 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73,
+ 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
+ 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e,
+ 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18,
+ 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
+ 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77,
+ 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e,
+ 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65,
+ 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16,
+ 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06,
+ 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65,
+ 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71,
+ 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18,
+ 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07,
+ 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44,
+ 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f,
+ 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52,
+ 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66,
+ 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61,
+ 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69,
+ 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65,
+ 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03,
+ 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
+ 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52,
+ 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70,
+ 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73,
+ 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b,
+ 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43,
+ 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d,
+ 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69,
+ 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03,
+ 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
+ 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65,
+ 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52,
+ 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70,
+ 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a,
+ 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c,
+ 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03,
+ 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f,
+ 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12,
+ 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01,
+ 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61,
+ 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69,
+ 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d,
+ 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03,
+ 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a,
+ 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61,
+ 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65,
+ 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
+ 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12,
+ 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52,
+ 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18,
+ 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c,
+ 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06,
+ 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65,
+ 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f,
+ 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
+ 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69,
+ 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a,
+ 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e,
+ 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41,
+ 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a,
+ 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32,
+ 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c,
+ 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f,
+ 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f,
+ 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74,
+ 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12,
+ 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61,
+ 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46,
+ 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65,
+ 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14,
+ 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04,
+ 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
+ 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48,
+ 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67,
+ 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d,
+ 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72,
+ 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52,
+ 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65,
+ 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73,
+ 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61,
+ 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63,
+ 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74,
+ 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69,
+ 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63,
+ 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63,
+ 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08,
+ 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74,
+ 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c,
+ 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28,
+ 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c,
+ 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
+ 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c,
+ 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55,
+ 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20,
+ 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12,
+ 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01,
+ 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a,
+ 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52,
+ 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f,
+ 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
- 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68,
- 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79,
- 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61,
- 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
- 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53,
- 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
+ 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
+ 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
+ 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65,
- 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70,
- 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74,
+ 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
+ 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
+ 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
+ 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e,
+ 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79,
+ 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41,
+ 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77,
+ 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e,
+ 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72,
+ 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58,
+ 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
+ 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
+ 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
+ 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63,
+ 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
+ 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
+ 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -3059,13 +3342,13 @@ func file_management_proto_rawDescGZIP() []byte {
}
var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
-var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 34)
+var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 37)
var file_management_proto_goTypes = []interface{}{
- (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol
- (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider
- (FirewallRuleDirection)(0), // 2: management.FirewallRule.direction
- (FirewallRuleAction)(0), // 3: management.FirewallRule.action
- (FirewallRuleProtocol)(0), // 4: management.FirewallRule.protocol
+ (RuleProtocol)(0), // 0: management.RuleProtocol
+ (RuleDirection)(0), // 1: management.RuleDirection
+ (RuleAction)(0), // 2: management.RuleAction
+ (HostConfig_Protocol)(0), // 3: management.HostConfig.Protocol
+ (DeviceAuthorizationFlowProvider)(0), // 4: management.DeviceAuthorizationFlow.provider
(*EncryptedMessage)(nil), // 5: management.EncryptedMessage
(*SyncRequest)(nil), // 6: management.SyncRequest
(*SyncResponse)(nil), // 7: management.SyncResponse
@@ -3100,7 +3383,10 @@ var file_management_proto_goTypes = []interface{}{
(*FirewallRule)(nil), // 36: management.FirewallRule
(*NetworkAddress)(nil), // 37: management.NetworkAddress
(*Checks)(nil), // 38: management.Checks
- (*timestamppb.Timestamp)(nil), // 39: google.protobuf.Timestamp
+ (*PortInfo)(nil), // 39: management.PortInfo
+ (*RouteFirewallRule)(nil), // 40: management.RouteFirewallRule
+ (*PortInfo_Range)(nil), // 41: management.PortInfo.Range
+ (*timestamppb.Timestamp)(nil), // 42: google.protobuf.Timestamp
}
var file_management_proto_depIdxs = []int32{
13, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta
@@ -3118,12 +3404,12 @@ var file_management_proto_depIdxs = []int32{
17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
38, // 14: management.LoginResponse.Checks:type_name -> management.Checks
- 39, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
+ 42, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig
- 0, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
+ 3, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig
24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
@@ -3132,36 +3418,41 @@ var file_management_proto_depIdxs = []int32{
31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
- 24, // 29: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
- 1, // 30: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
- 29, // 31: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 29, // 32: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 34, // 33: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
- 32, // 34: management.DNSConfig.CustomZones:type_name -> management.CustomZone
- 33, // 35: management.CustomZone.Records:type_name -> management.SimpleRecord
- 35, // 36: management.NameServerGroup.NameServers:type_name -> management.NameServer
- 2, // 37: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction
- 3, // 38: management.FirewallRule.Action:type_name -> management.FirewallRule.action
- 4, // 39: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol
- 5, // 40: management.ManagementService.Login:input_type -> management.EncryptedMessage
- 5, // 41: management.ManagementService.Sync:input_type -> management.EncryptedMessage
- 16, // 42: management.ManagementService.GetServerKey:input_type -> management.Empty
- 16, // 43: management.ManagementService.isHealthy:input_type -> management.Empty
- 5, // 44: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 45: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 46: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
- 5, // 47: management.ManagementService.Login:output_type -> management.EncryptedMessage
- 5, // 48: management.ManagementService.Sync:output_type -> management.EncryptedMessage
- 15, // 49: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
- 16, // 50: management.ManagementService.isHealthy:output_type -> management.Empty
- 5, // 51: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
- 5, // 52: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
- 16, // 53: management.ManagementService.SyncMeta:output_type -> management.Empty
- 47, // [47:54] is the sub-list for method output_type
- 40, // [40:47] is the sub-list for method input_type
- 40, // [40:40] is the sub-list for extension type_name
- 40, // [40:40] is the sub-list for extension extendee
- 0, // [0:40] is the sub-list for field type_name
+ 40, // 29: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule
+ 24, // 30: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
+ 4, // 31: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
+ 29, // 32: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 29, // 33: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 34, // 34: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
+ 32, // 35: management.DNSConfig.CustomZones:type_name -> management.CustomZone
+ 33, // 36: management.CustomZone.Records:type_name -> management.SimpleRecord
+ 35, // 37: management.NameServerGroup.NameServers:type_name -> management.NameServer
+ 1, // 38: management.FirewallRule.Direction:type_name -> management.RuleDirection
+ 2, // 39: management.FirewallRule.Action:type_name -> management.RuleAction
+ 0, // 40: management.FirewallRule.Protocol:type_name -> management.RuleProtocol
+ 41, // 41: management.PortInfo.range:type_name -> management.PortInfo.Range
+ 2, // 42: management.RouteFirewallRule.action:type_name -> management.RuleAction
+ 0, // 43: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol
+ 39, // 44: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo
+ 5, // 45: management.ManagementService.Login:input_type -> management.EncryptedMessage
+ 5, // 46: management.ManagementService.Sync:input_type -> management.EncryptedMessage
+ 16, // 47: management.ManagementService.GetServerKey:input_type -> management.Empty
+ 16, // 48: management.ManagementService.isHealthy:input_type -> management.Empty
+ 5, // 49: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 50: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 51: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
+ 5, // 52: management.ManagementService.Login:output_type -> management.EncryptedMessage
+ 5, // 53: management.ManagementService.Sync:output_type -> management.EncryptedMessage
+ 15, // 54: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
+ 16, // 55: management.ManagementService.isHealthy:output_type -> management.Empty
+ 5, // 56: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
+ 5, // 57: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
+ 16, // 58: management.ManagementService.SyncMeta:output_type -> management.Empty
+ 52, // [52:59] is the sub-list for method output_type
+ 45, // [45:52] is the sub-list for method input_type
+ 45, // [45:45] is the sub-list for extension type_name
+ 45, // [45:45] is the sub-list for extension extendee
+ 0, // [0:45] is the sub-list for field type_name
}
func init() { file_management_proto_init() }
@@ -3578,6 +3869,46 @@ func file_management_proto_init() {
return nil
}
}
+ file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*PortInfo); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*RouteFirewallRule); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*PortInfo_Range); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ file_management_proto_msgTypes[34].OneofWrappers = []interface{}{
+ (*PortInfo_Port)(nil),
+ (*PortInfo_Range_)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
@@ -3585,7 +3916,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc,
NumEnums: 5,
- NumMessages: 34,
+ NumMessages: 37,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/management/proto/management.proto b/management/proto/management.proto
index c5646820f..fe6a828b1 100644
--- a/management/proto/management.proto
+++ b/management/proto/management.proto
@@ -254,6 +254,12 @@ message NetworkMap {
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
bool firewallRulesIsEmpty = 9;
+
+ // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
+ repeated RouteFirewallRule routesFirewallRules = 10;
+
+ // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
+ bool routesFirewallRulesIsEmpty = 11;
}
// RemotePeerConfig represents a configuration of a remote peer.
@@ -384,29 +390,32 @@ message NameServer {
int64 Port = 3;
}
+enum RuleProtocol {
+ UNKNOWN = 0;
+ ALL = 1;
+ TCP = 2;
+ UDP = 3;
+ ICMP = 4;
+}
+
+enum RuleDirection {
+ IN = 0;
+ OUT = 1;
+}
+
+enum RuleAction {
+ ACCEPT = 0;
+ DROP = 1;
+}
+
+
// FirewallRule represents a firewall rule
message FirewallRule {
string PeerIP = 1;
- direction Direction = 2;
- action Action = 3;
- protocol Protocol = 4;
+ RuleDirection Direction = 2;
+ RuleAction Action = 3;
+ RuleProtocol Protocol = 4;
string Port = 5;
-
- enum direction {
- IN = 0;
- OUT = 1;
- }
- enum action {
- ACCEPT = 0;
- DROP = 1;
- }
- enum protocol {
- UNKNOWN = 0;
- ALL = 1;
- TCP = 2;
- UDP = 3;
- ICMP = 4;
- }
}
message NetworkAddress {
@@ -415,5 +424,40 @@ message NetworkAddress {
}
message Checks {
- repeated string Files= 1;
+ repeated string Files = 1;
}
+
+
+message PortInfo {
+ oneof portSelection {
+ uint32 port = 1;
+ Range range = 2;
+ }
+
+ message Range {
+ uint32 start = 1;
+ uint32 end = 2;
+ }
+}
+
+// RouteFirewallRule signifies a firewall rule applicable for a routed network.
+message RouteFirewallRule {
+ // sourceRanges IP ranges of the routing peers.
+ repeated string sourceRanges = 1;
+
+ // Action to be taken by the firewall when the rule is applicable.
+ RuleAction action = 2;
+
+ // Network prefix for the routed network.
+ string destination = 3;
+
+ // Protocol of the routed network.
+ RuleProtocol protocol = 4;
+
+ // Details about the port.
+ PortInfo portInfo = 5;
+
+ // IsDynamic indicates if the route is a DNS route.
+ bool isDynamic = 6;
+}
+
diff --git a/management/server/account.go b/management/server/account.go
index 710b6f62f..d5e8c8cf8 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -113,7 +113,7 @@ type AccountManager interface {
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
- CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
+ CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
@@ -460,6 +460,7 @@ func (a *Account) GetPeerNetworkMap(
}
routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect)
+ routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
@@ -483,6 +484,7 @@ func (a *Account) GetPeerNetworkMap(
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
+ RoutesFirewallRules: routesFirewallRules,
}
if metrics != nil {
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 303261bea..e554ae493 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -1599,9 +1599,10 @@ func TestAccount_Copy(t *testing.T) {
},
Routes: map[route.ID]*route.Route{
"route1": {
- ID: "route1",
- PeerGroups: []string{},
- Groups: []string{"group1"},
+ ID: "route1",
+ PeerGroups: []string{},
+ Groups: []string{"group1"},
+ AccessControlGroups: []string{},
},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index cda3bc748..4c4ef6c3c 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -596,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
+ routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
+ response.NetworkMap.RoutesFirewallRules = routesFirewallRules
+ response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
+
return response
}
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 2463f830e..fd0343e97 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -727,17 +727,39 @@ components:
enum: ["all", "tcp", "udp", "icmp"]
example: "tcp"
ports:
- description: Policy rule affected ports or it ranges list
+ description: Policy rule affected ports
type: array
items:
type: string
example: "80"
+ port_ranges:
+ description: Policy rule affected ports ranges list
+ type: array
+ items:
+ $ref: '#/components/schemas/RulePortRange'
required:
- name
- enabled
- bidirectional
- protocol
- action
+
+ RulePortRange:
+ description: Policy rule affected ports range
+ type: object
+ properties:
+ start:
+ description: The starting port of the range
+ type: integer
+ example: 80
+ end:
+ description: The ending port of the range
+ type: integer
+ example: 320
+ required:
+ - start
+ - end
+
PolicyRuleUpdate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
@@ -1106,6 +1128,12 @@ components:
description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore
type: boolean
example: true
+ access_control_groups:
+ description: Access control group identifier associated with route.
+ type: array
+ items:
+ type: string
+ example: "chacbco6lnnbn6cg5s91"
required:
- id
- description
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index b219d38fd..570ec03c5 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -780,7 +780,10 @@ type PolicyRule struct {
// Name Policy rule name identifier
Name string `json:"name"`
- // Ports Policy rule affected ports or it ranges list
+ // PortRanges Policy rule affected ports ranges list
+ PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
+
+ // Ports Policy rule affected ports
Ports *[]string `json:"ports,omitempty"`
// Protocol Policy rule type of the traffic
@@ -816,7 +819,10 @@ type PolicyRuleMinimum struct {
// Name Policy rule name identifier
Name string `json:"name"`
- // Ports Policy rule affected ports or it ranges list
+ // PortRanges Policy rule affected ports ranges list
+ PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
+
+ // Ports Policy rule affected ports
Ports *[]string `json:"ports,omitempty"`
// Protocol Policy rule type of the traffic
@@ -852,7 +858,10 @@ type PolicyRuleUpdate struct {
// Name Policy rule name identifier
Name string `json:"name"`
- // Ports Policy rule affected ports or it ranges list
+ // PortRanges Policy rule affected ports ranges list
+ PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
+
+ // Ports Policy rule affected ports
Ports *[]string `json:"ports,omitempty"`
// Protocol Policy rule type of the traffic
@@ -935,6 +944,9 @@ type ProcessCheck struct {
// Route defines model for Route.
type Route struct {
+ // AccessControlGroups Access control group identifier associated with route.
+ AccessControlGroups *[]string `json:"access_control_groups,omitempty"`
+
// Description Route description
Description string `json:"description"`
@@ -977,6 +989,9 @@ type Route struct {
// RouteRequest defines model for RouteRequest.
type RouteRequest struct {
+ // AccessControlGroups Access control group identifier associated with route.
+ AccessControlGroups *[]string `json:"access_control_groups,omitempty"`
+
// Description Route description
Description string `json:"description"`
@@ -1011,6 +1026,15 @@ type RouteRequest struct {
PeerGroups *[]string `json:"peer_groups,omitempty"`
}
+// RulePortRange Policy rule affected ports range
+type RulePortRange struct {
+ // End The ending port of the range
+ End int `json:"end"`
+
+ // Start The starting port of the range
+ Start int `json:"start"`
+}
+
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go
index 225d7e1f3..73f3803b5 100644
--- a/management/server/http/policies_handler.go
+++ b/management/server/http/policies_handler.go
@@ -172,6 +172,11 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
+ if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w)
+ return
+ }
+
if rule.Ports != nil && len(*rule.Ports) != 0 {
for _, v := range *rule.Ports {
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
@@ -182,10 +187,23 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
}
}
+ if rule.PortRanges != nil && len(*rule.PortRanges) != 0 {
+ for _, portRange := range *rule.PortRanges {
+ if portRange.Start < 1 || portRange.End > 65535 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
+ return
+ }
+ pr.PortRanges = append(pr.PortRanges, server.RulePortRange{
+ Start: uint16(portRange.Start),
+ End: uint16(portRange.End),
+ })
+ }
+ }
+
// validate policy object
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
- if len(pr.Ports) != 0 {
+ if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
}
@@ -194,7 +212,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
- if !pr.Bidirectional && len(pr.Ports) == 0 {
+ if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
@@ -320,6 +338,17 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic
rule.Ports = &portsCopy
}
+ if len(r.PortRanges) != 0 {
+ portRanges := make([]api.RulePortRange, 0, len(r.PortRanges))
+ for _, portRange := range r.PortRanges {
+ portRanges = append(portRanges, api.RulePortRange{
+ End: int(portRange.End),
+ Start: int(portRange.Start),
+ })
+ }
+ rule.PortRanges = &portRanges
+ }
+
for _, gid := range r.Sources {
_, ok := cache[gid]
if ok {
diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go
index 0932e6445..ce4edee4f 100644
--- a/management/server/http/routes_handler.go
+++ b/management/server/http/routes_handler.go
@@ -117,9 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerGroupIds = *req.PeerGroups
}
+ var accessControlGroupIds []string
+ if req.AccessControlGroups != nil {
+ accessControlGroupIds = *req.AccessControlGroups
+ }
+
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
- req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
- )
+ req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute)
+
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -233,6 +238,10 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups
}
+ if req.AccessControlGroups != nil {
+ newRoute.AccessControlGroups = *req.AccessControlGroups
+ }
+
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -326,6 +335,9 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
if len(serverRoute.PeerGroups) > 0 {
route.PeerGroups = &serverRoute.PeerGroups
}
+ if len(serverRoute.AccessControlGroups) > 0 {
+ route.AccessControlGroups = &serverRoute.AccessControlGroups
+ }
return route, nil
}
diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go
index 2c367cac3..83bd7004d 100644
--- a/management/server/http/routes_handler_test.go
+++ b/management/server/http/routes_handler_test.go
@@ -105,7 +105,7 @@ func initRoutesTestData() *RoutesHandler {
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
- CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
+ CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
@@ -119,18 +119,19 @@ func initRoutesTestData() *RoutesHandler {
}
return &route.Route{
- ID: existingRouteID,
- NetID: netID,
- Peer: peerID,
- PeerGroups: peerGroups,
- Network: prefix,
- Domains: domains,
- NetworkType: networkType,
- Description: description,
- Masquerade: masquerade,
- Enabled: enabled,
- Groups: groups,
- KeepRoute: keepRoute,
+ ID: existingRouteID,
+ NetID: netID,
+ Peer: peerID,
+ PeerGroups: peerGroups,
+ Network: prefix,
+ Domains: domains,
+ NetworkType: networkType,
+ Description: description,
+ Masquerade: masquerade,
+ Enabled: enabled,
+ Groups: groups,
+ KeepRoute: keepRoute,
+ AccessControlGroups: accessControlGroups,
}, nil
},
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
@@ -268,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) {
Groups: []string{existingGroupID},
},
},
+ {
+ name: "POST OK With Access Control Groups",
+ requestType: http.MethodPost,
+ requestPath: "/api/routes",
+ requestBody: bytes.NewBuffer(
+ []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
+ expectedStatus: http.StatusOK,
+ expectedBody: true,
+ expectedRoute: &api.Route{
+ Id: existingRouteID,
+ Description: "Post",
+ NetworkId: "awesomeNet",
+ Network: toPtr("192.168.0.0/16"),
+ Peer: &existingPeerID,
+ NetworkType: route.IPv4NetworkString,
+ Masquerade: false,
+ Enabled: false,
+ Groups: []string{existingGroupID},
+ AccessControlGroups: &[]string{existingGroupID},
+ },
+ },
{
name: "POST Non Linux Peer",
requestType: http.MethodPost,
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index df12ec1c4..b399be822 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -58,7 +58,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
- CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
+ CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -367,7 +367,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID,
if am.DeleteRuleFunc != nil {
return am.DeleteRuleFunc(ctx, accountID, ruleID, userID)
}
- return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
+ return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented")
}
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
@@ -442,9 +442,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
-func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
+func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
- return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
+ return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
diff --git a/management/server/network.go b/management/server/network.go
index 0e7d753a7..a5b188b46 100644
--- a/management/server/network.go
+++ b/management/server/network.go
@@ -26,12 +26,13 @@ const (
)
type NetworkMap struct {
- Peers []*nbpeer.Peer
- Network *Network
- Routes []*route.Route
- DNSConfig nbdns.Config
- OfflinePeers []*nbpeer.Peer
- FirewallRules []*FirewallRule
+ Peers []*nbpeer.Peer
+ Network *Network
+ Routes []*route.Route
+ DNSConfig nbdns.Config
+ OfflinePeers []*nbpeer.Peer
+ FirewallRules []*FirewallRule
+ RoutesFirewallRules []*RouteFirewallRule
}
type Network struct {
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index d329e04bc..387adb91d 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -646,7 +646,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
})
}
-
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
@@ -991,9 +990,9 @@ func TestToSyncResponse(t *testing.T) {
// assert network map Firewall
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
- assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
- assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
- assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
+ assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction)
+ assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
+ assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol)
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
// assert posture checks
assert.Equal(t, 1, len(response.Checks))
diff --git a/management/server/policy.go b/management/server/policy.go
index 5d07ba8f8..75647de44 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -76,6 +76,12 @@ type PolicyUpdateOperation struct {
Values []string
}
+// RulePortRange represents a range of ports for a firewall rule.
+type RulePortRange struct {
+ Start uint16
+ End uint16
+}
+
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule
@@ -110,6 +116,9 @@ type PolicyRule struct {
// Ports or it ranges list
Ports []string `gorm:"serializer:json"`
+
+ // PortRanges a list of port ranges.
+ PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a copy of a policy rule
@@ -125,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
+ PortRanges: make([]RulePortRange, len(pm.PortRanges)),
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
+ copy(rule.PortRanges, pm.PortRanges)
return rule
}
@@ -445,36 +456,17 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
return nil
}
-func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
- result := make([]*proto.FirewallRule, len(update))
- for i := range update {
- direction := proto.FirewallRule_IN
- if update[i].Direction == firewallRuleDirectionOUT {
- direction = proto.FirewallRule_OUT
- }
- action := proto.FirewallRule_ACCEPT
- if update[i].Action == string(PolicyTrafficActionDrop) {
- action = proto.FirewallRule_DROP
- }
-
- protocol := proto.FirewallRule_UNKNOWN
- switch PolicyRuleProtocolType(update[i].Protocol) {
- case PolicyRuleProtocolALL:
- protocol = proto.FirewallRule_ALL
- case PolicyRuleProtocolTCP:
- protocol = proto.FirewallRule_TCP
- case PolicyRuleProtocolUDP:
- protocol = proto.FirewallRule_UDP
- case PolicyRuleProtocolICMP:
- protocol = proto.FirewallRule_ICMP
- }
+func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
+ result := make([]*proto.FirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
result[i] = &proto.FirewallRule{
- PeerIP: update[i].PeerIP,
- Direction: direction,
- Action: action,
- Protocol: protocol,
- Port: update[i].Port,
+ PeerIP: rule.PeerIP,
+ Direction: getProtoDirection(rule.Direction),
+ Action: getProtoAction(rule.Action),
+ Protocol: getProtoProtocol(rule.Protocol),
+ Port: rule.Port,
}
}
return result
diff --git a/management/server/route.go b/management/server/route.go
index 6c1c8b1b3..39ee6170c 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -4,9 +4,15 @@ import (
"context"
"fmt"
"net/netip"
+ "slices"
+ "strconv"
+ "strings"
"unicode/utf8"
"github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
+
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
@@ -15,6 +21,30 @@ import (
"github.com/netbirdio/netbird/route"
)
+// RouteFirewallRule a firewall rule applicable for a routed network.
+type RouteFirewallRule struct {
+ // SourceRanges IP ranges of the routing peers.
+ SourceRanges []string
+
+ // Action of the traffic when the rule is applicable
+ Action string
+
+ // Destination a network prefix for the routed traffic
+ Destination string
+
+ // Protocol of the traffic
+ Protocol string
+
+ // Port of the traffic
+ Port uint16
+
+ // PortRange represents the range of ports for a firewall rule
+ PortRange RulePortRange
+
+ // isDynamic indicates whether the rule is for DNS routing
+ IsDynamic bool
+}
+
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
@@ -112,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
}
// CreateRoute creates and saves a new route
-func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
+func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -157,6 +187,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
}
+ if len(accessControlGroupIDs) > 0 {
+ err = validateGroups(accessControlGroupIDs, account.Groups)
+ if err != nil {
+ return nil, err
+ }
+ }
+
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
if err != nil {
return nil, err
@@ -187,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
+ newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
@@ -258,6 +296,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
}
+ if len(routeToSave.AccessControlGroups) > 0 {
+ err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
+ if err != nil {
+ return err
+ }
+ }
+
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil {
return err
@@ -351,3 +396,248 @@ func getPlaceholderIP() netip.Prefix {
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
+
+// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
+func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
+ routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
+
+ enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
+ for _, route := range enabledRoutes {
+ // If no access control groups are specified, accept all traffic.
+ if len(route.AccessControlGroups) == 0 {
+ defaultPermit := getDefaultPermit(route)
+ routesFirewallRules = append(routesFirewallRules, defaultPermit...)
+ continue
+ }
+
+ policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups)
+ for _, policy := range policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ for _, rule := range policy.Rules {
+ if !rule.Enabled {
+ continue
+ }
+
+ distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap)
+ rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN)
+ routesFirewallRules = append(routesFirewallRules, rules...)
+ }
+ }
+ }
+
+ return routesFirewallRules
+}
+
+func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
+ var rules []*RouteFirewallRule
+
+ sources := []string{"0.0.0.0/0"}
+ if route.Network.Addr().Is6() {
+ sources = []string{"::/0"}
+ }
+ rule := RouteFirewallRule{
+ SourceRanges: sources,
+ Action: string(PolicyTrafficActionAccept),
+ Destination: route.Network.String(),
+ Protocol: string(PolicyRuleProtocolALL),
+ IsDynamic: route.IsDynamic(),
+ }
+
+ rules = append(rules, &rule)
+
+ // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
+ if route.IsDynamic() {
+ ruleV6 := rule
+ ruleV6.SourceRanges = []string{"::/0"}
+ rules = append(rules, &ruleV6)
+ }
+
+ return rules
+}
+
+// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
+// and returns a list of policies that have rules with destinations matching the specified groups.
+func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
+ routePolicies := make([]*Policy, 0)
+ for _, groupID := range accessControlGroups {
+ group, ok := account.Groups[groupID]
+ if !ok {
+ continue
+ }
+
+ for _, policy := range account.Policies {
+ for _, rule := range policy.Rules {
+ exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
+ return groupID == group.ID
+ })
+ if exist {
+ routePolicies = append(routePolicies, policy)
+ continue
+ }
+ }
+ }
+ }
+
+ return routePolicies
+}
+
+// generateRouteFirewallRules generates a list of firewall rules for a given route.
+func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
+ rulesExists := make(map[string]struct{})
+ rules := make([]*RouteFirewallRule, 0)
+
+ sourceRanges := make([]string, 0, len(groupPeers))
+ for _, peer := range groupPeers {
+ if peer == nil {
+ continue
+ }
+ sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
+ }
+
+ baseRule := RouteFirewallRule{
+ SourceRanges: sourceRanges,
+ Action: string(rule.Action),
+ Destination: route.Network.String(),
+ Protocol: string(rule.Protocol),
+ IsDynamic: route.IsDynamic(),
+ }
+
+ // generate rule for port range
+ if len(rule.Ports) == 0 {
+ rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
+ } else {
+ rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
+
+ }
+
+ // TODO: generate IPv6 rules for dynamic routes
+
+ return rules
+}
+
+// generateRuleIDBase generates the base rule ID for checking duplicates.
+func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
+ return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
+}
+
+// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
+func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
+ rules := make([]*RouteFirewallRule, 0)
+
+ ruleIDBase := generateRuleIDBase(rule, baseRule)
+ if len(rule.Ports) == 0 {
+ if len(rule.PortRanges) == 0 {
+ if _, ok := rulesExists[ruleIDBase]; !ok {
+ rulesExists[ruleIDBase] = struct{}{}
+ rules = append(rules, &baseRule)
+ }
+ } else {
+ for _, portRange := range rule.PortRanges {
+ ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
+ if _, ok := rulesExists[ruleID]; !ok {
+ rulesExists[ruleID] = struct{}{}
+ pr := baseRule
+ pr.PortRange = portRange
+ rules = append(rules, &pr)
+ }
+ }
+ }
+ return rules
+ }
+
+ return rules
+}
+
+// generateRulesWithPorts generates rules when specific ports are provided.
+func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
+ rules := make([]*RouteFirewallRule, 0)
+ ruleIDBase := generateRuleIDBase(rule, baseRule)
+
+ for _, port := range rule.Ports {
+ ruleID := ruleIDBase + port
+ if _, ok := rulesExists[ruleID]; ok {
+ continue
+ }
+ rulesExists[ruleID] = struct{}{}
+
+ pr := baseRule
+ p, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
+ continue
+ }
+
+ pr.Port = uint16(p)
+ rules = append(rules, &pr)
+ }
+
+ return rules
+}
+
+func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule {
+ result := make([]*proto.RouteFirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
+ result[i] = &proto.RouteFirewallRule{
+ SourceRanges: rule.SourceRanges,
+ Action: getProtoAction(rule.Action),
+ Destination: rule.Destination,
+ Protocol: getProtoProtocol(rule.Protocol),
+ PortInfo: getProtoPortInfo(rule),
+ IsDynamic: rule.IsDynamic,
+ }
+ }
+
+ return result
+}
+
+// getProtoDirection converts the direction to proto.RuleDirection.
+func getProtoDirection(direction int) proto.RuleDirection {
+ if direction == firewallRuleDirectionOUT {
+ return proto.RuleDirection_OUT
+ }
+ return proto.RuleDirection_IN
+}
+
+// getProtoAction converts the action to proto.RuleAction.
+func getProtoAction(action string) proto.RuleAction {
+ if action == string(PolicyTrafficActionDrop) {
+ return proto.RuleAction_DROP
+ }
+ return proto.RuleAction_ACCEPT
+}
+
+// getProtoProtocol converts the protocol to proto.RuleProtocol.
+func getProtoProtocol(protocol string) proto.RuleProtocol {
+ switch PolicyRuleProtocolType(protocol) {
+ case PolicyRuleProtocolALL:
+ return proto.RuleProtocol_ALL
+ case PolicyRuleProtocolTCP:
+ return proto.RuleProtocol_TCP
+ case PolicyRuleProtocolUDP:
+ return proto.RuleProtocol_UDP
+ case PolicyRuleProtocolICMP:
+ return proto.RuleProtocol_ICMP
+ default:
+ return proto.RuleProtocol_UNKNOWN
+ }
+}
+
+// getProtoPortInfo converts the port info to proto.PortInfo.
+func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
+ var portInfo proto.PortInfo
+ if rule.Port != 0 {
+ portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
+ } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
+ portInfo.PortSelection = &proto.PortInfo_Range_{
+ Range: &proto.PortInfo_Range{
+ Start: uint32(portRange.Start),
+ End: uint32(portRange.End),
+ },
+ }
+ }
+ return &portInfo
+}
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 4533c6b7e..b556816be 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -2,6 +2,8 @@ package server
import (
"context"
+ "fmt"
+ "net"
"net/netip"
"testing"
@@ -44,18 +46,19 @@ var existingDomains = domain.List{"example.com"}
func TestCreateRoute(t *testing.T) {
type input struct {
- network netip.Prefix
- domains domain.List
- keepRoute bool
- networkType route.NetworkType
- netID route.NetID
- peerKey string
- peerGroupIDs []string
- description string
- masquerade bool
- metric int
- enabled bool
- groups []string
+ network netip.Prefix
+ domains domain.List
+ keepRoute bool
+ networkType route.NetworkType
+ netID route.NetID
+ peerKey string
+ peerGroupIDs []string
+ description string
+ masquerade bool
+ metric int
+ enabled bool
+ groups []string
+ accessControlGroups []string
}
testCases := []struct {
@@ -69,100 +72,107 @@ func TestCreateRoute(t *testing.T) {
{
name: "Happy Path Network",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- networkType: route.IPv4Network,
- netID: "happy",
- peerKey: peer1ID,
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ networkType: route.IPv4Network,
+ netID: "happy",
+ peerKey: peer1ID,
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup1},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetworkType: route.IPv4Network,
- NetID: "happy",
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetworkType: route.IPv4Network,
+ NetID: "happy",
+ Peer: peer1ID,
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1},
+ AccessControlGroups: []string{routeGroup1},
},
},
{
name: "Happy Path Domains",
inputArgs: input{
- domains: domain.List{"domain1", "domain2"},
- keepRoute: true,
- networkType: route.DomainNetwork,
- netID: "happy",
- peerKey: peer1ID,
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ domains: domain.List{"domain1", "domain2"},
+ keepRoute: true,
+ networkType: route.DomainNetwork,
+ netID: "happy",
+ peerKey: peer1ID,
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup1},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
- Network: netip.MustParsePrefix("192.0.2.0/32"),
- Domains: domain.List{"domain1", "domain2"},
- NetworkType: route.DomainNetwork,
- NetID: "happy",
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
- KeepRoute: true,
+ Network: netip.MustParsePrefix("192.0.2.0/32"),
+ Domains: domain.List{"domain1", "domain2"},
+ NetworkType: route.DomainNetwork,
+ NetID: "happy",
+ Peer: peer1ID,
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1},
+ KeepRoute: true,
+ AccessControlGroups: []string{routeGroup1},
},
},
{
name: "Happy Path Peer Groups",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- networkType: route.IPv4Network,
- netID: "happy",
- peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1, routeGroup2},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ networkType: route.IPv4Network,
+ netID: "happy",
+ peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1, routeGroup2},
+ accessControlGroups: []string{routeGroup1, routeGroup2},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetworkType: route.IPv4Network,
- NetID: "happy",
- PeerGroups: []string{routeGroupHA1, routeGroupHA2},
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1, routeGroup2},
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetworkType: route.IPv4Network,
+ NetID: "happy",
+ PeerGroups: []string{routeGroupHA1, routeGroupHA2},
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1, routeGroup2},
+ AccessControlGroups: []string{routeGroup1, routeGroup2},
},
},
{
name: "Both network and domains provided should fail",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- domains: domain.List{"domain1", "domain2"},
- netID: "happy",
- peerKey: peer1ID,
- peerGroupIDs: []string{routeGroupHA1},
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ domains: domain.List{"domain1", "domain2"},
+ netID: "happy",
+ peerKey: peer1ID,
+ peerGroupIDs: []string{routeGroupHA1},
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup2},
},
errFunc: require.Error,
shouldCreate: false,
@@ -170,16 +180,17 @@ func TestCreateRoute(t *testing.T) {
{
name: "Both peer and peer_groups Provided Should Fail",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- networkType: route.IPv4Network,
- netID: "happy",
- peerKey: peer1ID,
- peerGroupIDs: []string{routeGroupHA1},
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ networkType: route.IPv4Network,
+ netID: "happy",
+ peerKey: peer1ID,
+ peerGroupIDs: []string{routeGroupHA1},
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup2},
},
errFunc: require.Error,
shouldCreate: false,
@@ -423,13 +434,13 @@ func TestCreateRoute(t *testing.T) {
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
require.NoError(t, errInit)
- _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
+ _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit)
- _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
+ _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
}
- outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
+ outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err)
@@ -1037,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) {
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
baseRoute := &route.Route{
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetID: "superNet",
- NetworkType: route.IPv4Network,
- PeerGroups: []string{routeGroupHA1, routeGroupHA2},
- Description: "ha route",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1, routeGroup2},
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetID: "superNet",
+ NetworkType: route.IPv4Network,
+ PeerGroups: []string{routeGroupHA1, routeGroupHA2},
+ Description: "ha route",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1, routeGroup2},
+ AccessControlGroups: []string{routeGroup1},
}
am, err := createRouterManager(t)
@@ -1062,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
- newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
+ newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)
@@ -1127,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
// no routes for peer in different groups
// no routes when route is deleted
baseRoute := &route.Route{
- ID: "testingRoute",
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetID: "superNet",
- NetworkType: route.IPv4Network,
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
+ ID: "testingRoute",
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetID: "superNet",
+ NetworkType: route.IPv4Network,
+ Peer: peer1ID,
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1},
+ AccessControlGroups: []string{routeGroup1},
}
am, err := createRouterManager(t)
@@ -1153,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
- createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute)
+ createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1467,3 +1480,300 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return am.Store.GetAccount(context.Background(), account.Id)
}
+
+func TestAccount_getPeersRoutesFirewall(t *testing.T) {
+ var (
+ peerBIp = "100.65.80.39"
+ peerCIp = "100.65.254.139"
+ peerHIp = "100.65.29.55"
+ )
+
+ account := &Account{
+ Peers: map[string]*nbpeer.Peer{
+ "peerA": {
+ ID: "peerA",
+ IP: net.ParseIP("100.65.14.88"),
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{
+ GoOS: "linux",
+ },
+ },
+ "peerB": {
+ ID: "peerB",
+ IP: net.ParseIP(peerBIp),
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{},
+ },
+ "peerC": {
+ ID: "peerC",
+ IP: net.ParseIP(peerCIp),
+ Status: &nbpeer.PeerStatus{},
+ },
+ "peerD": {
+ ID: "peerD",
+ IP: net.ParseIP("100.65.62.5"),
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{
+ GoOS: "linux",
+ },
+ },
+ "peerE": {
+ ID: "peerE",
+ IP: net.ParseIP("100.65.32.206"),
+ Key: peer1Key,
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{
+ GoOS: "linux",
+ },
+ },
+ "peerF": {
+ ID: "peerF",
+ IP: net.ParseIP("100.65.250.202"),
+ Status: &nbpeer.PeerStatus{},
+ },
+ "peerG": {
+ ID: "peerG",
+ IP: net.ParseIP("100.65.13.186"),
+ Status: &nbpeer.PeerStatus{},
+ },
+ "peerH": {
+ ID: "peerH",
+ IP: net.ParseIP(peerHIp),
+ Status: &nbpeer.PeerStatus{},
+ },
+ },
+ Groups: map[string]*nbgroup.Group{
+ "routingPeer1": {
+ ID: "routingPeer1",
+ Name: "RoutingPeer1",
+ Peers: []string{
+ "peerA",
+ },
+ },
+ "routingPeer2": {
+ ID: "routingPeer2",
+ Name: "RoutingPeer2",
+ Peers: []string{
+ "peerD",
+ },
+ },
+ "route1": {
+ ID: "route1",
+ Name: "Route1",
+ Peers: []string{},
+ },
+ "route2": {
+ ID: "route2",
+ Name: "Route2",
+ Peers: []string{},
+ },
+ "finance": {
+ ID: "finance",
+ Name: "Finance",
+ Peers: []string{
+ "peerF",
+ "peerG",
+ },
+ },
+ "dev": {
+ ID: "dev",
+ Name: "Dev",
+ Peers: []string{
+ "peerC",
+ "peerH",
+ "peerB",
+ },
+ },
+ "contractors": {
+ ID: "contractors",
+ Name: "Contractors",
+ Peers: []string{},
+ },
+ },
+ Routes: map[route.ID]*route.Route{
+ "route1": {
+ ID: "route1",
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetID: "route1",
+ NetworkType: route.IPv4Network,
+ PeerGroups: []string{"routingPeer1", "routingPeer2"},
+ Description: "Route1 ha route",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{"dev"},
+ AccessControlGroups: []string{"route1"},
+ },
+ "route2": {
+ ID: "route2",
+ Network: existingNetwork,
+ NetID: "route2",
+ NetworkType: route.IPv4Network,
+ Peer: "peerE",
+ Description: "Allow",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{"finance"},
+ AccessControlGroups: []string{"route2"},
+ },
+ "route3": {
+ ID: "route3",
+ Network: netip.MustParsePrefix("192.0.2.0/32"),
+ Domains: domain.List{"example.com"},
+ NetID: "route3",
+ NetworkType: route.DomainNetwork,
+ Peer: "peerE",
+ Description: "Allow all traffic to routed DNS network",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{"contractors"},
+ AccessControlGroups: []string{},
+ },
+ },
+ Policies: []*Policy{
+ {
+ ID: "RuleRoute1",
+ Name: "Route1",
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ ID: "RuleRoute1",
+ Name: "ruleRoute1",
+ Bidirectional: true,
+ Enabled: true,
+ Protocol: PolicyRuleProtocolALL,
+ Action: PolicyTrafficActionAccept,
+ Ports: []string{"80", "320"},
+ Sources: []string{
+ "dev",
+ },
+ Destinations: []string{
+ "route1",
+ },
+ },
+ },
+ },
+ {
+ ID: "RuleRoute2",
+ Name: "Route2",
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ ID: "RuleRoute2",
+ Name: "ruleRoute2",
+ Bidirectional: true,
+ Enabled: true,
+ Protocol: PolicyRuleProtocolTCP,
+ Action: PolicyTrafficActionAccept,
+ PortRanges: []RulePortRange{
+ {
+ Start: 80,
+ End: 350,
+ }, {
+ Start: 80,
+ End: 350,
+ },
+ },
+ Sources: []string{
+ "finance",
+ },
+ Destinations: []string{
+ "route2",
+ },
+ },
+ },
+ },
+ },
+ }
+
+ validatedPeers := make(map[string]struct{})
+ for p := range account.Peers {
+ validatedPeers[p] = struct{}{}
+ }
+
+ t.Run("check applied policies for the route", func(t *testing.T) {
+ route1 := account.Routes["route1"]
+ policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
+ assert.Len(t, policies, 1)
+
+ route2 := account.Routes["route2"]
+ policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
+ assert.Len(t, policies, 1)
+
+ route3 := account.Routes["route3"]
+ policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
+ assert.Len(t, policies, 0)
+ })
+
+ t.Run("check peer routes firewall rules", func(t *testing.T) {
+ routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
+ assert.Len(t, routesFirewallRules, 2)
+
+ expectedRoutesFirewallRules := []*RouteFirewallRule{
+ {
+ SourceRanges: []string{
+ fmt.Sprintf(AllowedIPsFormat, peerCIp),
+ fmt.Sprintf(AllowedIPsFormat, peerHIp),
+ fmt.Sprintf(AllowedIPsFormat, peerBIp),
+ },
+ Action: "accept",
+ Destination: "192.168.0.0/16",
+ Protocol: "all",
+ Port: 80,
+ },
+ {
+ SourceRanges: []string{
+ fmt.Sprintf(AllowedIPsFormat, peerCIp),
+ fmt.Sprintf(AllowedIPsFormat, peerHIp),
+ fmt.Sprintf(AllowedIPsFormat, peerBIp),
+ },
+ Action: "accept",
+ Destination: "192.168.0.0/16",
+ Protocol: "all",
+ Port: 320,
+ },
+ }
+ assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
+
+ //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
+ routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
+ assert.Len(t, routesFirewallRules, 2)
+ assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
+
+ // peerE is a single routing peer for route 2 and route 3
+ routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
+ assert.Len(t, routesFirewallRules, 3)
+
+ expectedRoutesFirewallRules = []*RouteFirewallRule{
+ {
+ SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
+ Action: "accept",
+ Destination: existingNetwork.String(),
+ Protocol: "tcp",
+ PortRange: RulePortRange{Start: 80, End: 350},
+ },
+ {
+ SourceRanges: []string{"0.0.0.0/0"},
+ Action: "accept",
+ Destination: "192.0.2.0/32",
+ Protocol: "all",
+ IsDynamic: true,
+ },
+ {
+ SourceRanges: []string{"::/0"},
+ Action: "accept",
+ Destination: "192.0.2.0/32",
+ Protocol: "all",
+ IsDynamic: true,
+ },
+ }
+ assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
+
+ // peerC is part of route1 distribution groups but should not receive the routes firewall rules
+ routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
+ assert.Len(t, routesFirewallRules, 0)
+ })
+
+}
diff --git a/route/route.go b/route/route.go
index eb6c36bd8..e23801e6e 100644
--- a/route/route.go
+++ b/route/route.go
@@ -100,6 +100,7 @@ type Route struct {
Metric int
Enabled bool
Groups []string `gorm:"serializer:json"`
+ AccessControlGroups []string `gorm:"serializer:json"`
}
// EventMeta returns activity event meta related to the route
@@ -123,6 +124,7 @@ func (r *Route) Copy() *Route {
Masquerade: r.Masquerade,
Enabled: r.Enabled,
Groups: slices.Clone(r.Groups),
+ AccessControlGroups: slices.Clone(r.AccessControlGroups),
}
return route
}
@@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool {
other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled &&
slices.Equal(r.Groups, other.Groups) &&
- slices.Equal(r.PeerGroups, other.PeerGroups)
+ slices.Equal(r.PeerGroups, other.PeerGroups)&&
+ slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
}
// IsDynamic returns if the route is dynamic, i.e. has domains
From b7b08281336676f356c8e1032b1907c617b6439c Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 2 Oct 2024 15:14:09 +0200
Subject: [PATCH 85/89] [client] Adjust relay worker log level and message
(#2683)
---
client/internal/peer/worker_relay.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go
index 6bb385d3e..c02fccebc 100644
--- a/client/internal/peer/worker_relay.go
+++ b/client/internal/peer/worker_relay.go
@@ -74,7 +74,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
- w.log.Infof("do not need to reopen relay connection")
+ w.log.Debugf("handled offer by reusing existing relay connection")
return
}
w.log.Errorf("failed to open connection via Relay: %s", err)
From 7e5d3bdfe2306f69ef5daab3c742c4d206c69406 Mon Sep 17 00:00:00 2001
From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Wed, 2 Oct 2024 15:33:38 +0200
Subject: [PATCH 86/89] [signal] Move dummy signal message handling into
dispatcher (#2686)
---
go.mod | 2 +-
go.sum | 4 ++--
signal/server/signal.go | 5 -----
3 files changed, 3 insertions(+), 8 deletions(-)
diff --git a/go.mod b/go.mod
index c29ba0763..e7137ce5b 100644
--- a/go.mod
+++ b/go.mod
@@ -60,7 +60,7 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
- github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757
+ github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
diff --git a/go.sum b/go.sum
index 1f6cbb785..4563dc933 100644
--- a/go.sum
+++ b/go.sum
@@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 386ce7238..63cc43bd7 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -71,11 +71,6 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
- if msg.RemoteKey == "dummy" {
- // Test message send during netbird status
- return &proto.EncryptedMessage{}, nil
- }
-
if _, found := s.registry.Get(msg.RemoteKey); found {
s.forwardMessageToPeer(ctx, msg)
return &proto.EncryptedMessage{}, nil
From fd67892cb4fa0c4e4c23c0511796cc7ce9fe296c Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Wed, 2 Oct 2024 18:24:22 +0200
Subject: [PATCH 87/89] [client] Refactor/iface pkg (#2646)
Refactor the flat code structure
---
.github/workflows/golang-test-freebsd.yml | 2 +-
.github/workflows/golang-test-linux.yml | 2 +-
client/android/client.go | 6 +--
client/cmd/login_test.go | 2 +-
client/cmd/root_test.go | 2 +-
client/cmd/up.go | 2 +-
client/firewall/iface.go | 6 +--
client/firewall/iptables/manager_linux.go | 2 +-
.../firewall/iptables/manager_linux_test.go | 2 +-
client/firewall/nftables/acl_linux.go | 2 +-
.../firewall/nftables/manager_linux_test.go | 2 +-
client/firewall/uspfilter/uspfilter.go | 5 +-
client/firewall/uspfilter/uspfilter_test.go | 21 ++++----
{iface => client/iface}/bind/bind.go | 0
{iface => client/iface}/bind/udp_mux.go | 0
.../iface}/bind/udp_mux_universal.go | 0
.../iface}/bind/udp_muxed_conn.go | 0
client/iface/configurer/err.go | 5 ++
.../iface/configurer/kernel_unix.go | 27 +++++-----
{iface => client/iface/configurer}/name.go | 2 +-
.../iface/configurer}/name_darwin.go | 2 +-
{iface => client/iface/configurer}/uapi.go | 2 +-
.../iface/configurer}/uapi_windows.go | 2 +-
.../iface/configurer/usp.go | 24 ++++-----
.../iface/configurer/usp_test.go | 2 +-
client/iface/configurer/wgstats.go | 9 ++++
client/iface/device.go | 18 +++++++
.../iface/device/adapter.go | 2 +-
{iface => client/iface/device}/address.go | 8 +--
.../iface/device/args.go | 2 +-
.../iface/device/device_android.go | 54 +++++++++----------
.../iface/device/device_darwin.go | 49 ++++++++---------
.../iface/device/device_filter.go | 19 +++----
.../iface/device/device_filter_test.go | 13 ++---
.../iface/device/device_ios.go | 49 ++++++++---------
.../iface/device/device_kernel_unix.go | 31 +++++------
.../iface/device/device_netstack.go | 49 ++++++++---------
.../iface/device/device_usp_unix.go | 52 +++++++++---------
.../iface/device/device_windows.go | 47 ++++++++--------
client/iface/device/interface.go | 20 +++++++
.../iface/device/kernel_module.go | 2 +-
.../iface/device/kernel_module_freebsd.go | 6 +--
.../iface/device/kernel_module_linux.go | 6 +--
.../iface/device/kernel_module_linux_test.go | 8 +--
.../iface/device/wg_link_freebsd.go | 5 +-
.../iface/device/wg_link_linux.go | 2 +-
{iface => client/iface/device}/wg_log.go | 2 +-
client/iface/device/windows_guid.go | 4 ++
client/iface/device_android.go | 16 ++++++
{iface => client/iface}/freebsd/errors.go | 0
{iface => client/iface}/freebsd/iface.go | 0
.../iface}/freebsd/iface_internal_test.go | 0
{iface => client/iface}/freebsd/link.go | 0
{iface => client/iface}/iface.go | 53 +++++++++---------
{iface => client/iface}/iface_android.go | 9 ++--
{iface => client/iface}/iface_create.go | 0
{iface => client/iface}/iface_darwin.go | 13 ++---
{iface => client/iface}/iface_destroy_bsd.go | 0
.../iface}/iface_destroy_linux.go | 0
.../iface}/iface_destroy_mobile.go | 0
.../iface}/iface_destroy_windows.go | 0
{iface => client/iface}/iface_ios.go | 9 ++--
{iface => client/iface}/iface_moc.go | 24 +++++----
{iface => client/iface}/iface_test.go | 6 ++-
{iface => client/iface}/iface_unix.go | 19 +++----
{iface => client/iface}/iface_windows.go | 15 +++---
{iface => client/iface}/iwginterface.go | 14 ++---
.../iface}/iwginterface_windows.go | 14 ++---
{iface => client/iface}/mocks/README.md | 0
{iface => client/iface}/mocks/filter.go | 2 +-
.../iface}/mocks/iface/mocks/filter.go | 2 +-
{iface => client/iface}/mocks/tun.go | 0
{iface => client/iface}/netstack/dialer.go | 0
{iface => client/iface}/netstack/env.go | 0
{iface => client/iface}/netstack/proxy.go | 0
{iface => client/iface}/netstack/tun.go | 0
client/internal/acl/manager_test.go | 2 +-
client/internal/acl/mocks/iface_mapper.go | 5 +-
client/internal/config.go | 2 +-
client/internal/connect.go | 7 +--
client/internal/dns/response_writer_test.go | 2 +-
client/internal/dns/server_test.go | 18 ++++---
client/internal/dns/wgiface.go | 10 ++--
client/internal/dns/wgiface_windows.go | 12 +++--
client/internal/engine.go | 13 ++---
client/internal/engine_test.go | 7 +--
client/internal/mobile_dependency.go | 4 +-
client/internal/peer/conn.go | 5 +-
client/internal/peer/conn_test.go | 2 +-
client/internal/peer/status.go | 6 +--
client/internal/peer/worker_ice.go | 4 +-
client/internal/routemanager/client.go | 2 +-
client/internal/routemanager/dynamic/route.go | 2 +-
client/internal/routemanager/manager.go | 5 +-
client/internal/routemanager/manager_test.go | 2 +-
client/internal/routemanager/mock.go | 2 +-
.../internal/routemanager/server_android.go | 2 +-
.../routemanager/server_nonandroid.go | 2 +-
.../routemanager/sysctl/sysctl_linux.go | 2 +-
.../routemanager/systemops/systemops.go | 2 +-
.../systemops/systemops_generic.go | 2 +-
.../systemops/systemops_generic_test.go | 2 +-
iface/tun.go | 21 --------
iface/wg_configurer.go | 21 --------
util/net/net.go | 2 +-
105 files changed, 505 insertions(+), 438 deletions(-)
rename {iface => client/iface}/bind/bind.go (100%)
rename {iface => client/iface}/bind/udp_mux.go (100%)
rename {iface => client/iface}/bind/udp_mux_universal.go (100%)
rename {iface => client/iface}/bind/udp_muxed_conn.go (100%)
create mode 100644 client/iface/configurer/err.go
rename iface/wg_configurer_kernel_unix.go => client/iface/configurer/kernel_unix.go (83%)
rename {iface => client/iface/configurer}/name.go (87%)
rename {iface => client/iface/configurer}/name_darwin.go (86%)
rename {iface => client/iface/configurer}/uapi.go (96%)
rename {iface => client/iface/configurer}/uapi_windows.go (88%)
rename iface/wg_configurer_usp.go => client/iface/configurer/usp.go (93%)
rename iface/wg_configurer_usp_test.go => client/iface/configurer/usp_test.go (99%)
create mode 100644 client/iface/configurer/wgstats.go
create mode 100644 client/iface/device.go
rename iface/tun_adapter.go => client/iface/device/adapter.go (94%)
rename {iface => client/iface/device}/address.go (69%)
rename iface/tun_args.go => client/iface/device/args.go (88%)
rename iface/tun_android.go => client/iface/device/device_android.go (61%)
rename iface/tun_darwin.go => client/iface/device/device_darwin.go (69%)
rename iface/device_wrapper.go => client/iface/device/device_filter.go (81%)
rename iface/device_wrapper_test.go => client/iface/device/device_filter_test.go (95%)
rename iface/tun_ios.go => client/iface/device/device_ios.go (63%)
rename iface/tun_kernel_unix.go => client/iface/device/device_kernel_unix.go (75%)
rename iface/tun_netstack.go => client/iface/device/device_netstack.go (56%)
rename iface/tun_usp_unix.go => client/iface/device/device_usp_unix.go (63%)
rename iface/tun_windows.go => client/iface/device/device_windows.go (75%)
create mode 100644 client/iface/device/interface.go
rename iface/module.go => client/iface/device/kernel_module.go (92%)
rename iface/module_freebsd.go => client/iface/device/kernel_module_freebsd.go (84%)
rename iface/module_linux.go => client/iface/device/kernel_module_linux.go (98%)
rename iface/module_linux_test.go => client/iface/device/kernel_module_linux_test.go (98%)
rename iface/tun_link_freebsd.go => client/iface/device/wg_link_freebsd.go (95%)
rename iface/tun_link_linux.go => client/iface/device/wg_link_linux.go (99%)
rename {iface => client/iface/device}/wg_log.go (93%)
create mode 100644 client/iface/device/windows_guid.go
create mode 100644 client/iface/device_android.go
rename {iface => client/iface}/freebsd/errors.go (100%)
rename {iface => client/iface}/freebsd/iface.go (100%)
rename {iface => client/iface}/freebsd/iface_internal_test.go (100%)
rename {iface => client/iface}/freebsd/link.go (100%)
rename {iface => client/iface}/iface.go (79%)
rename {iface => client/iface}/iface_android.go (67%)
rename {iface => client/iface}/iface_create.go (100%)
rename {iface => client/iface}/iface_darwin.go (68%)
rename {iface => client/iface}/iface_destroy_bsd.go (100%)
rename {iface => client/iface}/iface_destroy_linux.go (100%)
rename {iface => client/iface}/iface_destroy_mobile.go (100%)
rename {iface => client/iface}/iface_destroy_windows.go (100%)
rename {iface => client/iface}/iface_ios.go (59%)
rename {iface => client/iface}/iface_moc.go (76%)
rename {iface => client/iface}/iface_test.go (98%)
rename {iface => client/iface}/iface_unix.go (53%)
rename {iface => client/iface}/iface_windows.go (52%)
rename {iface => client/iface}/iwginterface.go (65%)
rename {iface => client/iface}/iwginterface_windows.go (65%)
rename {iface => client/iface}/mocks/README.md (100%)
rename {iface => client/iface}/mocks/filter.go (97%)
rename {iface => client/iface}/mocks/iface/mocks/filter.go (97%)
rename {iface => client/iface}/mocks/tun.go (100%)
rename {iface => client/iface}/netstack/dialer.go (100%)
rename {iface => client/iface}/netstack/env.go (100%)
rename {iface => client/iface}/netstack/proxy.go (100%)
rename {iface => client/iface}/netstack/tun.go (100%)
delete mode 100644 iface/tun.go
delete mode 100644 iface/wg_configurer.go
diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml
index 4f13ee30e..a2d743715 100644
--- a/.github/workflows/golang-test-freebsd.yml
+++ b/.github/workflows/golang-test-freebsd.yml
@@ -38,7 +38,7 @@ jobs:
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
- time go test -timeout 1m -failfast ./iface/...
+ time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 2d5cf2856..524f35f6f 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -80,7 +80,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
- run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
+ run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
diff --git a/client/android/client.go b/client/android/client.go
index d937e132e..229bcd974 100644
--- a/client/android/client.go
+++ b/client/android/client.go
@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -15,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
)
@@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile
type TunAdapter interface {
- iface.TunAdapter
+ device.TunAdapter
}
// IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
- tunAdapter iface.TunAdapter
+ tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go
index 6bb7eff4f..fa20435ea 100644
--- a/client/cmd/login_test.go
+++ b/client/cmd/login_test.go
@@ -5,8 +5,8 @@ import (
"strings"
"testing"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go
index f2805cf35..4cbbe8783 100644
--- a/client/cmd/root_test.go
+++ b/client/cmd/root_test.go
@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
func TestInitCommands(t *testing.T) {
diff --git a/client/cmd/up.go b/client/cmd/up.go
index b447f7141..05ecce9e0 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -15,11 +15,11 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
diff --git a/client/firewall/iface.go b/client/firewall/iface.go
index d0b5209c0..f349f9210 100644
--- a/client/firewall/iface.go
+++ b/client/firewall/iface.go
@@ -1,13 +1,13 @@
package firewall
import (
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
- Address() iface.WGAddress
+ Address() device.WGAddress
IsUserspaceBind() bool
- SetFilter(iface.PacketFilter) error
+ SetFilter(device.PacketFilter) error
}
diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go
index fae41d9c5..6fefd58e6 100644
--- a/client/firewall/iptables/manager_linux.go
+++ b/client/firewall/iptables/manager_linux.go
@@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
// Manager of iptables firewall
diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go
index 0072aa159..498d8f58b 100644
--- a/client/firewall/iptables/manager_linux_test.go
+++ b/client/firewall/iptables/manager_linux_test.go
@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go
index 85cba9e1c..eaf7fb6a0 100644
--- a/client/firewall/nftables/acl_linux.go
+++ b/client/firewall/nftables/acl_linux.go
@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
const (
diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go
index 7f78a9a2e..904050a51 100644
--- a/client/firewall/nftables/manager_linux_test.go
+++ b/client/firewall/nftables/manager_linux_test.go
@@ -15,7 +15,7 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go
index 681058ea9..0e3ee9799 100644
--- a/client/firewall/uspfilter/uspfilter.go
+++ b/client/firewall/uspfilter/uspfilter.go
@@ -12,7 +12,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
const layerTypeAll = 0
@@ -23,7 +24,7 @@ var (
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
- SetFilter(iface.PacketFilter) error
+ SetFilter(device.PacketFilter) error
Address() iface.WGAddress
}
diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go
index dd7366fe9..c188deea4 100644
--- a/client/firewall/uspfilter/uspfilter_test.go
+++ b/client/firewall/uspfilter/uspfilter_test.go
@@ -11,15 +11,16 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
type IFaceMock struct {
- SetFilterFunc func(iface.PacketFilter) error
+ SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress
}
-func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
+func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented")
}
@@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -52,7 +53,7 @@ func TestManagerCreate(t *testing.T) {
func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error {
+ SetFilterFunc: func(device.PacketFilter) error {
isSetFilterCalled = true
return nil
},
@@ -90,7 +91,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) {
// creating mock iface
iface := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// creating manager instance
@@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
require.NoError(t, err)
diff --git a/iface/bind/bind.go b/client/iface/bind/bind.go
similarity index 100%
rename from iface/bind/bind.go
rename to client/iface/bind/bind.go
diff --git a/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go
similarity index 100%
rename from iface/bind/udp_mux.go
rename to client/iface/bind/udp_mux.go
diff --git a/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go
similarity index 100%
rename from iface/bind/udp_mux_universal.go
rename to client/iface/bind/udp_mux_universal.go
diff --git a/iface/bind/udp_muxed_conn.go b/client/iface/bind/udp_muxed_conn.go
similarity index 100%
rename from iface/bind/udp_muxed_conn.go
rename to client/iface/bind/udp_muxed_conn.go
diff --git a/client/iface/configurer/err.go b/client/iface/configurer/err.go
new file mode 100644
index 000000000..a64bba2dd
--- /dev/null
+++ b/client/iface/configurer/err.go
@@ -0,0 +1,5 @@
+package configurer
+
+import "errors"
+
+var ErrPeerNotFound = errors.New("peer not found")
diff --git a/iface/wg_configurer_kernel_unix.go b/client/iface/configurer/kernel_unix.go
similarity index 83%
rename from iface/wg_configurer_kernel_unix.go
rename to client/iface/configurer/kernel_unix.go
index 8b47082da..7c1c41669 100644
--- a/iface/wg_configurer_kernel_unix.go
+++ b/client/iface/configurer/kernel_unix.go
@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
-package iface
+package configurer
import (
"fmt"
@@ -12,18 +12,17 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
-type wgKernelConfigurer struct {
+type KernelConfigurer struct {
deviceName string
}
-func newWGConfigurer(deviceName string) wgConfigurer {
- wgc := &wgKernelConfigurer{
+func NewKernelConfigurer(deviceName string) *KernelConfigurer {
+ return &KernelConfigurer{
deviceName: deviceName,
}
- return wgc
}
-func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error {
+func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
@@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
return nil
}
-func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@@ -75,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return nil
}
-func (c *wgKernelConfigurer) removePeer(peerKey string) error {
+func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -96,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error {
return nil
}
-func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
+func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
@@ -123,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro
return nil
}
-func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
+func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
@@ -165,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e
return nil
}
-func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
+func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
@@ -189,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer
return wgtypes.Peer{}, ErrPeerNotFound
}
-func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
+func (c *KernelConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
@@ -205,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
return wg.ConfigureDevice(c.deviceName, config)
}
-func (c *wgKernelConfigurer) close() {
+func (c *KernelConfigurer) Close() {
}
-func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) {
+func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
diff --git a/iface/name.go b/client/iface/configurer/name.go
similarity index 87%
rename from iface/name.go
rename to client/iface/configurer/name.go
index 706cb65ad..e2133d0ea 100644
--- a/iface/name.go
+++ b/client/iface/configurer/name.go
@@ -1,6 +1,6 @@
//go:build linux || windows || freebsd
-package iface
+package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0"
diff --git a/iface/name_darwin.go b/client/iface/configurer/name_darwin.go
similarity index 86%
rename from iface/name_darwin.go
rename to client/iface/configurer/name_darwin.go
index a4016ce15..034ce388d 100644
--- a/iface/name_darwin.go
+++ b/client/iface/configurer/name_darwin.go
@@ -1,6 +1,6 @@
//go:build darwin
-package iface
+package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100"
diff --git a/iface/uapi.go b/client/iface/configurer/uapi.go
similarity index 96%
rename from iface/uapi.go
rename to client/iface/configurer/uapi.go
index d7ff52e7b..4801841de 100644
--- a/iface/uapi.go
+++ b/client/iface/configurer/uapi.go
@@ -1,6 +1,6 @@
//go:build !windows
-package iface
+package configurer
import (
"net"
diff --git a/iface/uapi_windows.go b/client/iface/configurer/uapi_windows.go
similarity index 88%
rename from iface/uapi_windows.go
rename to client/iface/configurer/uapi_windows.go
index e1f466364..46fa90c2e 100644
--- a/iface/uapi_windows.go
+++ b/client/iface/configurer/uapi_windows.go
@@ -1,4 +1,4 @@
-package iface
+package configurer
import (
"net"
diff --git a/iface/wg_configurer_usp.go b/client/iface/configurer/usp.go
similarity index 93%
rename from iface/wg_configurer_usp.go
rename to client/iface/configurer/usp.go
index cd1d9d0b6..21d65ab2a 100644
--- a/iface/wg_configurer_usp.go
+++ b/client/iface/configurer/usp.go
@@ -1,4 +1,4 @@
-package iface
+package configurer
import (
"encoding/hex"
@@ -19,15 +19,15 @@ import (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
-type wgUSPConfigurer struct {
+type WGUSPConfigurer struct {
device *device.Device
deviceName string
uapiListener net.Listener
}
-func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
- wgCfg := &wgUSPConfigurer{
+func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
+ wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
}
@@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
return wgCfg
}
-func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error {
+func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
@@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@@ -80,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) removePeer(peerKey string) error {
+func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -97,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
+func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
@@ -121,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
+func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
@@ -185,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
-func (t *wgUSPConfigurer) startUAPI() {
+func (t *WGUSPConfigurer) startUAPI() {
var err error
t.uapiListener, err = openUAPI(t.deviceName)
if err != nil {
@@ -207,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() {
}(t.uapiListener)
}
-func (t *wgUSPConfigurer) close() {
+func (t *WGUSPConfigurer) Close() {
if t.uapiListener != nil {
err := t.uapiListener.Close()
if err != nil {
@@ -223,7 +223,7 @@ func (t *wgUSPConfigurer) close() {
}
}
-func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) {
+func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err)
diff --git a/iface/wg_configurer_usp_test.go b/client/iface/configurer/usp_test.go
similarity index 99%
rename from iface/wg_configurer_usp_test.go
rename to client/iface/configurer/usp_test.go
index ac0fc6130..775339f24 100644
--- a/iface/wg_configurer_usp_test.go
+++ b/client/iface/configurer/usp_test.go
@@ -1,4 +1,4 @@
-package iface
+package configurer
import (
"encoding/hex"
diff --git a/client/iface/configurer/wgstats.go b/client/iface/configurer/wgstats.go
new file mode 100644
index 000000000..56d0d7310
--- /dev/null
+++ b/client/iface/configurer/wgstats.go
@@ -0,0 +1,9 @@
+package configurer
+
+import "time"
+
+type WGStats struct {
+ LastHandshake time.Time
+ TxBytes int64
+ RxBytes int64
+}
diff --git a/client/iface/device.go b/client/iface/device.go
new file mode 100644
index 000000000..0d4e69145
--- /dev/null
+++ b/client/iface/device.go
@@ -0,0 +1,18 @@
+//go:build !android
+
+package iface
+
+import (
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type WGTunDevice interface {
+ Create() (device.WGConfigurer, error)
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(address WGAddress) error
+ WgAddress() WGAddress
+ DeviceName() string
+ Close() error
+ FilteredDevice() *device.FilteredDevice
+}
diff --git a/iface/tun_adapter.go b/client/iface/device/adapter.go
similarity index 94%
rename from iface/tun_adapter.go
rename to client/iface/device/adapter.go
index adec93ed1..6ebc05390 100644
--- a/iface/tun_adapter.go
+++ b/client/iface/device/adapter.go
@@ -1,4 +1,4 @@
-package iface
+package device
// TunAdapter is an interface for create tun device from external service
type TunAdapter interface {
diff --git a/iface/address.go b/client/iface/device/address.go
similarity index 69%
rename from iface/address.go
rename to client/iface/device/address.go
index 5ff4fbc06..15de301da 100644
--- a/iface/address.go
+++ b/client/iface/device/address.go
@@ -1,18 +1,18 @@
-package iface
+package device
import (
"fmt"
"net"
)
-// WGAddress Wireguard parsed address
+// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
-// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address
-func parseWGAddress(address string) (WGAddress, error) {
+// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
+func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err
diff --git a/iface/tun_args.go b/client/iface/device/args.go
similarity index 88%
rename from iface/tun_args.go
rename to client/iface/device/args.go
index 0eac2c4c0..d7b86b335 100644
--- a/iface/tun_args.go
+++ b/client/iface/device/args.go
@@ -1,4 +1,4 @@
-package iface
+package device
type MobileIFaceArguments struct {
TunAdapter TunAdapter // only for Android
diff --git a/iface/tun_android.go b/client/iface/device/device_android.go
similarity index 61%
rename from iface/tun_android.go
rename to client/iface/device/device_android.go
index 504993094..29e3f409d 100644
--- a/iface/tun_android.go
+++ b/client/iface/device/device_android.go
@@ -1,7 +1,6 @@
//go:build android
-// +build android
-package iface
+package device
import (
"strings"
@@ -12,11 +11,12 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform
-type wgTunDevice struct {
+// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
+type WGTunDevice struct {
address WGAddress
port int
key string
@@ -24,15 +24,15 @@ type wgTunDevice struct {
iceBind *bind.ICEBind
tunAdapter TunAdapter
- name string
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ name string
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice {
- return wgTunDevice{
+func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
+ return &WGTunDevice{
address: address,
port: port,
key: key,
@@ -42,7 +42,7 @@ func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet
}
}
-func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) {
+func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
log.Info("create tun interface")
routesString := routesToString(routes)
@@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err
}
t.name = name
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name)
- t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
+ t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
-func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *wgTunDevice) UpdateAddr(addr WGAddress) error {
+func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
-func (t *wgTunDevice) Close() error {
+func (t *WGTunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error {
return nil
}
-func (t *wgTunDevice) Device() *device.Device {
+func (t *WGTunDevice) Device() *device.Device {
return t.device
}
-func (t *wgTunDevice) DeviceName() string {
+func (t *WGTunDevice) DeviceName() string {
return t.name
}
-func (t *wgTunDevice) WgAddress() WGAddress {
+func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *wgTunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
func routesToString(routes []string) string {
diff --git a/iface/tun_darwin.go b/client/iface/device/device_darwin.go
similarity index 69%
rename from iface/tun_darwin.go
rename to client/iface/device/device_darwin.go
index fcf9f8ba0..03e85a7f1 100644
--- a/iface/tun_darwin.go
+++ b/client/iface/device/device_darwin.go
@@ -1,6 +1,6 @@
//go:build !ios
-package iface
+package device
import (
"fmt"
@@ -11,10 +11,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunDevice struct {
+type TunDevice struct {
name string
address WGAddress
port int
@@ -22,14 +23,14 @@ type tunDevice struct {
mtu int
iceBind *bind.ICEBind
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
- return &tunDevice{
+func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
+ return &TunDevice{
name: name,
address: address,
port: port,
@@ -39,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
}
-func (t *tunDevice) Create() (wgConfigurer, error) {
+func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@@ -59,17 +60,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
-func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -84,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunDevice) UpdateAddr(address WGAddress) error {
+func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunDevice) Close() error {
+func (t *TunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -105,20 +106,20 @@ func (t *tunDevice) Close() error {
return nil
}
-func (t *tunDevice) WgAddress() WGAddress {
+func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunDevice) DeviceName() string {
+func (t *TunDevice) DeviceName() string {
return t.name
}
-func (t *tunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
-func (t *tunDevice) assignAddr() error {
+func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
diff --git a/iface/device_wrapper.go b/client/iface/device/device_filter.go
similarity index 81%
rename from iface/device_wrapper.go
rename to client/iface/device/device_filter.go
index 2fa219395..f87f10429 100644
--- a/iface/device_wrapper.go
+++ b/client/iface/device/device_filter.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"net"
@@ -28,22 +28,23 @@ type PacketFilter interface {
SetNetwork(*net.IPNet)
}
-// DeviceWrapper to override Read or Write of packets
-type DeviceWrapper struct {
+// FilteredDevice to override Read or Write of packets
+type FilteredDevice struct {
tun.Device
+
filter PacketFilter
mutex sync.RWMutex
}
-// newDeviceWrapper constructor function
-func newDeviceWrapper(device tun.Device) *DeviceWrapper {
- return &DeviceWrapper{
+// newDeviceFilter constructor function
+func newDeviceFilter(device tun.Device) *FilteredDevice {
+ return &FilteredDevice{
Device: device,
}
}
// Read wraps read method with filtering feature
-func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
@@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err
}
// Write wraps write method with filtering feature
-func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
+func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
@@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
}
// SetFilter sets packet filter to device
-func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
+func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()
diff --git a/iface/device_wrapper_test.go b/client/iface/device/device_filter_test.go
similarity index 95%
rename from iface/device_wrapper_test.go
rename to client/iface/device/device_filter_test.go
index 2d3725ea4..d3278b918 100644
--- a/iface/device_wrapper_test.go
+++ b/client/iface/device/device_filter_test.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"net"
@@ -7,7 +7,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
- mocks "github.com/netbirdio/netbird/iface/mocks"
+
+ mocks "github.com/netbirdio/netbird/client/iface/mocks"
)
func TestDeviceWrapperRead(t *testing.T) {
@@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
bufs := [][]byte{{}}
sizes := []int{0}
@@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
bufs := [][]byte{buffer.Bytes()}
@@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{buffer.Bytes()}
@@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{{}}
diff --git a/iface/tun_ios.go b/client/iface/device/device_ios.go
similarity index 63%
rename from iface/tun_ios.go
rename to client/iface/device/device_ios.go
index 6d53cc333..226e8a2e0 100644
--- a/iface/tun_ios.go
+++ b/client/iface/device/device_ios.go
@@ -1,7 +1,7 @@
//go:build ios
// +build ios
-package iface
+package device
import (
"os"
@@ -12,10 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunDevice struct {
+type TunDevice struct {
name string
address WGAddress
port int
@@ -23,14 +24,14 @@ type tunDevice struct {
iceBind *bind.ICEBind
tunFd int
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice {
- return &tunDevice{
+func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
+ return &TunDevice{
name: name,
address: address,
port: port,
@@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor
}
}
-func (t *tunDevice) Create() (wgConfigurer, error) {
+func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
@@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, err
}
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface")
- t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
+ t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
-func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunDevice) Device() *device.Device {
+func (t *TunDevice) Device() *device.Device {
return t.device
}
-func (t *tunDevice) DeviceName() string {
+func (t *TunDevice) DeviceName() string {
return t.name
}
-func (t *tunDevice) Close() error {
+func (t *TunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -119,15 +120,15 @@ func (t *tunDevice) Close() error {
return nil
}
-func (t *tunDevice) WgAddress() WGAddress {
+func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunDevice) UpdateAddr(addr WGAddress) error {
+func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
-func (t *tunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
diff --git a/iface/tun_kernel_unix.go b/client/iface/device/device_kernel_unix.go
similarity index 75%
rename from iface/tun_kernel_unix.go
rename to client/iface/device/device_kernel_unix.go
index 220c07888..f355d2cf7 100644
--- a/iface/tun_kernel_unix.go
+++ b/client/iface/device/device_kernel_unix.go
@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
-package iface
+package device
import (
"context"
@@ -10,11 +10,12 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/sharedsock"
)
-type tunKernelDevice struct {
+type TunKernelDevice struct {
name string
address WGAddress
wgPort int
@@ -31,11 +32,11 @@ type tunKernelDevice struct {
filterFn bind.FilterFn
}
-func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
+func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background())
- return &tunKernelDevice{
+ return &TunKernelDevice{
ctx: ctx,
ctxCancel: cancel,
name: name,
@@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in
}
}
-func (t *tunKernelDevice) Create() (wgConfigurer, error) {
+func (t *TunKernelDevice) Create() (WGConfigurer, error) {
link := newWGLink(t.name)
if err := link.recreate(); err != nil {
@@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("set mtu: %w", err)
}
- configurer := newWGConfigurer(t.name)
+ configurer := configurer.NewKernelConfigurer(t.name)
- if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
+ if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil {
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil
}
-func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
-func (t *tunKernelDevice) UpdateAddr(address WGAddress) error {
+func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunKernelDevice) Close() error {
+func (t *TunKernelDevice) Close() error {
if t.link == nil {
return nil
}
@@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error {
return closErr
}
-func (t *tunKernelDevice) WgAddress() WGAddress {
+func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunKernelDevice) DeviceName() string {
+func (t *TunKernelDevice) DeviceName() string {
return t.name
}
-func (t *tunKernelDevice) Wrapper() *DeviceWrapper {
+func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil
}
// assignAddr Adds IP address to the tunnel interface
-func (t *tunKernelDevice) assignAddr() error {
+func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}
diff --git a/iface/tun_netstack.go b/client/iface/device/device_netstack.go
similarity index 56%
rename from iface/tun_netstack.go
rename to client/iface/device/device_netstack.go
index de1ff6654..440a1ca19 100644
--- a/iface/tun_netstack.go
+++ b/client/iface/device/device_netstack.go
@@ -1,7 +1,7 @@
//go:build !android
// +build !android
-package iface
+package device
import (
"fmt"
@@ -10,11 +10,12 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
-type tunNetstackDevice struct {
+type TunNetstackDevice struct {
name string
address WGAddress
port int
@@ -23,15 +24,15 @@ type tunNetstackDevice struct {
listenAddress string
iceBind *bind.ICEBind
- device *device.Device
- wrapper *DeviceWrapper
- nsTun *netstack.NetStackTun
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ nsTun *netstack.NetStackTun
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice {
- return &tunNetstackDevice{
+func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
+ return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
@@ -42,23 +43,23 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string
}
}
-func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
+func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
- t.wrapper = newDeviceWrapper(tunIface)
+ t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
@@ -68,7 +69,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
return t.configurer, nil
}
-func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunNetstackDevice) UpdateAddr(WGAddress) error {
+func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
-func (t *tunNetstackDevice) Close() error {
+func (t *TunNetstackDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error {
return nil
}
-func (t *tunNetstackDevice) WgAddress() WGAddress {
+func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunNetstackDevice) DeviceName() string {
+func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
-func (t *tunNetstackDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
diff --git a/iface/tun_usp_unix.go b/client/iface/device/device_usp_unix.go
similarity index 63%
rename from iface/tun_usp_unix.go
rename to client/iface/device/device_usp_unix.go
index 1c1d3ac89..4175f6556 100644
--- a/iface/tun_usp_unix.go
+++ b/client/iface/device/device_usp_unix.go
@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
-package iface
+package device
import (
"fmt"
@@ -12,10 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunUSPDevice struct {
+type USPDevice struct {
name string
address WGAddress
port int
@@ -23,39 +24,38 @@ type tunUSPDevice struct {
mtu int
iceBind *bind.ICEBind
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
+func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
- return &tunUSPDevice{
+ return &USPDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
- iceBind: bind.NewICEBind(transportNet, filterFn),
- }
+ iceBind: bind.NewICEBind(transportNet, filterFn)}
}
-func (t *tunUSPDevice) Create() (wgConfigurer, error) {
+func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
- t.wrapper = newDeviceWrapper(tunIface)
+ t.filteredDevice = newDeviceFilter(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@@ -66,17 +66,17 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
-func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunUSPDevice) UpdateAddr(address WGAddress) error {
+func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunUSPDevice) Close() error {
+func (t *USPDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error {
return nil
}
-func (t *tunUSPDevice) WgAddress() WGAddress {
+func (t *USPDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunUSPDevice) DeviceName() string {
+func (t *USPDevice) DeviceName() string {
return t.name
}
-func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *USPDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface
-func (t *tunUSPDevice) assignAddr() error {
+func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(t.address)
diff --git a/iface/tun_windows.go b/client/iface/device/device_windows.go
similarity index 75%
rename from iface/tun_windows.go
rename to client/iface/device/device_windows.go
index afb67bcc0..f3e216ccd 100644
--- a/iface/tun_windows.go
+++ b/client/iface/device/device_windows.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"fmt"
@@ -11,12 +11,13 @@ import (
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
-type tunDevice struct {
+type TunDevice struct {
name string
address WGAddress
port int
@@ -26,13 +27,13 @@ type tunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
- wrapper *DeviceWrapper
+ filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ configurer WGConfigurer
}
-func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
- return &tunDevice{
+func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
+ return &TunDevice{
name: name,
address: address,
port: port,
@@ -50,7 +51,7 @@ func getGUID() (windows.GUID, error) {
return windows.GUIDFromString(guidString)
}
-func (t *tunDevice) Create() (wgConfigurer, error) {
+func (t *TunDevice) Create() (WGConfigurer, error) {
guid, err := getGUID()
if err != nil {
log.Errorf("failed to get GUID: %s", err)
@@ -62,11 +63,11 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@@ -92,17 +93,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
-func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -117,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunDevice) UpdateAddr(address WGAddress) error {
+func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunDevice) Close() error {
+func (t *TunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -138,19 +139,19 @@ func (t *tunDevice) Close() error {
}
return nil
}
-func (t *tunDevice) WgAddress() WGAddress {
+func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunDevice) DeviceName() string {
+func (t *TunDevice) DeviceName() string {
return t.name
}
-func (t *tunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
-func (t *tunDevice) getInterfaceGUIDString() (string, error) {
+func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
@@ -164,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) {
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
-func (t *tunDevice) assignAddr() error {
+func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go
new file mode 100644
index 000000000..0196b0085
--- /dev/null
+++ b/client/iface/device/interface.go
@@ -0,0 +1,20 @@
+package device
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/configurer"
+)
+
+type WGConfigurer interface {
+ ConfigureInterface(privateKey string, port int) error
+ UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+ Close()
+ GetStats(peerKey string) (configurer.WGStats, error)
+}
diff --git a/iface/module.go b/client/iface/device/kernel_module.go
similarity index 92%
rename from iface/module.go
rename to client/iface/device/kernel_module.go
index ca70cf3c7..1bdd6f7c6 100644
--- a/iface/module.go
+++ b/client/iface/device/kernel_module.go
@@ -1,6 +1,6 @@
//go:build (!linux && !freebsd) || android
-package iface
+package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
diff --git a/iface/module_freebsd.go b/client/iface/device/kernel_module_freebsd.go
similarity index 84%
rename from iface/module_freebsd.go
rename to client/iface/device/kernel_module_freebsd.go
index 00ad882c2..dd6c8b408 100644
--- a/iface/module_freebsd.go
+++ b/client/iface/device/kernel_module_freebsd.go
@@ -1,4 +1,4 @@
-package iface
+package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
@@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool {
return false
}
-// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
-func tunModuleIsLoaded() bool {
+// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
+func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
diff --git a/iface/module_linux.go b/client/iface/device/kernel_module_linux.go
similarity index 98%
rename from iface/module_linux.go
rename to client/iface/device/kernel_module_linux.go
index 11c0482d5..0d195779d 100644
--- a/iface/module_linux.go
+++ b/client/iface/device/kernel_module_linux.go
@@ -1,7 +1,7 @@
//go:build linux && !android
// Package iface provides wireguard network interface creation and management
-package iface
+package device
import (
"bufio"
@@ -66,8 +66,8 @@ func getModuleRoot() string {
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
}
-// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
-func tunModuleIsLoaded() bool {
+// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
+func ModuleTunIsLoaded() bool {
_, err := os.Stat("/dev/net/tun")
if err == nil {
return true
diff --git a/iface/module_linux_test.go b/client/iface/device/kernel_module_linux_test.go
similarity index 98%
rename from iface/module_linux_test.go
rename to client/iface/device/kernel_module_linux_test.go
index 97e9b1f78..de9656e47 100644
--- a/iface/module_linux_test.go
+++ b/client/iface/device/kernel_module_linux_test.go
@@ -1,4 +1,6 @@
-package iface
+//go:build linux && !android
+
+package device
import (
"bufio"
@@ -132,7 +134,7 @@ func resetGlobals() {
}
func createFiles(t *testing.T) (string, []module) {
- t.Helper()
+ t.Helper()
writeFile := func(path, text string) {
if err := os.WriteFile(path, []byte(text), 0644); err != nil {
t.Fatal(err)
@@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) {
}
func getRandomLoadedModule(t *testing.T) (string, error) {
- t.Helper()
+ t.Helper()
f, err := os.Open("/proc/modules")
if err != nil {
return "", err
diff --git a/iface/tun_link_freebsd.go b/client/iface/device/wg_link_freebsd.go
similarity index 95%
rename from iface/tun_link_freebsd.go
rename to client/iface/device/wg_link_freebsd.go
index be7921fdb..104010f47 100644
--- a/iface/tun_link_freebsd.go
+++ b/client/iface/device/wg_link_freebsd.go
@@ -1,10 +1,11 @@
-package iface
+package device
import (
"fmt"
- "github.com/netbirdio/netbird/iface/freebsd"
log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/iface/freebsd"
)
type wgLink struct {
diff --git a/iface/tun_link_linux.go b/client/iface/device/wg_link_linux.go
similarity index 99%
rename from iface/tun_link_linux.go
rename to client/iface/device/wg_link_linux.go
index 3ce644e84..a15cffe48 100644
--- a/iface/tun_link_linux.go
+++ b/client/iface/device/wg_link_linux.go
@@ -1,6 +1,6 @@
//go:build linux && !android
-package iface
+package device
import (
"fmt"
diff --git a/iface/wg_log.go b/client/iface/device/wg_log.go
similarity index 93%
rename from iface/wg_log.go
rename to client/iface/device/wg_log.go
index b44f6fc0b..db2f3111f 100644
--- a/iface/wg_log.go
+++ b/client/iface/device/wg_log.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"os"
diff --git a/client/iface/device/windows_guid.go b/client/iface/device/windows_guid.go
new file mode 100644
index 000000000..1c7d40d13
--- /dev/null
+++ b/client/iface/device/windows_guid.go
@@ -0,0 +1,4 @@
+package device
+
+// CustomWindowsGUIDString is a custom GUID string for the interface
+var CustomWindowsGUIDString string
diff --git a/client/iface/device_android.go b/client/iface/device_android.go
new file mode 100644
index 000000000..3d15080ff
--- /dev/null
+++ b/client/iface/device_android.go
@@ -0,0 +1,16 @@
+package iface
+
+import (
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type WGTunDevice interface {
+ Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(address WGAddress) error
+ WgAddress() WGAddress
+ DeviceName() string
+ Close() error
+ FilteredDevice() *device.FilteredDevice
+}
diff --git a/iface/freebsd/errors.go b/client/iface/freebsd/errors.go
similarity index 100%
rename from iface/freebsd/errors.go
rename to client/iface/freebsd/errors.go
diff --git a/iface/freebsd/iface.go b/client/iface/freebsd/iface.go
similarity index 100%
rename from iface/freebsd/iface.go
rename to client/iface/freebsd/iface.go
diff --git a/iface/freebsd/iface_internal_test.go b/client/iface/freebsd/iface_internal_test.go
similarity index 100%
rename from iface/freebsd/iface_internal_test.go
rename to client/iface/freebsd/iface_internal_test.go
diff --git a/iface/freebsd/link.go b/client/iface/freebsd/link.go
similarity index 100%
rename from iface/freebsd/link.go
rename to client/iface/freebsd/link.go
diff --git a/iface/iface.go b/client/iface/iface.go
similarity index 79%
rename from iface/iface.go
rename to client/iface/iface.go
index 545feffcf..accf5ce0a 100644
--- a/iface/iface.go
+++ b/client/iface/iface.go
@@ -9,28 +9,27 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
const (
- DefaultMTU = 1280
- DefaultWgPort = 51820
+ DefaultMTU = 1280
+ DefaultWgPort = 51820
+ WgInterfaceDefault = configurer.WgInterfaceDefault
)
-// WGIface represents a interface instance
+type WGAddress = device.WGAddress
+
+// WGIface represents an interface instance
type WGIface struct {
- tun wgTunDevice
+ tun WGTunDevice
userspaceBind bool
mu sync.Mutex
- configurer wgConfigurer
- filter PacketFilter
-}
-
-type WGStats struct {
- LastHandshake time.Time
- TxBytes int64
- RxBytes int64
+ configurer device.WGConfigurer
+ filter device.PacketFilter
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@@ -44,7 +43,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
-func (w *WGIface) Address() WGAddress {
+func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
@@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
- addr, err := parseWGAddress(newAddr)
+ addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}
@@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
- return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
+ return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
- return w.configurer.removePeer(peerKey)
+ return w.configurer.RemovePeer(peerKey)
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
@@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock()
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
- return w.configurer.addAllowedIP(peerKey, allowedIP)
+ return w.configurer.AddAllowedIP(peerKey, allowedIP)
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
@@ -117,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock()
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
- return w.configurer.removeAllowedIP(peerKey, allowedIP)
+ return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
}
// Close closes the tunnel interface
@@ -144,23 +143,23 @@ func (w *WGIface) Close() error {
}
// SetFilter sets packet filters for the userspace implementation
-func (w *WGIface) SetFilter(filter PacketFilter) error {
+func (w *WGIface) SetFilter(filter device.PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()
- if w.tun.Wrapper() == nil {
+ if w.tun.FilteredDevice() == nil {
return fmt.Errorf("userspace packet filtering not handled on this device")
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
- w.tun.Wrapper().SetFilter(filter)
+ w.tun.FilteredDevice().SetFilter(filter)
return nil
}
// GetFilter returns packet filter used by interface if it uses userspace device implementation
-func (w *WGIface) GetFilter() PacketFilter {
+func (w *WGIface) GetFilter() device.PacketFilter {
w.mu.Lock()
defer w.mu.Unlock()
@@ -168,16 +167,16 @@ func (w *WGIface) GetFilter() PacketFilter {
}
// GetDevice to interact with raw device (with filtering)
-func (w *WGIface) GetDevice() *DeviceWrapper {
+func (w *WGIface) GetDevice() *device.FilteredDevice {
w.mu.Lock()
defer w.mu.Unlock()
- return w.tun.Wrapper()
+ return w.tun.FilteredDevice()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer
-func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
- return w.configurer.getStats(peerKey)
+func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
+ return w.configurer.GetStats(peerKey)
}
func (w *WGIface) waitUntilRemoved() error {
diff --git a/iface/iface_android.go b/client/iface/iface_android.go
similarity index 67%
rename from iface/iface_android.go
rename to client/iface/iface_android.go
index 99f6885a5..5ed476e70 100644
--- a/iface/iface_android.go
+++ b/client/iface/iface_android.go
@@ -5,18 +5,19 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
- tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
+ tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
diff --git a/iface/iface_create.go b/client/iface/iface_create.go
similarity index 100%
rename from iface/iface_create.go
rename to client/iface/iface_create.go
diff --git a/iface/iface_darwin.go b/client/iface/iface_darwin.go
similarity index 68%
rename from iface/iface_darwin.go
rename to client/iface/iface_darwin.go
index f48f324c3..b46ea0f80 100644
--- a/iface/iface_darwin.go
+++ b/client/iface/iface_darwin.go
@@ -9,13 +9,14 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@@ -25,11 +26,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
}
if netstack.IsEnabled() {
- wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
+ wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
+ wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
diff --git a/iface/iface_destroy_bsd.go b/client/iface/iface_destroy_bsd.go
similarity index 100%
rename from iface/iface_destroy_bsd.go
rename to client/iface/iface_destroy_bsd.go
diff --git a/iface/iface_destroy_linux.go b/client/iface/iface_destroy_linux.go
similarity index 100%
rename from iface/iface_destroy_linux.go
rename to client/iface/iface_destroy_linux.go
diff --git a/iface/iface_destroy_mobile.go b/client/iface/iface_destroy_mobile.go
similarity index 100%
rename from iface/iface_destroy_mobile.go
rename to client/iface/iface_destroy_mobile.go
diff --git a/iface/iface_destroy_windows.go b/client/iface/iface_destroy_windows.go
similarity index 100%
rename from iface/iface_destroy_windows.go
rename to client/iface/iface_destroy_windows.go
diff --git a/iface/iface_ios.go b/client/iface/iface_ios.go
similarity index 59%
rename from iface/iface_ios.go
rename to client/iface/iface_ios.go
index 6babe5964..fc0214748 100644
--- a/iface/iface_ios.go
+++ b/client/iface/iface_ios.go
@@ -7,17 +7,18 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
- tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
+ tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil
diff --git a/iface/iface_moc.go b/client/iface/iface_moc.go
similarity index 76%
rename from iface/iface_moc.go
rename to client/iface/iface_moc.go
index fab3054a0..703da9ce0 100644
--- a/iface/iface_moc.go
+++ b/client/iface/iface_moc.go
@@ -6,7 +6,9 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
type MockWGIface struct {
@@ -14,7 +16,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
- AddressFunc func() WGAddress
+ AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
@@ -23,10 +25,10 @@ type MockWGIface struct {
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
- SetFilterFunc func(filter PacketFilter) error
- GetFilterFunc func() PacketFilter
- GetDeviceFunc func() *DeviceWrapper
- GetStatsFunc func(peerKey string) (WGStats, error)
+ SetFilterFunc func(filter device.PacketFilter) error
+ GetFilterFunc func() device.PacketFilter
+ GetDeviceFunc func() *device.FilteredDevice
+ GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
}
@@ -50,7 +52,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc()
}
-func (m *MockWGIface) Address() WGAddress {
+func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
@@ -86,18 +88,18 @@ func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
-func (m *MockWGIface) SetFilter(filter PacketFilter) error {
+func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
-func (m *MockWGIface) GetFilter() PacketFilter {
+func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
-func (m *MockWGIface) GetDevice() *DeviceWrapper {
+func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
-func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) {
+func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
diff --git a/iface/iface_test.go b/client/iface/iface_test.go
similarity index 98%
rename from iface/iface_test.go
rename to client/iface/iface_test.go
index 8de9f647e..87a68addb 100644
--- a/iface/iface_test.go
+++ b/client/iface/iface_test.go
@@ -14,6 +14,8 @@ import (
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/device"
)
// keep darwin compatibility
@@ -414,7 +416,7 @@ func Test_ConnectPeers(t *testing.T) {
}
guid := fmt.Sprintf("{%s}", uuid.New().String())
- CustomWindowsGUIDString = strings.ToLower(guid)
+ device.CustomWindowsGUIDString = strings.ToLower(guid)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
@@ -436,7 +438,7 @@ func Test_ConnectPeers(t *testing.T) {
}
guid = fmt.Sprintf("{%s}", uuid.New().String())
- CustomWindowsGUIDString = strings.ToLower(guid)
+ device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet()
if err != nil {
diff --git a/iface/iface_unix.go b/client/iface/iface_unix.go
similarity index 53%
rename from iface/iface_unix.go
rename to client/iface/iface_unix.go
index 9608df1ad..09dbb2c1f 100644
--- a/iface/iface_unix.go
+++ b/client/iface/iface_unix.go
@@ -8,13 +8,14 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
- wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
+ wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
- if WireGuardModuleIsLoaded() {
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
+ if device.WireGuardModuleIsLoaded() {
+ wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
- if !tunModuleIsLoaded() {
+ if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
- wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
+ wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
diff --git a/iface/iface_windows.go b/client/iface/iface_windows.go
similarity index 52%
rename from iface/iface_windows.go
rename to client/iface/iface_windows.go
index c5edd27a9..6845ef3dd 100644
--- a/iface/iface_windows.go
+++ b/client/iface/iface_windows.go
@@ -5,13 +5,14 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
}
if netstack.IsEnabled() {
- wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
+ wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
+ wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
@@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
- return w.tun.(*tunDevice).getInterfaceGUIDString()
+ return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}
diff --git a/iface/iwginterface.go b/client/iface/iwginterface.go
similarity index 65%
rename from iface/iwginterface.go
rename to client/iface/iwginterface.go
index 501f51d2b..cb6d7ccd9 100644
--- a/iface/iwginterface.go
+++ b/client/iface/iwginterface.go
@@ -8,7 +8,9 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
@@ -16,7 +18,7 @@ type IWGIface interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
- Address() WGAddress
+ Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
@@ -25,8 +27,8 @@ type IWGIface interface {
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
- SetFilter(filter PacketFilter) error
- GetFilter() PacketFilter
- GetDevice() *DeviceWrapper
- GetStats(peerKey string) (WGStats, error)
+ SetFilter(filter device.PacketFilter) error
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
}
diff --git a/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go
similarity index 65%
rename from iface/iwginterface_windows.go
rename to client/iface/iwginterface_windows.go
index b5053474e..6baeb66ae 100644
--- a/iface/iwginterface_windows.go
+++ b/client/iface/iwginterface_windows.go
@@ -6,7 +6,9 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
@@ -14,7 +16,7 @@ type IWGIface interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
- Address() WGAddress
+ Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
@@ -23,9 +25,9 @@ type IWGIface interface {
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
- SetFilter(filter PacketFilter) error
- GetFilter() PacketFilter
- GetDevice() *DeviceWrapper
- GetStats(peerKey string) (WGStats, error)
+ SetFilter(filter device.PacketFilter) error
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}
diff --git a/iface/mocks/README.md b/client/iface/mocks/README.md
similarity index 100%
rename from iface/mocks/README.md
rename to client/iface/mocks/README.md
diff --git a/iface/mocks/filter.go b/client/iface/mocks/filter.go
similarity index 97%
rename from iface/mocks/filter.go
rename to client/iface/mocks/filter.go
index 2d80d69f1..6348e0e77 100644
--- a/iface/mocks/filter.go
+++ b/client/iface/mocks/filter.go
@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
+// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
diff --git a/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go
similarity index 97%
rename from iface/mocks/iface/mocks/filter.go
rename to client/iface/mocks/iface/mocks/filter.go
index 059a2b9a0..17e123abb 100644
--- a/iface/mocks/iface/mocks/filter.go
+++ b/client/iface/mocks/iface/mocks/filter.go
@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
+// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
diff --git a/iface/mocks/tun.go b/client/iface/mocks/tun.go
similarity index 100%
rename from iface/mocks/tun.go
rename to client/iface/mocks/tun.go
diff --git a/iface/netstack/dialer.go b/client/iface/netstack/dialer.go
similarity index 100%
rename from iface/netstack/dialer.go
rename to client/iface/netstack/dialer.go
diff --git a/iface/netstack/env.go b/client/iface/netstack/env.go
similarity index 100%
rename from iface/netstack/env.go
rename to client/iface/netstack/env.go
diff --git a/iface/netstack/proxy.go b/client/iface/netstack/proxy.go
similarity index 100%
rename from iface/netstack/proxy.go
rename to client/iface/netstack/proxy.go
diff --git a/iface/netstack/tun.go b/client/iface/netstack/tun.go
similarity index 100%
rename from iface/netstack/tun.go
rename to client/iface/netstack/tun.go
diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go
index eec3d3b8c..7d999669a 100644
--- a/client/internal/acl/manager_test.go
+++ b/client/internal/acl/manager_test.go
@@ -9,8 +9,8 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
- "github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go
index 621b29513..3ed12b6dd 100644
--- a/client/internal/acl/mocks/iface_mapper.go
+++ b/client/internal/acl/mocks/iface_mapper.go
@@ -8,7 +8,8 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
- iface "github.com/netbirdio/netbird/iface"
+ iface "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// MockIFaceMapper is a mock of IFaceMapper interface.
@@ -77,7 +78,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
}
// SetFilter mocks base method.
-func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
+func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetFilter", arg0)
ret0, _ := ret[0].(error)
diff --git a/client/internal/config.go b/client/internal/config.go
index 1df1e0547..ee54c6380 100644
--- a/client/internal/config.go
+++ b/client/internal/config.go
@@ -16,9 +16,9 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
- "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/util"
)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 36b340cfb..c77f95603 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -17,13 +17,14 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
- "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac"
@@ -70,7 +71,7 @@ func (c *ConnectClient) RunWithProbes(
// RunOnAndroid with main logic on mobile system
func (c *ConnectClient) RunOnAndroid(
- tunAdapter iface.TunAdapter,
+ tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string,
@@ -205,7 +206,7 @@ func (c *ConnectClient) run(
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
- KernelInterface: iface.WireGuardModuleIsLoaded(),
+ KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
c.statusRecorder.UpdateLocalPeerState(localPeerState)
diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go
index 5a0047700..857964406 100644
--- a/client/internal/dns/response_writer_test.go
+++ b/client/internal/dns/response_writer_test.go
@@ -9,7 +9,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
- "github.com/netbirdio/netbird/iface/mocks"
+ "github.com/netbirdio/netbird/client/iface/mocks"
)
func TestResponseWriterLocalAddr(t *testing.T) {
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index b9552bc17..53d18a678 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -15,16 +15,18 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+ pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
- "github.com/netbirdio/netbird/iface"
- pfmock "github.com/netbirdio/netbird/iface/mocks"
)
type mocWGIface struct {
- filter iface.PacketFilter
+ filter device.PacketFilter
}
func (w *mocWGIface) Name() string {
@@ -43,11 +45,11 @@ func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
-func (w *mocWGIface) GetFilter() iface.PacketFilter {
+func (w *mocWGIface) GetFilter() device.PacketFilter {
return w.filter
}
-func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
+func (w *mocWGIface) GetDevice() *device.FilteredDevice {
panic("implement me")
}
@@ -59,13 +61,13 @@ func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
-func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
+func (w *mocWGIface) SetFilter(filter device.PacketFilter) error {
w.filter = filter
return nil
}
-func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
- return iface.WGStats{}, nil
+func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
+ return configurer.WGStats{}, nil
}
var zoneRecords = []nbdns.SimpleRecord{
diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go
index 2f08e8d52..69bc83659 100644
--- a/client/internal/dns/wgiface.go
+++ b/client/internal/dns/wgiface.go
@@ -5,7 +5,9 @@ package dns
import (
"net"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// WGIface defines subset methods of interface required for manager
@@ -14,7 +16,7 @@ type WGIface interface {
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
- GetFilter() iface.PacketFilter
- GetDevice() *iface.DeviceWrapper
- GetStats(peerKey string) (iface.WGStats, error)
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
}
diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go
index f8bb80fb9..765132fdb 100644
--- a/client/internal/dns/wgiface_windows.go
+++ b/client/internal/dns/wgiface_windows.go
@@ -1,14 +1,18 @@
package dns
-import "github.com/netbirdio/netbird/iface"
+import (
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
- GetFilter() iface.PacketFilter
- GetDevice() *iface.DeviceWrapper
- GetStats(peerKey string) (iface.WGStats, error)
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 998cbce2d..c51901a22 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -23,9 +23,12 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay"
@@ -36,8 +39,6 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
- "github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
@@ -619,7 +620,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
- KernelInterface: iface.WireGuardModuleIsLoaded(),
+ KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
})
@@ -1165,15 +1166,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
log.Errorf("failed to create pion's stdnet: %s", err)
}
- var mArgs *iface.MobileIFaceArguments
+ var mArgs *device.MobileIFaceArguments
switch runtime.GOOS {
case "android":
- mArgs = &iface.MobileIFaceArguments{
+ mArgs = &device.MobileIFaceArguments{
TunAdapter: e.mobileDep.TunAdapter,
TunFd: int(e.mobileDep.FileDescriptor),
}
case "ios":
- mArgs = &iface.MobileIFaceArguments{
+ mArgs = &device.MobileIFaceArguments{
TunFd: int(e.mobileDep.FileDescriptor),
}
default:
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 95aadf141..29a8439a2 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -25,14 +25,15 @@ import (
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
- "github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
mgmt "github.com/netbirdio/netbird/management/client"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
@@ -874,7 +875,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
- iface.CustomWindowsGUIDString = strings.ToLower(guid)
+ device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start()
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go
index 2355c67c3..2b0c92cc6 100644
--- a/client/internal/mobile_dependency.go
+++ b/client/internal/mobile_dependency.go
@@ -1,16 +1,16 @@
package internal
import (
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
- "github.com/netbirdio/netbird/iface"
)
// MobileDependency collect all dependencies for mobile platform
type MobileDependency struct {
// Android only
- TunAdapter iface.TunAdapter
+ TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index baff1372a..ad84bd700 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -15,9 +15,10 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
- "github.com/netbirdio/netbird/iface"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -684,7 +685,7 @@ func (conn *Conn) setStatusToDisconnected() {
// todo rethink status updates
conn.log.Debugf("error while updating peer's state, err: %v", err)
}
- if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
+ if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
}
}
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index 22e5409f8..b4926a9d2 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -9,9 +9,9 @@ import (
"github.com/magiconair/properties/assert"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go
index 915fa63f0..a28992fac 100644
--- a/client/internal/peer/status.go
+++ b/client/internal/peer/status.go
@@ -11,8 +11,8 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/relay"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client"
)
@@ -203,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey]
if !ok {
- return State{}, iface.ErrPeerNotFound
+ return State{}, configurer.ErrPeerNotFound
}
return state, nil
}
@@ -412,7 +412,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
-func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
+func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
d.mux.Lock()
defer d.mux.Unlock()
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
index 8bf1b7568..c4e9d1950 100644
--- a/client/internal/peer/worker_ice.go
+++ b/client/internal/peer/worker_ice.go
@@ -15,9 +15,9 @@ import (
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/stdnet"
- "github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
)
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index db2caea7f..eaa232151 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -10,12 +10,12 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go
index e86a52810..ac94d4a5c 100644
--- a/client/internal/routemanager/dynamic/route.go
+++ b/client/internal/routemanager/dynamic/route.go
@@ -13,10 +13,10 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index d97fe631f..d7ddf7ae8 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -14,6 +14,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -21,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
- "github.com/netbirdio/netbird/iface"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -102,7 +103,7 @@ func NewManager(
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
- if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
+ if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go
index 2995e2740..2f26f7a5e 100644
--- a/client/internal/routemanager/manager_test.go
+++ b/client/internal/routemanager/manager_test.go
@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go
index 58a66715c..908279c88 100644
--- a/client/internal/routemanager/mock.go
+++ b/client/internal/routemanager/mock.go
@@ -5,9 +5,9 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/routeselector"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
)
diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go
index 2057b9cc8..c75a0a7f2 100644
--- a/client/internal/routemanager/server_android.go
+++ b/client/internal/routemanager/server_android.go
@@ -7,8 +7,8 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
- "github.com/netbirdio/netbird/iface"
)
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go
index 1d1a4b063..ef38d5707 100644
--- a/client/internal/routemanager/server_nonandroid.go
+++ b/client/internal/routemanager/server_nonandroid.go
@@ -11,9 +11,9 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go
index 13e1229f8..bb620ee68 100644
--- a/client/internal/routemanager/sysctl/sysctl_linux.go
+++ b/client/internal/routemanager/sysctl/sysctl_linux.go
@@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
const (
diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go
index 10944c1e2..d1cb83bfb 100644
--- a/client/internal/routemanager/systemops/systemops.go
+++ b/client/internal/routemanager/systemops/systemops.go
@@ -5,9 +5,9 @@ import (
"net/netip"
"sync"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
- "github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 90f06ba78..9258f4a4e 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -16,10 +16,10 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
- "github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go
index 94965c119..238225807 100644
--- a/client/internal/routemanager/systemops/systemops_generic_test.go
+++ b/client/internal/routemanager/systemops/systemops_generic_test.go
@@ -19,7 +19,7 @@ import (
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
type dialer interface {
diff --git a/iface/tun.go b/iface/tun.go
deleted file mode 100644
index 7d0a57ed6..000000000
--- a/iface/tun.go
+++ /dev/null
@@ -1,21 +0,0 @@
-//go:build !android
-// +build !android
-
-package iface
-
-import (
- "github.com/netbirdio/netbird/iface/bind"
-)
-
-// CustomWindowsGUIDString is a custom GUID string for the interface
-var CustomWindowsGUIDString string
-
-type wgTunDevice interface {
- Create() (wgConfigurer, error)
- Up() (*bind.UniversalUDPMuxDefault, error)
- UpdateAddr(address WGAddress) error
- WgAddress() WGAddress
- DeviceName() string
- Close() error
- Wrapper() *DeviceWrapper // todo eliminate this function
-}
diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go
deleted file mode 100644
index dd38ba075..000000000
--- a/iface/wg_configurer.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package iface
-
-import (
- "errors"
- "net"
- "time"
-
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
-)
-
-var ErrPeerNotFound = errors.New("peer not found")
-
-type wgConfigurer interface {
- configureInterface(privateKey string, port int) error
- updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
- removePeer(peerKey string) error
- addAllowedIP(peerKey string, allowedIP string) error
- removeAllowedIP(peerKey string, allowedIP string) error
- close()
- getStats(peerKey string) (WGStats, error)
-}
diff --git a/util/net/net.go b/util/net/net.go
index 8d1fcebd0..61b47dbe7 100644
--- a/util/net/net.go
+++ b/util/net/net.go
@@ -4,7 +4,7 @@ import (
"net"
"os"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/google/uuid"
)
From 8934453b30309e508df09236ac102c0537259291 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 2 Oct 2024 18:29:51 +0200
Subject: [PATCH 88/89] Update management base docker image (#2687)
---
management/Dockerfile | 4 ++--
management/Dockerfile.debug | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/management/Dockerfile b/management/Dockerfile
index cac640bf4..3b2df2623 100644
--- a/management/Dockerfile
+++ b/management/Dockerfile
@@ -1,5 +1,5 @@
-FROM ubuntu:22.04
+FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
-COPY netbird-mgmt /go/bin/netbird-mgmt
\ No newline at end of file
+COPY netbird-mgmt /go/bin/netbird-mgmt
diff --git a/management/Dockerfile.debug b/management/Dockerfile.debug
index f4be366a8..4d9730bd7 100644
--- a/management/Dockerfile.debug
+++ b/management/Dockerfile.debug
@@ -1,4 +1,4 @@
-FROM ubuntu:22.04
+FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
CMD ["--log-file", "console"]
From a915707d13e46512ac6da01af97874d843152d4e Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Thu, 3 Oct 2024 14:12:53 +0300
Subject: [PATCH 89/89] fix merge
Signed-off-by: bcmmbaga
---
management/server/peer.go | 2 --
1 file changed, 2 deletions(-)
diff --git a/management/server/peer.go b/management/server/peer.go
index 0311c867b..97e11c08a 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -996,8 +996,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
- //am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
- //update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks})
}(peer)
}