[management, client] Add API to change the network range (#4177)

This commit is contained in:
Viktor Liu
2025-08-04 16:45:49 +02:00
committed by GitHub
parent 58eb3c8cc2
commit beb66208a0
20 changed files with 606 additions and 27 deletions

View File

@@ -171,7 +171,7 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
return nil, fmt.Errorf("parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))

View File

@@ -95,7 +95,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
return fmt.Errorf("close remote conn: %w", err)
}
return nil
}

View File

@@ -861,15 +861,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return errors.New("wireguard interface is not initialized")
}
// Cannot update the IP address without restarting the engine because
// the firewall, route manager, and other components cache the old address
if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
err := e.wgInterface.UpdateAddr(conf.Address)
if err != nil {
return err
}
e.config.WgAddr = conf.Address
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
}
if conf.GetSshConfig() != nil {
@@ -880,7 +875,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.config.WgAddr
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn()

View File

@@ -142,7 +142,7 @@ var (
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("failed to migrate files %v", err)
return fmt.Errorf("migrate files %v", err)
}
if _, err = os.Stat(config.Datadir); os.IsNotExist(err) {
@@ -184,7 +184,7 @@ var (
}
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
return fmt.Errorf("failed to initialize database: %s", err)
return fmt.Errorf("initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key {
@@ -192,7 +192,7 @@ var (
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
return fmt.Errorf("write out store encryption key: %s", err)
}
}
@@ -205,7 +205,7 @@ var (
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
return fmt.Errorf("initialize integrated peer validator: %v", err)
}
permissionsManager := integrations.InitPermissionsManager(store)
@@ -217,7 +217,7 @@ var (
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
return fmt.Errorf("build default manager: %v", err)
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"reflect"
"regexp"
@@ -324,6 +325,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
updateAccountPeers = true
}
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain {
@@ -362,7 +370,18 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
eventMeta := map[string]any{
"old_dns_domain": oldSettings.DNSDomain,
"new_dns_domain": newSettings.DNSDomain,
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
eventMeta := map[string]any{
"old_network_range": oldSettings.NetworkRange.String(),
"new_network_range": newSettings.NetworkRange.String(),
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
@@ -2018,3 +2037,154 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
return len(updatedGroups) > 0, peersAffected, nil
}
// reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes
func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error {
if !newNetworkRange.IsValid() {
return nil
}
newIPNet := net.IPNet{
IP: newNetworkRange.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
}
account, err := transaction.GetAccount(ctx, accountID)
if err != nil {
return err
}
account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return err
}
var takenIPs []net.IP
for _, peer := range peers {
newIP, err := types.AllocatePeerIP(newIPNet, takenIPs)
if err != nil {
return status.Errorf(status.Internal, "allocate IP for peer %s: %v", peer.ID, err)
}
log.WithContext(ctx).Infof("reallocating peer %s IP from %s to %s due to network range change",
peer.ID, peer.IP.String(), newIP.String())
peer.IP = newIP
takenIPs = append(takenIPs, newIP)
}
if err = transaction.SaveAccount(ctx, account); err != nil {
return err
}
for _, peer := range peers {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
}
}
log.WithContext(ctx).Infof("successfully re-allocated IPs for %d peers in account %s to network range %s",
len(peers), accountID, newNetworkRange.String())
return nil
}
func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, peers []*nbpeer.Peer, peerID string, newIP netip.Addr) error {
if !account.Network.Net.Contains(newIP.AsSlice()) {
return status.Errorf(status.InvalidArgument, "IP %s is not within the account network range %s", newIP.String(), account.Network.Net.String())
}
for _, peer := range peers {
if peer.ID != peerID && peer.IP.Equal(newIP.AsSlice()) {
return status.Errorf(status.InvalidArgument, "IP %s is already assigned to peer %s", newIP.String(), peer.ID)
}
}
return nil
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err)
}
if updateNetworkMap {
am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil
}
func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) (bool, error) {
var updateNetworkMap bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
account, err := transaction.GetAccount(ctx, accountID)
if err != nil {
return fmt.Errorf("get account: %w", err)
}
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
if existingPeer.IP.Equal(newIP.AsSlice()) {
return nil
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return fmt.Errorf("get account peers: %w", err)
}
if err := am.validateIPForUpdate(account, peers, peerID, newIP); err != nil {
return err
}
if err := am.savePeerIPUpdate(ctx, transaction, accountID, userID, existingPeer, newIP); err != nil {
return err
}
updateNetworkMap = true
return nil
})
return updateNetworkMap, err
}
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
dnsDomain := am.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String()
peer.IP = newIP.AsSlice()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer)
if err != nil {
return fmt.Errorf("save peer: %w", err)
}
eventMeta["old_ip"] = oldIP
eventMeta["ip"] = newIP.String()
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerIPUpdated, eventMeta)
return nil
}

