[management, client] Add API to change the network range (#4177)

This commit is contained in:
Viktor Liu
2025-08-04 16:45:49 +02:00
committed by GitHub
parent 58eb3c8cc2
commit beb66208a0
20 changed files with 606 additions and 27 deletions

View File

@@ -171,7 +171,7 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err) return nil, fmt.Errorf("parse new IP: %w", err)
} }
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))

View File

@@ -95,7 +95,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.closeListener.SetCloseListener(nil) e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("close remote conn: %w", err)
} }
return nil return nil
} }

View File

@@ -861,15 +861,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return errors.New("wireguard interface is not initialized") return errors.New("wireguard interface is not initialized")
} }
// Cannot update the IP address without restarting the engine because
// the firewall, route manager, and other components cache the old address
if e.wgInterface.Address().String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String() log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
err := e.wgInterface.UpdateAddr(conf.Address)
if err != nil {
return err
}
e.config.WgAddr = conf.Address
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
} }
if conf.GetSshConfig() != nil { if conf.GetSshConfig() != nil {
@@ -880,7 +875,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
} }
state := e.statusRecorder.GetLocalPeerState() state := e.statusRecorder.GetLocalPeerState()
state.IP = e.config.WgAddr state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String() state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = device.WireGuardModuleIsLoaded() state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn() state.FQDN = conf.GetFqdn()

View File

@@ -142,7 +142,7 @@ var (
err := handleRebrand(cmd) err := handleRebrand(cmd)
if err != nil { if err != nil {
return fmt.Errorf("failed to migrate files %v", err) return fmt.Errorf("migrate files %v", err)
} }
if _, err = os.Stat(config.Datadir); os.IsNotExist(err) { if _, err = os.Stat(config.Datadir); os.IsNotExist(err) {
@@ -184,7 +184,7 @@ var (
} }
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics) eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize database: %s", err) return fmt.Errorf("initialize database: %s", err)
} }
if config.DataStoreEncryptionKey != key { if config.DataStoreEncryptionKey != key {
@@ -192,7 +192,7 @@ var (
config.DataStoreEncryptionKey = key config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config) err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
if err != nil { if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err) return fmt.Errorf("write out store encryption key: %s", err)
} }
} }
@@ -205,7 +205,7 @@ var (
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore) integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err) return fmt.Errorf("initialize integrated peer validator: %v", err)
} }
permissionsManager := integrations.InitPermissionsManager(store) permissionsManager := integrations.InitPermissionsManager(store)
@@ -217,7 +217,7 @@ var (
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy) dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("build default manager: %v", err)
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"os" "os"
"reflect" "reflect"
"regexp" "regexp"
@@ -324,6 +325,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err return err
} }
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
updateAccountPeers = true
}
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain { oldSettings.DNSDomain != newSettings.DNSDomain {
@@ -362,7 +370,18 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err return nil, err
} }
if oldSettings.DNSDomain != newSettings.DNSDomain { if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil) eventMeta := map[string]any{
"old_dns_domain": oldSettings.DNSDomain,
"new_dns_domain": newSettings.DNSDomain,
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
eventMeta := map[string]any{
"old_network_range": oldSettings.NetworkRange.String(),
"new_network_range": newSettings.NetworkRange.String(),
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
} }
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
@@ -2018,3 +2037,154 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
return len(updatedGroups) > 0, peersAffected, nil return len(updatedGroups) > 0, peersAffected, nil
} }
// reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes
func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error {
if !newNetworkRange.IsValid() {
return nil
}
newIPNet := net.IPNet{
IP: newNetworkRange.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
}
account, err := transaction.GetAccount(ctx, accountID)
if err != nil {
return err
}
account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return err
}
var takenIPs []net.IP
for _, peer := range peers {
newIP, err := types.AllocatePeerIP(newIPNet, takenIPs)
if err != nil {
return status.Errorf(status.Internal, "allocate IP for peer %s: %v", peer.ID, err)
}
log.WithContext(ctx).Infof("reallocating peer %s IP from %s to %s due to network range change",
peer.ID, peer.IP.String(), newIP.String())
peer.IP = newIP
takenIPs = append(takenIPs, newIP)
}
if err = transaction.SaveAccount(ctx, account); err != nil {
return err
}
for _, peer := range peers {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
}
}
log.WithContext(ctx).Infof("successfully re-allocated IPs for %d peers in account %s to network range %s",
len(peers), accountID, newNetworkRange.String())
return nil
}
func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, peers []*nbpeer.Peer, peerID string, newIP netip.Addr) error {
if !account.Network.Net.Contains(newIP.AsSlice()) {
return status.Errorf(status.InvalidArgument, "IP %s is not within the account network range %s", newIP.String(), account.Network.Net.String())
}
for _, peer := range peers {
if peer.ID != peerID && peer.IP.Equal(newIP.AsSlice()) {
return status.Errorf(status.InvalidArgument, "IP %s is already assigned to peer %s", newIP.String(), peer.ID)
}
}
return nil
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err)
}
if updateNetworkMap {
am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil
}
func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) (bool, error) {
var updateNetworkMap bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
account, err := transaction.GetAccount(ctx, accountID)
if err != nil {
return fmt.Errorf("get account: %w", err)
}
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
if existingPeer.IP.Equal(newIP.AsSlice()) {
return nil
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return fmt.Errorf("get account peers: %w", err)
}
if err := am.validateIPForUpdate(account, peers, peerID, newIP); err != nil {
return err
}
if err := am.savePeerIPUpdate(ctx, transaction, accountID, userID, existingPeer, newIP); err != nil {
return err
}
updateNetworkMap = true
return nil
})
return updateNetworkMap, err
}
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
dnsDomain := am.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String()
peer.IP = newIP.AsSlice()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer)
if err != nil {
return fmt.Errorf("save peer: %w", err)
}
eventMeta["old_ip"] = oldIP
eventMeta["ip"] = newIP.String()
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerIPUpdated, eventMeta)
return nil
}

