mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 03:16:58 +02:00
[management, client] Add API to change the network range (#4177)
This commit is contained in:
@@ -171,7 +171,7 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
|||||||
|
|
||||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse new IP: %w", err)
|
return nil, fmt.Errorf("parse new IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||||
|
@@ -95,7 +95,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
e.closeListener.SetCloseListener(nil)
|
e.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -861,15 +861,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
return errors.New("wireguard interface is not initialized")
|
return errors.New("wireguard interface is not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cannot update the IP address without restarting the engine because
|
||||||
|
// the firewall, route manager, and other components cache the old address
|
||||||
if e.wgInterface.Address().String() != conf.Address {
|
if e.wgInterface.Address().String() != conf.Address {
|
||||||
oldAddr := e.wgInterface.Address().String()
|
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
|
||||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
|
||||||
err := e.wgInterface.UpdateAddr(conf.Address)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.config.WgAddr = conf.Address
|
|
||||||
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.GetSshConfig() != nil {
|
if conf.GetSshConfig() != nil {
|
||||||
@@ -880,7 +875,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.config.WgAddr
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
@@ -142,7 +142,7 @@ var (
|
|||||||
|
|
||||||
err := handleRebrand(cmd)
|
err := handleRebrand(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to migrate files %v", err)
|
return fmt.Errorf("migrate files %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = os.Stat(config.Datadir); os.IsNotExist(err) {
|
if _, err = os.Stat(config.Datadir); os.IsNotExist(err) {
|
||||||
@@ -184,7 +184,7 @@ var (
|
|||||||
}
|
}
|
||||||
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
|
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize database: %s", err)
|
return fmt.Errorf("initialize database: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.DataStoreEncryptionKey != key {
|
if config.DataStoreEncryptionKey != key {
|
||||||
@@ -192,7 +192,7 @@ var (
|
|||||||
config.DataStoreEncryptionKey = key
|
config.DataStoreEncryptionKey = key
|
||||||
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
|
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write out store encryption key: %s", err)
|
return fmt.Errorf("write out store encryption key: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,7 +205,7 @@ var (
|
|||||||
|
|
||||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
|
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
return fmt.Errorf("initialize integrated peer validator: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
permissionsManager := integrations.InitPermissionsManager(store)
|
permissionsManager := integrations.InitPermissionsManager(store)
|
||||||
@@ -217,7 +217,7 @@ var (
|
|||||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
|
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to build default manager: %v", err)
|
return fmt.Errorf("build default manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)
|
||||||
|
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -324,6 +325,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||||
|
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updateAccountPeers = true
|
||||||
|
}
|
||||||
|
|
||||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
||||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||||
oldSettings.DNSDomain != newSettings.DNSDomain {
|
oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||||
@@ -362,7 +370,18 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oldSettings.DNSDomain != newSettings.DNSDomain {
|
if oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
|
eventMeta := map[string]any{
|
||||||
|
"old_dns_domain": oldSettings.DNSDomain,
|
||||||
|
"new_dns_domain": newSettings.DNSDomain,
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
|
||||||
|
}
|
||||||
|
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||||
|
eventMeta := map[string]any{
|
||||||
|
"old_network_range": oldSettings.NetworkRange.String(),
|
||||||
|
"new_network_range": newSettings.NetworkRange.String(),
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||||
@@ -2018,3 +2037,154 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
|
|||||||
|
|
||||||
return len(updatedGroups) > 0, peersAffected, nil
|
return len(updatedGroups) > 0, peersAffected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes
|
||||||
|
func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error {
|
||||||
|
if !newNetworkRange.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newIPNet := net.IPNet{
|
||||||
|
IP: newNetworkRange.Masked().Addr().AsSlice(),
|
||||||
|
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := transaction.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Network.Net = newIPNet
|
||||||
|
|
||||||
|
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var takenIPs []net.IP
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
newIP, err := types.AllocatePeerIP(newIPNet, takenIPs)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "allocate IP for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("reallocating peer %s IP from %s to %s due to network range change",
|
||||||
|
peer.ID, peer.IP.String(), newIP.String())
|
||||||
|
|
||||||
|
peer.IP = newIP
|
||||||
|
takenIPs = append(takenIPs, newIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SaveAccount(ctx, account); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
|
||||||
|
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("successfully re-allocated IPs for %d peers in account %s to network range %s",
|
||||||
|
len(peers), accountID, newNetworkRange.String())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, peers []*nbpeer.Peer, peerID string, newIP netip.Addr) error {
|
||||||
|
if !account.Network.Net.Contains(newIP.AsSlice()) {
|
||||||
|
return status.Errorf(status.InvalidArgument, "IP %s is not within the account network range %s", newIP.String(), account.Network.Net.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
if peer.ID != peerID && peer.IP.Equal(newIP.AsSlice()) {
|
||||||
|
return status.Errorf(status.InvalidArgument, "IP %s is already assigned to peer %s", newIP.String(), peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validate user permissions: %w", err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update peer IP transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateNetworkMap {
|
||||||
|
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) (bool, error) {
|
||||||
|
var updateNetworkMap bool
|
||||||
|
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
account, err := transaction.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingPeer.IP.Equal(newIP.AsSlice()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get account peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := am.validateIPForUpdate(account, peers, peerID, newIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := am.savePeerIPUpdate(ctx, transaction, accountID, userID, existingPeer, newIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateNetworkMap = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return updateNetworkMap, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
|
||||||
|
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
|
||||||
|
|
||||||
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get account settings: %w", err)
|
||||||
|
}
|
||||||
|
dnsDomain := am.GetDNSDomain(settings)
|
||||||
|
|
||||||
|
eventMeta := peer.EventMeta(dnsDomain)
|
||||||
|
oldIP := peer.IP.String()
|
||||||
|
|
||||||
|
peer.IP = newIP.AsSlice()
|
||||||
|
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("save peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventMeta["old_ip"] = oldIP
|
||||||
|
eventMeta["ip"] = newIP.String()
|
||||||
|
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerIPUpdated, eventMeta)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -51,6 +51,7 @@ type Manager interface {
|
|||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
|
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||||
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
|
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -3522,3 +3523,70 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
key1, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
key2, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
|
||||||
|
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||||
|
Key: key1.PublicKey().String(),
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "unable to add peer1")
|
||||||
|
|
||||||
|
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||||
|
Key: key2.PublicKey().String(),
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "unable to add peer2")
|
||||||
|
|
||||||
|
t.Run("update peer IP successfully", func(t *testing.T) {
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get account")
|
||||||
|
|
||||||
|
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
|
||||||
|
require.NoError(t, err, "unable to allocate new IP")
|
||||||
|
|
||||||
|
newAddr := netip.MustParseAddr(newIP.String())
|
||||||
|
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
|
||||||
|
require.NoError(t, err, "unable to update peer IP")
|
||||||
|
|
||||||
|
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
|
||||||
|
require.NoError(t, err, "unable to get updated peer")
|
||||||
|
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
|
||||||
|
currentAddr := netip.MustParseAddr(peer1.IP.String())
|
||||||
|
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
|
||||||
|
require.NoError(t, err, "updating with same IP should not error")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update peer IP with collision should fail", func(t *testing.T) {
|
||||||
|
peer2Addr := netip.MustParseAddr(peer2.IP.String())
|
||||||
|
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
|
||||||
|
require.Error(t, err, "should fail when IP is already assigned")
|
||||||
|
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
|
||||||
|
invalidAddr := netip.MustParseAddr("192.168.1.100")
|
||||||
|
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
|
||||||
|
require.Error(t, err, "should fail when IP is outside network range")
|
||||||
|
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
|
||||||
|
newAddr := netip.MustParseAddr("100.64.0.101")
|
||||||
|
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
|
||||||
|
require.Error(t, err, "should fail with invalid peer ID")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@@ -175,6 +175,9 @@ const (
|
|||||||
AccountLazyConnectionEnabled Activity = 85
|
AccountLazyConnectionEnabled Activity = 85
|
||||||
AccountLazyConnectionDisabled Activity = 86
|
AccountLazyConnectionDisabled Activity = 86
|
||||||
|
|
||||||
|
AccountNetworkRangeUpdated Activity = 87
|
||||||
|
PeerIPUpdated Activity = 88
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -277,6 +280,10 @@ var activityMap = map[Activity]Code{
|
|||||||
|
|
||||||
AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"},
|
AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"},
|
||||||
AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"},
|
AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"},
|
||||||
|
|
||||||
|
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
|
||||||
|
|
||||||
|
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
@@ -133,6 +133,11 @@ components:
|
|||||||
description: Allows to define a custom dns domain for the account
|
description: Allows to define a custom dns domain for the account
|
||||||
type: string
|
type: string
|
||||||
example: my-organization.org
|
example: my-organization.org
|
||||||
|
network_range:
|
||||||
|
description: Allows to define a custom network range for the account in CIDR format
|
||||||
|
type: string
|
||||||
|
format: cidr
|
||||||
|
example: 100.64.0.0/16
|
||||||
extra:
|
extra:
|
||||||
$ref: '#/components/schemas/AccountExtraSettings'
|
$ref: '#/components/schemas/AccountExtraSettings'
|
||||||
lazy_connection_enabled:
|
lazy_connection_enabled:
|
||||||
@@ -342,6 +347,11 @@ components:
|
|||||||
description: (Cloud only) Indicates whether peer needs approval
|
description: (Cloud only) Indicates whether peer needs approval
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
ip:
|
||||||
|
description: Peer's IP address
|
||||||
|
type: string
|
||||||
|
format: ipv4
|
||||||
|
example: 100.64.0.15
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- ssh_enabled
|
- ssh_enabled
|
||||||
|
@@ -303,6 +303,9 @@ type AccountSettings struct {
|
|||||||
// LazyConnectionEnabled Enables or disables experimental lazy connection
|
// LazyConnectionEnabled Enables or disables experimental lazy connection
|
||||||
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||||
|
|
||||||
|
// NetworkRange Allows to define a custom network range for the account in CIDR format
|
||||||
|
NetworkRange *string `json:"network_range,omitempty"`
|
||||||
|
|
||||||
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
|
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
|
||||||
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
|
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
|
||||||
|
|
||||||
@@ -1198,6 +1201,9 @@ type PeerRequest struct {
|
|||||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
|
|
||||||
|
// Ip Peer's IP address
|
||||||
|
Ip *string `json:"ip,omitempty"`
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
|
@@ -1,8 +1,10 @@
|
|||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
@@ -16,6 +18,17 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerBufferPercentage is the percentage of peers to add as buffer for network range calculations
|
||||||
|
PeerBufferPercentage = 0.5
|
||||||
|
// MinRequiredAddresses is the minimum number of addresses required in a network range
|
||||||
|
MinRequiredAddresses = 10
|
||||||
|
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
|
||||||
|
MinNetworkBitsIPv4 = 28
|
||||||
|
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
|
||||||
|
MinNetworkBitsIPv6 = 120
|
||||||
|
)
|
||||||
|
|
||||||
// handler is a handler that handles the server.Account HTTP endpoints
|
// handler is a handler that handles the server.Account HTTP endpoints
|
||||||
type handler struct {
|
type handler struct {
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
@@ -37,6 +50,86 @@ func newHandler(accountManager account.Manager, settingsManager settings.Manager
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateIPAddress(addr netip.Addr) error {
|
||||||
|
if addr.IsLoopback() {
|
||||||
|
return status.Errorf(status.InvalidArgument, "loopback address range not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.IsMulticast() {
|
||||||
|
return status.Errorf(status.InvalidArgument, "multicast address range not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() {
|
||||||
|
return status.Errorf(status.InvalidArgument, "link-local address range not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateMinimumSize(prefix netip.Prefix) error {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 {
|
||||||
|
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4)
|
||||||
|
}
|
||||||
|
if addr.Is6() && prefix.Bits() > MinNetworkBitsIPv6 {
|
||||||
|
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error {
|
||||||
|
if !networkRange.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateIPAddress(networkRange.Addr()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateMinimumSize(networkRange); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.validateCapacity(ctx, accountID, userID, networkRange)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
|
||||||
|
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "get peer count: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxHosts := calculateMaxHosts(prefix)
|
||||||
|
requiredAddresses := calculateRequiredAddresses(len(peers))
|
||||||
|
|
||||||
|
if maxHosts < requiredAddresses {
|
||||||
|
return status.Errorf(status.InvalidArgument,
|
||||||
|
"network range too small: need at least %d addresses for %d peers + buffer, but range provides %d",
|
||||||
|
requiredAddresses, len(peers), maxHosts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateMaxHosts(prefix netip.Prefix) int64 {
|
||||||
|
availableAddresses := prefix.Addr().BitLen() - prefix.Bits()
|
||||||
|
maxHosts := int64(1) << availableAddresses
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
maxHosts -= 2 // network and broadcast addresses
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxHosts
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateRequiredAddresses(peerCount int) int64 {
|
||||||
|
requiredAddresses := int64(peerCount) + int64(float64(peerCount)*PeerBufferPercentage)
|
||||||
|
if requiredAddresses < MinRequiredAddresses {
|
||||||
|
requiredAddresses = MinRequiredAddresses
|
||||||
|
}
|
||||||
|
return requiredAddresses
|
||||||
|
}
|
||||||
|
|
||||||
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
@@ -131,6 +224,18 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
if req.Settings.LazyConnectionEnabled != nil {
|
if req.Settings.LazyConnectionEnabled != nil {
|
||||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||||
}
|
}
|
||||||
|
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
|
||||||
|
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
settings.NetworkRange = prefix
|
||||||
|
}
|
||||||
|
|
||||||
var onboarding *types.AccountOnboarding
|
var onboarding *types.AccountOnboarding
|
||||||
if req.Onboarding != nil {
|
if req.Onboarding != nil {
|
||||||
@@ -208,6 +313,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
DnsDomain: &settings.DNSDomain,
|
DnsDomain: &settings.DNSDomain,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if settings.NetworkRange.IsValid() {
|
||||||
|
networkRangeStr := settings.NetworkRange.String()
|
||||||
|
apiSettings.NetworkRange = &networkRangeStr
|
||||||
|
}
|
||||||
|
|
||||||
apiOnboarding := api.AccountOnboarding{
|
apiOnboarding := api.AccountOnboarding{
|
||||||
OnboardingFlowPending: onboarding.OnboardingFlowPending,
|
OnboardingFlowPending: onboarding.OnboardingFlowPending,
|
||||||
SignupFormPending: onboarding.SignupFormPending,
|
SignupFormPending: onboarding.SignupFormPending,
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -111,6 +112,19 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Ip != nil {
|
||||||
|
addr, err := netip.ParseAddr(*req.Ip)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
|
||||||
|
util.WriteError(ctx, err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
|
@@ -9,6 +9,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
@@ -112,6 +114,15 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
|||||||
p.Name = update.Name
|
p.Name = update.Name
|
||||||
return p, nil
|
return p, nil
|
||||||
},
|
},
|
||||||
|
UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||||
|
for _, peer := range peers {
|
||||||
|
if peer.ID == peerID {
|
||||||
|
peer.IP = net.IP(newIP.AsSlice())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("peer not found")
|
||||||
|
},
|
||||||
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||||
var p *nbpeer.Peer
|
var p *nbpeer.Peer
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
@@ -450,3 +461,73 @@ func TestGetAccessiblePeers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||||
|
testPeer := &nbpeer.Peer{
|
||||||
|
ID: testPeerID,
|
||||||
|
Key: "key",
|
||||||
|
IP: net.ParseIP("100.64.0.1"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
Name: "test-host@netbird.io",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: regularUser,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "test-host@netbird.io",
|
||||||
|
Core: "22.04",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := initTestMetaData(testPeer)
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
peerID string
|
||||||
|
requestBody string
|
||||||
|
callerUserID string
|
||||||
|
expectedStatus int
|
||||||
|
expectedIP string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "update peer IP successfully",
|
||||||
|
peerID: testPeerID,
|
||||||
|
requestBody: `{"ip": "100.64.0.100"}`,
|
||||||
|
callerUserID: adminUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedIP: "100.64.0.100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update peer IP with invalid IP",
|
||||||
|
peerID: testPeerID,
|
||||||
|
requestBody: `{"ip": "invalid-ip"}`,
|
||||||
|
callerUserID: adminUser,
|
||||||
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||||
|
UserId: tc.callerUserID,
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
AccountId: "test_id",
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||||
|
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||||
|
|
||||||
|
if tc.expectedStatus == http.StatusOK && tc.expectedIP != "" {
|
||||||
|
var updatedPeer api.Peer
|
||||||
|
err := json.Unmarshal(rr.Body.Bytes(), &updatedPeer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expectedIP, updatedPeer.Ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -60,6 +60,7 @@ type MockAccountManager struct {
|
|||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
|
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||||
@@ -483,6 +484,13 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||||
|
if am.UpdatePeerIPFunc != nil {
|
||||||
|
return am.UpdatePeerIPFunc(ctx, accountID, userID, peerID, newIP)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||||
if am.CreateRouteFunc != nil {
|
if am.CreateRouteFunc != nil {
|
||||||
|
@@ -163,7 +163,10 @@ func (n *Network) Copy() *Network {
|
|||||||
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
||||||
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||||
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||||
totalIPs := uint32(1 << SubnetSize)
|
|
||||||
|
ones, bits := ipNet.Mask.Size()
|
||||||
|
hostBits := bits - ones
|
||||||
|
totalIPs := uint32(1 << hostBits)
|
||||||
|
|
||||||
taken := make(map[uint32]struct{}, len(takenIps)+1)
|
taken := make(map[uint32]struct{}, len(takenIps)+1)
|
||||||
taken[baseIP] = struct{}{} // reserve network IP
|
taken[baseIP] = struct{}{} // reserve network IP
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewNetwork(t *testing.T) {
|
func TestNewNetwork(t *testing.T) {
|
||||||
@@ -38,6 +39,107 @@ func TestAllocatePeerIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAllocatePeerIPSmallSubnet(t *testing.T) {
|
||||||
|
// Test /27 network (10.0.0.0/27) - should only have 30 usable IPs (10.0.0.1 to 10.0.0.30)
|
||||||
|
ipNet := net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.IPMask{255, 255, 255, 224}}
|
||||||
|
var ips []net.IP
|
||||||
|
|
||||||
|
// Allocate all available IPs in the /27 network
|
||||||
|
for i := 0; i < 30; i++ {
|
||||||
|
ip, err := AllocatePeerIP(ipNet, ips)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify IP is within the correct range
|
||||||
|
if !ipNet.Contains(ip) {
|
||||||
|
t.Errorf("allocated IP %s is not within network %s", ip.String(), ipNet.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, ips, 30)
|
||||||
|
|
||||||
|
// Verify all IPs are unique
|
||||||
|
uniq := make(map[string]struct{})
|
||||||
|
for _, ip := range ips {
|
||||||
|
if _, ok := uniq[ip.String()]; !ok {
|
||||||
|
uniq[ip.String()] = struct{}{}
|
||||||
|
} else {
|
||||||
|
t.Errorf("found duplicate IP %s", ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to allocate one more IP - should fail as network is full
|
||||||
|
_, err := AllocatePeerIP(ipNet, ips)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when network is full, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocatePeerIPVariousCIDRs(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cidr string
|
||||||
|
expectedUsable int
|
||||||
|
}{
|
||||||
|
{"/30 network", "192.168.1.0/30", 2}, // 4 total - 2 reserved = 2 usable
|
||||||
|
{"/29 network", "192.168.1.0/29", 6}, // 8 total - 2 reserved = 6 usable
|
||||||
|
{"/28 network", "192.168.1.0/28", 14}, // 16 total - 2 reserved = 14 usable
|
||||||
|
{"/27 network", "192.168.1.0/27", 30}, // 32 total - 2 reserved = 30 usable
|
||||||
|
{"/26 network", "192.168.1.0/26", 62}, // 64 total - 2 reserved = 62 usable
|
||||||
|
{"/25 network", "192.168.1.0/25", 126}, // 128 total - 2 reserved = 126 usable
|
||||||
|
{"/16 network", "10.0.0.0/16", 65534}, // 65536 total - 2 reserved = 65534 usable
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, ipNet, err := net.ParseCIDR(tc.cidr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var ips []net.IP
|
||||||
|
|
||||||
|
// For larger networks, test only a subset to avoid long test runs
|
||||||
|
testCount := tc.expectedUsable
|
||||||
|
if testCount > 1000 {
|
||||||
|
testCount = 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate IPs and verify they're within the correct range
|
||||||
|
for i := 0; i < testCount; i++ {
|
||||||
|
ip, err := AllocatePeerIP(*ipNet, ips)
|
||||||
|
require.NoError(t, err, "failed to allocate IP %d", i)
|
||||||
|
|
||||||
|
// Verify IP is within the correct range
|
||||||
|
assert.True(t, ipNet.Contains(ip), "allocated IP %s is not within network %s", ip.String(), ipNet.String())
|
||||||
|
|
||||||
|
// Verify IP is not network or broadcast address
|
||||||
|
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||||
|
ones, bits := ipNet.Mask.Size()
|
||||||
|
hostBits := bits - ones
|
||||||
|
broadcastInt := uint32(ipToUint32(networkIP)) + (1 << hostBits) - 1
|
||||||
|
broadcastIP := uint32ToIP(broadcastInt)
|
||||||
|
|
||||||
|
assert.False(t, ip.Equal(networkIP), "allocated network address %s", ip.String())
|
||||||
|
assert.False(t, ip.Equal(broadcastIP), "allocated broadcast address %s", ip.String())
|
||||||
|
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, ips, testCount)
|
||||||
|
|
||||||
|
// Verify all IPs are unique
|
||||||
|
uniq := make(map[string]struct{})
|
||||||
|
for _, ip := range ips {
|
||||||
|
ipStr := ip.String()
|
||||||
|
assert.NotContains(t, uniq, ipStr, "found duplicate IP %s", ipStr)
|
||||||
|
uniq[ipStr] = struct{}{}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateIPs(t *testing.T) {
|
func TestGenerateIPs(t *testing.T) {
|
||||||
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
|
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
|
||||||
ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}})
|
ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}})
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,6 +43,9 @@ type Settings struct {
|
|||||||
// DNSDomain is the custom domain for that account
|
// DNSDomain is the custom domain for that account
|
||||||
DNSDomain string
|
DNSDomain string
|
||||||
|
|
||||||
|
// NetworkRange is the custom network range for that account
|
||||||
|
NetworkRange netip.Prefix `gorm:"serializer:json"`
|
||||||
|
|
||||||
// Extra is a dictionary of Account settings
|
// Extra is a dictionary of Account settings
|
||||||
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
|
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
|
||||||
|
|
||||||
@@ -66,6 +70,7 @@ func (s *Settings) Copy() *Settings {
|
|||||||
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
|
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
|
||||||
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
||||||
DNSDomain: s.DNSDomain,
|
DNSDomain: s.DNSDomain,
|
||||||
|
NetworkRange: s.NetworkRange,
|
||||||
}
|
}
|
||||||
if s.Extra != nil {
|
if s.Extra != nil {
|
||||||
settings.Extra = s.Extra.Copy()
|
settings.Extra = s.Extra.Copy()
|
||||||
|
@@ -23,7 +23,6 @@ func FileExists(path string) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Bool helpers
|
/// Bool helpers
|
||||||
|
|
||||||
// True returns a *bool whose underlying value is true.
|
// True returns a *bool whose underlying value is true.
|
||||||
|
@@ -6,7 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
//Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
||||||
type Duration struct {
|
type Duration struct {
|
||||||
time.Duration
|
time.Duration
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user