View File

@@ -51,6 +51,7 @@ type Manager interface {
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os"
"reflect"
"strconv"
@@ -3522,3 +3523,70 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
require.NoError(t, err)
})
}
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
require.NoError(t, err, "unable to add peer2")
t.Run("update peer IP successfully", func(t *testing.T) {
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get account")
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
require.NoError(t, err, "unable to allocate new IP")
newAddr := netip.MustParseAddr(newIP.String())
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
require.NoError(t, err, "unable to update peer IP")
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
require.NoError(t, err, "unable to get updated peer")
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
})
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
currentAddr := netip.MustParseAddr(peer1.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
require.NoError(t, err, "updating with same IP should not error")
})
t.Run("update peer IP with collision should fail", func(t *testing.T) {
peer2Addr := netip.MustParseAddr(peer2.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
require.Error(t, err, "should fail when IP is already assigned")
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
})
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
invalidAddr := netip.MustParseAddr("192.168.1.100")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
require.Error(t, err, "should fail when IP is outside network range")
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
})
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
newAddr := netip.MustParseAddr("100.64.0.101")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
require.Error(t, err, "should fail with invalid peer ID")
})
}

View File

@@ -175,6 +175,9 @@ const (
AccountLazyConnectionEnabled Activity = 85
AccountLazyConnectionDisabled Activity = 86
AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88
AccountDeleted Activity = 99999
)
@@ -277,6 +280,10 @@ var activityMap = map[Activity]Code{
AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"},
AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"},
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
}
// StringCode returns a string code of the activity

View File

@@ -133,6 +133,11 @@ components:
description: Allows to define a custom dns domain for the account
type: string
example: my-organization.org
network_range:
description: Allows to define a custom network range for the account in CIDR format
type: string
format: cidr
example: 100.64.0.0/16
extra:
$ref: '#/components/schemas/AccountExtraSettings'
lazy_connection_enabled:
@@ -342,6 +347,11 @@ components:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
ip:
description: Peer's IP address
type: string
format: ipv4
example: 100.64.0.15
required:
- name
- ssh_enabled

View File