View File

@@ -51,6 +51,7 @@ type Manager interface {
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@@ -3522,3 +3523,70 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}) })
} }
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
require.NoError(t, err, "unable to add peer2")
t.Run("update peer IP successfully", func(t *testing.T) {
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get account")
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
require.NoError(t, err, "unable to allocate new IP")
newAddr := netip.MustParseAddr(newIP.String())
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
require.NoError(t, err, "unable to update peer IP")
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
require.NoError(t, err, "unable to get updated peer")
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
})
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
currentAddr := netip.MustParseAddr(peer1.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
require.NoError(t, err, "updating with same IP should not error")
})
t.Run("update peer IP with collision should fail", func(t *testing.T) {
peer2Addr := netip.MustParseAddr(peer2.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
require.Error(t, err, "should fail when IP is already assigned")
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
})
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
invalidAddr := netip.MustParseAddr("192.168.1.100")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
require.Error(t, err, "should fail when IP is outside network range")
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
})
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
newAddr := netip.MustParseAddr("100.64.0.101")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
require.Error(t, err, "should fail with invalid peer ID")
})
}

View File

@@ -175,6 +175,9 @@ const (
AccountLazyConnectionEnabled Activity = 85 AccountLazyConnectionEnabled Activity = 85
AccountLazyConnectionDisabled Activity = 86 AccountLazyConnectionDisabled Activity = 86
AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -277,6 +280,10 @@ var activityMap = map[Activity]Code{
AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"}, AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"},
AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"}, AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"},
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -133,6 +133,11 @@ components:
description: Allows to define a custom dns domain for the account description: Allows to define a custom dns domain for the account
type: string type: string
example: my-organization.org example: my-organization.org
network_range:
description: Allows to define a custom network range for the account in CIDR format
type: string
format: cidr
example: 100.64.0.0/16
extra: extra:
$ref: '#/components/schemas/AccountExtraSettings' $ref: '#/components/schemas/AccountExtraSettings'
lazy_connection_enabled: lazy_connection_enabled:
@@ -342,6 +347,11 @@ components:
description: (Cloud only) Indicates whether peer needs approval description: (Cloud only) Indicates whether peer needs approval
type: boolean type: boolean
example: true example: true
ip:
description: Peer's IP address
type: string
format: ipv4
example: 100.64.0.15
required: required:
- name - name
- ssh_enabled - ssh_enabled

View File

@@ -303,6 +303,9 @@ type AccountSettings struct {
// LazyConnectionEnabled Enables or disables experimental lazy connection // LazyConnectionEnabled Enables or disables experimental lazy connection
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
// NetworkRange Allows to define a custom network range for the account in CIDR format
NetworkRange *string `json:"network_range,omitempty"`
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
PeerInactivityExpiration int `json:"peer_inactivity_expiration"` PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
@@ -1196,11 +1199,14 @@ type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest. // PeerRequest defines model for PeerRequest.
type PeerRequest struct { type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"` ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"` // Ip Peer's IP address
SshEnabled bool `json:"ssh_enabled"` Ip *string `json:"ip,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
} }
// PersonalAccessToken defines model for PersonalAccessToken. // PersonalAccessToken defines model for PersonalAccessToken.

View File

@@ -1,8 +1,10 @@
package accounts package accounts
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/netip"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@@ -16,6 +18,17 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
const (
// PeerBufferPercentage is the percentage of peers to add as buffer for network range calculations
PeerBufferPercentage = 0.5
// MinRequiredAddresses is the minimum number of addresses required in a network range
MinRequiredAddresses = 10
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
MinNetworkBitsIPv4 = 28
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
MinNetworkBitsIPv6 = 120
)
// handler is a handler that handles the server.Account HTTP endpoints // handler is a handler that handles the server.Account HTTP endpoints
type handler struct { type handler struct {
accountManager account.Manager accountManager account.Manager
@@ -37,6 +50,86 @@ func newHandler(accountManager account.Manager, settingsManager settings.Manager
} }
} }
func validateIPAddress(addr netip.Addr) error {
if addr.IsLoopback() {
return status.Errorf(status.InvalidArgument, "loopback address range not allowed")
}
if addr.IsMulticast() {
return status.Errorf(status.InvalidArgument, "multicast address range not allowed")
}
if addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() {
return status.Errorf(status.InvalidArgument, "link-local address range not allowed")
}
return nil
}
func validateMinimumSize(prefix netip.Prefix) error {
addr := prefix.Addr()
if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 {
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4)
}
if addr.Is6() && prefix.Bits() > MinNetworkBitsIPv6 {
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6)
}
return nil
}
func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error {
if !networkRange.IsValid() {
return nil
}
if err := validateIPAddress(networkRange.Addr()); err != nil {
return err
}
if err := validateMinimumSize(networkRange); err != nil {
return err
}
return h.validateCapacity(ctx, accountID, userID, networkRange)
}
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
if err != nil {
return status.Errorf(status.Internal, "get peer count: %v", err)
}
maxHosts := calculateMaxHosts(prefix)
requiredAddresses := calculateRequiredAddresses(len(peers))
if maxHosts < requiredAddresses {
return status.Errorf(status.InvalidArgument,
"network range too small: need at least %d addresses for %d peers + buffer, but range provides %d",
requiredAddresses, len(peers), maxHosts)
}
return nil
}
func calculateMaxHosts(prefix netip.Prefix) int64 {
availableAddresses := prefix.Addr().BitLen() - prefix.Bits()
maxHosts := int64(1) << availableAddresses
if prefix.Addr().Is4() {
maxHosts -= 2 // network and broadcast addresses
}
return maxHosts
}
func calculateRequiredAddresses(peerCount int) int64 {
requiredAddresses := int64(peerCount) + int64(float64(peerCount)*PeerBufferPercentage)
if requiredAddresses < MinRequiredAddresses {
requiredAddresses = MinRequiredAddresses
}
return requiredAddresses
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
@@ -131,6 +224,18 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.LazyConnectionEnabled != nil { if req.Settings.LazyConnectionEnabled != nil {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
} }
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return
}
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings.NetworkRange = prefix
}
var onboarding *types.AccountOnboarding var onboarding *types.AccountOnboarding
if req.Onboarding != nil { if req.Onboarding != nil {
@@ -208,6 +313,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
DnsDomain: &settings.DNSDomain, DnsDomain: &settings.DNSDomain,
} }
if settings.NetworkRange.IsValid() {
networkRangeStr := settings.NetworkRange.String()
apiSettings.NetworkRange = &networkRangeStr
}
apiOnboarding := api.AccountOnboarding{ apiOnboarding := api.AccountOnboarding{
OnboardingFlowPending: onboarding.OnboardingFlowPending, OnboardingFlowPending: onboarding.OnboardingFlowPending,
SignupFormPending: onboarding.SignupFormPending, SignupFormPending: onboarding.SignupFormPending,

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/netip"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -111,6 +112,19 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
} }
} }
if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip)
if err != nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return
}
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
util.WriteError(ctx, err, w)
return
}
}
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)