@@ -303,6 +303,9 @@ type AccountSettings struct {
// LazyConnectionEnabled Enables or disables experimental lazy connection
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
// NetworkRange Allows to define a custom network range for the account in CIDR format
NetworkRange *string `json:"network_range,omitempty"`
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
@@ -1196,11 +1199,14 @@ type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest.
type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address
Ip *string `json:"ip,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
}
// PersonalAccessToken defines model for PersonalAccessToken.

View File

@@ -1,8 +1,10 @@
package accounts
import (
"context"
"encoding/json"
"net/http"
"net/netip"
"time"
"github.com/gorilla/mux"
@@ -16,6 +18,17 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
const (
// PeerBufferPercentage is the percentage of peers to add as buffer for network range calculations
PeerBufferPercentage = 0.5
// MinRequiredAddresses is the minimum number of addresses required in a network range
MinRequiredAddresses = 10
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
MinNetworkBitsIPv4 = 28
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
MinNetworkBitsIPv6 = 120
)
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager account.Manager
@@ -37,6 +50,86 @@ func newHandler(accountManager account.Manager, settingsManager settings.Manager
}
}
func validateIPAddress(addr netip.Addr) error {
if addr.IsLoopback() {
return status.Errorf(status.InvalidArgument, "loopback address range not allowed")
}
if addr.IsMulticast() {
return status.Errorf(status.InvalidArgument, "multicast address range not allowed")
}
if addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() {
return status.Errorf(status.InvalidArgument, "link-local address range not allowed")
}
return nil
}
func validateMinimumSize(prefix netip.Prefix) error {
addr := prefix.Addr()
if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 {
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4)
}
if addr.Is6() && prefix.Bits() > MinNetworkBitsIPv6 {
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6)
}
return nil
}
func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error {
if !networkRange.IsValid() {
return nil
}
if err := validateIPAddress(networkRange.Addr()); err != nil {
return err
}
if err := validateMinimumSize(networkRange); err != nil {
return err
}
return h.validateCapacity(ctx, accountID, userID, networkRange)
}
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
if err != nil {
return status.Errorf(status.Internal, "get peer count: %v", err)
}
maxHosts := calculateMaxHosts(prefix)
requiredAddresses := calculateRequiredAddresses(len(peers))
if maxHosts < requiredAddresses {
return status.Errorf(status.InvalidArgument,
"network range too small: need at least %d addresses for %d peers + buffer, but range provides %d",
requiredAddresses, len(peers), maxHosts)
}
return nil
}
func calculateMaxHosts(prefix netip.Prefix) int64 {
availableAddresses := prefix.Addr().BitLen() - prefix.Bits()
maxHosts := int64(1) << availableAddresses
if prefix.Addr().Is4() {
maxHosts -= 2 // network and broadcast addresses
}
return maxHosts
}
func calculateRequiredAddresses(peerCount int) int64 {
requiredAddresses := int64(peerCount) + int64(float64(peerCount)*PeerBufferPercentage)
if requiredAddresses < MinRequiredAddresses {
requiredAddresses = MinRequiredAddresses
}
return requiredAddresses
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
@@ -131,6 +224,18 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.LazyConnectionEnabled != nil {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
}
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return
}
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings.NetworkRange = prefix
}
var onboarding *types.AccountOnboarding
if req.Onboarding != nil {
@@ -208,6 +313,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
DnsDomain: &settings.DNSDomain,
}
if settings.NetworkRange.IsValid() {
networkRangeStr := settings.NetworkRange.String()
apiSettings.NetworkRange = &networkRangeStr
}
apiOnboarding := api.AccountOnboarding{
OnboardingFlowPending: onboarding.OnboardingFlowPending,
SignupFormPending: onboarding.SignupFormPending,

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
@@ -111,6 +112,19 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
}
}
if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip)
if err != nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return
}
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
util.WriteError(ctx, err, w)
return
}
}
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil {
util.WriteError(ctx, err, w)

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -112,6 +114,15 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
p.Name = update.Name
return p, nil
},
UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
for _, peer := range peers {
if peer.ID == peerID {
peer.IP = net.IP(newIP.AsSlice())
return nil
}
}
return fmt.Errorf("peer not found")
},
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
@@ -450,3 +461,73 @@ func TestGetAccessiblePeers(t *testing.T) {
})
}
}
func TestPeersHandlerUpdatePeerIP(t *testing.T) {
testPeer := &nbpeer.Peer{
ID: testPeerID,
Key: "key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
Name: "test-host@netbird.io",
LoginExpirationEnabled: false,
UserID: regularUser,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host@netbird.io",
Core: "22.04",
},
}
p := initTestMetaData(testPeer)
tt := []struct {
name string
peerID string
requestBody string
callerUserID string
expectedStatus int
expectedIP string
}{
{
name: "update peer IP successfully",
peerID: testPeerID,
requestBody: `{"ip": "100.64.0.100"}`,
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedIP: "100.64.0.100",
},
{
name: "update peer IP with invalid IP",
peerID: testPeerID,
requestBody: `{"ip": "invalid-ip"}`,
callerUserID: adminUser,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
})
rr := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK && tc.expectedIP != "" {
var updatedPeer api.Peer
err := json.Unmarshal(rr.Body.Bytes(), &updatedPeer)
require.NoError(t, err)
assert.Equal(t, tc.expectedIP, updatedPeer.Ip)
}
})
}
}

View File

@@ -60,6 +60,7 @@ type MockAccountManager struct {
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
@@ -483,6 +484,13 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
}
func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
if am.UpdatePeerIPFunc != nil {
return am.UpdatePeerIPFunc(ctx, accountID, userID, peerID, newIP)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented")
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {

View File

@@ -163,7 +163,10 @@ func (n *Network) Copy() *Network {
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
totalIPs := uint32(1 << SubnetSize)
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
totalIPs := uint32(1 << hostBits)
taken := make(map[uint32]struct{}, len(takenIps)+1)
taken[baseIP] = struct{}{} // reserve network IP

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewNetwork(t *testing.T) {
@@ -38,6 +39,107 @@ func TestAllocatePeerIP(t *testing.T) {
}
}
func TestAllocatePeerIPSmallSubnet(t *testing.T) {
// Test /27 network (10.0.0.0/27) - should only have 30 usable IPs (10.0.0.1 to 10.0.0.30)
ipNet := net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.IPMask{255, 255, 255, 224}}
var ips []net.IP
// Allocate all available IPs in the /27 network
for i := 0; i < 30; i++ {
ip, err := AllocatePeerIP(ipNet, ips)
if err != nil {
t.Fatal(err)
}
// Verify IP is within the correct range
if !ipNet.Contains(ip) {
t.Errorf("allocated IP %s is not within network %s", ip.String(), ipNet.String())
}
ips = append(ips, ip)
}
assert.Len(t, ips, 30)
// Verify all IPs are unique
uniq := make(map[string]struct{})
for _, ip := range ips {
if _, ok := uniq[ip.String()]; !ok {
uniq[ip.String()] = struct{}{}
} else {
t.Errorf("found duplicate IP %s", ip.String())
}
}
// Try to allocate one more IP - should fail as network is full
_, err := AllocatePeerIP(ipNet, ips)
if err == nil {
t.Error("expected error when network is full, but got none")
}
}
func TestAllocatePeerIPVariousCIDRs(t *testing.T) {
testCases := []struct {
name string
cidr string
expectedUsable int
}{
{"/30 network", "192.168.1.0/30", 2}, // 4 total - 2 reserved = 2 usable
{"/29 network", "192.168.1.0/29", 6}, // 8 total - 2 reserved = 6 usable
{"/28 network", "192.168.1.0/28", 14}, // 16 total - 2 reserved = 14 usable
{"/27 network", "192.168.1.0/27", 30}, // 32 total - 2 reserved = 30 usable
{"/26 network", "192.168.1.0/26", 62}, // 64 total - 2 reserved = 62 usable
{"/25 network", "192.168.1.0/25", 126}, // 128 total - 2 reserved = 126 usable
{"/16 network", "10.0.0.0/16", 65534}, // 65536 total - 2 reserved = 65534 usable
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, ipNet, err := net.ParseCIDR(tc.cidr)
require.NoError(t, err)
var ips []net.IP
// For larger networks, test only a subset to avoid long test runs
testCount := tc.expectedUsable
if testCount > 1000 {
testCount = 1000
}
// Allocate IPs and verify they're within the correct range
for i := 0; i < testCount; i++ {
ip, err := AllocatePeerIP(*ipNet, ips)
require.NoError(t, err, "failed to allocate IP %d", i)
// Verify IP is within the correct range
assert.True(t, ipNet.Contains(ip), "allocated IP %s is not within network %s", ip.String(), ipNet.String())
// Verify IP is not network or broadcast address
networkIP := ipNet.IP.Mask(ipNet.Mask)
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
broadcastInt := uint32(ipToUint32(networkIP)) + (1 << hostBits) - 1
broadcastIP := uint32ToIP(broadcastInt)
assert.False(t, ip.Equal(networkIP), "allocated network address %s", ip.String())
assert.False(t, ip.Equal(broadcastIP), "allocated broadcast address %s", ip.String())
ips = append(ips, ip)
}
assert.Len(t, ips, testCount)
// Verify all IPs are unique
uniq := make(map[string]struct{})
for _, ip := range ips {
ipStr := ip.String()
assert.NotContains(t, uniq, ipStr, "found duplicate IP %s", ipStr)
uniq[ipStr] = struct{}{}
}
})
}
}
func TestGenerateIPs(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}})

View File

@@ -1,6 +1,7 @@
package types
import (
"net/netip"
"time"
)
@@ -42,6 +43,9 @@ type Settings struct {
// DNSDomain is the custom domain for that account
DNSDomain string
// NetworkRange is the custom network range for that account
NetworkRange netip.Prefix `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
@@ -66,6 +70,7 @@ func (s *Settings) Copy() *Settings {
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -23,7 +23,6 @@ func FileExists(path string) bool {
return err == nil
}
/// Bool helpers
// True returns a *bool whose underlying value is true.

View File

@@ -6,7 +6,7 @@ import (
"time"
)
//Duration is used strictly for JSON requests/responses due to duration marshalling issues
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}