View File

@@ -9,6 +9,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"testing" "testing"
"time" "time"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
@@ -112,6 +114,15 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
p.Name = update.Name p.Name = update.Name
return p, nil return p, nil
}, },
UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
for _, peer := range peers {
if peer.ID == peerID {
peer.IP = net.IP(newIP.AsSlice())
return nil
}
}
return fmt.Errorf("peer not found")
},
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer var p *nbpeer.Peer
for _, peer := range peers { for _, peer := range peers {
@@ -450,3 +461,73 @@ func TestGetAccessiblePeers(t *testing.T) {
}) })
} }
} }
func TestPeersHandlerUpdatePeerIP(t *testing.T) {
testPeer := &nbpeer.Peer{
ID: testPeerID,
Key: "key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
Name: "test-host@netbird.io",
LoginExpirationEnabled: false,
UserID: regularUser,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host@netbird.io",
Core: "22.04",
},
}
p := initTestMetaData(testPeer)
tt := []struct {
name string
peerID string
requestBody string
callerUserID string
expectedStatus int
expectedIP string
}{
{
name: "update peer IP successfully",
peerID: testPeerID,
requestBody: `{"ip": "100.64.0.100"}`,
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedIP: "100.64.0.100",
},
{
name: "update peer IP with invalid IP",
peerID: testPeerID,
requestBody: `{"ip": "invalid-ip"}`,
callerUserID: adminUser,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
})
rr := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK && tc.expectedIP != "" {
var updatedPeer api.Peer
err := json.Unmarshal(rr.Body.Bytes(), &updatedPeer)
require.NoError(t, err)
assert.Equal(t, tc.expectedIP, updatedPeer.Ip)
}
})
}
}

View File

@@ -60,6 +60,7 @@ type MockAccountManager struct {
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
@@ -483,6 +484,13 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented") return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
} }
func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
if am.UpdatePeerIPFunc != nil {
return am.UpdatePeerIPFunc(ctx, accountID, userID, peerID, newIP)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented")
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil { if am.CreateRouteFunc != nil {

View File

@@ -53,7 +53,7 @@ type Config struct {
StoreConfig StoreConfig StoreConfig StoreConfig
ReverseProxy ReverseProxy ReverseProxy ReverseProxy
// disable default all-to-all policy // disable default all-to-all policy
DisableDefaultPolicy bool DisableDefaultPolicy bool
} }

View File

@@ -163,7 +163,10 @@ func (n *Network) Copy() *Network {
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 // E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
totalIPs := uint32(1 << SubnetSize)
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
totalIPs := uint32(1 << hostBits)
taken := make(map[uint32]struct{}, len(takenIps)+1) taken := make(map[uint32]struct{}, len(takenIps)+1)
taken[baseIP] = struct{}{} // reserve network IP taken[baseIP] = struct{}{} // reserve network IP

View File

@@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewNetwork(t *testing.T) { func TestNewNetwork(t *testing.T) {
@@ -38,6 +39,107 @@ func TestAllocatePeerIP(t *testing.T) {
} }
} }
func TestAllocatePeerIPSmallSubnet(t *testing.T) {
// Test /27 network (10.0.0.0/27) - should only have 30 usable IPs (10.0.0.1 to 10.0.0.30)
ipNet := net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.IPMask{255, 255, 255, 224}}
var ips []net.IP
// Allocate all available IPs in the /27 network
for i := 0; i < 30; i++ {
ip, err := AllocatePeerIP(ipNet, ips)
if err != nil {
t.Fatal(err)
}
// Verify IP is within the correct range
if !ipNet.Contains(ip) {
t.Errorf("allocated IP %s is not within network %s", ip.String(), ipNet.String())
}
ips = append(ips, ip)
}
assert.Len(t, ips, 30)
// Verify all IPs are unique
uniq := make(map[string]struct{})
for _, ip := range ips {
if _, ok := uniq[ip.String()]; !ok {
uniq[ip.String()] = struct{}{}
} else {
t.Errorf("found duplicate IP %s", ip.String())
}
}
// Try to allocate one more IP - should fail as network is full
_, err := AllocatePeerIP(ipNet, ips)
if err == nil {
t.Error("expected error when network is full, but got none")
}
}
func TestAllocatePeerIPVariousCIDRs(t *testing.T) {
testCases := []struct {
name string
cidr string
expectedUsable int
}{
{"/30 network", "192.168.1.0/30", 2}, // 4 total - 2 reserved = 2 usable
{"/29 network", "192.168.1.0/29", 6}, // 8 total - 2 reserved = 6 usable
{"/28 network", "192.168.1.0/28", 14}, // 16 total - 2 reserved = 14 usable
{"/27 network", "192.168.1.0/27", 30}, // 32 total - 2 reserved = 30 usable
{"/26 network", "192.168.1.0/26", 62}, // 64 total - 2 reserved = 62 usable
{"/25 network", "192.168.1.0/25", 126}, // 128 total - 2 reserved = 126 usable
{"/16 network", "10.0.0.0/16", 65534}, // 65536 total - 2 reserved = 65534 usable
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, ipNet, err := net.ParseCIDR(tc.cidr)
require.NoError(t, err)
var ips []net.IP
// For larger networks, test only a subset to avoid long test runs
testCount := tc.expectedUsable
if testCount > 1000 {
testCount = 1000
}
// Allocate IPs and verify they're within the correct range
for i := 0; i < testCount; i++ {
ip, err := AllocatePeerIP(*ipNet, ips)
require.NoError(t, err, "failed to allocate IP %d", i)
// Verify IP is within the correct range
assert.True(t, ipNet.Contains(ip), "allocated IP %s is not within network %s", ip.String(), ipNet.String())
// Verify IP is not network or broadcast address
networkIP := ipNet.IP.Mask(ipNet.Mask)
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
broadcastInt := uint32(ipToUint32(networkIP)) + (1 << hostBits) - 1
broadcastIP := uint32ToIP(broadcastInt)
assert.False(t, ip.Equal(networkIP), "allocated network address %s", ip.String())
assert.False(t, ip.Equal(broadcastIP), "allocated broadcast address %s", ip.String())
ips = append(ips, ip)
}
assert.Len(t, ips, testCount)
// Verify all IPs are unique
uniq := make(map[string]struct{})
for _, ip := range ips {
ipStr := ip.String()
assert.NotContains(t, uniq, ipStr, "found duplicate IP %s", ipStr)
uniq[ipStr] = struct{}{}
}
})
}
}
func TestGenerateIPs(t *testing.T) { func TestGenerateIPs(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}} ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}}) ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}})

View File

@@ -1,6 +1,7 @@
package types package types
import ( import (
"net/netip"
"time" "time"
) )
@@ -42,6 +43,9 @@ type Settings struct {
// DNSDomain is the custom domain for that account // DNSDomain is the custom domain for that account
DNSDomain string DNSDomain string
// NetworkRange is the custom network range for that account
NetworkRange netip.Prefix `gorm:"serializer:json"`
// Extra is a dictionary of Account settings // Extra is a dictionary of Account settings
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
@@ -66,6 +70,7 @@ func (s *Settings) Copy() *Settings {
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: s.LazyConnectionEnabled, LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain, DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
} }
if s.Extra != nil { if s.Extra != nil {
settings.Extra = s.Extra.Copy() settings.Extra = s.Extra.Copy()

View File

@@ -23,7 +23,6 @@ func FileExists(path string) bool {
return err == nil return err == nil
} }
/// Bool helpers /// Bool helpers
// True returns a *bool whose underlying value is true. // True returns a *bool whose underlying value is true.
@@ -56,4 +55,4 @@ func ReturnBoolWithDefaultTrue(b *bool) bool {
return true return true
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"time" "time"
) )
//Duration is used strictly for JSON requests/responses due to duration marshalling issues // Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct { type Duration struct {
time.Duration time.Duration
} }